package config
import (
"os"
audit "github.com/vigiloauth/vigilo/v2/cmd/config/audit"
login "github.com/vigiloauth/vigilo/v2/cmd/config/login"
password "github.com/vigiloauth/vigilo/v2/cmd/config/password"
server "github.com/vigiloauth/vigilo/v2/cmd/config/server"
smtp "github.com/vigiloauth/vigilo/v2/cmd/config/smtp"
token "github.com/vigiloauth/vigilo/v2/cmd/config/token"
lib "github.com/vigiloauth/vigilo/v2/idp/config"
"gopkg.in/yaml.v3"
)
// ApplicationConfig is the top-level shape of the YAML file loaded by
// LoadConfigurations. Each nested section is turned into the matching
// library configuration through its ToOptions method.
type ApplicationConfig struct {
	ServerConfig   *server.ServerConfigYAML    `yaml:"server_config"`
	TokenConfig    token.TokenConfigYAML       `yaml:"token_config,omitempty"`
	PasswordConfig password.PasswordConfigYAML `yaml:"password_config,omitempty"`
	LoginConfig    login.LoginConfigYAML       `yaml:"login_config,omitempty"`
	SMTPConfig     smtp.SMTPConfigYAML         `yaml:"smtp_config,omitempty"`
	AuditLogConfig audit.AuditLogConfigYAML    `yaml:"audit_config,omitempty"`
	LogLevel       *string                     `yaml:"log_level,omitempty"`

	// NOTE(review): Port is decoded but LoadConfigurations never reads it;
	// the listening port is taken from server_config.port. Confirm whether
	// this top-level key is still needed.
	Port *string `yaml:"port"`

	// Logger and Module are runtime-only fields. They are excluded from
	// YAML so that stray "logger"/"module" keys in the file cannot be
	// decoded into them.
	Logger *lib.Logger `yaml:"-"`
	Module string      `yaml:"-"`
}
func LoadConfigurations() *ApplicationConfig {
configFile := os.Getenv("VIGILO_CONFIG_PATH")
ac := &ApplicationConfig{
Logger: lib.GetLogger(),
Module: "Identity Server",
}
appConfig := ac.loadFromYAML(configFile)
if appConfig == nil {
ac.Logger.Warn(ac.Module, "", "No YAML file present. Using default configurations")
return ac
}
loginOptions := appConfig.LoginConfig.ToOptions()
loginConfig := lib.NewLoginConfig(loginOptions...)
passwordOptions := appConfig.PasswordConfig.ToOptions()
passwordConfig := lib.NewPasswordConfig(passwordOptions...)
tokenOptions := appConfig.TokenConfig.ToOptions()
tokenConfig := lib.NewTokenConfig(tokenOptions...)
smtpOptions := appConfig.SMTPConfig.ToOptions()
smtpConfig := lib.NewSMTPConfig(smtpOptions...)
auditLogOptions := appConfig.AuditLogConfig.ToOptions()
auditLogConfig := lib.NewAuditLogConfig(auditLogOptions...)
serverOptions := appConfig.ServerConfig.ToOptions()
serverConfig := lib.NewServerConfig(serverOptions...)
serverConfig.SetLoginConfig(loginConfig)
serverConfig.SetPasswordConfig(passwordConfig)
serverConfig.SetTokenConfig(tokenConfig)
serverConfig.SetSMTPConfig(smtpConfig)
serverConfig.SetAuditLogConfig(auditLogConfig)
if appConfig.LogLevel != nil {
lib.SetLevel(*appConfig.LogLevel)
}
return appConfig
}
func (ac *ApplicationConfig) loadFromYAML(path string) *ApplicationConfig {
data, err := os.ReadFile(path)
if err != nil {
ac.Logger.Error(ac.Module, "", "Failed to load yaml configuration: %v. Using default settings ", err)
return nil
}
var appConfig ApplicationConfig
if err := yaml.Unmarshal(data, &appConfig); err != nil {
ac.Logger.Fatal(ac.Module, "", "Failed to unmarshal YAML: %v", err)
}
return &appConfig
}
package config
import (
"time"
"github.com/vigiloauth/vigilo/v2/idp/config"
)
// oneDay is the unit of retention_period, which the YAML expresses in days.
const oneDay time.Duration = 24 * time.Hour

// AuditLogConfigYAML mirrors the audit_config section of the YAML file.
type AuditLogConfigYAML struct {
	// RetentionPeriod is the number of days audit logs are kept.
	RetentionPeriod *int `yaml:"retention_period,omitempty"`
}

// ToOptions translates the section into config.AuditLogConfigOptions. A
// field that is absent from the YAML produces no option, so the library
// default stays in effect. The result is never nil.
func (alc *AuditLogConfigYAML) ToOptions() []config.AuditLogConfigOptions {
	if alc.RetentionPeriod == nil {
		return []config.AuditLogConfigOptions{}
	}

	days := time.Duration(*alc.RetentionPeriod)
	return []config.AuditLogConfigOptions{
		config.WithRetentionPeriod(days * oneDay),
	}
}
package config
import (
"time"
"github.com/vigiloauth/vigilo/v2/idp/config"
)
// LoginConfigYAML mirrors the login_config section of the YAML file.
type LoginConfigYAML struct {
	MaxFailedAttempts *int    `yaml:"max_failed_attempts,omitempty"`
	Delay             *int64  `yaml:"delay,omitempty"` // in milliseconds
	LoginURL          *string `yaml:"login_url,omitempty"`
}

// ToOptions translates the section into config.LoginConfigOptions. Each
// field present in the YAML contributes one option, in declaration order.
// Absent fields contribute nothing. The result is never nil.
func (lc *LoginConfigYAML) ToOptions() []config.LoginConfigOptions {
	opts := make([]config.LoginConfigOptions, 0, 3)

	if attempts := lc.MaxFailedAttempts; attempts != nil {
		opts = append(opts, config.WithMaxFailedAttempts(*attempts))
	}
	if ms := lc.Delay; ms != nil {
		opts = append(opts, config.WithDelay(time.Millisecond*time.Duration(*ms)))
	}
	if url := lc.LoginURL; url != nil {
		opts = append(opts, config.WithLoginURL(*url))
	}

	return opts
}
package config
import "github.com/vigiloauth/vigilo/v2/idp/config"
// PasswordConfigYAML mirrors the password_config section of the YAML file.
type PasswordConfigYAML struct {
	RequireUppercase *bool `yaml:"require_uppercase,omitempty"`
	RequireNumber    *bool `yaml:"require_number,omitempty"`
	RequireSymbol    *bool `yaml:"require_symbol,omitempty"`
	MinimumLength    *int  `yaml:"minimum_length,omitempty"`
}

// ToOptions translates the section into config.PasswordConfigOptions.
//
// A boolean requirement is enabled only when the key is present AND true.
// The library's WithUppercase/WithNumber/WithSymbol take no argument and
// always enable their requirement, so an explicit `false` must produce no
// option. Previously any present key, including `false`, switched the
// requirement on. The result is never nil.
func (pc *PasswordConfigYAML) ToOptions() []config.PasswordConfigOptions {
	options := []config.PasswordConfigOptions{}
	if pc.RequireUppercase != nil && *pc.RequireUppercase {
		options = append(options, config.WithUppercase())
	}
	if pc.RequireNumber != nil && *pc.RequireNumber {
		options = append(options, config.WithNumber())
	}
	if pc.RequireSymbol != nil && *pc.RequireSymbol {
		options = append(options, config.WithSymbol())
	}
	if pc.MinimumLength != nil {
		options = append(options, config.WithMinLength(*pc.MinimumLength))
	}
	return options
}
package config
import (
"time"
"github.com/vigiloauth/vigilo/v2/idp/config"
)
// ServerConfigYAML mirrors the server_config section of the YAML file.
// Durations are plain integers: read_timeout and write_timeout are in
// seconds, and authorization_code_duration is in minutes.
type ServerConfigYAML struct {
	Port                 *string `yaml:"port,omitempty"`
	CertFilePath         *string `yaml:"cert_file_path,omitempty"`
	KeyFilePath          *string `yaml:"key_file_path,omitempty"`
	SessionCookieName    *string `yaml:"session_cookie_name,omitempty"`
	ForceHTTPS           *bool   `yaml:"force_https,omitempty"`
	Domain               *string `yaml:"domain,omitempty"`
	EnableRequestLogging *bool   `yaml:"enable_request_logging,omitempty"`
	ReadTimeout          *int64  `yaml:"read_timeout,omitempty"`
	WriteTimeout         *int64  `yaml:"write_timeout,omitempty"`
	AuthzCodeDuration    *int64  `yaml:"authorization_code_duration,omitempty"`
}

// ToOptions translates the section into config.ServerConfigOptions. Fields
// absent from the YAML produce no option.
//
// Order matters: the certificate and key paths are emitted before the
// HTTPS option, because config.WithForceHTTPS checks that both paths are
// already set when it is applied.
func (sc *ServerConfigYAML) ToOptions() []config.ServerConfigOptions {
	options := []config.ServerConfigOptions{}
	if sc.Port != nil {
		options = append(options, config.WithPort(*sc.Port))
	}
	if sc.CertFilePath != nil {
		options = append(options, config.WithCertFilePath(*sc.CertFilePath))
	}
	if sc.KeyFilePath != nil {
		options = append(options, config.WithKeyFilePath(*sc.KeyFilePath))
	}
	if sc.SessionCookieName != nil {
		options = append(options, config.WithSessionCookieName(*sc.SessionCookieName))
	}
	// WithForceHTTPS takes no argument and always turns HTTPS on, so only
	// emit it for an explicit `true`. Previously `force_https: false` also
	// enabled it.
	if sc.ForceHTTPS != nil && *sc.ForceHTTPS {
		options = append(options, config.WithForceHTTPS())
	}
	if sc.EnableRequestLogging != nil {
		options = append(options, config.WithRequestLogging(*sc.EnableRequestLogging))
	}
	if sc.ReadTimeout != nil {
		timeoutDuration := time.Duration(*sc.ReadTimeout) * time.Second
		options = append(options, config.WithReadTimeout(timeoutDuration))
	}
	if sc.WriteTimeout != nil {
		timeoutDuration := time.Duration(*sc.WriteTimeout) * time.Second
		options = append(options, config.WithWriteTimeout(timeoutDuration))
	}
	if sc.AuthzCodeDuration != nil {
		duration := time.Duration(*sc.AuthzCodeDuration) * time.Minute
		options = append(options, config.WithAuthorizationCodeDuration(duration))
	}
	if sc.Domain != nil {
		options = append(options, config.WithDomain(*sc.Domain))
	}
	return options
}
package config
import "github.com/vigiloauth/vigilo/v2/idp/config"
// Ports recognised by SMTPConfigYAML.ToOptions to select the connection mode.
const (
	TLSPort int = 587 // mapped to config.WithTLS
	SSLPort int = 465 // mapped to config.WithSSL
)

// SMTPConfigYAML mirrors the smtp_config section of the YAML file.
type SMTPConfigYAML struct {
	Host        *string `yaml:"host,omitempty"`
	Port        *int    `yaml:"port,omitempty"`
	Username    *string `yaml:"username,omitempty"`
	Password    *string `yaml:"password,omitempty"`
	FromAddress *string `yaml:"from_address,omitempty"`
	Encryption  *string `yaml:"encryption,omitempty"`
}

// ToOptions translates the section into config.SMTPConfigOptions, in the
// order host, port-derived mode, credentials, sender address, encryption.
// The result is never nil.
func (s *SMTPConfigYAML) ToOptions() []config.SMTPConfigOptions {
	opts := make([]config.SMTPConfigOptions, 0, 5)

	if host := s.Host; host != nil {
		opts = append(opts, config.WithSMTPHost(*host))
	}

	// The port number itself is not forwarded; it only selects TLS or SSL.
	// NOTE(review): any port other than TLSPort/SSLPort is silently dropped.
	if s.Port != nil {
		if *s.Port == TLSPort {
			opts = append(opts, config.WithTLS())
		} else if *s.Port == SSLPort {
			opts = append(opts, config.WithSSL())
		}
	}

	// Credentials are only applied when both halves are present.
	if user, pass := s.Username, s.Password; user != nil && pass != nil {
		opts = append(opts, config.WithCredentials(*user, *pass))
	}

	if from := s.FromAddress; from != nil {
		opts = append(opts, config.WithFromAddress(*from))
	}

	if enc := s.Encryption; enc != nil {
		opts = append(opts, config.WithEncryption(*enc))
	}

	return opts
}
package config
import (
"time"
"github.com/vigiloauth/vigilo/v2/idp/config"
)
// oneDay is the unit of refresh_token_duration, which the YAML expresses in days.
const oneDay time.Duration = 24 * time.Hour

// TokenConfigYAML mirrors the token_config section of the YAML file.
// expiration_time and access_token_duration are in minutes;
// refresh_token_duration is in days.
type TokenConfigYAML struct {
	// NOTE(review): SecretKey is decoded, but ToOptions never forwards it,
	// so a secret_key in the YAML currently has no effect. Confirm whether
	// a config option for it should be wired in.
	SecretKey            *string `yaml:"secret_key,omitempty"`
	ExpirationTime       *int64  `yaml:"expiration_time,omitempty"`
	AccessTokenDuration  *int64  `yaml:"access_token_duration,omitempty"`
	RefreshTokenDuration *int64  `yaml:"refresh_token_duration,omitempty"`
}

// ToOptions translates the section into config.TokenConfigOptions. Fields
// absent from the YAML produce no option. The result is never nil.
func (tc *TokenConfigYAML) ToOptions() []config.TokenConfigOptions {
	inMinutes := func(v int64) time.Duration { return time.Duration(v) * time.Minute }
	opts := make([]config.TokenConfigOptions, 0, 3)

	if tc.ExpirationTime != nil {
		opts = append(opts, config.WithExpirationTime(inMinutes(*tc.ExpirationTime)))
	}
	if tc.AccessTokenDuration != nil {
		opts = append(opts, config.WithAccessTokenDuration(inMinutes(*tc.AccessTokenDuration)))
	}
	if tc.RefreshTokenDuration != nil {
		days := time.Duration(*tc.RefreshTokenDuration)
		opts = append(opts, config.WithRefreshTokenDuration(days*oneDay))
	}

	return opts
}
package main
import (
"net/http"
"os"
"path/filepath"
"strings"
"github.com/go-chi/chi/v5"
config "github.com/vigiloauth/vigilo/v2/cmd/config/application"
"github.com/vigiloauth/vigilo/v2/idp/server"
"github.com/vigiloauth/vigilo/v2/internal/constants"
"github.com/vigiloauth/vigilo/v2/internal/errors"
"github.com/vigiloauth/vigilo/v2/internal/web"
)
// main starts the identity server together with the bundled React UI.
// It does nothing unless the server-mode environment variable equals
// "docker".
func main() {
	if mode := os.Getenv(constants.VigiloServerModeENV); mode != "docker" {
		return
	}

	// The returned *ApplicationConfig is discarded. The call is made for its
	// side effects, such as NewServerConfig storing the package-level
	// server-config instance.
	config.LoadConfigurations()

	identityServer := server.NewVigiloIdentityServer()
	router := chi.NewRouter()
	setupSpaRouting(router)

	identityServer.StartServer(router)
	identityServer.Shutdown()
}
// setupSpaRouting mounts the single-page-app frontend on r. Files come from
// the directory named by the React build-path environment variable:
//
//   - /static/* is served from <build>/static, with an explicit
//     Content-Type for common asset extensions. A missing file gets an
//     error response.
//   - /, /authenticate, /consent and /error return <build>/index.html.
//   - Any other GET path returns index.html, with one exception. Paths
//     of the form /{authenticate,consent,error}/static/... are rewritten
//     to /static/... and served by a file server rooted at <build>.
func setupSpaRouting(r *chi.Mux) {
	buildPath := os.Getenv(constants.ReactBuildPathENV)
	assets := http.FileServer(http.Dir(buildPath))
	index := serveIndexHTML(buildPath)

	r.HandleFunc("/static/*", func(w http.ResponseWriter, req *http.Request) {
		rel := strings.TrimPrefix(req.URL.Path, "/static/")
		target := filepath.Join(buildPath, "static", rel)

		// Only "does not exist" is handled here. Other Stat errors fall
		// through to http.ServeFile, which reports them itself.
		// NOTE(review): a missing asset is reported with an
		// internal-server-error code; a 404 would be more accurate.
		if _, statErr := os.Stat(target); os.IsNotExist(statErr) {
			web.WriteError(w, errors.New(errors.ErrCodeInternalServerError, "file not found"))
			return
		}

		setContentTypeHeader(w, target)
		http.ServeFile(w, req, target)
	})

	r.Get("/", index)
	for _, page := range []string{"/authenticate", "/consent", "/error"} {
		r.Get(page, index)
	}

	r.Get("/*", func(w http.ResponseWriter, req *http.Request) {
		for _, prefix := range []string{"/authenticate/", "/consent/", "/error/"} {
			if !strings.HasPrefix(req.URL.Path, prefix+"static/") {
				continue
			}
			// Strip the page prefix so the asset resolves under the build root.
			req.URL.Path = "/" + strings.TrimPrefix(req.URL.Path, prefix)
			assets.ServeHTTP(w, req)
			return
		}
		index(w, req)
	})
}
// serveIndexHTML returns a handler that answers every request with the
// file <buildPath>/index.html, whatever the request path.
func serveIndexHTML(buildPath string) http.HandlerFunc {
	indexFile := filepath.Join(buildPath, "index.html")
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		http.ServeFile(w, r, indexFile)
	})
}
// spaAssetContentTypes maps the front-end asset extensions handled by
// setContentTypeHeader to the Content-Type they are served with.
var spaAssetContentTypes = map[string]string{
	".js":   "application/javascript",
	".css":  "text/css",
	".json": "application/json",
	".png":  "image/png",
	".jpg":  "image/jpeg",
	".jpeg": "image/jpeg",
	".svg":  "image/svg+xml",
}

// setContentTypeHeader sets an explicit Content-Type on w when fullPath
// ends in one of the extensions listed in spaAssetContentTypes. Matching is
// case-sensitive. Any other extension leaves the header untouched.
func setContentTypeHeader(w http.ResponseWriter, fullPath string) {
	contentType, known := spaAssetContentTypes[filepath.Ext(fullPath)]
	if !known {
		return
	}
	w.Header().Set("Content-Type", contentType)
}
package config
import "time"
// AuditLogConfig holds the audit-logging settings. It currently has one:
// how long audit logs are retained.
type AuditLogConfig struct {
	retentionPeriod time.Duration
}

// AuditLogConfigOptions is a functional option that mutates an AuditLogConfig.
type AuditLogConfigOptions func(*AuditLogConfig)

// defaultRetentionPeriod keeps audit logs for 90 days unless overridden.
const defaultRetentionPeriod time.Duration = 90 * 24 * time.Hour

// NewAuditLogConfig builds an AuditLogConfig that starts from the defaults
// and then applies opts in order, so a later option overrides an earlier one.
func NewAuditLogConfig(opts ...AuditLogConfigOptions) *AuditLogConfig {
	cfg := new(AuditLogConfig)
	cfg.retentionPeriod = defaultRetentionPeriod
	for i := range opts {
		opts[i](cfg)
	}
	return cfg
}

// WithRetentionPeriod returns an option that replaces the retention period
// with the given duration. The value is stored as-is, without validation.
func WithRetentionPeriod(retentionPeriod time.Duration) AuditLogConfigOptions {
	return func(cfg *AuditLogConfig) { cfg.retentionPeriod = retentionPeriod }
}

// RetentionPeriod reports how long audit logs are kept.
func (alc *AuditLogConfig) RetentionPeriod() time.Duration { return alc.retentionPeriod }
package config
import (
"fmt"
"os"
"strings"
"sync"
"time"
)
// LogLevel represents the severity of a log message. Levels are ordered:
// a message is written only when its level is >= the logger's level.
type LogLevel int

const (
	DEBUG LogLevel = iota
	INFO
	WARN
	ERROR
	FATAL
)

// levelNames maps each LogLevel to the label printed in log lines.
var levelNames = map[LogLevel]string{
	DEBUG: "DEBUG",
	INFO:  "INFO",
	WARN:  "WARN",
	ERROR: "ERROR",
	FATAL: "FATAL",
}

// levelByName is the inverse of levelNames and is used to parse SetLevel input.
var levelByName = map[string]LogLevel{
	"DEBUG": DEBUG,
	"INFO":  INFO,
	"WARN":  WARN,
	"ERROR": ERROR,
	"FATAL": FATAL,
}

// colors holds the ANSI color escape used for each level when output is colorized.
var colors = map[LogLevel]string{
	DEBUG: "\033[36m", // Cyan
	INFO:  "\033[32m", // Green
	WARN:  "\033[33m", // Yellow
	ERROR: "\033[31m", // Red
	FATAL: "\033[31m", // Red
}

// colorReset is the ANSI escape that restores the terminal's default color.
var colorReset = "\033[0m"

// Logger is the base application logger. It writes one line per message to
// standard output. Its level and colorization are guarded by mu.
type Logger struct {
	level     LogLevel
	colorized bool
	mu        sync.RWMutex
}

var (
	instance *Logger
	once     sync.Once
)

// GetLogger returns the process-wide logger. On first use it is created
// with level INFO and colorized output enabled.
func GetLogger() *Logger {
	once.Do(func() {
		instance = &Logger{
			level:     INFO,
			colorized: true,
		}
	})
	return instance
}

// SetLevel sets the minimum level of messages to emit. The name is matched
// case-insensitively, and surrounding whitespace is ignored (e.g. "debug"
// or " WARN "). The change is announced on standard output as a [LOGGER]
// line, which is printed whatever the level. Selecting DEBUG also prints a
// warning against using it in production.
//
// An unknown name resets the level to INFO, as the warning it prints says.
// The output honours the colorized setting.
func (l *Logger) SetLevel(level string) {
	l.mu.Lock()
	defer l.mu.Unlock()

	levelUpper := strings.ToUpper(strings.TrimSpace(level))
	lvl, known := levelByName[levelUpper]
	if !known {
		// The warning announces INFO, so actually apply it instead of
		// silently keeping the previous level.
		l.level = INFO
		l.writeLoggerLine(WARN, fmt.Sprintf("Invalid log level: %s. Using INFO", level))
		return
	}

	l.level = lvl
	l.writeLoggerLine(INFO, "Log level set to "+levelUpper)
	if lvl == DEBUG {
		l.writeLoggerLine(WARN, "It is highly recommended to disable DEBUG logs in production environments")
	}
}

// writeLoggerLine prints one "[LOGGER]" line labelled with level, formatted
// like ordinary log lines but not filtered by the current level. It reads
// l.colorized, so the caller must hold l.mu.
func (l *Logger) writeLoggerLine(level LogLevel, message string) {
	timestamp := time.Now().Format("2006-01-02 15:04:05.000")
	var line string
	if l.colorized {
		line = fmt.Sprintf("%s[%s] %s%s%s [LOGGER] %s",
			colors[level], timestamp, colors[level], levelNames[level], colorReset, message,
		)
	} else {
		line = fmt.Sprintf("[%s] [%s] [LOGGER] %s", timestamp, levelNames[level], message)
	}
	_, _ = fmt.Fprintln(os.Stdout, line)
}

// SetColorized enables or disables ANSI-colored output.
func (l *Logger) SetColorized(enabled bool) {
	l.mu.Lock()
	defer l.mu.Unlock()
	l.colorized = enabled
}

// GetLevel returns the name of the current log level, e.g. "INFO".
func (l *Logger) GetLevel() string {
	l.mu.RLock()
	defer l.mu.RUnlock()
	return levelNames[l.level]
}
// log formats one message and writes it to standard output, provided level
// is at or above the logger's current level. The message is
// fmt.Sprintf(format, args...). The line also carries a millisecond
// timestamp, the level name and the module. requestID is included only
// when it is non-empty. With colorization on, the timestamp and level
// label are wrapped in the level's ANSI color.
func (l *Logger) log(level LogLevel, module string, requestID string, format string, args ...any) {
	l.mu.RLock()
	minLevel, useColor := l.level, l.colorized
	l.mu.RUnlock()

	if level < minLevel {
		return
	}

	stamp := time.Now().Format("2006-01-02 15:04:05.000")
	name := levelNames[level]
	msg := fmt.Sprintf(format, args...)

	var line string
	switch {
	case useColor && requestID != "":
		c := colors[level]
		line = fmt.Sprintf("%s[%s] %s [RequestID=%s] %s%s [%s] %s",
			c, stamp, name, requestID, c, colorReset, module, msg)
	case useColor:
		c := colors[level]
		line = fmt.Sprintf("%s[%s] %s%s%s [%s] %s",
			c, stamp, c, name, colorReset, module, msg)
	case requestID != "":
		line = fmt.Sprintf("[%s] [RequestID=%s] [%s] [%s] %s", stamp, requestID, name, module, msg)
	default:
		line = fmt.Sprintf("[%s] [%s] [%s] %s", stamp, name, module, msg)
	}

	_, _ = fmt.Fprintln(os.Stdout, line)
}
// Debug emits a DEBUG-level message. requestID may be empty; format and
// args follow fmt.Sprintf conventions.
func (l *Logger) Debug(module, requestID, format string, args ...any) {
	l.log(DEBUG, module, requestID, format, args...)
}

// Info emits an INFO-level message. requestID may be empty; format and
// args follow fmt.Sprintf conventions.
func (l *Logger) Info(module, requestID, format string, args ...any) {
	l.log(INFO, module, requestID, format, args...)
}

// Warn emits a WARN-level message. requestID may be empty; format and
// args follow fmt.Sprintf conventions.
func (l *Logger) Warn(module, requestID, format string, args ...any) {
	l.log(WARN, module, requestID, format, args...)
}

// Error emits an ERROR-level message. requestID may be empty; format and
// args follow fmt.Sprintf conventions.
func (l *Logger) Error(module, requestID, format string, args ...any) {
	l.log(ERROR, module, requestID, format, args...)
}

// Fatal emits a FATAL-level message and then terminates the process with
// exit status 1 via os.Exit, so deferred functions do not run.
func (l *Logger) Fatal(module, requestID, format string, args ...any) {
	l.log(FATAL, module, requestID, format, args...)
	os.Exit(1)
}
// Package-level shortcuts that forward to the singleton returned by GetLogger.

// Debug emits a DEBUG-level message through the shared logger.
func Debug(module, requestID, format string, args ...any) {
	GetLogger().Debug(module, requestID, format, args...)
}

// Info emits an INFO-level message through the shared logger.
func Info(module, requestID, format string, args ...any) {
	GetLogger().Info(module, requestID, format, args...)
}

// Warn emits a WARN-level message through the shared logger.
func Warn(module, requestID, format string, args ...any) {
	GetLogger().Warn(module, requestID, format, args...)
}

// Error emits an ERROR-level message through the shared logger.
func Error(module, requestID, format string, args ...any) {
	GetLogger().Error(module, requestID, format, args...)
}

// Fatal emits a FATAL-level message through the shared logger and then
// exits the process with status 1.
func Fatal(module, requestID, format string, args ...any) {
	GetLogger().Fatal(module, requestID, format, args...)
}

// SetLevel changes the minimum level of the shared logger.
func SetLevel(level string) {
	GetLogger().SetLevel(level)
}

// SetColorized toggles ANSI-colored output on the shared logger.
func SetColorized(enabled bool) {
	GetLogger().SetColorized(enabled)
}
package config
import (
"fmt"
"time"
"github.com/vigiloauth/vigilo/v2/internal/web"
)
// LoginConfig holds the login-attempt throttling settings and the login
// page URL. NewLoginConfig builds one from defaults plus options.
type LoginConfig struct {
	maxFailedAttempts int           // Maximum number of failed login attempts allowed.
	delay             time.Duration // Delay applied after exceeding max failed attempts; WithDelay only accepts values that pass isInMilliseconds.
	loginURL          string        // Login page URL; defaults to web.UserEndpoints.Login.
	logger            *Logger       // Logger used to report applied options.
	module            string        // Module label passed to every logger call.
}

// LoginConfigOptions is a functional option that NewLoginConfig applies to a LoginConfig.
type LoginConfigOptions func(*LoginConfig)

const (
	defaultMaxFailedAttempts int           = 5                      // Default maximum failed attempts; WithMaxFailedAttempts only applies larger values.
	defaultDelay             time.Duration = 500 * time.Millisecond // Default delay (500 milliseconds).
)
// NewLoginConfig creates a LoginConfig from the defaults and then applies
// opts in order. The resulting parameters are logged at DEBUG level.
//
// Parameters:
//
//	opts ...LoginConfigOptions: options applied on top of the defaults.
//
// Returns:
//
//	*LoginConfig: the new configuration.
func NewLoginConfig(opts ...LoginConfigOptions) *LoginConfig {
	cfg := defaultLoginConfig()
	cfg.loadOptions(opts...)
	// Debug's second argument is the request ID, which does not apply here.
	// Previously the format string was passed in that slot, so the output
	// was mangled.
	cfg.logger.Debug(cfg.module, "", "\n\nLogin config parameters: %s", cfg.String())
	return cfg
}
// WithMaxFailedAttempts returns an option that sets the maximum number of
// failed login attempts.
//
// Only values greater than the default (defaultMaxFailedAttempts) are
// applied. A smaller value is rejected with a warning and the current
// value is kept. Passing exactly the default is a silent no-op.
//
// Parameters:
//
//	maxAttempts int: the requested maximum number of failed attempts.
//
// Returns:
//
//	LoginConfigOptions: a function that configures the maximum failed attempts.
func WithMaxFailedAttempts(maxAttempts int) LoginConfigOptions {
	return func(lc *LoginConfig) {
		if maxAttempts <= defaultMaxFailedAttempts {
			// Report a rejected value instead of dropping it silently.
			if maxAttempts < defaultMaxFailedAttempts {
				lc.logger.Warn(lc.module, "", "Ignoring max failed login attempts [%d]: values below %d are not applied. Keeping [%d]",
					maxAttempts, defaultMaxFailedAttempts, lc.maxFailedAttempts)
			}
			return
		}
		lc.logger.Info(lc.module, "", "Configuring LoginConfig to use [%d] max failed login attempts", maxAttempts)
		lc.maxFailedAttempts = maxAttempts
	}
}
// WithDelay returns an option that sets the login delay.
//
// A delay accepted by isInMilliseconds is stored and logged at INFO level.
// Any other value is replaced by defaultDelay (500ms), and a warning is logged.
//
// Parameters:
//
//	delay time.Duration: the requested delay.
//
// Returns:
//
//	LoginConfigOptions: a function that configures the delay.
func WithDelay(delay time.Duration) LoginConfigOptions {
	return func(lc *LoginConfig) {
		if isInMilliseconds(delay) {
			lc.logger.Info(lc.module, "", "Configuring LoginConfig to use delay=[%s]", delay)
			lc.delay = delay
			return
		}
		lc.logger.Warn(lc.module, "", "Delay duration is not in milliseconds, using default value of 500ms")
		lc.delay = defaultDelay
	}
}
// MaxFailedAttempts reports how many failed login attempts are allowed.
func (lc *LoginConfig) MaxFailedAttempts() int { return lc.maxFailedAttempts }

// WithLoginURL returns an option that replaces the login URL (by default
// web.UserEndpoints.Login) with url and logs the change at INFO level.
// The URL is stored as given, without validation.
//
// Parameters:
//
//	url string: the login URL.
//
// Returns:
//
//	LoginConfigOptions: a function that configures the login URL.
func WithLoginURL(url string) LoginConfigOptions {
	return func(lc *LoginConfig) {
		lc.logger.Info(lc.module, "", "Configuring LoginConfig to use URL=[%s]", url)
		lc.loginURL = url
	}
}

// LoginURL reports the configured login URL.
func (lc *LoginConfig) LoginURL() string { return lc.loginURL }

// Delay reports the configured login delay.
func (lc *LoginConfig) Delay() time.Duration { return lc.delay }
// String renders the login settings as indented "Name: value" lines, as
// used by the DEBUG dump in NewLoginConfig.
func (lc *LoginConfig) String() string {
	return fmt.Sprintf("\n\tMaxFailedAttempts: %d\n\tDelay: %s\n\tLoginURL: %s\n",
		lc.maxFailedAttempts, lc.delay, lc.loginURL)
}

// defaultLoginConfig returns a LoginConfig populated with the package
// defaults, the built-in login endpoint and the shared logger.
func defaultLoginConfig() *LoginConfig {
	lc := &LoginConfig{
		logger: GetLogger(),
		module: "Login Config",
	}
	lc.maxFailedAttempts = defaultMaxFailedAttempts
	lc.delay = defaultDelay
	lc.loginURL = web.UserEndpoints.Login
	return lc
}

// loadOptions applies opts to lc in order and logs whether any were given.
func (lc *LoginConfig) loadOptions(opts ...LoginConfigOptions) {
	if len(opts) == 0 {
		lc.logger.Info(lc.module, "", "Using default login config")
		return
	}

	lc.logger.Info(lc.module, "", "Creating login config with %d options", len(opts))
	for _, apply := range opts {
		apply(lc)
	}
}
package config
import (
"fmt"
)
// defaultRequiredPasswordLength is the default minimum password length. It
// is also the floor used by WithMinLength, which only applies longer lengths.
const defaultRequiredPasswordLength int = 5

// PasswordConfig holds the password complexity requirements.
// NewPasswordConfig builds one from defaults plus options.
type PasswordConfig struct {
	requireUpper  bool    // Whether uppercase letters are required.
	requireNumber bool    // Whether numbers are required.
	requireSymbol bool    // Whether symbols are required.
	minLength     int     // Minimum required password length.
	logger        *Logger // Logger used to report applied options.
	module        string  // Module label passed to every logger call.
}

// PasswordConfigOptions is a functional option that NewPasswordConfig applies to a PasswordConfig.
type PasswordConfigOptions func(*PasswordConfig)
// NewPasswordConfig creates a PasswordConfig from the defaults and then
// applies opts in order. The resulting parameters are logged at DEBUG level.
//
// Parameters:
//
//	opts ...PasswordConfigOptions: options applied on top of the defaults.
//
// Returns:
//
//	*PasswordConfig: the new configuration.
func NewPasswordConfig(opts ...PasswordConfigOptions) *PasswordConfig {
	cfg := defaultPasswordConfig()
	cfg.loadOptions(opts...)
	// Debug's second argument is the request ID, which does not apply here.
	// Previously the format string was passed in that slot, so the output
	// was mangled.
	cfg.logger.Debug(cfg.module, "", "\n\nPassword config parameters: %s", cfg.String())
	return cfg
}
// WithUppercase returns an option that turns on the uppercase-letter
// requirement and logs the change at INFO level. There is no option to
// turn it back off.
//
// Returns:
//
//	PasswordConfigOptions: a function enabling the uppercase requirement.
func WithUppercase() PasswordConfigOptions {
	return func(pc *PasswordConfig) {
		pc.logger.Info(pc.module, "", "Configuring PasswordConfig to require an uppercase")
		pc.requireUpper = true
	}
}

// WithNumber returns an option that turns on the digit requirement and
// logs the change at INFO level.
//
// Returns:
//
//	PasswordConfigOptions: a function enabling the number requirement.
func WithNumber() PasswordConfigOptions {
	return func(pc *PasswordConfig) {
		pc.logger.Info(pc.module, "", "Configuring PasswordConfig to require a number")
		pc.requireNumber = true
	}
}

// WithSymbol returns an option that turns on the symbol requirement and
// logs the change at INFO level.
//
// Returns:
//
//	PasswordConfigOptions: a function enabling the symbol requirement.
func WithSymbol() PasswordConfigOptions {
	return func(pc *PasswordConfig) {
		pc.logger.Info(pc.module, "", "Configuring PasswordConfig to require a symbol")
		pc.requireSymbol = true
	}
}
// WithMinLength returns an option that sets the minimum password length.
//
// Only lengths greater than defaultRequiredPasswordLength are applied. A
// shorter length is rejected with a warning and the current value is kept.
// Passing exactly the default is a silent no-op.
//
// Parameters:
//
//	length int: the requested minimum password length.
//
// Returns:
//
//	PasswordConfigOptions: a function that configures the minimum length.
func WithMinLength(length int) PasswordConfigOptions {
	return func(pc *PasswordConfig) {
		if length <= defaultRequiredPasswordLength {
			// Report a weaker-than-allowed length instead of dropping it silently.
			if length < defaultRequiredPasswordLength {
				pc.logger.Warn(pc.module, "", "Ignoring minimum password length [%d]: it cannot be lower than %d. Keeping [%d]",
					length, defaultRequiredPasswordLength, pc.minLength)
			}
			return
		}
		pc.logger.Info(pc.module, "", "Configuring PasswordConfig minimum length=[%d]", length)
		pc.minLength = length
	}
}
// RequireUppercase reports whether passwords must contain uppercase letters.
func (pc *PasswordConfig) RequireUppercase() bool { return pc.requireUpper }

// RequireNumber reports whether passwords must contain numbers.
func (pc *PasswordConfig) RequireNumber() bool { return pc.requireNumber }

// RequireSymbol reports whether passwords must contain symbols.
func (pc *PasswordConfig) RequireSymbol() bool { return pc.requireSymbol }

// MinLength reports the minimum required password length.
func (pc *PasswordConfig) MinLength() int { return pc.minLength }
// String renders the password requirements as indented "Name: value"
// lines, as used by the DEBUG dump in NewPasswordConfig.
func (pc *PasswordConfig) String() string {
	return fmt.Sprintf("\n\tRequireUppercase: %t\n\tRequireNumber: %t\n\tRequireSymbol: %t\n\tMinLength: %d\n",
		pc.requireUpper, pc.requireNumber, pc.requireSymbol, pc.minLength)
}

// defaultPasswordConfig returns a PasswordConfig with no character-class
// requirements (all flags keep their false zero value), the default
// minimum length and the shared logger.
func defaultPasswordConfig() *PasswordConfig {
	return &PasswordConfig{
		minLength: defaultRequiredPasswordLength,
		logger:    GetLogger(),
		module:    "Password Config",
	}
}

// loadOptions applies opts to pc in order and logs whether any were given.
func (pc *PasswordConfig) loadOptions(opts ...PasswordConfigOptions) {
	if len(opts) == 0 {
		pc.logger.Info(pc.module, "", "Using default password config")
		return
	}

	pc.logger.Info(pc.module, "", "Creating password config with %d options", len(opts))
	for _, apply := range opts {
		apply(pc)
	}
}
package config
import (
_ "embed"
"fmt"
"os"
"path/filepath"
"runtime"
"strings"
"sync"
"testing"
"time"
"github.com/joho/godotenv"
)
// ServerConfig holds the server configuration, including the nested
// token, login, password, SMTP and audit-log configurations.
type ServerConfig struct {
	certFilePath              string          // Path to the SSL certificate file.
	keyFilePath               string          // Path to the SSL key file.
	baseURL                   string          // Base URL of the server.
	sessionCookieName         string          // Name of the session cookie.
	domain                    string          // Server domain.
	forceHTTPS                bool            // Whether to force HTTPS connections.
	port                      string          // Port the server listens on (a string, e.g. "8080").
	requestsPerMinute         int             // Maximum requests allowed per minute.
	requestLogging            bool            // Whether request logging is enabled.
	readTimeout               time.Duration   // Read timeout for HTTP requests (WithReadTimeout requires whole seconds); defaultReadTimeout is 15s.
	writeTimeout              time.Duration   // Write timeout for HTTP responses (WithWriteTimeout requires whole seconds); defaultWriteTimeout is 15s.
	authorizationCodeDuration time.Duration   // Authorization code lifetime; defaultAuthorizationCodeDuration is 10 minutes.
	tokenConfig               *TokenConfig    // JWT configuration.
	loginConfig               *LoginConfig    // Login configuration.
	passwordConfig            *PasswordConfig // Password configuration.
	smtpConfig                *SMTPConfig     // SMTP configuration.
	auditLogConfig            *AuditLogConfig // Audit log configuration.
	logger                    *Logger         // Logger used to report configuration changes.
	module                    string          // Module label passed to every logger call.
}

// ServerConfigOptions is a functional option applied to a ServerConfig.
type ServerConfigOptions func(*ServerConfig)

var (
	serverConfigInstance *ServerConfig // Package-level instance; replaced by every NewServerConfig call.
	serverConfigOnce     sync.Once     // Guards lazy creation in GetServerConfig.
)

const (
	defaultPort                      string        = "8080"                       // Default port number.
	defaultHTTPSRequirement          bool          = false                        // Default HTTPS requirement.
	defaultDomain                    string        = "localhost"                  // Default domain.
	defaultReadTimeout               time.Duration = 15 * time.Second             // Default read timeout.
	defaultWriteTimeout              time.Duration = 15 * time.Second             // Default write timeout.
	defaultAuthorizationCodeDuration time.Duration = 10 * time.Minute             // Default authorization code duration.
	defaultRequestsPerMinute         int           = 100                          // Default maximum requests per minute.
	defaultSessionCookieName         string        = "vigilo-auth-session-cookie" // Default session cookie name.
	defaultRequestLogging            bool          = true                         // Default request logging.
)
// GetServerConfig returns the package-level server configuration.
//
// If NewServerConfig has already stored an instance, that instance is
// returned. Otherwise a default configuration is created once, through
// NewServerConfig with no options.
//
// Returns:
//
//	*ServerConfig: the server configuration instance.
func GetServerConfig() *ServerConfig {
	if serverConfigInstance != nil {
		return serverConfigInstance
	}
	serverConfigOnce.Do(func() {
		serverConfigInstance = NewServerConfig()
	})
	return serverConfigInstance
}
// NewServerConfig builds a ServerConfig from the defaults, applies opts in
// order, and stores the result as the package-level instance returned by
// GetServerConfig. Each call replaces any previously stored instance.
//
// Parameters:
//
//	opts ...ServerConfigOptions: options applied on top of the defaults.
//
// Returns:
//
//	*ServerConfig: the new configuration.
func NewServerConfig(opts ...ServerConfigOptions) *ServerConfig {
	sc := defaultServerConfig()
	sc.loadOptions(opts...)
	sc.logger.Info(sc.module, "", "Initializing server config")

	serverConfigInstance = sc
	return sc
}
// WithPort returns an option that sets the port the server listens on and
// logs it at INFO level. The port is a string (defaultPort is "8080") and
// is stored without validation.
//
// Parameters:
//
//	port string: the port to listen on.
//
// Returns:
//
//	ServerConfigOptions: a function that configures the port.
func WithPort(port string) ServerConfigOptions {
	return func(sc *ServerConfig) {
		sc.logger.Info(sc.module, "", "Configuring server to run on port [%s]", port)
		sc.port = port
	}
}
// WithCertFilePath returns an option that records the SSL certificate file
// path. The path is not checked for existence. It must be applied before
// WithForceHTTPS, which requires both the certificate and key paths.
//
// Parameters:
//
//	filePath string: the certificate file path.
//
// Returns:
//
//	ServerConfigOptions: a function that configures the certificate path.
func WithCertFilePath(filePath string) ServerConfigOptions {
	return func(sc *ServerConfig) { sc.certFilePath = filePath }
}

// WithKeyFilePath returns an option that records the SSL key file path.
// The path is not checked for existence. It must be applied before
// WithForceHTTPS, which requires both the certificate and key paths.
//
// Parameters:
//
//	filePath string: the key file path.
//
// Returns:
//
//	ServerConfigOptions: a function that configures the key path.
func WithKeyFilePath(filePath string) ServerConfigOptions {
	return func(sc *ServerConfig) { sc.keyFilePath = filePath }
}
// WithSessionCookieName returns an option that replaces the session cookie
// name (defaultSessionCookieName is "vigilo-auth-session-cookie").
//
// Parameters:
//
//	cookieName string: the session cookie name.
//
// Returns:
//
//	ServerConfigOptions: a function that configures the cookie name.
func WithSessionCookieName(cookieName string) ServerConfigOptions {
	return func(sc *ServerConfig) { sc.sessionCookieName = cookieName }
}

// WithBaseURL returns an option that sets the server's base URL. The value
// is stored as given, without validation.
//
// Parameters:
//
//	baseURL string: the base URL.
//
// Returns:
//
//	ServerConfigOptions: a function that configures the base URL.
func WithBaseURL(baseURL string) ServerConfigOptions {
	return func(sc *ServerConfig) { sc.baseURL = baseURL }
}
// WithForceHTTPS returns an option that enables forced HTTPS.
//
// HTTPS is only enabled when both the certificate and key file paths are
// already set at the moment the option runs. Apply WithCertFilePath and
// WithKeyFilePath first. Otherwise a warning is logged and forceHTTPS is
// left unchanged.
//
// Returns:
//
//	ServerConfigOptions: a function that enables forced HTTPS.
func WithForceHTTPS() ServerConfigOptions {
	return func(sc *ServerConfig) {
		haveCertAndKey := sc.certFilePath != "" && sc.keyFilePath != ""
		if haveCertAndKey {
			sc.forceHTTPS = true
			return
		}
		sc.logger.Warn(sc.module, "", "SSL certificate or key file path is not set. Defaulting to HTTP.")
	}
}
// WithReadTimeout returns an option that sets the HTTP read timeout.
// The default is 15 seconds (defaultReadTimeout).
//
// The timeout must pass isInSeconds. Otherwise a warning is logged and the
// read timeout is reset to defaultReadTimeout.
//
// Parameters:
//
//	timeout time.Duration: the read timeout.
//
// Returns:
//
//	ServerConfigOptions: a function that configures the read timeout.
func WithReadTimeout(timeout time.Duration) ServerConfigOptions {
	return func(sc *ServerConfig) {
		if !isInSeconds(timeout) {
			sc.logger.Warn(sc.module, "", "Read timeout was not set to seconds. Defaulting to 15 seconds.")
			// Assign the field, as the warning promises. Previously only the
			// captured parameter was reassigned: the field kept whatever an
			// earlier option had set, and the closure's captured value was
			// mutated for any later reuse.
			sc.readTimeout = defaultReadTimeout
			return
		}
		sc.readTimeout = timeout
	}
}
// WithWriteTimeout returns an option that sets the HTTP write timeout.
// The default is 15 seconds (defaultWriteTimeout).
//
// The timeout must pass isInSeconds. Otherwise a warning is logged and the
// write timeout is reset to defaultWriteTimeout.
//
// Parameters:
//
//	timeout time.Duration: the write timeout.
//
// Returns:
//
//	ServerConfigOptions: a function that configures the write timeout.
func WithWriteTimeout(timeout time.Duration) ServerConfigOptions {
	return func(sc *ServerConfig) {
		if !isInSeconds(timeout) {
			sc.logger.Warn(sc.module, "", "Write timeout was not set to seconds. Defaulting to 15 seconds.")
			// Assign the field, as the warning promises. Previously only the
			// captured parameter was reassigned, which had no effect on sc.
			sc.writeTimeout = defaultWriteTimeout
			return
		}
		sc.writeTimeout = timeout
	}
}
// WithTokenConfig sets the token (JWT) configuration.
//
// Parameters:
//
//	jwtConfig *TokenConfig: The token configuration.
//
// Returns:
//
//	ServerConfigOptions: A function that stores the token configuration.
func WithTokenConfig(jwtConfig *TokenConfig) ServerConfigOptions {
	return func(sc *ServerConfig) {
		sc.tokenConfig = jwtConfig
	}
}

// WithLoginConfig sets the login configuration.
//
// Parameters:
//
//	loginConfig *LoginConfig: The login configuration.
//
// Returns:
//
//	ServerConfigOptions: A function that stores the login configuration.
func WithLoginConfig(loginConfig *LoginConfig) ServerConfigOptions {
	return func(sc *ServerConfig) {
		sc.loginConfig = loginConfig
	}
}

// WithPasswordConfig sets the password configuration.
//
// Parameters:
//
//	passwordConfig *PasswordConfig: The password configuration.
//
// Returns:
//
//	ServerConfigOptions: A function that stores the password configuration.
func WithPasswordConfig(passwordConfig *PasswordConfig) ServerConfigOptions {
	return func(sc *ServerConfig) {
		sc.passwordConfig = passwordConfig
	}
}
// WithMaxRequestsPerMinute sets the per-minute request limit. The value is
// stored as given, without validation.
//
// Parameters:
//
//	requests int: The maximum number of requests per minute.
//
// Returns:
//
//	ServerConfigOptions: An option that records the request limit.
func WithMaxRequestsPerMinute(requests int) ServerConfigOptions {
	return func(cfg *ServerConfig) { cfg.requestsPerMinute = requests }
}

// WithAuthorizationCodeDuration sets the authorization code duration. The
// value is stored as given, without validation.
//
// Parameters:
//
//	duration time.Duration: The authorization code duration.
//
// Returns:
//
//	ServerConfigOptions: An option that records the duration.
func WithAuthorizationCodeDuration(duration time.Duration) ServerConfigOptions {
	return func(cfg *ServerConfig) { cfg.authorizationCodeDuration = duration }
}

// WithSMTPConfig sets the SMTP configuration.
//
// Parameters:
//
//	smtpConfig *SMTPConfig: The SMTP configuration.
//
// Returns:
//
//	ServerConfigOptions: An option that records the SMTP configuration.
func WithSMTPConfig(smtpConfig *SMTPConfig) ServerConfigOptions {
	return func(cfg *ServerConfig) { cfg.smtpConfig = smtpConfig }
}

// WithAuditLogConfig sets the audit log configuration.
//
// Parameters:
//
//	auditLogConfig *AuditLogConfig: The audit log configuration.
//
// Returns:
//
//	ServerConfigOptions: An option that records the audit log configuration.
func WithAuditLogConfig(auditLogConfig *AuditLogConfig) ServerConfigOptions {
	return func(cfg *ServerConfig) { cfg.auditLogConfig = auditLogConfig }
}

// WithRequestLogging turns request logging on or off.
//
// Parameters:
//
//	enable bool: true to enable request logging.
//
// Returns:
//
//	ServerConfigOptions: An option that records the request-logging flag.
func WithRequestLogging(enable bool) ServerConfigOptions {
	return func(cfg *ServerConfig) { cfg.requestLogging = enable }
}

// WithDomain sets the server domain.
//
// Parameters:
//
//	domain string: The domain.
//
// Returns:
//
//	ServerConfigOptions: An option that records the domain.
func WithDomain(domain string) ServerConfigOptions {
	return func(cfg *ServerConfig) { cfg.domain = domain }
}
// Port returns the port the server is configured to listen on.
//
// Returns:
//
//	string: The port, as a string.
func (sc *ServerConfig) Port() string {
	return sc.port
}

// BaseURL returns the server's base URL.
//
// Returns:
//
//	string: The base URL.
func (sc *ServerConfig) BaseURL() string {
	return sc.baseURL
}

// CertFilePath returns the path to the SSL certificate file.
//
// Returns:
//
//	string: The certificate file path (empty when unset).
func (sc *ServerConfig) CertFilePath() string {
	return sc.certFilePath
}

// KeyFilePath returns the path to the SSL key file.
//
// Returns:
//
//	string: The key file path (empty when unset).
func (sc *ServerConfig) KeyFilePath() string {
	return sc.keyFilePath
}

// ForceHTTPS reports whether HTTPS is enforced.
//
// Returns:
//
//	bool: true if HTTPS is enforced, false otherwise.
func (sc *ServerConfig) ForceHTTPS() bool {
	return sc.forceHTTPS
}

// ReadTimeout returns the read timeout for HTTP requests.
//
// Returns:
//
//	time.Duration: The read timeout.
func (sc *ServerConfig) ReadTimeout() time.Duration {
	return sc.readTimeout
}

// WriteTimeout returns the write timeout for HTTP responses.
//
// Returns:
//
//	time.Duration: The write timeout.
func (sc *ServerConfig) WriteTimeout() time.Duration {
	return sc.writeTimeout
}

// TokenConfig returns the token configuration.
//
// Returns:
//
//	*TokenConfig: The token configuration.
func (sc *ServerConfig) TokenConfig() *TokenConfig {
	return sc.tokenConfig
}

// LoginConfig returns the login configuration.
//
// Returns:
//
//	*LoginConfig: The login configuration.
func (sc *ServerConfig) LoginConfig() *LoginConfig {
	return sc.loginConfig
}

// PasswordConfig returns the password configuration.
//
// Returns:
//
//	*PasswordConfig: The password configuration.
func (sc *ServerConfig) PasswordConfig() *PasswordConfig {
	return sc.passwordConfig
}

// SessionCookieName returns the name of the session cookie.
//
// Returns:
//
//	string: The session cookie name.
func (sc *ServerConfig) SessionCookieName() string {
	return sc.sessionCookieName
}

// MaxRequestsPerMinute returns the maximum number of requests allowed per minute.
//
// Returns:
//
//	int: The per-minute request limit.
func (sc *ServerConfig) MaxRequestsPerMinute() int {
	return sc.requestsPerMinute
}
// URL builds the server URL from the domain and base URL. Without forced
// HTTPS, the scheme is "http" and the configured port follows the domain.
// With forced HTTPS, the scheme is "https" and no port is included.
// NOTE(review): the HTTPS form therefore assumes the default port or a
// fronting proxy — confirm.
//
// Returns:
//
//	string: The server URL.
func (sc *ServerConfig) URL() string {
	if !sc.forceHTTPS {
		return fmt.Sprintf("http://%s:%s%s", sc.domain, sc.port, sc.baseURL)
	}
	return fmt.Sprintf("https://%s%s", sc.domain, sc.baseURL)
}
// EnableRequestLogging reports whether request logging is enabled.
func (sc *ServerConfig) EnableRequestLogging() bool { return sc.requestLogging }

// Domain returns the configured server domain.
func (sc *ServerConfig) Domain() string { return sc.domain }

// Logger returns the logger attached to this configuration.
func (sc *ServerConfig) Logger() *Logger { return sc.logger }

// AuthorizationCodeDuration returns the configured authorization code duration.
func (sc *ServerConfig) AuthorizationCodeDuration() time.Duration {
	return sc.authorizationCodeDuration
}

// SMTPConfig returns the SMTP configuration.
func (sc *ServerConfig) SMTPConfig() *SMTPConfig { return sc.smtpConfig }

// AuditLogConfig returns the audit log configuration.
func (sc *ServerConfig) AuditLogConfig() *AuditLogConfig { return sc.auditLogConfig }

// SetLoginConfig replaces the login configuration.
func (sc *ServerConfig) SetLoginConfig(loginConfig *LoginConfig) { sc.loginConfig = loginConfig }

// SetPasswordConfig replaces the password configuration.
func (sc *ServerConfig) SetPasswordConfig(passwordConfig *PasswordConfig) {
	sc.passwordConfig = passwordConfig
}

// SetTokenConfig replaces the token configuration.
func (sc *ServerConfig) SetTokenConfig(tokenConfig *TokenConfig) { sc.tokenConfig = tokenConfig }

// SetSMTPConfig replaces the SMTP configuration.
func (sc *ServerConfig) SetSMTPConfig(smtpConfig *SMTPConfig) { sc.smtpConfig = smtpConfig }

// SetAuditLogConfig replaces the audit log configuration.
func (sc *ServerConfig) SetAuditLogConfig(auditLogConfig *AuditLogConfig) {
	sc.auditLogConfig = auditLogConfig
}

// SetBaseURL replaces the base URL.
func (sc *ServerConfig) SetBaseURL(url string) { sc.baseURL = url }
// isInSeconds reports whether duration is an exact multiple of one second.
// Truncate rounds toward zero, so the result equals duration iff there is no remainder.
func isInSeconds(duration time.Duration) bool { return duration.Truncate(time.Second) == duration }

// isInHours reports whether duration is an exact multiple of one hour.
func isInHours(duration time.Duration) bool { return duration.Truncate(time.Hour) == duration }

// isInMinutes reports whether duration is an exact multiple of one minute.
func isInMinutes(duration time.Duration) bool { return duration.Truncate(time.Minute) == duration }

// isInMilliseconds reports whether duration is an exact multiple of one millisecond.
func isInMilliseconds(duration time.Duration) bool {
	return duration.Truncate(time.Millisecond) == duration
}
// defaultServerConfig builds a ServerConfig from the package defaults and then
// loads the .env file via loadEnvFiles. After that it attaches default token,
// login, password, SMTP and audit-log sub-configurations. The env file is
// loaded first because defaultTokenConfig and defaultSMTPConfig read
// environment variables.
func defaultServerConfig() *ServerConfig {
	sc := &ServerConfig{
		port:                      defaultPort,
		domain:                    defaultDomain,
		forceHTTPS:                defaultHTTPSRequirement,
		requestLogging:            defaultRequestLogging,
		readTimeout:               defaultReadTimeout,
		writeTimeout:              defaultWriteTimeout,
		requestsPerMinute:         defaultRequestsPerMinute,
		sessionCookieName:         defaultSessionCookieName,
		authorizationCodeDuration: defaultAuthorizationCodeDuration,
		logger:                    GetLogger(),
		module:                    "Server Config",
		baseURL:                   "/identity",
	}

	// Must precede the sub-config constructors, which read the environment.
	sc.loadEnvFiles()

	sc.tokenConfig = NewTokenConfig()
	sc.loginConfig = NewLoginConfig()
	sc.passwordConfig = NewPasswordConfig()
	sc.smtpConfig = NewSMTPConfig()
	sc.auditLogConfig = NewAuditLogConfig()

	return sc
}
// loadOptions applies each option to the config in the order given and logs
// how many were supplied. With no options it logs that the defaults are used.
func (sc *ServerConfig) loadOptions(opts ...ServerConfigOptions) {
	if len(opts) == 0 {
		sc.logger.Info(sc.module, "", "Using default server config")
		return
	}

	sc.logger.Info(sc.module, "", "Creating server config with %d options", len(opts))
	for _, apply := range opts {
		apply(sc)
	}
}
// loadEnvFiles loads environment variables from a .env file two directories
// above this source file. It uses ".env.test" when running under tests and
// ".env" otherwise. runtime.Caller reports the file path recorded at compile
// time, so this relies on the source tree being present at that path at
// runtime. If the location cannot be determined, the paths resolve relative to
// the working directory and a warning is logged.
func (sc *ServerConfig) loadEnvFiles() {
	_, thisFile, _, ok := runtime.Caller(0)
	if !ok {
		sc.logger.Warn(sc.module, "", "Unable to determine source location; resolving env files relative to the working directory")
	}

	baseDir := filepath.Dir(thisFile)
	envPath := filepath.Join(baseDir, "../../.env")
	testEnvPath := filepath.Join(baseDir, "../../.env.test")

	if isTestEnvironment() {
		sc.logger.Info(sc.module, "", "Loading test environment file")
		sc.loadEnvFile(testEnvPath)
		return
	}

	sc.logger.Info(sc.module, "", "Loading environment file: %s", envPath)
	sc.loadEnvFile(envPath)
}
// isTestEnvironment reports whether the process is a Go test binary. It checks
// testing.Testing() first. As a fallback it looks for a "-test.*" flag (as
// passed by `go test`) among the process arguments.
func isTestEnvironment() bool {
	return testing.Testing() || hasGoTestFlag(os.Args)
}

// hasGoTestFlag reports whether any argument is a flag ("-" or "--" prefixed)
// whose name starts with "test.". Matching only flags avoids false positives
// from paths or values that merely contain "test." (e.g. "latest.yaml").
func hasGoTestFlag(args []string) bool {
	for _, arg := range args {
		if !strings.HasPrefix(arg, "-") {
			continue
		}
		if strings.HasPrefix(strings.TrimLeft(arg, "-"), "test.") {
			return true
		}
	}
	return false
}
// loadEnvFile loads variables from fileName with godotenv.Load. A failure
// (e.g. a missing file) is not fatal: it is logged as a warning and startup
// continues.
func (sc *ServerConfig) loadEnvFile(fileName string) {
	if loadErr := godotenv.Load(fileName); loadErr != nil {
		sc.logger.Warn(sc.module, "", "Environment file not loaded: %v", loadErr)
	}
}
package config
import (
"os"
"github.com/vigiloauth/vigilo/v2/internal/constants"
)
// SMTPConfig holds the settings for sending email over SMTP. Fields are
// unexported; they are set through SMTPConfigOptions and (except encryption)
// read back through accessor methods.
type SMTPConfig struct {
host string // SMTP server host name
port int // SMTP server port (TLSPort or SSLPort via WithTLS / WithSSL)
username string // SMTP auth username
password string // SMTP auth password
fromAddress string // sender address
encryption string // TLSEncryption or SSLEncryption (see WithEncryption)
isHealthy bool // updated through SetHealth
logger *Logger
module string // module name passed to logger calls
}
// SMTP defaults and well-known values. defaultSMTPConfig starts from
// defaultSMTPHost with TLSPort and TLSEncryption; WithSSL selects SSLPort.
const (
defaultSMTPHost string = "smtp.gmail.com"
TLSPort int = 587
SSLPort int = 465
SSLEncryption string = "ssl"
TLSEncryption string = "tls"
)
// SMTPConfigOptions is a functional option that mutates an SMTPConfig; see NewSMTPConfig.
type SMTPConfigOptions func(*SMTPConfig)
// NewSMTPConfig returns an SMTPConfig seeded by defaultSMTPConfig, which reads
// the from-address, username and password from the environment. It then hands
// opts to loadOptions, which decides whether they are applied.
func NewSMTPConfig(opts ...SMTPConfigOptions) *SMTPConfig {
	smtpCfg := defaultSMTPConfig()
	smtpCfg.loadOptions(opts...)
	return smtpCfg
}
// WithSMTPHost sets the SMTP server host and logs the chosen host at info level.
func WithSMTPHost(host string) SMTPConfigOptions {
	return func(cfg *SMTPConfig) {
		cfg.logger.Info(cfg.module, "", "Configuring SMTP Config to use host [%s]", host)
		cfg.host = host
	}
}
// WithSSL selects SSL: it sets the port to SSLPort (465) and the encryption
// mode to SSLEncryption, so the port and the mode never disagree.
func WithSSL() SMTPConfigOptions {
	return func(s *SMTPConfig) {
		s.port = SSLPort
		s.encryption = SSLEncryption
	}
}

// WithTLS selects TLS: it sets the port to TLSPort (587) and the encryption
// mode to TLSEncryption, so the port and the mode never disagree.
func WithTLS() SMTPConfigOptions {
	return func(s *SMTPConfig) {
		s.port = TLSPort
		s.encryption = TLSEncryption
	}
}
// WithCredentials sets the SMTP authentication username and password.
func WithCredentials(username, password string) SMTPConfigOptions {
	return func(s *SMTPConfig) {
		s.username, s.password = username, password
	}
}

// WithFromAddress sets the sender address used for outgoing mail.
func WithFromAddress(fromAddress string) SMTPConfigOptions {
	return func(s *SMTPConfig) { s.fromAddress = fromAddress }
}
// WithEncryption sets the encryption mode. Only SSLEncryption and
// TLSEncryption are accepted. Any other value logs a warning and falls back to
// SSLEncryption. The port is not changed by this option.
func WithEncryption(encryption string) SMTPConfigOptions {
	return func(s *SMTPConfig) {
		switch encryption {
		case SSLEncryption, TLSEncryption:
			s.encryption = encryption
		default:
			s.logger.Warn(s.module, "", "SMTP Configuration not using TLS or SSL, default to SSL")
			s.encryption = SSLEncryption
		}
	}
}
// Host returns the SMTP server host.
func (s *SMTPConfig) Host() string { return s.host }

// Port returns the SMTP server port.
func (s *SMTPConfig) Port() int { return s.port }

// Username returns the SMTP authentication username.
func (s *SMTPConfig) Username() string { return s.username }

// Password returns the SMTP authentication password.
func (s *SMTPConfig) Password() string { return s.password }

// FromAddress returns the sender address.
func (s *SMTPConfig) FromAddress() string { return s.fromAddress }

// SetHealth records whether the SMTP server is currently considered healthy.
// It performs no synchronization.
func (s *SMTPConfig) SetHealth(isHealthy bool) { s.isHealthy = isHealthy }

// IsHealthy returns the health flag last stored by SetHealth (false by default).
func (s *SMTPConfig) IsHealthy() bool { return s.isHealthy }
// loadOptions applies opts to the configuration only when at least
// requiredCustomOptions options are supplied. With fewer, including none, a
// warning is logged, no option is applied, and the defaults from
// defaultSMTPConfig remain in effect.
func (s *SMTPConfig) loadOptions(opts ...SMTPConfigOptions) {
	// Minimum number of options that make up a custom SMTP configuration.
	// Options beyond this minimum are applied as well rather than being
	// rejected as "missing". NOTE(review): counting options is only a proxy
	// for "all required settings present" — confirm against callers.
	const requiredCustomOptions = 5

	if len(opts) < requiredCustomOptions {
		s.logger.Warn(s.module, "", "Missing required options for a custom SMTP configuration. Falling back to default settings")
		return
	}

	s.logger.Info(s.module, "", "Creating custom SMTP configuration")
	for _, opt := range opts {
		opt(s)
	}
}
// defaultSMTPConfig returns an SMTPConfig for defaultSMTPHost on TLSPort with
// TLSEncryption, marked unhealthy. The from-address, username and password come
// from the environment variables named by constants.SMTPFromAddressENV,
// SMTPUsernameENV and SMTPPasswordENV, and are empty when unset.
func defaultSMTPConfig() *SMTPConfig {
	cfg := &SMTPConfig{
		host:       defaultSMTPHost,
		port:       TLSPort,
		encryption: TLSEncryption,
		isHealthy:  false,
		module:     "SMTP Config",
	}
	cfg.fromAddress = os.Getenv(constants.SMTPFromAddressENV)
	cfg.username = os.Getenv(constants.SMTPUsernameENV)
	cfg.password = os.Getenv(constants.SMTPPasswordENV)
	cfg.logger = GetLogger()
	return cfg
}
package config
import (
"crypto/rsa"
"encoding/base64"
"fmt"
"os"
"time"
"github.com/golang-jwt/jwt"
"github.com/vigiloauth/vigilo/v2/internal/constants"
"github.com/vigiloauth/vigilo/v2/internal/utils"
)
// TokenConfig holds the configuration for JWT token generation and validation:
// the RSA key pair loaded from the environment, token lifetimes, and the issuer.
type TokenConfig struct {
privateKey *rsa.PrivateKey // RSA private key (exposed via SecretKey); presumably used to sign tokens.
publicKey *rsa.PublicKey // RSA public key used for verifying JWT tokens.
keyID string // Key ID derived from the base64 public key (utils.GenerateJWKKeyID).
expirationTime time.Duration // Token expiration; WithExpirationTime requires whole hours.
accessTokenDuration time.Duration // Access token lifetime; WithAccessTokenDuration requires whole minutes.
refreshTokenDuration time.Duration // Refresh token lifetime (default 1 day); not unit-validated.
issuer string // Issuer value, set via SetIssuer.
logger *Logger
module string // module name passed to logger calls
}
// TokenConfigOptions is a functional option that mutates a TokenConfig; see NewTokenConfig.
type TokenConfigOptions func(*TokenConfig)
// Default token lifetimes used by defaultTokenConfig. The first two also serve
// as fallbacks when WithExpirationTime / WithAccessTokenDuration reject a value.
const (
defaultExpirationTime time.Duration = time.Duration(24) * time.Hour // 24 hours
defaultAccessTokenDuration time.Duration = time.Duration(30) * time.Minute // 30 minutes
defaultRefreshTokenDuration time.Duration = time.Duration(1) * 24 * time.Hour // 1 day
)
// NewTokenConfig creates a TokenConfig from defaultTokenConfig, applies opts in
// order, and logs the resulting lifetimes at debug level. defaultTokenConfig
// loads the RSA key pair from the environment and panics if it is missing or
// invalid.
//
// Parameters:
//
//	opts ...TokenConfigOptions: Options applied on top of the defaults.
//
// Returns:
//
//	*TokenConfig: The configured token settings.
func NewTokenConfig(opts ...TokenConfigOptions) *TokenConfig {
	cfg := defaultTokenConfig()
	cfg.loadOptions(opts...)
	cfg.logger.Debug(cfg.module, "", "\n\nToken config parameters: %v", cfg.String())
	return cfg
}
// WithExpirationTime sets the token expiration time. The duration must be a
// whole number of hours. If it is not, a warning is logged and
// defaultExpirationTime (24 hours) is used.
//
// Parameters:
//
//	duration time.Duration: The expiration time.
//
// Returns:
//
//	TokenConfigOptions: A function that configures the expiration time.
func WithExpirationTime(duration time.Duration) TokenConfigOptions {
	return func(c *TokenConfig) {
		if !isInHours(duration) {
			c.logger.Warn(c.module, "", "Token expiration time is not in hours, using default value")
			c.expirationTime = defaultExpirationTime
			return
		}
		c.logger.Debug(c.module, "", "Configuring TokenConfig with expiration time=[%s]", duration)
		c.expirationTime = duration
	}
}

// WithAccessTokenDuration sets the access token lifetime. The duration must be
// a whole number of minutes. If it is not, a warning is logged and
// defaultAccessTokenDuration (30 minutes) is used.
//
// Parameters:
//
//	duration time.Duration: The access token lifetime.
//
// Returns:
//
//	TokenConfigOptions: A function that configures the access token duration.
func WithAccessTokenDuration(duration time.Duration) TokenConfigOptions {
	return func(c *TokenConfig) {
		if !isInMinutes(duration) {
			c.logger.Warn(c.module, "", "Access token duration is not in minutes, using default value")
			c.accessTokenDuration = defaultAccessTokenDuration
			return
		}
		c.logger.Debug(c.module, "", "Configuring TokenConfig with access token duration=[%s]", duration)
		c.accessTokenDuration = duration
	}
}

// WithRefreshTokenDuration sets the refresh token lifetime. Unlike the other
// duration options, the value is not unit-validated. The default is
// defaultRefreshTokenDuration (1 day).
//
// Parameters:
//
//	duration time.Duration: The refresh token lifetime.
//
// Returns:
//
//	TokenConfigOptions: A function that configures the refresh token duration.
func WithRefreshTokenDuration(duration time.Duration) TokenConfigOptions {
	return func(c *TokenConfig) {
		c.logger.Debug(c.module, "", "Configuring TokenConfig with refresh token duration=[%s]", duration)
		c.refreshTokenDuration = duration
	}
}
// SecretKey returns the RSA private key loaded by defaultTokenConfig.
//
// Returns:
//
//	*rsa.PrivateKey: The private key.
func (j *TokenConfig) SecretKey() *rsa.PrivateKey {
	return j.privateKey
}

// PublicKey returns the RSA public key loaded by defaultTokenConfig.
//
// Returns:
//
//	*rsa.PublicKey: The public key.
func (j *TokenConfig) PublicKey() *rsa.PublicKey {
	return j.publicKey
}

// KeyID returns the key ID derived from the public key.
//
// Returns:
//
//	string: The key ID.
func (j *TokenConfig) KeyID() string {
	return j.keyID
}

// ExpirationTime returns the token expiration time.
//
// Returns:
//
//	time.Duration: The expiration time.
func (j *TokenConfig) ExpirationTime() time.Duration {
	return j.expirationTime
}

// RefreshTokenDuration returns the refresh token lifetime.
func (j *TokenConfig) RefreshTokenDuration() time.Duration {
	return j.refreshTokenDuration
}

// AccessTokenDuration returns the access token lifetime.
func (j *TokenConfig) AccessTokenDuration() time.Duration {
	return j.accessTokenDuration
}

// Issuer returns the configured token issuer (empty until SetIssuer is called).
func (j *TokenConfig) Issuer() string {
	return j.issuer
}

// SetIssuer sets the token issuer.
func (j *TokenConfig) SetIssuer(issuer string) {
	j.issuer = issuer
}
// String renders the token lifetimes (expiration, refresh, access) as
// tab-indented "Name: value" lines, each ending in a newline. Key material is
// not included.
func (j *TokenConfig) String() string {
	const layout = "\tExpirationTime: %s\n\tRefreshTokenDuration: %s\n\tAccessTokenDuration: %s\n"
	return fmt.Sprintf(layout, j.expirationTime, j.refreshTokenDuration, j.accessTokenDuration)
}
// defaultTokenConfig builds a TokenConfig with the default lifetimes. The RSA
// key pair comes from the base64-encoded PEM values in the environment
// variables named by constants.TokenPrivateKeyENV and TokenPublicKeyENV.
// It panics if either variable is empty, not valid base64, or not a parsable
// RSA PEM key. The checks run in stages: both presence checks, then both
// decodes, then both parses, private before public at each stage.
func defaultTokenConfig() *TokenConfig {
	privB64 := os.Getenv(constants.TokenPrivateKeyENV)
	pubB64 := os.Getenv(constants.TokenPublicKeyENV)

	switch {
	case privB64 == "":
		panic("Private key not found in environment variable")
	case pubB64 == "":
		panic("Public key not found in environment variable")
	}

	privPEM, err := base64.StdEncoding.DecodeString(privB64)
	if err != nil {
		panic("Failed to decode private key: " + err.Error())
	}
	pubPEM, err := base64.StdEncoding.DecodeString(pubB64)
	if err != nil {
		panic("Failed to decode public key: " + err.Error())
	}

	privKey, err := jwt.ParseRSAPrivateKeyFromPEM(privPEM)
	if err != nil {
		panic("Failed to parse private key: " + err.Error())
	}
	pubKey, err := jwt.ParseRSAPublicKeyFromPEM(pubPEM)
	if err != nil {
		panic("Failed to parse public key: " + err.Error())
	}

	return &TokenConfig{
		privateKey:           privKey,
		publicKey:            pubKey,
		keyID:                utils.GenerateJWKKeyID(pubB64),
		expirationTime:       defaultExpirationTime,
		accessTokenDuration:  defaultAccessTokenDuration,
		refreshTokenDuration: defaultRefreshTokenDuration,
		logger:               GetLogger(),
		module:               "Token Config",
	}
}
// loadOptions applies each option in order and logs how many were supplied.
// With no options it logs that the default token config is used.
func (j *TokenConfig) loadOptions(opts ...TokenConfigOptions) {
	if len(opts) == 0 {
		j.logger.Info(j.module, "", "Using default token config")
		return
	}

	j.logger.Info(j.module, "", "Creating token config with %d options", len(opts))
	for _, apply := range opts {
		apply(j)
	}
}
package server
import (
"context"
"net/http"
"os"
"os/signal"
"syscall"
"github.com/go-chi/chi/v5"
"github.com/vigiloauth/vigilo/v2/idp/config"
"github.com/vigiloauth/vigilo/v2/internal/constants"
"github.com/vigiloauth/vigilo/v2/internal/container"
"github.com/vigiloauth/vigilo/v2/internal/routes"
)
// VigiloIdentityServer represents the main identity server structure for the Vigilo application.
// It encapsulates the server configuration, dependency injection container, HTTP server, router,
// logger, and module information.
//
// Fields:
// - serverConfig *config.ServerConfig: Configuration settings for the server.
// - container *container.DIContainer: Dependency injection container for managing application dependencies.
// - httpServer *http.Server: The HTTP server instance used to handle incoming requests.
// - router chi.Router: The router instance for defining and managing HTTP routes.
// - logger *config.Logger: Logger instance for logging server activities and errors.
// - module string: The name of the module associated with this server.
type VigiloIdentityServer struct {
serverConfig *config.ServerConfig // shared server configuration
container *container.DIContainer // owns registries; shut down in Shutdown
httpServer *http.Server // obtained from the container; Handler set in StartServer
router chi.Router // application router mounted under "/identity"
logger *config.Logger
module string // module name passed to logger calls
}
// NewVigiloIdentityServer builds a VigiloIdentityServer that is ready to start.
//
// Steps:
//  1. Fetches the server configuration and its logger.
//  2. Logs that the Vigilo Identity Provider is initializing.
//  3. Creates the dependency injection container and calls Init on it.
//  4. Builds and initializes the application router from the container's
//     middleware and handler registries. The HTTPS and request-logging
//     settings are read from config.GetServerConfig() after the container
//     has been initialized.
//  5. Returns the server holding the container, configuration, logger, the
//     container's HTTP server, and the router.
//
// Returns:
//   - *VigiloIdentityServer: The initialized server.
func NewVigiloIdentityServer() *VigiloIdentityServer {
	const module = "Vigilo Identity Provider"

	serverConfig := config.GetServerConfig()
	logger := serverConfig.Logger()
	logger.Info(module, "", "Initializing Vigilo Identity Provider")

	// Named di so the local does not shadow the imported container package.
	di := container.NewDIContainer(logger)
	di.Init()

	forceHTTPS := config.GetServerConfig().ForceHTTPS()
	requestLogging := config.GetServerConfig().EnableRequestLogging()

	appRouter := routes.NewRouterConfig(
		chi.NewRouter(),
		logger,
		forceHTTPS,
		requestLogging,
		di.ServiceRegistry().Middleware(),
		di.HandlerRegistry(),
	)
	appRouter.Init()

	return &VigiloIdentityServer{
		container:    di,
		serverConfig: serverConfig,
		logger:       logger,
		module:       module,
		httpServer:   di.HTTPServer(),
		router:       appRouter.Router(),
	}
}
// StartServer mounts the server's router under "/identity" on r and installs r
// as the HTTP handler. It then starts serving in a background goroutine and
// blocks until SIGINT (os.Interrupt) or SIGTERM arrives. It does not shut the
// server down itself; call Shutdown after it returns.
//
// Parameters:
//   - r *chi.Mux: The root router; the "/identity" sub-route is added to it.
//
// Behavior:
//   - When ForceHTTPS is set, serves TLS with the configured certificate and
//     key files, and exits the process if either path is empty.
//   - Otherwise serves plain HTTP.
//   - Logs the error and exits the process if serving fails with anything
//     other than http.ErrServerClosed.
//   - Unregisters its signal channel before returning. A second interrupt
//     during shutdown then gets default handling again, unless other code has
//     also registered for those signals.
//
// Notes:
//   - The mount path is hard-coded to "/identity" and does not read
//     ServerConfig.BaseURL. NOTE(review): confirm this is intended when the
//     base URL is overridden via SetBaseURL.
func (s *VigiloIdentityServer) StartServer(r *chi.Mux) {
	r.Route("/identity", func(subRouter chi.Router) {
		subRouter.Mount("/", s.router)
	})
	s.httpServer.Handler = r

	stop := make(chan os.Signal, 1)
	signal.Notify(stop, os.Interrupt, syscall.SIGTERM)
	// Once we return nobody reads stop. Without Stop, a second Ctrl+C during a
	// hung shutdown would be delivered to this channel instead of terminating
	// the process.
	defer signal.Stop(stop)

	go func() {
		s.logger.Info(s.module, "", "Starting VigiloAuth Identity Provider on port [%s] with base URL [%s]",
			s.serverConfig.Port(),
			s.serverConfig.BaseURL(),
		)

		var err error
		if s.serverConfig.ForceHTTPS() {
			certFile := s.serverConfig.CertFilePath()
			keyFile := s.serverConfig.KeyFilePath()
			if certFile == "" || keyFile == "" {
				s.logger.Error(s.module, "", "HTTPS requested but certificate or key file path is not configured. Exiting.")
				os.Exit(1)
			}
			err = s.httpServer.ListenAndServeTLS(certFile, keyFile)
		} else {
			err = s.httpServer.ListenAndServe()
		}

		// ErrServerClosed is the expected result after Shutdown/Close.
		if err != nil && err != http.ErrServerClosed {
			s.logger.Error(s.module, "", "HTTP server error: %v", err)
			os.Exit(1)
		}
	}()

	<-stop
}
// Shutdown gracefully stops the VigiloIdentityServer.
//
//  1. Shuts down the HTTP server with a constants.TenSecondTimeout deadline.
//     New connections are refused and in-flight requests get the deadline to
//     finish. Any error (including hitting the deadline) is logged; otherwise
//     a success message is logged.
//  2. Shuts down the dependency container only afterwards. The handlers are
//     built from the container's registries, so their dependencies stay
//     available while requests drain, and container teardown does not eat into
//     the HTTP drain deadline.
func (s *VigiloIdentityServer) Shutdown() {
	ctx, cancel := context.WithTimeout(context.Background(), constants.TenSecondTimeout)
	defer cancel()

	if err := s.httpServer.Shutdown(ctx); err != nil {
		s.logger.Error(s.module, "", "HTTP server shutdown err: %v", err)
	} else {
		s.logger.Info(s.module, "", "HTTP server shutdown gracefully")
	}

	s.container.Shutdown()
}
// Router returns the application router built in NewVigiloIdentityServer,
// which StartServer mounts under "/identity".
func (s *VigiloIdentityServer) Router() chi.Router { return s.router }
package background
import (
"context"
"time"
"github.com/vigiloauth/vigilo/v2/idp/config"
domain "github.com/vigiloauth/vigilo/v2/internal/domain/audit"
)
// AuditJobs holds the dependencies and timing for the periodic audit-log purge
// job (see PurgeLogs).
type AuditJobs struct {
auditLogger domain.AuditLogger // DeleteOldEvents is called with the cutoff
retentionPeriod time.Duration // cutoff is time.Now() minus this period
purgeInterval time.Duration // ticker interval between purge attempts
logger *config.Logger
module string // module name passed to logger calls
}
// NewAuditJobs returns an AuditJobs that purges through auditLogger. Events
// older than retentionPeriod are purged every purgeInterval. It logs through
// the shared server configuration's logger.
func NewAuditJobs(auditLogger domain.AuditLogger, retentionPeriod, purgeInterval time.Duration) *AuditJobs {
	jobs := &AuditJobs{module: "Audit Jobs"}
	jobs.auditLogger = auditLogger
	jobs.retentionPeriod = retentionPeriod
	jobs.purgeInterval = purgeInterval
	jobs.logger = config.GetServerConfig().Logger()
	return jobs
}
// PurgeLogs runs until ctx is cancelled. On each purgeInterval tick it calls
// DeleteOldEvents with a cutoff of time.Now() minus retentionPeriod. A failed
// deletion is logged, and the purge is attempted again on the next tick.
func (a *AuditJobs) PurgeLogs(ctx context.Context) {
	a.logger.Info(a.module, "", "[PurgeLogs]: Starting process of removing old audit logs")

	ticker := time.NewTicker(a.purgeInterval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			a.logger.Info(a.module, "", "[PurgeLogs]: Stopping the process of deleting old audit logs")
			return
		case <-ticker.C:
			cutoff := time.Now().Add(-a.retentionPeriod)
			if err := a.auditLogger.DeleteOldEvents(ctx, cutoff); err != nil {
				a.logger.Error(a.module, "", "[PurgeLogs]: There was an error deleting old audit logs: %v", err)
			}
		}
	}
}
package background
import (
"context"
"sync"
"github.com/vigiloauth/vigilo/v2/idp/config"
)
// JobFunc is a background job run by Scheduler.StartJobs with the context
// passed to StartJobs. The jobs in this package return once that context is cancelled.
type JobFunc func(ctx context.Context)
// Scheduler keeps a list of background jobs and runs them concurrently via
// StartJobs.
type Scheduler struct {
mu sync.RWMutex // guards jobs (and serializes wg.Add against Wait)
jobs []JobFunc
wg sync.WaitGroup // counts job goroutines started by StartJobs
stopCh chan struct{} // closed by Stop; NOTE(review): not read anywhere in this file
logger *config.Logger
module string // module name passed to logger calls
}
// NewScheduler returns an empty Scheduler with an open stop channel. It logs
// through the shared server configuration's logger under the "Scheduler" module.
func NewScheduler() *Scheduler {
	sched := &Scheduler{module: "Scheduler"}
	sched.logger = config.GetServerConfig().Logger()
	sched.stopCh = make(chan struct{})
	return sched
}
// RegisterJob appends job to the job list under the write lock and logs
// jobName with the new total. jobName is used only for logging.
func (s *Scheduler) RegisterJob(jobName string, job JobFunc) {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.jobs = append(s.jobs, job)
	total := len(s.jobs)
	s.logger.Info(s.module, "", "[RegisterJob]: Registered job [%s]. Total jobs: %d", jobName, total)
}
// StartJobs starts every registered job in its own goroutine with ctx and
// blocks until all of them return.
//
// The job list is read, and the goroutines are started, under the read lock.
// Wait (which takes the write lock) therefore cannot run while wg.Add calls
// are in progress. The lock is released before blocking on the jobs, so
// RegisterJob and GetJobs are not stalled for the jobs' whole lifetime. Jobs
// registered after this snapshot are not started by this call.
func (s *Scheduler) StartJobs(ctx context.Context) {
	s.mu.RLock()
	jobs := s.jobs
	s.logger.Info(s.module, "", "[StartJobs]: Starting %d background jobs...", len(jobs))
	for i, job := range jobs {
		s.wg.Add(1)
		go func(i int, j JobFunc) {
			defer s.wg.Done()
			s.logger.Info(s.module, "", "[StartJobs]: Starting job #%d", i+1)
			j(ctx)
		}(i, job)
	}
	s.mu.RUnlock()

	s.wg.Wait()
	s.logger.Info(s.module, "", "[StartJobs]: All background jobs completed.")
}
// Stop closes the scheduler's stop channel. Calling it twice panics because a
// closed channel is closed again. NOTE(review): nothing in this file reads
// stopCh, so running jobs end only when the ctx given to StartJobs is cancelled.
func (s *Scheduler) Stop() { close(s.stopCh) }
// Wait blocks until every job goroutine started by StartJobs has returned. It
// holds the write lock for the whole wait, so RegisterJob, GetJobs and
// StartJobs block until the jobs have finished.
func (s *Scheduler) Wait() {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.logger.Info(s.module, "", "Waiting for all background jobs to finish...")
	s.wg.Wait()
	s.logger.Info(s.module, "", "All background jobs have finished.")
}
// GetJobs returns a copy of the registered jobs, taken under the read lock.
// Callers therefore cannot modify the scheduler's list, either by writing
// elements or by appending into spare capacity of the backing array.
// It returns nil when no job is registered.
func (s *Scheduler) GetJobs() []JobFunc {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return append([]JobFunc(nil), s.jobs...)
}
package background
import (
"context"
"sync"
"time"
"github.com/vigiloauth/vigilo/v2/idp/config"
domain "github.com/vigiloauth/vigilo/v2/internal/domain/email"
)
// maxRetries is the Retries count at which retryWorker drops a queued email
// instead of attempting to send it again.
const maxRetries int = 5
// SMTPJobs holds the email service and ticker intervals for the SMTP
// background jobs (RunHealthCheck and RunRetryQueueProcessor).
type SMTPJobs struct {
healthCheckTickerInterval time.Duration // interval between TestConnection probes
queueTickerInterval time.Duration // interval between retry-queue passes
emailService domain.EmailService
logger *config.Logger
module string // module name passed to logger calls
}
// NewSMTPJobs returns SMTPJobs that use emailService. Health checks run every
// healthCheckTicker and the retry queue is processed every queueTicker.
// Logging goes through the shared server configuration's logger.
func NewSMTPJobs(
	emailService domain.EmailService,
	healthCheckTicker time.Duration,
	queueTicker time.Duration,
) *SMTPJobs {
	jobs := &SMTPJobs{emailService: emailService, module: "SMTP Jobs"}
	jobs.healthCheckTickerInterval = healthCheckTicker
	jobs.queueTickerInterval = queueTicker
	jobs.logger = config.GetServerConfig().Logger()
	return jobs
}
// RunHealthCheck calls emailService.TestConnection every
// healthCheckTickerInterval until ctx is cancelled. A failure is logged, and
// the check runs again on the next tick.
func (s *SMTPJobs) RunHealthCheck(ctx context.Context) {
	s.logger.Info(s.module, "", "[RunHealthCheck]: Starting SMTP health check")

	ticker := time.NewTicker(s.healthCheckTickerInterval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			s.logger.Info(s.module, "", "[RunHealthCheck]: Stopping SMTP health check")
			return
		case <-ticker.C:
			if err := s.emailService.TestConnection(); err != nil {
				s.logger.Error(s.module, "", "[RunHealthCheck]: Failed to test SMTP connection: %v", err)
			}
		}
	}
}
// RunRetryQueueProcessor calls processRetryQueue every queueTickerInterval
// until ctx is cancelled.
func (s *SMTPJobs) RunRetryQueueProcessor(ctx context.Context) {
	s.logger.Info(s.module, "", "[RunRetryQueueProcessor]: Starting retry queue processor")

	ticker := time.NewTicker(s.queueTickerInterval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			s.logger.Info(s.module, "", "[RunRetryQueueProcessor]: Stopping retry queue processor")
			return
		case <-ticker.C:
			s.processRetryQueue(ctx)
		}
	}
}
// processRetryQueue makes one pass over the email retry queue. It starts one
// retryWorker per email queued when the pass begins. A feeder goroutine then
// moves requests from the queue to the workers over an unbuffered channel
// until the queue is empty or ctx is cancelled. Workers re-enqueue requests
// that fail again, so a request can be retried more than once in the same
// pass, until it reaches maxRetries. The call returns after every worker has
// exited.
func (s *SMTPJobs) processRetryQueue(ctx context.Context) {
	s.logger.Info(s.module, "", "Processing email retry queue")
	retryQueue := s.emailService.GetEmailRetryQueue()
	if retryQueue.IsEmpty() {
		s.logger.Debug(s.module, "", "Retry queue is empty, skipping")
		return
	}

	var waitGroup sync.WaitGroup
	workerChan := make(chan *domain.EmailRequest)
	numEmails := retryQueue.Size()
	for i := range numEmails {
		waitGroup.Add(1)
		go s.retryWorker(retryQueue, i+1, ctx, workerChan, &waitGroup)
	}

	go func() {
		defer close(workerChan)
		for !retryQueue.IsEmpty() {
			if ctx.Err() != nil {
				return
			}
			request := retryQueue.Remove()
			if request == nil {
				continue
			}
			// The send must also watch ctx. Once ctx is cancelled the workers
			// exit and stop receiving, and a bare send on the unbuffered channel
			// would block this goroutine forever.
			select {
			case workerChan <- request:
			case <-ctx.Done():
				// Put the request back so cancellation does not silently drop it.
				retryQueue.Add(request)
				return
			}
		}
	}()

	waitGroup.Wait()
	s.logger.Info(s.module, "", "Retry queue process finished")
}
// retryWorker receives email requests until requests is closed, and calls
// waitGroup.Done when it exits. If ctx is already cancelled when a request
// arrives, the worker returns without handling that request. A request whose
// Retries count has reached maxRetries is logged and dropped. Otherwise the
// send is retried: on failure Retries is incremented and the request is put
// back on retryQueue; on success a debug message is logged.
func (s *SMTPJobs) retryWorker(retryQueue *domain.EmailRetryQueue, workerID int, ctx context.Context, requests <-chan *domain.EmailRequest, waitGroup *sync.WaitGroup) {
	defer waitGroup.Done()

	for request := range requests {
		if ctx.Err() != nil {
			return
		}

		if request.Retries >= maxRetries {
			s.logger.Error(s.module, "", "[Worker=%d] Max retries reached for email %s. Dropping.", workerID, request.ID)
			continue
		}

		sendErr := s.emailService.SendEmail(ctx, request)
		if sendErr == nil {
			s.logger.Debug(s.module, "", "[Worker=%d] Successfully retried sending email %s.", workerID, request.ID)
			continue
		}

		request.Retries++
		retryQueue.Add(request)
		s.logger.Error(s.module, "", "[Worker=%d] Failed to retry sending email %s. Retrying. Error: %v", workerID, request.ID, sendErr)
	}
}
package background
import (
"context"
"time"
"github.com/vigiloauth/vigilo/v2/idp/config"
domain "github.com/vigiloauth/vigilo/v2/internal/domain/token"
)
// TokenJobs holds the dependencies for the periodic expired-token cleanup job
// (see DeleteExpiredTokens).
type TokenJobs struct {
tokenService domain.TokenManager // DeleteExpiredTokens is called on each tick
interval time.Duration // ticker interval between cleanups
logger *config.Logger
module string // module name passed to logger calls
}
// NewTokenJobs returns TokenJobs that clean up expired tokens through
// tokenService every interval. Logging goes through the shared server
// configuration's logger.
func NewTokenJobs(tokenService domain.TokenManager, interval time.Duration) *TokenJobs {
	jobs := &TokenJobs{module: "Token Jobs"}
	jobs.tokenService = tokenService
	jobs.interval = interval
	jobs.logger = config.GetServerConfig().Logger()
	return jobs
}
// DeleteExpiredTokens calls tokenService.DeleteExpiredTokens on every interval
// tick until ctx is cancelled. A failure is logged, and the cleanup runs again
// on the next tick.
func (t *TokenJobs) DeleteExpiredTokens(ctx context.Context) {
	t.logger.Info(t.module, "", "[DeleteExpiredTokens]: Starting process of deleting expired tokens")

	ticker := time.NewTicker(t.interval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			t.logger.Info(t.module, "", "[DeleteExpiredTokens]: Stopping process of deleting expired tokens")
			return
		case <-ticker.C:
			if err := t.tokenService.DeleteExpiredTokens(ctx); err != nil {
				t.logger.Error(t.module, "", "[DeleteExpiredTokens]: An error occurred deleting expired tokens: %v", err)
			}
		}
	}
}
package background
import (
"context"
"time"
"github.com/vigiloauth/vigilo/v2/idp/config"
domain "github.com/vigiloauth/vigilo/v2/internal/domain/user"
)
// UserJobs holds the dependencies for the periodic unverified-user cleanup job
// (see DeleteUnverifiedUsers).
type UserJobs struct {
userService domain.UserManager // DeleteUnverifiedUsers is called on each tick
interval time.Duration // ticker interval between cleanups
logger *config.Logger
module string // module name passed to logger calls
}
// NewUserJobs returns UserJobs that remove unverified users through
// userService every interval. Logging goes through the shared server
// configuration's logger.
func NewUserJobs(userService domain.UserManager, interval time.Duration) *UserJobs {
	jobs := &UserJobs{module: "User Jobs"}
	jobs.userService = userService
	jobs.interval = interval
	jobs.logger = config.GetServerConfig().Logger()
	return jobs
}
// DeleteUnverifiedUsers calls userService.DeleteUnverifiedUsers on every
// interval tick until ctx is cancelled. A failure is logged, and the cleanup
// runs again on the next tick. Which users qualify is decided by the user service.
func (u *UserJobs) DeleteUnverifiedUsers(ctx context.Context) {
	u.logger.Info(u.module, "", "[DeleteUnverifiedUsers]: Starting Process of deleting unverified users that were created over a week ago")

	ticker := time.NewTicker(u.interval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			u.logger.Info(u.module, "", "[DeleteUnverifiedUsers]: Stopping process of deleting unverified users")
			return
		case <-ticker.C:
			if err := u.userService.DeleteUnverifiedUsers(ctx); err != nil {
				u.logger.Error(u.module, "", "[DeleteUnverifiedUsers]: Failed to delete unverified users: %v", err)
			}
		}
	}
}
package container
import (
"net/http"
"os"
"time"
"github.com/vigiloauth/vigilo/v2/idp/config"
"github.com/vigiloauth/vigilo/v2/internal/constants"
)
// DIContainer is the application's composition root. It owns one registry per
// kind of component; the registries are created and wired together by Init.
//
// Fields:
//   - serviceRegistry: application services, constructed lazily on first use.
//   - handlerRegistry: HTTP handlers built on top of the services.
//   - repoRegistry: repositories the services read from and write to.
//   - schedulerRegistry: periodic background jobs.
//   - serverConfigRegistry: TLS settings and the *http.Server.
//   - exitCh: channel shared with the scheduler registry to signal shutdown.
//   - logger: logger used for container lifecycle messages.
//   - module: module name attached to each log line ("DI Container").
type DIContainer struct {
	serviceRegistry      *ServiceRegistry
	handlerRegistry      *HandlerRegistry
	repoRegistry         *RepositoryRegistry
	schedulerRegistry    *SchedulerRegistry
	serverConfigRegistry *ServerConfigRegistry
	exitCh               chan struct{}
	logger               *config.Logger
	module               string
}

// NewDIContainer returns a container bound to logger. No registries exist
// until Init is called, so the accessors return nil before then.
func NewDIContainer(logger *config.Logger) *DIContainer {
	const module = "DI Container"
	logger.Info(module, "", "Initializing Dependencies")

	di := &DIContainer{module: module, logger: logger}
	return di
}
// Init builds every registry in dependency order and starts the background
// scheduler:
//
//	repositories -> services -> handlers, server config, schedulers
//
// A fresh exit channel is created and handed to the scheduler registry.
// Calling Init again would build a second set of registries and start the
// scheduler a second time.
func (di *DIContainer) Init() {
	repos := NewRepositoryRegistry(di.logger)
	di.repoRegistry = repos

	services := NewServiceRegistry(repos, di.logger)
	di.serviceRegistry = services

	di.handlerRegistry = NewHandlerRegistry(services, di.logger)
	di.serverConfigRegistry = NewServerConfigRegistry(services)

	exitCh := make(chan struct{})
	di.exitCh = exitCh

	schedulers := NewSchedulerRegistry(services, di.logger, exitCh)
	di.schedulerRegistry = schedulers
	schedulers.Start()
}
// ServiceRegistry returns the service registry created by Init (nil before Init).
func (di *DIContainer) ServiceRegistry() *ServiceRegistry { return di.serviceRegistry }

// HandlerRegistry returns the handler registry created by Init (nil before Init).
func (di *DIContainer) HandlerRegistry() *HandlerRegistry { return di.handlerRegistry }

// RepositoryRegistry returns the repository registry created by Init (nil before Init).
func (di *DIContainer) RepositoryRegistry() *RepositoryRegistry { return di.repoRegistry }

// ServerConfigRegistry returns the server configuration registry created by Init
// (nil before Init).
func (di *DIContainer) ServerConfigRegistry() *ServerConfigRegistry { return di.serverConfigRegistry }

// HTTPServer returns the *http.Server held by the server configuration registry.
func (di *DIContainer) HTTPServer() *http.Server {
	registry := di.serverConfigRegistry
	return registry.HTTPServer()
}
// Shutdown stops the background schedulers and waits for them to finish.
//
// The wait is bounded by constants.ThirtySecondTimeout. If the schedulers have
// not stopped by then, a warning is logged and the process is terminated with
// os.Exit(1); deferred functions do not run in that case.
//
// If Init has not been called, there is no scheduler registry. Shutdown then
// only logs and returns. Previously, the nil registry was dereferenced inside a
// goroutine, which crashed the process.
func (di *DIContainer) Shutdown() {
	di.logger.Info(di.module, "", "Shutting down DI Container")

	schedulers := di.schedulerRegistry
	if schedulers == nil {
		// Nothing was started, so there is nothing to stop.
		di.logger.Info(di.module, "", "DI Container shut down successfully")
		return
	}

	done := make(chan struct{})
	go func() {
		schedulers.Shutdown()
		close(done)
	}()

	// A stoppable timer releases its resources as soon as shutdown completes,
	// instead of lingering for the full timeout as time.After would.
	timeout := time.NewTimer(constants.ThirtySecondTimeout)
	defer timeout.Stop()

	select {
	case <-done:
		di.logger.Info(di.module, "", "DI Container shut down successfully")
	case <-timeout.C:
		di.logger.Warn(di.module, "", "Shutdown timeout reached. Forcing application exit.")
		os.Exit(1)
	}
}
package container
import (
"github.com/vigiloauth/vigilo/v2/idp/config"
audit "github.com/vigiloauth/vigilo/v2/internal/domain/audit"
authzCode "github.com/vigiloauth/vigilo/v2/internal/domain/authzcode"
domain "github.com/vigiloauth/vigilo/v2/internal/domain/client"
login "github.com/vigiloauth/vigilo/v2/internal/domain/login"
session "github.com/vigiloauth/vigilo/v2/internal/domain/session"
token "github.com/vigiloauth/vigilo/v2/internal/domain/token"
user "github.com/vigiloauth/vigilo/v2/internal/domain/user"
userConsent "github.com/vigiloauth/vigilo/v2/internal/domain/userconsent"
auditEventRepo "github.com/vigiloauth/vigilo/v2/internal/repository/audit"
authzCodeRepo "github.com/vigiloauth/vigilo/v2/internal/repository/authzcode"
clientRepo "github.com/vigiloauth/vigilo/v2/internal/repository/client"
loginRepo "github.com/vigiloauth/vigilo/v2/internal/repository/login"
sessionRepo "github.com/vigiloauth/vigilo/v2/internal/repository/session"
tokenRepo "github.com/vigiloauth/vigilo/v2/internal/repository/token"
userRepo "github.com/vigiloauth/vigilo/v2/internal/repository/user"
consentRepo "github.com/vigiloauth/vigilo/v2/internal/repository/userconsent"
)
// RepositoryRegistry owns the repository instances consumed by the services.
// Every slot is currently bound to an in-memory implementation.
type RepositoryRegistry struct {
	tokenRepo        token.TokenRepository
	loginAttemptRepo login.LoginAttemptRepository
	userRepo         user.UserRepository
	clientRepo       domain.ClientRepository
	consentRepo      userConsent.UserConsentRepository
	authzCodeRepo    authzCode.AuthorizationCodeRepository
	sessionRepo      session.SessionRepository
	auditEventRepo   audit.AuditRepository
	logger           *config.Logger
	module           string
}

// NewRepositoryRegistry logs its start-up and returns a registry whose
// repositories are all bound to their in-memory implementations.
func NewRepositoryRegistry(logger *config.Logger) *RepositoryRegistry {
	const module = "Repository Registry"
	logger.Info(module, "", "Initializing repositories")

	registry := &RepositoryRegistry{module: module, logger: logger}
	registry.initInMemoryRepositories()
	return registry
}

// initInMemoryRepositories fills every repository slot with the value returned
// by the corresponding package's GetInMemory* accessor.
// NOTE(review): the accessor names suggest process-wide singletons; confirm in
// the repository packages before relying on that.
func (dr *RepositoryRegistry) initInMemoryRepositories() {
	// Tokens and login attempts.
	dr.tokenRepo = tokenRepo.GetInMemoryTokenRepository()
	dr.loginAttemptRepo = loginRepo.GetInMemoryLoginRepository()

	// Users and OAuth clients.
	dr.userRepo = userRepo.GetInMemoryUserRepository()
	dr.clientRepo = clientRepo.GetInMemoryClientRepository()

	// Consent records and authorization codes.
	dr.consentRepo = consentRepo.GetInMemoryUserConsentRepository()
	dr.authzCodeRepo = authzCodeRepo.GetInMemoryAuthorizationCodeRepository()

	// Sessions and audit events.
	dr.sessionRepo = sessionRepo.GetInMemorySessionRepository()
	dr.auditEventRepo = auditEventRepo.GetInMemoryAuditEventRepository()
}
// TokenRepository returns the token repository.
func (dr *RepositoryRegistry) TokenRepository() token.TokenRepository { return dr.tokenRepo }

// LoginAttemptRepository returns the login-attempt repository.
func (dr *RepositoryRegistry) LoginAttemptRepository() login.LoginAttemptRepository {
	return dr.loginAttemptRepo
}

// UserRepository returns the user repository.
func (dr *RepositoryRegistry) UserRepository() user.UserRepository { return dr.userRepo }

// ClientRepository returns the OAuth client repository.
func (dr *RepositoryRegistry) ClientRepository() domain.ClientRepository { return dr.clientRepo }

// UserConsentRepository returns the user-consent repository.
func (dr *RepositoryRegistry) UserConsentRepository() userConsent.UserConsentRepository {
	return dr.consentRepo
}

// AuthorizationCodeRepository returns the authorization-code repository.
func (dr *RepositoryRegistry) AuthorizationCodeRepository() authzCode.AuthorizationCodeRepository {
	return dr.authzCodeRepo
}

// SessionRepository returns the session repository.
func (dr *RepositoryRegistry) SessionRepository() session.SessionRepository { return dr.sessionRepo }

// AuditEventRepository returns the audit-event repository.
func (dr *RepositoryRegistry) AuditEventRepository() audit.AuditRepository { return dr.auditEventRepo }
package container
import (
"github.com/vigiloauth/vigilo/v2/idp/config"
"github.com/vigiloauth/vigilo/v2/internal/handlers"
)
// HandlerRegistry exposes the HTTP handlers. Each handler is built lazily, on
// the first call to its accessor, from services obtained through sr.
//
// Note: oauthHandler / OAuthHandler hold a *handlers.ConsentHandler.
type HandlerRegistry struct {
	sr            *ServiceRegistry
	userHandler   LazyInit[*handlers.UserHandler]
	clientHandler LazyInit[*handlers.ClientHandler]
	tokenHandler  LazyInit[*handlers.TokenHandler]
	authzHandler  LazyInit[*handlers.AuthorizationHandler]
	oauthHandler  LazyInit[*handlers.ConsentHandler]
	adminHandler  LazyInit[*handlers.AdminHandler]
	oidcHandler   LazyInit[*handlers.OIDCHandler]
	logger        *config.Logger
	module        string
}

// NewHandlerRegistry logs its start-up and registers a lazy constructor for
// every handler. No handler is actually built here.
func NewHandlerRegistry(sr *ServiceRegistry, logger *config.Logger) *HandlerRegistry {
	const module = "Handler Registry"
	logger.Info(module, "", "Initializing handlers")

	registry := &HandlerRegistry{sr: sr, module: module, logger: logger}
	registry.initHandlers()
	return registry
}

// initHandlers installs the lazy constructor for each handler, in a fixed order.
func (h *HandlerRegistry) initHandlers() {
	registrations := []func(){
		h.initUserHandler,
		h.initClientHandler,
		h.initTokenHandler,
		h.initAuthzHandler,
		h.initOAuthHandler,
		h.initAdminHandler,
		h.initOIDCHandler,
	}
	for _, register := range registrations {
		register()
	}
}
// initUserHandler installs the lazy constructor for the user handler, which
// needs the user creator, authenticator, manager, verifier and the session service.
func (h *HandlerRegistry) initUserHandler() {
	build := func() *handlers.UserHandler {
		return handlers.NewUserHandler(
			h.sr.UserCreator(),
			h.sr.UserAuthenticator(),
			h.sr.UserManager(),
			h.sr.UserVerifier(),
			h.sr.SessionService(),
		)
	}
	h.userHandler = LazyInit[*handlers.UserHandler]{initFunc: build}
}

// initClientHandler installs the lazy constructor for the client handler.
func (h *HandlerRegistry) initClientHandler() {
	build := func() *handlers.ClientHandler {
		return handlers.NewClientHandler(h.sr.ClientCreator(), h.sr.ClientManager())
	}
	h.clientHandler = LazyInit[*handlers.ClientHandler]{initFunc: build}
}

// initTokenHandler installs the lazy constructor for the token endpoint handler.
func (h *HandlerRegistry) initTokenHandler() {
	build := func() *handlers.TokenHandler {
		return handlers.NewTokenHandler(h.sr.TokenGrantProcessor())
	}
	h.tokenHandler = LazyInit[*handlers.TokenHandler]{initFunc: build}
}

// initAuthzHandler installs the lazy constructor for the authorization handler.
func (h *HandlerRegistry) initAuthzHandler() {
	build := func() *handlers.AuthorizationHandler {
		return handlers.NewAuthorizationHandler(h.sr.ClientAuthorization())
	}
	h.authzHandler = LazyInit[*handlers.AuthorizationHandler]{initFunc: build}
}

// initOAuthHandler installs the lazy constructor for the consent handler
// (exposed through OAuthHandler).
func (h *HandlerRegistry) initOAuthHandler() {
	build := func() *handlers.ConsentHandler {
		return handlers.NewConsentHandler(h.sr.SessionService(), h.sr.UserConsentService())
	}
	h.oauthHandler = LazyInit[*handlers.ConsentHandler]{initFunc: build}
}

// initAdminHandler installs the lazy constructor for the admin handler.
func (h *HandlerRegistry) initAdminHandler() {
	build := func() *handlers.AdminHandler {
		return handlers.NewAdminHandler(h.sr.AuditLogger())
	}
	h.adminHandler = LazyInit[*handlers.AdminHandler]{initFunc: build}
}

// initOIDCHandler installs the lazy constructor for the OIDC handler.
func (h *HandlerRegistry) initOIDCHandler() {
	build := func() *handlers.OIDCHandler {
		return handlers.NewOIDCHandler(h.sr.OIDCService())
	}
	h.oidcHandler = LazyInit[*handlers.OIDCHandler]{initFunc: build}
}
// UserHandler returns the user handler, building it on first use.
func (h *HandlerRegistry) UserHandler() *handlers.UserHandler { return h.userHandler.Get() }

// ClientHandler returns the client handler, building it on first use.
func (h *HandlerRegistry) ClientHandler() *handlers.ClientHandler { return h.clientHandler.Get() }

// TokenHandler returns the token handler, building it on first use.
func (h *HandlerRegistry) TokenHandler() *handlers.TokenHandler { return h.tokenHandler.Get() }

// AuthorizationHandler returns the authorization handler, building it on first use.
func (h *HandlerRegistry) AuthorizationHandler() *handlers.AuthorizationHandler {
	return h.authzHandler.Get()
}

// OAuthHandler returns the consent handler, building it on first use.
func (h *HandlerRegistry) OAuthHandler() *handlers.ConsentHandler { return h.oauthHandler.Get() }

// AdminHandler returns the admin handler, building it on first use.
func (h *HandlerRegistry) AdminHandler() *handlers.AdminHandler { return h.adminHandler.Get() }

// OIDCHandler returns the OIDC handler, building it on first use.
func (h *HandlerRegistry) OIDCHandler() *handlers.OIDCHandler { return h.oidcHandler.Get() }
package container
import "sync"
// LazyInit defers construction of a value of type T until the first call to
// Get. It guarantees that initFunc runs at most once, even when Get is called
// from several goroutines concurrently.
//
// Fields:
//   - once: guards the single execution of initFunc.
//   - value: the result of initFunc after it has returned normally.
//   - initFunc: the constructor; it must be set before the first Get.
//   - panicked / panicValue: record a failed initialization. sync.Once treats a
//     panicking function as done, so without this record every later Get would
//     silently return the zero value (for example a nil interface) and the
//     failure would surface far away from its cause.
//
// A LazyInit contains a sync.Once and must not be copied after first use.
type LazyInit[T any] struct {
	once       sync.Once
	value      T
	initFunc   func() T
	panicked   bool
	panicValue any
}

// Get returns the lazily initialized value, running initFunc on the first call.
//
// If initFunc panics, that panic propagates to the first caller. Every
// subsequent Get re-panics with the same value instead of returning a zero T.
// If initFunc was never set, Get panics with a descriptive message.
//
// Returns:
//   - T: the value produced by initFunc.
func (l *LazyInit[T]) Get() T {
	l.once.Do(func() {
		if l.initFunc == nil {
			l.panicked = true
			l.panicValue = "container: LazyInit.Get called without an initFunc"
			return
		}
		defer func() {
			if r := recover(); r != nil {
				l.panicked = true
				l.panicValue = r
				// Re-panic from inside the deferred call so the first caller's
				// stack trace still shows where initFunc failed.
				panic(r)
			}
		}()
		l.value = l.initFunc()
	})
	// once.Do synchronizes with the completed initialization, so these fields
	// are safely visible to every goroutine that reaches this point.
	if l.panicked {
		panic(l.panicValue)
	}
	return l.value
}
package container
import (
"context"
"os"
"os/signal"
"syscall"
"time"
"github.com/vigiloauth/vigilo/v2/idp/config"
"github.com/vigiloauth/vigilo/v2/internal/background"
)
// SchedulerRegistry registers the periodic background jobs and owns the
// scheduler that runs them.
//
// Fields:
//   - services: source of the services each job operates on.
//   - scheduler: the background scheduler that executes registered jobs.
//   - exitCh: shutdown signal channel shared with the DI container.
//   - logger / module: logger and module name used for log lines.
type SchedulerRegistry struct {
	services  *ServiceRegistry
	scheduler *background.Scheduler
	exitCh    chan struct{}
	logger    *config.Logger
	module    string
}

// NewSchedulerRegistry logs its start-up and returns a registry with a fresh
// background scheduler. No job is registered or started until Start is called.
func NewSchedulerRegistry(services *ServiceRegistry, logger *config.Logger, exitCh chan struct{}) *SchedulerRegistry {
	const module = "Scheduler Registry"
	logger.Info(module, "", "Initializing schedulers")

	return &SchedulerRegistry{
		services:  services,
		scheduler: background.NewScheduler(),
		exitCh:    exitCh,
		logger:    logger,
		module:    module,
	}
}

// Start registers all periodic jobs and launches the scheduler in the background.
func (sr *SchedulerRegistry) Start() { sr.initJobs() }

// Shutdown stops the scheduler, if one exists, and then calls its Wait method.
func (sr *SchedulerRegistry) Shutdown() {
	sr.logger.Info(sr.module, "", "Shutting down schedulers and worker pool")
	if sr.scheduler == nil {
		return
	}
	sr.scheduler.Stop()
	sr.scheduler.Wait()
}
// initJobs registers every periodic job and starts the scheduler with a
// cancellable context. The context is cancelled when the process receives
// SIGINT/SIGTERM, or when exitCh is closed elsewhere. On a signal, this method
// also closes exitCh.
//
// The signal handler is installed synchronously, before any goroutine starts.
// Previously signal.Notify ran inside a goroutine, so a signal delivered while
// the jobs were starting took the default action and killed the process without
// cancelling the jobs.
//
// NOTE(review): exitCh is closed here when a signal arrives; any other code
// that also closes it would panic on the double close. Confirm channel ownership.
func (sr *SchedulerRegistry) initJobs() {
	sr.registerSMTPJobs()
	sr.registerTokenJobs()
	sr.registerUserJobs()
	sr.registerAuditLogJobs()

	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)

	ctx, cancel := context.WithCancel(context.Background())
	go sr.scheduler.StartJobs(ctx)

	go func() {
		defer signal.Stop(sigCh)
		select {
		case <-sigCh:
			sr.logger.Info(sr.module, "", "Received termination signal")
			cancel()
			close(sr.exitCh)
		case <-sr.exitCh:
			cancel()
		}
	}()
}
// registerSMTPJobs schedules the SMTP health check (every 15 minutes) and the
// email retry-queue processor (every 10 minutes).
func (sr *SchedulerRegistry) registerSMTPJobs() {
	const (
		healthCheckInterval    time.Duration = 15 * time.Minute
		queueProcessorInterval time.Duration = 10 * time.Minute
	)
	smtpJobs := background.NewSMTPJobs(sr.services.EmailService(), healthCheckInterval, queueProcessorInterval)
	sr.scheduler.RegisterJob("SMTP Health Check", smtpJobs.RunHealthCheck)
	sr.scheduler.RegisterJob("Email Retry Queue", smtpJobs.RunRetryQueueProcessor)
}

// registerTokenJobs schedules the expired-token cleanup every 5 minutes.
func (sr *SchedulerRegistry) registerTokenJobs() {
	const tokenDeletionInterval time.Duration = 5 * time.Minute
	tokenJobs := background.NewTokenJobs(sr.services.TokenManager(), tokenDeletionInterval)
	sr.scheduler.RegisterJob("Expired Token Deletion", tokenJobs.DeleteExpiredTokens)
}

// registerUserJobs schedules the unverified-user cleanup every 24 hours.
func (sr *SchedulerRegistry) registerUserJobs() {
	const userDeletionInterval time.Duration = 24 * time.Hour
	userJobs := background.NewUserJobs(sr.services.UserManager(), userDeletionInterval)
	sr.scheduler.RegisterJob("Unverified User Deletion", userJobs.DeleteUnverifiedUsers)
}

// registerAuditLogJobs schedules an audit-log purge every 24 hours. The
// retention period is read from the server's audit-log configuration.
func (sr *SchedulerRegistry) registerAuditLogJobs() {
	retentionPeriod := config.GetServerConfig().AuditLogConfig().RetentionPeriod()
	const purgeInterval time.Duration = 24 * time.Hour
	auditLogJobs := background.NewAuditJobs(sr.services.AuditLogger(), retentionPeriod, purgeInterval)
	sr.scheduler.RegisterJob("Audit Log Deletion", auditLogJobs.PurgeLogs)
}
package container
import (
"crypto/tls"
"fmt"
"net/http"
"github.com/vigiloauth/vigilo/v2/idp/config"
)
// ServerConfigRegistry holds the TLS configuration and the *http.Server built
// from the global server configuration.
type ServerConfigRegistry struct {
	tlsConfig  *tls.Config
	httpServer *http.Server
}

// NewServerConfigRegistry builds the TLS configuration and the HTTP server.
// The services argument is currently unused.
func NewServerConfigRegistry(services *ServiceRegistry) *ServerConfigRegistry {
	registry := new(ServerConfigRegistry)
	registry.initServerConfigurations()
	return registry
}

// HTTPServer returns the configured HTTP server.
func (sr *ServerConfigRegistry) HTTPServer() *http.Server { return sr.httpServer }

// initServerConfigurations builds the TLS config first, because the HTTP
// server embeds it.
func (sr *ServerConfigRegistry) initServerConfigurations() {
	sr.initTLS()
	sr.initHTTPServer()
}

// initTLS requires TLS 1.2 or newer. For TLS 1.2 connections only the listed
// ECDHE AEAD cipher suites are enabled (crypto/tls does not let TLS 1.3 suites
// be configured).
func (sr *ServerConfigRegistry) initTLS() {
	suites := []uint16{
		tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
		tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
		tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
		tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
		tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
		tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
	}
	sr.tlsConfig = &tls.Config{MinVersion: tls.VersionTLS12, CipherSuites: suites}
}

// initHTTPServer builds the HTTP server from the global server configuration:
// listen port, read timeout, write timeout, and the TLS config built by initTLS.
func (sr *ServerConfigRegistry) initHTTPServer() {
	cfg := config.GetServerConfig()
	sr.httpServer = &http.Server{
		Addr:         fmt.Sprintf(":%s", cfg.Port()),
		ReadTimeout:  cfg.ReadTimeout(),
		WriteTimeout: cfg.WriteTimeout(),
		TLSConfig:    sr.tlsConfig,
	}
}
package container
import (
"github.com/vigiloauth/vigilo/v2/idp/config"
audit "github.com/vigiloauth/vigilo/v2/internal/domain/audit"
authz "github.com/vigiloauth/vigilo/v2/internal/domain/authorization"
authzCode "github.com/vigiloauth/vigilo/v2/internal/domain/authzcode"
client "github.com/vigiloauth/vigilo/v2/internal/domain/client"
cookie "github.com/vigiloauth/vigilo/v2/internal/domain/cookies"
crypto "github.com/vigiloauth/vigilo/v2/internal/domain/crypto"
email "github.com/vigiloauth/vigilo/v2/internal/domain/email"
jwt "github.com/vigiloauth/vigilo/v2/internal/domain/jwt"
login "github.com/vigiloauth/vigilo/v2/internal/domain/login"
oidc "github.com/vigiloauth/vigilo/v2/internal/domain/oidc"
session "github.com/vigiloauth/vigilo/v2/internal/domain/session"
token "github.com/vigiloauth/vigilo/v2/internal/domain/token"
user "github.com/vigiloauth/vigilo/v2/internal/domain/user"
consent "github.com/vigiloauth/vigilo/v2/internal/domain/userconsent"
"github.com/vigiloauth/vigilo/v2/internal/middleware"
auditLogger "github.com/vigiloauth/vigilo/v2/internal/service/audit"
authzService "github.com/vigiloauth/vigilo/v2/internal/service/authorization"
authzCodeService "github.com/vigiloauth/vigilo/v2/internal/service/authzcode"
clientService "github.com/vigiloauth/vigilo/v2/internal/service/client"
cookieService "github.com/vigiloauth/vigilo/v2/internal/service/cookies"
cryptoService "github.com/vigiloauth/vigilo/v2/internal/service/crypto"
emailService "github.com/vigiloauth/vigilo/v2/internal/service/email"
jwtService "github.com/vigiloauth/vigilo/v2/internal/service/jwt"
loginService "github.com/vigiloauth/vigilo/v2/internal/service/login"
oidcService "github.com/vigiloauth/vigilo/v2/internal/service/oidc"
sessionService "github.com/vigiloauth/vigilo/v2/internal/service/session"
tokenService "github.com/vigiloauth/vigilo/v2/internal/service/token"
userService "github.com/vigiloauth/vigilo/v2/internal/service/user"
consentService "github.com/vigiloauth/vigilo/v2/internal/service/userconsent"
)
// ServiceRegistry exposes every application service. Each service is wrapped in
// a LazyInit and is built on the first call to its accessor; its dependencies
// come from other accessors or from the repository registry in db.
type ServiceRegistry struct {
	db                        *RepositoryRegistry
	consentService            LazyInit[consent.UserConsentService]
	loginAttemptService       LazyInit[login.LoginAttemptService]
	authorizationService      LazyInit[authz.AuthorizationService]
	httpCookieService         LazyInit[cookie.HTTPCookieService]
	emailService              LazyInit[email.EmailService]
	goMailerService           LazyInit[email.Mailer]
	auditLogger               LazyInit[audit.AuditLogger]
	oidcService               LazyInit[oidc.OIDCService]
	jwtService                LazyInit[jwt.JWTService]
	encryptor                 LazyInit[crypto.Cryptographer]
	middlewares               LazyInit[*middleware.Middleware]
	sessionService            LazyInit[session.SessionService]
	sessionManager            LazyInit[session.SessionManager]
	authzCodeManager          LazyInit[authzCode.AuthorizationCodeManager]
	authzCodeCreator          LazyInit[authzCode.AuthorizationCodeCreator]
	authzCodeIssuer           LazyInit[authzCode.AuthorizationCodeIssuer]
	authzCodeRequestValidator LazyInit[authzCode.AuthorizationCodeValidator]
	clientAuthenticator       LazyInit[client.ClientAuthenticator]
	clientValidator           LazyInit[client.ClientValidator]
	clientCreator             LazyInit[client.ClientCreator]
	clientManager             LazyInit[client.ClientManager]
	clientAuthorization       LazyInit[client.ClientAuthorization]
	userAuthenticator         LazyInit[user.UserAuthenticator]
	userManager               LazyInit[user.UserManager]
	userVerifier              LazyInit[user.UserVerifier]
	userCreator               LazyInit[user.UserCreator]
	tokenManager              LazyInit[token.TokenManager]
	tokenParser               LazyInit[token.TokenParser]
	tokenRequestProcessor     LazyInit[token.TokenGrantProcessor]
	tokenIssuer               LazyInit[token.TokenIssuer]
	tokenValidator            LazyInit[token.TokenValidator]
	tokenCreator              LazyInit[token.TokenCreator]
	logger                    *config.Logger
	module                    string
}

// NewServiceRegistry logs its start-up and registers a lazy constructor for
// every service. No service is actually built here.
func NewServiceRegistry(dbRegistry *RepositoryRegistry, logger *config.Logger) *ServiceRegistry {
	const module = "Service Registry"
	logger.Info(module, "", "Initializing services")

	registry := &ServiceRegistry{db: dbRegistry, module: module, logger: logger}
	registry.initServices()
	return registry
}
// initServices installs the lazy constructor of every service, in a fixed order.
// Installing a constructor does not build anything.
func (sr *ServiceRegistry) initServices() {
	registrations := []func(){
		sr.initMiddleware,
		// Tokens.
		sr.initTokenRequestProcessor,
		sr.initTokenManager,
		sr.initTokenParser,
		sr.initTokenValidator,
		sr.initTokenCreator,
		sr.initTokenIssuer,
		sr.initJWTService,
		// OAuth clients.
		sr.initClientAuthenticator,
		sr.initClientValidator,
		sr.initClientCreator,
		sr.initClientManager,
		sr.initClientAuthorization,
		// Users.
		sr.initUserAuthenticator,
		sr.initUserManager,
		sr.initUserVerifier,
		sr.initUserCreator,
		// Authorization codes.
		sr.initAuthzCodeManager,
		sr.initAuthzCodeCreator,
		sr.initAuthzCodeIssuer,
		sr.initAuthzCodeRequestValidator,
		// Everything else.
		sr.initCryptographer,
		sr.initSessionService,
		sr.initSessionManager,
		sr.initConsentService,
		sr.initLoginAttemptService,
		sr.initAuthorizationService,
		sr.initHTTPCookieService,
		sr.initEmailService,
		sr.initAuditLogger,
		sr.initOIDCService,
	}
	for _, register := range registrations {
		register()
	}
}
// initMiddleware installs the lazy constructor for the HTTP middleware, built
// from the token parser and token validator.
func (sr *ServiceRegistry) initMiddleware() {
	build := func() *middleware.Middleware {
		return middleware.NewMiddleware(sr.TokenParser(), sr.TokenValidator())
	}
	sr.middlewares = LazyInit[*middleware.Middleware]{initFunc: build}
}

// initJWTService installs the lazy constructor for the JWT service.
func (sr *ServiceRegistry) initJWTService() {
	build := func() jwt.JWTService {
		return jwtService.NewJWTService()
	}
	sr.jwtService = LazyInit[jwt.JWTService]{initFunc: build}
}

// initSessionService installs the lazy constructor for the session service
// (session repository, cookie service, audit logger).
func (sr *ServiceRegistry) initSessionService() {
	build := func() session.SessionService {
		return sessionService.NewSessionService(
			sr.db.SessionRepository(),
			sr.HTTPCookieService(),
			sr.AuditLogger(),
		)
	}
	sr.sessionService = LazyInit[session.SessionService]{initFunc: build}
}

// initSessionManager installs the lazy constructor for the session manager
// (session repository, cookie service).
func (sr *ServiceRegistry) initSessionManager() {
	build := func() session.SessionManager {
		return sessionService.NewSessionManager(sr.db.SessionRepository(), sr.HTTPCookieService())
	}
	sr.sessionManager = LazyInit[session.SessionManager]{initFunc: build}
}

// initUserAuthenticator installs the lazy constructor for the user
// authenticator (user repository, audit logger, login-attempt service).
func (sr *ServiceRegistry) initUserAuthenticator() {
	build := func() user.UserAuthenticator {
		return userService.NewUserAuthenticator(
			sr.db.UserRepository(),
			sr.AuditLogger(),
			sr.LoginAttemptService(),
		)
	}
	sr.userAuthenticator = LazyInit[user.UserAuthenticator]{initFunc: build}
}

// initUserManager installs the lazy constructor for the user manager
// (user repository, token parser, token manager, cryptographer).
func (sr *ServiceRegistry) initUserManager() {
	build := func() user.UserManager {
		return userService.NewUserManager(
			sr.db.UserRepository(),
			sr.TokenParser(),
			sr.TokenManager(),
			sr.Cryptographer(),
		)
	}
	sr.userManager = LazyInit[user.UserManager]{initFunc: build}
}

// initUserVerifier installs the lazy constructor for the user verifier
// (user repository, token parser, token validator, token manager).
func (sr *ServiceRegistry) initUserVerifier() {
	build := func() user.UserVerifier {
		return userService.NewUserVerifier(
			sr.db.UserRepository(),
			sr.TokenParser(),
			sr.TokenValidator(),
			sr.TokenManager(),
		)
	}
	sr.userVerifier = LazyInit[user.UserVerifier]{initFunc: build}
}

// initUserCreator installs the lazy constructor for the user creator
// (user repository, token issuer, audit logger, email service, cryptographer).
func (sr *ServiceRegistry) initUserCreator() {
	build := func() user.UserCreator {
		return userService.NewUserCreator(
			sr.db.UserRepository(),
			sr.TokenIssuer(),
			sr.AuditLogger(),
			sr.EmailService(),
			sr.Cryptographer(),
		)
	}
	sr.userCreator = LazyInit[user.UserCreator]{initFunc: build}
}

// initCryptographer installs the lazy constructor for the cryptographer.
func (sr *ServiceRegistry) initCryptographer() {
	build := func() crypto.Cryptographer {
		return cryptoService.NewCryptographer()
	}
	sr.encryptor = LazyInit[crypto.Cryptographer]{initFunc: build}
}
// initClientAuthenticator installs the lazy constructor for the client
// authenticator (client repository, token validator, token parser).
func (sr *ServiceRegistry) initClientAuthenticator() {
	build := func() client.ClientAuthenticator {
		return clientService.NewClientAuthenticator(
			sr.db.ClientRepository(),
			sr.TokenValidator(),
			sr.TokenParser(),
		)
	}
	sr.clientAuthenticator = LazyInit[client.ClientAuthenticator]{initFunc: build}
}

// initClientValidator installs the lazy constructor for the client validator
// (client repository, token manager, token validator, token parser).
func (sr *ServiceRegistry) initClientValidator() {
	build := func() client.ClientValidator {
		return clientService.NewClientValidator(
			sr.db.ClientRepository(),
			sr.TokenManager(),
			sr.TokenValidator(),
			sr.TokenParser(),
		)
	}
	sr.clientValidator = LazyInit[client.ClientValidator]{initFunc: build}
}

// initClientCreator installs the lazy constructor for the client creator
// (client repository, client validator, token issuer, cryptographer).
func (sr *ServiceRegistry) initClientCreator() {
	build := func() client.ClientCreator {
		return clientService.NewClientCreator(
			sr.db.ClientRepository(),
			sr.ClientValidator(),
			sr.TokenIssuer(),
			sr.Cryptographer(),
		)
	}
	sr.clientCreator = LazyInit[client.ClientCreator]{initFunc: build}
}

// initClientManager installs the lazy constructor for the client manager
// (client repository, client validator, client authenticator, cryptographer).
func (sr *ServiceRegistry) initClientManager() {
	build := func() client.ClientManager {
		return clientService.NewClientManager(
			sr.db.ClientRepository(),
			sr.ClientValidator(),
			sr.ClientAuthenticator(),
			sr.Cryptographer(),
		)
	}
	sr.clientManager = LazyInit[client.ClientManager]{initFunc: build}
}

// initClientAuthorization installs the lazy constructor for the client
// authorization service (client validator, client manager, session manager,
// consent service, authorization-code issuer).
func (sr *ServiceRegistry) initClientAuthorization() {
	build := func() client.ClientAuthorization {
		return clientService.NewClientAuthorization(
			sr.ClientValidator(),
			sr.ClientManager(),
			sr.SessionManager(),
			sr.UserConsentService(),
			sr.AuthorizationCodeIssuer(),
		)
	}
	sr.clientAuthorization = LazyInit[client.ClientAuthorization]{initFunc: build}
}

// initConsentService installs the lazy constructor for the user-consent service
// (consent repository, user repository, session service, client manager).
func (sr *ServiceRegistry) initConsentService() {
	build := func() consent.UserConsentService {
		return consentService.NewUserConsentService(
			sr.db.UserConsentRepository(),
			sr.db.UserRepository(),
			sr.SessionService(),
			sr.ClientManager(),
		)
	}
	sr.consentService = LazyInit[consent.UserConsentService]{initFunc: build}
}
// initAuthzCodeManager installs the lazy constructor for the authorization-code
// manager (authorization-code repository).
func (sr *ServiceRegistry) initAuthzCodeManager() {
	build := func() authzCode.AuthorizationCodeManager {
		return authzCodeService.NewAuthorizationCodeManager(sr.db.AuthorizationCodeRepository())
	}
	sr.authzCodeManager = LazyInit[authzCode.AuthorizationCodeManager]{initFunc: build}
}

// initAuthzCodeCreator installs the lazy constructor for the authorization-code
// creator (authorization-code repository, cryptographer).
func (sr *ServiceRegistry) initAuthzCodeCreator() {
	build := func() authzCode.AuthorizationCodeCreator {
		return authzCodeService.NewAuthorizationCodeCreator(
			sr.db.AuthorizationCodeRepository(),
			sr.Cryptographer(),
		)
	}
	sr.authzCodeCreator = LazyInit[authzCode.AuthorizationCodeCreator]{initFunc: build}
}

// initAuthzCodeIssuer installs the lazy constructor for the authorization-code
// issuer (authorization-code creator).
func (sr *ServiceRegistry) initAuthzCodeIssuer() {
	build := func() authzCode.AuthorizationCodeIssuer {
		return authzCodeService.NewAuthorizationCodeIssuer(sr.AuthorizationCodeCreator())
	}
	sr.authzCodeIssuer = LazyInit[authzCode.AuthorizationCodeIssuer]{initFunc: build}
}

// initAuthzCodeRequestValidator installs the lazy constructor for the
// authorization-code validator (authorization-code repository, client validator,
// client authenticator).
func (sr *ServiceRegistry) initAuthzCodeRequestValidator() {
	build := func() authzCode.AuthorizationCodeValidator {
		return authzCodeService.NewAuthorizationCodeValidator(
			sr.db.AuthorizationCodeRepository(),
			sr.ClientValidator(),
			sr.ClientAuthenticator(),
		)
	}
	sr.authzCodeRequestValidator = LazyInit[authzCode.AuthorizationCodeValidator]{initFunc: build}
}

// initLoginAttemptService installs the lazy constructor for the login-attempt
// service (user repository, login-attempt repository).
func (sr *ServiceRegistry) initLoginAttemptService() {
	build := func() login.LoginAttemptService {
		return loginService.NewLoginAttemptService(
			sr.db.UserRepository(),
			sr.db.LoginAttemptRepository(),
		)
	}
	sr.loginAttemptService = LazyInit[login.LoginAttemptService]{initFunc: build}
}

// initAuthorizationService installs the lazy constructor for the authorization
// service (code manager, consent service, token manager, client manager,
// client validator, user manager).
func (sr *ServiceRegistry) initAuthorizationService() {
	build := func() authz.AuthorizationService {
		return authzService.NewAuthorizationService(
			sr.AuthorizationCodeManager(),
			sr.UserConsentService(),
			sr.TokenManager(),
			sr.ClientManager(),
			sr.ClientValidator(),
			sr.UserManager(),
		)
	}
	sr.authorizationService = LazyInit[authz.AuthorizationService]{initFunc: build}
}

// initHTTPCookieService installs the lazy constructor for the HTTP cookie service.
func (sr *ServiceRegistry) initHTTPCookieService() {
	build := func() cookie.HTTPCookieService {
		return cookieService.NewHTTPCookieService()
	}
	sr.httpCookieService = LazyInit[cookie.HTTPCookieService]{initFunc: build}
}

// initEmailService installs the lazy constructors for the email service and the
// mailer it is built on.
func (sr *ServiceRegistry) initEmailService() {
	buildEmail := func() email.EmailService {
		return emailService.NewEmailService(sr.GoMailerService())
	}
	buildMailer := func() email.Mailer {
		return emailService.NewGoMailer()
	}
	sr.emailService = LazyInit[email.EmailService]{initFunc: buildEmail}
	sr.goMailerService = LazyInit[email.Mailer]{initFunc: buildMailer}
}

// initAuditLogger installs the lazy constructor for the audit logger
// (audit-event repository).
func (sr *ServiceRegistry) initAuditLogger() {
	build := func() audit.AuditLogger {
		return auditLogger.NewAuditLogger(sr.db.AuditEventRepository())
	}
	sr.auditLogger = LazyInit[audit.AuditLogger]{initFunc: build}
}

// initOIDCService installs the lazy constructor for the OIDC service
// (authorization service).
func (sr *ServiceRegistry) initOIDCService() {
	build := func() oidc.OIDCService {
		return oidcService.NewOIDCService(sr.AuthorizationService())
	}
	sr.oidcService = LazyInit[oidc.OIDCService]{initFunc: build}
}
// initTokenManager installs the lazy constructor for the token manager
// (token repository, token parser, token validator).
func (sr *ServiceRegistry) initTokenManager() {
	build := func() token.TokenManager {
		return tokenService.NewTokenManager(
			sr.db.TokenRepository(),
			sr.TokenParser(),
			sr.TokenValidator(),
		)
	}
	sr.tokenManager = LazyInit[token.TokenManager]{initFunc: build}
}

// initTokenValidator installs the lazy constructor for the token validator
// (token repository, token parser).
func (sr *ServiceRegistry) initTokenValidator() {
	build := func() token.TokenValidator {
		return tokenService.NewTokenValidator(sr.db.TokenRepository(), sr.TokenParser())
	}
	sr.tokenValidator = LazyInit[token.TokenValidator]{initFunc: build}
}

// initTokenCreator installs the lazy constructor for the token creator
// (token repository, JWT service, cryptographer).
func (sr *ServiceRegistry) initTokenCreator() {
	build := func() token.TokenCreator {
		return tokenService.NewTokenCreator(
			sr.db.TokenRepository(),
			sr.JWTService(),
			sr.Cryptographer(),
		)
	}
	sr.tokenCreator = LazyInit[token.TokenCreator]{initFunc: build}
}

// initTokenParser installs the lazy constructor for the token parser (JWT service).
func (sr *ServiceRegistry) initTokenParser() {
	build := func() token.TokenParser {
		return tokenService.NewTokenParser(sr.JWTService())
	}
	sr.tokenParser = LazyInit[token.TokenParser]{initFunc: build}
}

// initTokenIssuer installs the lazy constructor for the token issuer (token creator).
func (sr *ServiceRegistry) initTokenIssuer() {
	build := func() token.TokenIssuer {
		return tokenService.NewTokenIssuer(sr.TokenCreator())
	}
	sr.tokenIssuer = LazyInit[token.TokenIssuer]{initFunc: build}
}

// initTokenRequestProcessor installs the lazy constructor for the token grant
// processor (token issuer, token manager, client authenticator, user
// authenticator, authorization service).
func (sr *ServiceRegistry) initTokenRequestProcessor() {
	build := func() token.TokenGrantProcessor {
		return tokenService.NewTokenGrantProcessor(
			sr.TokenIssuer(),
			sr.TokenManager(),
			sr.ClientAuthenticator(),
			sr.UserAuthenticator(),
			sr.AuthorizationService(),
		)
	}
	sr.tokenRequestProcessor = LazyInit[token.TokenGrantProcessor]{initFunc: build}
}
// Each accessor below returns its service, building it (and any dependencies it
// pulls through other accessors) on the first call.

// Middleware returns the HTTP middleware.
func (sr *ServiceRegistry) Middleware() *middleware.Middleware { return sr.middlewares.Get() }

// TokenManager returns the token manager.
func (sr *ServiceRegistry) TokenManager() token.TokenManager { return sr.tokenManager.Get() }

// TokenParser returns the token parser.
func (sr *ServiceRegistry) TokenParser() token.TokenParser { return sr.tokenParser.Get() }

// TokenValidator returns the token validator.
func (sr *ServiceRegistry) TokenValidator() token.TokenValidator { return sr.tokenValidator.Get() }

// TokenCreator returns the token creator.
func (sr *ServiceRegistry) TokenCreator() token.TokenCreator { return sr.tokenCreator.Get() }

// TokenGrantProcessor returns the token grant processor.
func (sr *ServiceRegistry) TokenGrantProcessor() token.TokenGrantProcessor {
	return sr.tokenRequestProcessor.Get()
}

// TokenIssuer returns the token issuer.
func (sr *ServiceRegistry) TokenIssuer() token.TokenIssuer { return sr.tokenIssuer.Get() }

// SessionService returns the session service.
func (sr *ServiceRegistry) SessionService() session.SessionService { return sr.sessionService.Get() }

// SessionManager returns the session manager.
func (sr *ServiceRegistry) SessionManager() session.SessionManager { return sr.sessionManager.Get() }

// UserAuthenticator returns the user authenticator.
func (sr *ServiceRegistry) UserAuthenticator() user.UserAuthenticator {
	return sr.userAuthenticator.Get()
}

// UserManager returns the user manager.
func (sr *ServiceRegistry) UserManager() user.UserManager { return sr.userManager.Get() }

// UserVerifier returns the user verifier.
func (sr *ServiceRegistry) UserVerifier() user.UserVerifier { return sr.userVerifier.Get() }

// UserCreator returns the user creator.
func (sr *ServiceRegistry) UserCreator() user.UserCreator { return sr.userCreator.Get() }

// ClientAuthenticator returns the client authenticator.
func (sr *ServiceRegistry) ClientAuthenticator() client.ClientAuthenticator {
	return sr.clientAuthenticator.Get()
}

// ClientValidator returns the client validator.
func (sr *ServiceRegistry) ClientValidator() client.ClientValidator { return sr.clientValidator.Get() }

// ClientCreator returns the client creator.
func (sr *ServiceRegistry) ClientCreator() client.ClientCreator { return sr.clientCreator.Get() }

// ClientManager returns the client manager.
func (sr *ServiceRegistry) ClientManager() client.ClientManager { return sr.clientManager.Get() }

// ClientAuthorization returns the client authorization service.
func (sr *ServiceRegistry) ClientAuthorization() client.ClientAuthorization {
	return sr.clientAuthorization.Get()
}

// UserConsentService returns the user-consent service.
func (sr *ServiceRegistry) UserConsentService() consent.UserConsentService {
	return sr.consentService.Get()
}

// AuthorizationCodeManager returns the authorization-code manager.
func (sr *ServiceRegistry) AuthorizationCodeManager() authzCode.AuthorizationCodeManager {
	return sr.authzCodeManager.Get()
}

// AuthorizationCodeCreator returns the authorization-code creator.
func (sr *ServiceRegistry) AuthorizationCodeCreator() authzCode.AuthorizationCodeCreator {
	return sr.authzCodeCreator.Get()
}

// AuthorizationCodeIssuer returns the authorization-code issuer.
func (sr *ServiceRegistry) AuthorizationCodeIssuer() authzCode.AuthorizationCodeIssuer {
	return sr.authzCodeIssuer.Get()
}

// AuthorizationCodeRequestValidator returns the authorization-code validator.
func (sr *ServiceRegistry) AuthorizationCodeRequestValidator() authzCode.AuthorizationCodeValidator {
	return sr.authzCodeRequestValidator.Get()
}

// LoginAttemptService returns the login-attempt service.
func (sr *ServiceRegistry) LoginAttemptService() login.LoginAttemptService {
	return sr.loginAttemptService.Get()
}

// AuthorizationService returns the authorization service.
func (sr *ServiceRegistry) AuthorizationService() authz.AuthorizationService {
	return sr.authorizationService.Get()
}

// HTTPCookieService returns the HTTP cookie service.
func (sr *ServiceRegistry) HTTPCookieService() cookie.HTTPCookieService {
	return sr.httpCookieService.Get()
}

// JWTService returns the JWT service.
func (sr *ServiceRegistry) JWTService() jwt.JWTService { return sr.jwtService.Get() }

// EmailService returns the email service.
func (sr *ServiceRegistry) EmailService() email.EmailService { return sr.emailService.Get() }

// GoMailerService returns the mailer used by the email service.
func (sr *ServiceRegistry) GoMailerService() email.Mailer { return sr.goMailerService.Get() }

// AuditLogger returns the audit logger.
func (sr *ServiceRegistry) AuditLogger() audit.AuditLogger { return sr.auditLogger.Get() }

// OIDCService returns the OIDC service.
func (sr *ServiceRegistry) OIDCService() oidc.OIDCService { return sr.oidcService.Get() }

// Cryptographer returns the cryptographer.
func (sr *ServiceRegistry) Cryptographer() crypto.Cryptographer { return sr.encryptor.Get() }
package domain
import (
"context"
"encoding/json"
"time"
"github.com/vigiloauth/vigilo/v2/internal/constants"
"github.com/vigiloauth/vigilo/v2/internal/utils"
)
// AuditEvent is a single audit-log record. It is serialised to JSON using the
// struct tags below; fields tagged omitempty are dropped when empty.
type AuditEvent struct {
EventID string `json:"event_id"` // Unique event ID (NewAuditEvent prefixes a UUID with constants.AuditEventIDPrefix).
Timestamp time.Time `json:"timestamp"` // Creation time; NewAuditEvent records it in UTC.
EventType EventType `json:"event_type"` // Kind of activity being audited.
Success bool `json:"success"` // Whether the audited operation succeeded.
UserID string `json:"user_id,omitempty"`
IP string `json:"ip_address"` // No omitempty: always emitted, even when empty.
UserAgent string `json:"user_agent,omitempty"`
RequestID string `json:"request_id,omitempty"`
Details json.RawMessage `json:"details,omitempty"` // Pre-encoded JSON object holding optional action/method entries.
SessionID string `json:"session_id,omitempty"`
ErrCode string `json:"error_code,omitempty"` // Error code for failed operations; empty on success.
}
type (
	// EventType identifies the kind of activity an audit event records.
	EventType string
	// ActionType names the action recorded in an audit event's details.
	ActionType string
	// MethodType names the mechanism used to identify the subject of an action.
	MethodType string
)

// Audit event kinds.
const (
	LoginAttempt EventType = "login_attempt"
	// PasswordChange is emitted with the wire value "password_reset".
	PasswordChange      EventType = "password_reset"
	RegistrationAttempt EventType = "registration_attempt"
	AccountDeletion     EventType = "account_deletion_attempt"
	SessionCreated      EventType = "session_created"
	SessionDeleted      EventType = "session_deleted"
)

// Actions recorded in event details.
const (
	RegistrationAction    ActionType = "registration"
	AuthenticationAction  ActionType = "authentication"
	PasswordResetAction   ActionType = "password_reset"
	AccountDeletionAction ActionType = "deletion"
	SessionCreationAction ActionType = "session_creation"
	SessionDeletionAction ActionType = "session_deletion"
)

// Methods recorded in event details.
const (
	EmailMethod  MethodType = "email"
	OAuthMethod  MethodType = "oauth"
	IDMethod     MethodType = "id"
	CookieMethod MethodType = "cookie"
)

// String returns the raw string value of the event type.
func (e EventType) String() string {
	return string(e)
}

// String returns the raw string value of the action type.
func (a ActionType) String() string {
	return string(a)
}

// String returns the raw string value of the method type.
func (m MethodType) String() string {
	return string(m)
}
// NewAuditEvent builds an AuditEvent stamped with a prefixed UUID and the current
// UTC time.
//
// The request ID is read via utils.GetRequestID. The user ID, IP address, user agent
// and session ID are read from ctx. A context value that is missing or is not a
// string leaves the corresponding field empty. A non-empty action and method are
// recorded in the event's Details payload.
func NewAuditEvent(ctx context.Context, eventType EventType, success bool, action ActionType, method MethodType, errCode string) *AuditEvent {
	// asString applies the comma-ok assertion used for every context value:
	// nil or non-string values collapse to "".
	asString := func(v any) string {
		s, _ := v.(string)
		return s
	}

	event := &AuditEvent{
		EventID:   constants.AuditEventIDPrefix + utils.GenerateUUID(),
		Timestamp: time.Now().UTC(),
		EventType: eventType,
		Success:   success,
		RequestID: utils.GetRequestID(ctx),
		ErrCode:   errCode,
		UserID:    asString(utils.GetValueFromContext(ctx, constants.ContextKeyUserID)),
		IP:        asString(utils.GetValueFromContext(ctx, constants.ContextKeyIPAddress)),
		UserAgent: asString(utils.GetValueFromContext(ctx, constants.ContextKeyUserAgent)),
		SessionID: asString(utils.GetValueFromContext(ctx, constants.ContextKeySessionID)),
	}

	event.addEventDetails(action, method)
	return event
}
// addEventDetails encodes the optional action and method as a JSON object into
// e.Details. Empty values are left out, so two empty arguments produce "{}". If
// encoding fails, e.Details is left untouched.
func (e *AuditEvent) addEventDetails(action ActionType, method MethodType) {
	details := make(map[string]string, 2)
	if action != "" {
		details[constants.ActionDetails] = string(action)
	}
	if method != "" {
		details[constants.MethodDetails] = string(method)
	}

	if encoded, err := json.Marshal(details); err == nil {
		e.Details = encoded
	}
}
// String renders the event as indented JSON (one-space indent), e.g. for logging.
// If serialisation fails, a fixed placeholder message is returned instead.
func (e *AuditEvent) String() string {
	if pretty, err := json.MarshalIndent(e, "", " "); err == nil {
		return string(pretty)
	}
	return "AuditEvent: error serializing to string"
}
package domain
import "github.com/vigiloauth/vigilo/v2/internal/errors"
// ValidateFields checks that an authorization code may be redeemed by the given
// client and redirect URI.
//
// The checks run in this order, and the first failure is returned as an
// ErrCodeInvalidGrant error:
//   - the code must not have been used already;
//   - the code must have been issued to clientID;
//   - the code must have been issued for redirectURI.
//
// It returns nil when all checks pass.
func (c *AuthorizationCodeData) ValidateFields(clientID, redirectURI string) error {
	switch {
	case c.Used:
		return errors.New(errors.ErrCodeInvalidGrant, "authorization code has already been used")
	case c.ClientID != clientID:
		return errors.New(errors.ErrCodeInvalidGrant, "authorization code client ID and request client ID do not match")
	case c.RedirectURI != redirectURI:
		return errors.New(errors.ErrCodeInvalidGrant, "authorization code redirect URI and request redirect URI do not match")
	default:
		return nil
	}
}
package domain
import (
"encoding/json"
"net/url"
)
// ClaimsRequest models the OpenID Connect "claims" request parameter. Only the
// "userinfo" member is represented.
type ClaimsRequest struct {
	UserInfo *ClaimSet `json:"userinfo,omitempty"`
}

// ClaimSet maps a claim name to its request options. A JSON null entry decodes to
// a nil *ClaimSpec (in OIDC, a claim requested in the default manner).
type ClaimSet map[string]*ClaimSpec

// ClaimSpec holds the per-claim request options.
type ClaimSpec struct {
	Essential bool   `json:"essential,omitempty"`
	Value     string `json:"value,omitempty"`
}

// ParseClaimsParameter decodes the "claims" request parameter.
//
// claimsParam is first URL-unescaped with query semantics ("+" becomes a space),
// then unmarshalled as JSON into a ClaimsRequest.
//
// An empty claimsParam yields an empty ClaimsRequest and a nil error. A malformed
// escape sequence or invalid JSON is returned as an error, with a nil
// *ClaimsRequest, instead of being silently ignored.
func ParseClaimsParameter(claimsParam string) (*ClaimsRequest, error) {
	if claimsParam == "" {
		return &ClaimsRequest{}, nil
	}

	decodedClaims, err := url.QueryUnescape(claimsParam)
	if err != nil {
		return nil, err
	}

	var claimsRequest ClaimsRequest
	if err := json.Unmarshal([]byte(decodedClaims), &claimsRequest); err != nil {
		return nil, err
	}
	return &claimsRequest, nil
}

// SerializeClaims encodes claims as compact JSON. It returns "" when claims is nil
// or cannot be marshalled.
func SerializeClaims(claims *ClaimsRequest) string {
	if claims == nil {
		return ""
	}
	claimsJSON, err := json.Marshal(claims)
	if err != nil {
		return ""
	}
	return string(claimsJSON)
}
package domain
import (
	"crypto/subtle"
	"net/http"
	"net/url"
	"slices"
	"time"

	"github.com/vigiloauth/vigilo/v2/internal/constants"
	claims "github.com/vigiloauth/vigilo/v2/internal/domain/claims"
	jwks "github.com/vigiloauth/vigilo/v2/internal/domain/jwks"
	"github.com/vigiloauth/vigilo/v2/internal/types"
)
// Client represents an OAuth 2.0 client application.
// It stores the client's metadata and configuration.
type Client struct {
Name string // The human-readable name of the client application.
ID string // The unique identifier assigned to the client.
Secret string // The client secret used for confidential client authentication.
Type types.ClientType // The type of the client: "confidential" or "public".
TokenEndpointAuthMethod types.TokenAuthMethod // The authentication method used by the client at the token endpoint (e.g., "client_secret_basic", "client_secret_post", "private_key_jwt").
JwksURI string // The URL of the client's JSON Web Key Set (JWKS) document for verifying signatures.
LogoURI string // The URL of the client's logo.
PolicyURI string // The URL of the client's privacy policy.
SectorIdentifierURI string // The URL of a document listing the client's redirect URIs (sector identifier URI).
ApplicationType string // The type of the application (e.g., "web", "native").
RegistrationAccessToken string // The access token used to read and update the client's registration information.
RedirectURIs []string // A list of allowed redirect URIs for the client.
GrantTypes []string // A list of OAuth 2.0 grant types the client is authorized to use.
Scopes []types.Scope // A list of authorization scopes the client can request.
ResponseTypes []string // A list of OAuth 2.0 response types the client is authorized to use.
Contacts []string // A list of contact persons for the client.
CreatedAt time.Time // The timestamp when the client was created.
UpdatedAt time.Time // The timestamp when the client was last updated.
IDIssuedAt time.Time // The timestamp when the client ID was issued.
SecretExpiration int // The expiration time of the client secret in seconds (0 for no expiration).
RequiresPKCE bool // Indicates if the client requires Proof Key for Code Exchange (PKCE) for the authorization code grant.
JWKS *jwks.Jwks // The client's JSON Web Key Set (JWKS) for verifying signatures, embedded directly.
RegistrationClientURI string // The URL of this client's configuration endpoint (where its registration can be read/updated).
// CanRequestScopes records whether the client registered without an explicit scope list.
// NewClientFromRegistrationRequest sets it to true when no scopes were registered and
// false when a scope list was supplied. NOTE(review): presumably true means the client
// may request any valid scope — confirm against the authorization logic.
CanRequestScopes bool
}
// ClientReadResponse is a reduced view of a client's metadata. Every field is
// omitted from JSON when empty.
// NOTE(review): Name is serialised as "name", whereas the other client payloads in
// this file use "client_name" — confirm this is intended.
type ClientReadResponse struct {
ID string `json:"client_id,omitempty"`
Name string `json:"name,omitempty"`
LogoURI string `json:"logo_uri,omitempty"`
PolicyURI string `json:"policy_uri,omitempty"`
SectorIdentifierURI string `json:"sector_identifier_uri,omitempty"`
}
// ClientRegistrationRequest represents a request to register a new OAuth client.
// NewClientFromRegistrationRequest converts it into a Client.
type ClientRegistrationRequest struct {
Name string `json:"client_name"`
ApplicationType string `json:"application_type,omitempty"`
RedirectURIs []string `json:"redirect_uris"`
Scopes []types.Scope `json:"scope,omitempty"` // When empty, the created client gets CanRequestScopes = true.
GrantTypes []string `json:"grant_types"`
ResponseTypes []string `json:"response_types"`
Contacts []string `json:"contacts,omitempty"`
JwksURI string `json:"jwks_uri,omitempty"`
LogoURI string `json:"logo_uri,omitempty"`
PolicyURI string `json:"policy_uri,omitempty"`
SectorIdentifierURI string `json:"sector_identifier_uri,omitempty"`
TokenEndpointAuthMethod types.TokenAuthMethod `json:"token_endpoint_auth_method,omitempty"`
JWKS *jwks.Jwks `json:"jwks,omitempty"` // Inline key set, as an alternative to JwksURI.
// RequiresPKCE and Type carry no json tag, so encoding/json reads/writes them under
// their Go field names. NOTE(review): confirm whether clients may set these directly
// or whether they are meant to be filled in server-side.
RequiresPKCE bool
Type types.ClientType
}
// ClientUpdateRequest represents a request to update an existing OAuth client.
// Client.UpdateValues applies it to a stored Client.
type ClientUpdateRequest struct {
ID string `json:"client_id"`
Secret string `json:"client_secret,omitempty"`
Name string `json:"client_name,omitempty"`
ApplicationType string `json:"application_type,omitempty"`
RedirectURIs []string `json:"redirect_uris,omitempty"`
GrantTypes []string `json:"grant_types,omitempty"`
Scopes []types.Scope `json:"scope,omitempty"`
ResponseTypes []string `json:"response_types,omitempty"`
Contacts []string `json:"contacts,omitempty"`
JwksURI string `json:"jwks_uri,omitempty"`
LogoURI string `json:"logo_uri,omitempty"`
PolicyURI string `json:"policy_uri,omitempty"`
SectorIdentifierURI string `json:"sector_identifier_uri,omitempty"`
TokenEndpointAuthMethod types.TokenAuthMethod `json:"token_endpoint_auth_method,omitempty"`
// Type has no json tag, so encoding/json uses the Go field name "Type".
Type types.ClientType
}
// ClientRegistrationResponse represents a response after registering an OAuth client.
// Fields without omitempty are always emitted, even when empty.
type ClientRegistrationResponse struct {
ID string `json:"client_id"`
Name string `json:"client_name"`
Type types.ClientType `json:"client_type"`
Secret string `json:"client_secret,omitempty"` // Empty (and omitted) for clients without a secret.
SecretExpiration int `json:"client_secret_expires_at,omitempty"` // Copied from Client.SecretExpiration.
ApplicationType string `json:"application_type,omitempty"`
RedirectURIs []string `json:"redirect_uris"`
GrantTypes []string `json:"grant_types"`
Scopes []types.Scope `json:"scope"`
ResponseTypes []string `json:"response_types"`
Contacts []string `json:"contacts,omitempty"`
JwksURI string `json:"jwks_uri,omitempty"`
LogoURI string `json:"logo_uri,omitempty"`
PolicyURI string `json:"policy_uri,omitempty"`
SectorIdentifierURI string `json:"sector_identifier_uri,omitempty"`
RegistrationAccessToken string `json:"registration_access_token"`
RegistrationClientURI string `json:"registration_client_uri"`
TokenEndpointAuthMethod types.TokenAuthMethod `json:"token_endpoint_auth_method,omitempty"`
UpdatedAt time.Time `json:"updated_at"`
CreatedAt time.Time `json:"created_at"`
IDIssuedAt time.Time `json:"client_id_issued_at"`
}
// ClientConfigurationEndpoint is a JSON view of a client's registered metadata
// together with the URL of its configuration endpoint.
type ClientConfigurationEndpoint struct {
Name string `json:"client_name"`
RedirectURIs []string `json:"redirect_uris"`
GrantTypes []string `json:"grant_types"`
Scopes []types.Scope `json:"scope,omitempty"`
ResponseTypes []string `json:"response_types"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
IDIssuedAt time.Time `json:"client_id_issued_at"`
TokenEndpointAuthMethod types.TokenAuthMethod `json:"token_endpoint_auth_method,omitempty"`
ConfigurationEndpoint string `json:"client_configuration_endpoint"`
}
// ClientSecretRegenerationResponse represents the response when regenerating a client secret.
// It carries the new secret and the time the client was updated.
type ClientSecretRegenerationResponse struct {
ClientID string `json:"client_id"`
ClientSecret string `json:"client_secret"`
UpdatedAt time.Time `json:"updated_at"`
}
// ClientAuthorizationRequest represents the incoming request to the /authorize endpoint.
// The schema-tagged fields come from request parameters (see NewClientAuthorizationRequest).
// The untagged fields after RequestObject are populated by the server while handling the request.
type ClientAuthorizationRequest struct {
ClientID string `schema:"client_id"`
ResponseType string `schema:"response_type"`
RedirectURI string `schema:"redirect_uri"`
Scope types.Scope `schema:"scope,omitempty"`
State string `schema:"state,omitempty"`
Nonce string `schema:"nonce,omitempty"`
CodeChallenge string `schema:"code_challenge,omitempty"` // PKCE challenge.
CodeChallengeMethod types.CodeChallengeMethod `schema:"code_challenge_method,omitempty"`
Display string `schema:"display,omitempty"`
Prompt string `schema:"prompt,omitempty"`
MaxAge string `schema:"max_age,omitempty"` // Kept as the raw string; not parsed here.
ClaimsRequest *claims.ClaimsRequest `schema:"-"` // Parsed separately from the "claims" parameter; nil if absent or unparseable.
ACRValues string `schema:"acr_values,omitempty"`
RequestURI string `schema:"request_uri,omitempty"`
RequestObject string `schema:"request,omitempty"`
UserID string
ConsentApproved bool
Client *Client
HTTPWriter http.ResponseWriter
HTTPRequest *http.Request
UserAuthenticationTime time.Time
}
// ClientInformationResponse is the JSON payload describing a registered client's
// credentials and configuration endpoint. Secret is omitted when empty.
type ClientInformationResponse struct {
ID string `json:"client_id"`
Secret string `json:"client_secret,omitempty"`
RegistrationClientURI string `json:"registration_client_uri"`
RegistrationAccessToken string `json:"registration_access_token"`
}
// ClientAuthenticationRequest groups the client credentials and request parameters
// needed to authenticate a client. It has no serialisation tags.
type ClientAuthenticationRequest struct {
ClientID string
ClientSecret string
RequestedGrant string
RequestedScopes types.Scope
RedirectURI string
}
// ClientRequest is the common view over client registration and update requests
// (implemented by *ClientRegistrationRequest and *ClientUpdateRequest), letting
// the same validation code inspect either kind of request.
type ClientRequest interface {
	// Client classification.
	GetType() types.ClientType

	// OAuth flow metadata.
	GetGrantTypes() []string
	GetResponseTypes() []string
	GetRedirectURIS() []string
	HasGrantType(grantType string) bool

	// Scopes, with a setter so callers can replace the requested list.
	GetScopes() []types.Scope
	SetScopes(scopes []types.Scope)

	// Optional URIs.
	GetJwksURI() string
	GetLogoURI() string
	GetSectorIdentifierURI() string
}
// NewClientFromRegistrationRequest builds a Client from a registration request.
//
// String, pointer and required list fields are copied as-is (an empty string or nil
// pointer leaves the corresponding Client field at its zero value).
//
// When the request lists scopes, the client is limited to them (CanRequestScopes =
// false). Otherwise Scopes is an empty, non-nil slice and CanRequestScopes is true.
//
// Contacts is only copied when non-empty, so an empty request slice leaves it nil.
//
// NOTE(review): req.RequiresPKCE is not copied onto the Client — confirm the caller
// sets Client.RequiresPKCE separately.
func NewClientFromRegistrationRequest(req *ClientRegistrationRequest) *Client {
	c := &Client{
		Name:                    req.Name,
		Type:                    req.Type,
		ApplicationType:         req.ApplicationType,
		RedirectURIs:            req.RedirectURIs,
		GrantTypes:              req.GrantTypes,
		ResponseTypes:           req.ResponseTypes,
		JwksURI:                 req.JwksURI,
		PolicyURI:               req.PolicyURI,
		SectorIdentifierURI:     req.SectorIdentifierURI,
		LogoURI:                 req.LogoURI,
		TokenEndpointAuthMethod: req.TokenEndpointAuthMethod,
		JWKS:                    req.JWKS,
		Scopes:                  []types.Scope{},
		CanRequestScopes:        true,
	}

	if len(req.Scopes) > 0 {
		c.Scopes = req.Scopes
		c.CanRequestScopes = false
	}
	if len(req.Contacts) > 0 {
		c.Contacts = req.Contacts
	}
	return c
}
// NewClientRegistrationResponseFromClient builds the response returned to a client
// after registration.
//
// Identity, redirect URIs, grant types, scopes, response types, timestamps and the
// registration access token/URI are always copied. The optional string metadata
// (application type, JWKS/logo/policy/sector-identifier URIs, token endpoint auth
// method) is copied as-is; those response fields are omitempty, so unset values do
// not appear in JSON.
//
// The secret and its expiration are included only when the client has a secret.
// Contacts is copied only when non-empty.
func NewClientRegistrationResponseFromClient(client *Client) *ClientRegistrationResponse {
	response := &ClientRegistrationResponse{
		ID:                      client.ID,
		Name:                    client.Name,
		Type:                    client.Type,
		RedirectURIs:            client.RedirectURIs,
		GrantTypes:              client.GrantTypes,
		Scopes:                  client.Scopes,
		ResponseTypes:           client.ResponseTypes,
		CreatedAt:               client.CreatedAt,
		UpdatedAt:               client.UpdatedAt,
		RegistrationAccessToken: client.RegistrationAccessToken,
		IDIssuedAt:              client.IDIssuedAt,
		RegistrationClientURI:   client.RegistrationClientURI,
		ApplicationType:         client.ApplicationType,
		JwksURI:                 client.JwksURI,
		LogoURI:                 client.LogoURI,
		PolicyURI:               client.PolicyURI,
		SectorIdentifierURI:     client.SectorIdentifierURI,
		TokenEndpointAuthMethod: client.TokenEndpointAuthMethod,
	}

	if client.Secret != "" {
		response.Secret = client.Secret
		response.SecretExpiration = client.SecretExpiration
	}
	if len(client.Contacts) != 0 {
		response.Contacts = client.Contacts
	}
	return response
}
// NewClientInformationResponse assembles a ClientInformationResponse from the given
// values. Secret is tagged omitempty, so an empty clientSecret is simply left out of
// the JSON output.
func NewClientInformationResponse(clientID, clientSecret, registrationClientURI, registrationAccessToken string) *ClientInformationResponse {
	return &ClientInformationResponse{
		ID:                      clientID,
		Secret:                  clientSecret,
		RegistrationClientURI:   registrationClientURI,
		RegistrationAccessToken: registrationAccessToken,
	}
}
// NewClientAuthorizationRequest builds a ClientAuthorizationRequest from the query
// parameters of an /authorize request.
//
// Every field is read with query.Get: a missing parameter yields "", and only the
// first value of a repeated parameter is used. ConsentApproved is true only when
// the consent parameter is exactly "true".
//
// The optional claims parameter is parsed with claims.ParseClaimsParameter. If it is
// absent, or parsing returns an error, ClaimsRequest stays nil and the error is not
// surfaced.
func NewClientAuthorizationRequest(query url.Values) *ClientAuthorizationRequest {
	param := query.Get

	req := &ClientAuthorizationRequest{
		ClientID:            param(constants.ClientIDReqField),
		RedirectURI:         param(constants.RedirectURIReqField),
		Scope:               types.Scope(param(constants.ScopeReqField)),
		State:               param(constants.StateReqField),
		ResponseType:        param(constants.ResponseTypeReqField),
		CodeChallenge:       param(constants.CodeChallengeReqField),
		CodeChallengeMethod: types.CodeChallengeMethod(param(constants.CodeChallengeMethodReqField)),
		Nonce:               param(constants.NonceReqField),
		Display:             param(constants.DisplayReqField),
		ConsentApproved:     param(constants.ConsentApprovedURLValue) == "true",
		Prompt:              param(constants.PromptReqField),
		MaxAge:              param(constants.MaxAgeReqField),
		ACRValues:           param(constants.ACRReqField),
		RequestURI:          param(constants.RequestURIReqField),
		RequestObject:       param(constants.RequestObjectReqField),
	}

	if raw := param(constants.ClaimsReqField); raw != "" {
		if parsed, err := claims.ParseClaimsParameter(raw); err == nil {
			req.ClaimsRequest = parsed
		}
	}
	return req
}
// HasGrantType reports whether requiredGrantType is one of the client's registered grant types.
func (c *Client) HasGrantType(requiredGrantType string) bool {
	return slices.Contains(c.GrantTypes, requiredGrantType)
}

// HasRedirectURI reports whether redirectURI exactly matches one of the client's
// registered redirect URIs.
func (c *Client) HasRedirectURI(redirectURI string) bool {
	return slices.Contains(c.RedirectURIs, redirectURI)
}

// HasScope reports whether requiredScope is among the client's registered scopes.
func (c *Client) HasScope(requiredScope types.Scope) bool {
	return slices.Contains(c.Scopes, requiredScope)
}

// HasResponseType reports whether responseType is one of the client's registered
// response types.
func (c *Client) HasResponseType(responseType string) bool {
	return slices.Contains(c.ResponseTypes, responseType)
}

// IsConfidential reports whether the client's type is types.ConfidentialClient.
func (c *Client) IsConfidential() bool {
	return c.Type == types.ConfidentialClient
}
// SecretsMatch reports whether secret equals the client's stored secret.
//
// The comparison uses subtle.ConstantTimeCompare. Its running time does not depend
// on how many leading bytes of a guessed secret are correct, so an attacker cannot
// recover the secret byte by byte through timing. Only whether the lengths differ
// can leak. Two empty strings compare equal, matching the previous `==` behaviour.
func (c *Client) SecretsMatch(secret string) bool {
	return subtle.ConstantTimeCompare([]byte(c.Secret), []byte(secret)) == 1
}
// UpdateValues applies the non-empty fields of request to the client and sets
// UpdatedAt to the current time.
//
// Scalar fields (name, logo/policy/sector-identifier/JWKS URIs, token endpoint auth
// method) are overwritten only when the request supplies a non-empty value.
//
// List fields (redirect URIs, grant types, scopes, response types) are merged:
// request values are appended unless the client already has them. Repeating the
// same update therefore does not accumulate duplicates. Existing list entries are
// never removed.
func (c *Client) UpdateValues(request *ClientUpdateRequest) {
	if request.Name != "" {
		c.Name = request.Name
	}
	if request.LogoURI != "" {
		c.LogoURI = request.LogoURI
	}
	if request.PolicyURI != "" {
		c.PolicyURI = request.PolicyURI
	}
	if request.SectorIdentifierURI != "" {
		c.SectorIdentifierURI = request.SectorIdentifierURI
	}
	if request.JwksURI != "" {
		c.JwksURI = request.JwksURI
	}

	c.RedirectURIs = appendMissing(c.RedirectURIs, request.RedirectURIs)
	c.GrantTypes = appendMissing(c.GrantTypes, request.GrantTypes)
	c.Scopes = appendMissing(c.Scopes, request.Scopes)
	c.ResponseTypes = appendMissing(c.ResponseTypes, request.ResponseTypes)

	if request.TokenEndpointAuthMethod != "" {
		c.TokenEndpointAuthMethod = request.TokenEndpointAuthMethod
	}
	c.UpdatedAt = time.Now()
}

// appendMissing appends each element of src that is not already in dst (including
// elements appended earlier from src) and returns the result. An empty src returns
// dst unchanged.
func appendMissing[T comparable](dst, src []T) []T {
	for _, v := range src {
		if !slices.Contains(dst, v) {
			dst = append(dst, v)
		}
	}
	return dst
}
// GetType returns the requested client type.
func (req *ClientRegistrationRequest) GetType() types.ClientType { return req.Type }

// GetGrantTypes returns the requested grant types.
func (req *ClientRegistrationRequest) GetGrantTypes() []string { return req.GrantTypes }

// GetRedirectURIS returns the requested redirect URIs.
func (req *ClientRegistrationRequest) GetRedirectURIS() []string { return req.RedirectURIs }

// GetScopes returns the requested scopes.
func (req *ClientRegistrationRequest) GetScopes() []types.Scope { return req.Scopes }

// GetResponseTypes returns the requested response types.
func (req *ClientRegistrationRequest) GetResponseTypes() []string { return req.ResponseTypes }

// GetLogoURI returns the requested logo URI.
func (req *ClientRegistrationRequest) GetLogoURI() string { return req.LogoURI }

// GetSectorIdentifierURI returns the requested sector identifier URI.
func (req *ClientRegistrationRequest) GetSectorIdentifierURI() string { return req.SectorIdentifierURI }

// GetJwksURI returns the requested JWKS URI.
func (req *ClientRegistrationRequest) GetJwksURI() string { return req.JwksURI }

// SetScopes replaces the requested scopes.
func (req *ClientRegistrationRequest) SetScopes(scopes []types.Scope) { req.Scopes = scopes }

// HasGrantType reports whether grantType appears in the requested grant types.
func (req *ClientRegistrationRequest) HasGrantType(grantType string) bool {
	return slices.Contains(req.GrantTypes, grantType)
}
// GetType returns the client type carried by the update request.
func (req *ClientUpdateRequest) GetType() types.ClientType { return req.Type }

// GetGrantTypes returns the grant types carried by the update request.
func (req *ClientUpdateRequest) GetGrantTypes() []string { return req.GrantTypes }

// GetRedirectURIS returns the redirect URIs carried by the update request.
func (req *ClientUpdateRequest) GetRedirectURIS() []string { return req.RedirectURIs }

// GetScopes returns the scopes carried by the update request.
func (req *ClientUpdateRequest) GetScopes() []types.Scope { return req.Scopes }

// GetResponseTypes returns the response types carried by the update request.
func (req *ClientUpdateRequest) GetResponseTypes() []string { return req.ResponseTypes }

// GetLogoURI returns the logo URI carried by the update request.
func (req *ClientUpdateRequest) GetLogoURI() string { return req.LogoURI }

// GetSectorIdentifierURI returns the sector identifier URI carried by the update request.
func (req *ClientUpdateRequest) GetSectorIdentifierURI() string { return req.SectorIdentifierURI }

// GetJwksURI returns the JWKS URI carried by the update request.
func (req *ClientUpdateRequest) GetJwksURI() string { return req.JwksURI }

// SetScopes replaces the scopes carried by the update request.
func (req *ClientUpdateRequest) SetScopes(scopes []types.Scope) { req.Scopes = scopes }

// HasGrantType reports whether grantType appears in the update request's grant types.
func (req *ClientUpdateRequest) HasGrantType(grantType string) bool {
	return slices.Contains(req.GrantTypes, grantType)
}
package domain
import (
"sync"
"github.com/vigiloauth/vigilo/v2/internal/utils"
)
// EmailRequest describes one outgoing email and its retry state.
type EmailRequest struct {
Recipient string
EmailType EmailType
VerificationCode string
VerificationToken string
BaseURL string // Not set by NewEmailRequest; presumably filled in by the sender — confirm.
ID string // Random UUID assigned by NewEmailRequest.
Retries int // Number of retry attempts so far; starts at 0.
}
// NewEmailRequest creates an EmailRequest for recipient with a freshly generated
// UUID and a retry count of zero.
//
// verificationCode and verificationToken are stored in their respective fields.
// BaseURL is left empty.
func NewEmailRequest(recipient, verificationCode, verificationToken string, emailType EmailType) *EmailRequest {
	return &EmailRequest{
		Recipient:         recipient,
		VerificationCode:  verificationCode,
		VerificationToken: verificationToken,
		EmailType:         emailType,
		ID:                utils.GenerateUUID(),
		Retries:           0,
	}
}
// EmailType identifies the purpose of an outgoing email.
type EmailType string

// Supported email purposes.
const (
	AccountVerification EmailType = "account_verification"
	AccountDeletion     EmailType = "account_deletion"
)

// String returns the raw string value of the email type.
func (t EmailType) String() string { return string(t) }
// EmailRetryQueue is a mutex-guarded FIFO queue of email requests awaiting retry.
// The zero value is an empty, ready-to-use queue. Because it embeds a sync.Mutex,
// it must not be copied after first use.
type EmailRetryQueue struct {
	mu       sync.Mutex
	requests []*EmailRequest
}

// Add appends request to the back of the queue.
func (q *EmailRetryQueue) Add(request *EmailRequest) {
	q.mu.Lock()
	defer q.mu.Unlock()
	q.requests = append(q.requests, request)
}

// Remove pops and returns the request at the front of the queue, or nil if the
// queue is empty.
func (q *EmailRetryQueue) Remove() *EmailRequest {
	q.mu.Lock()
	defer q.mu.Unlock()
	if len(q.requests) == 0 {
		return nil
	}
	request := q.requests[0]
	// Clear the vacated slot: re-slicing alone would keep the dequeued request
	// reachable through the shared backing array.
	q.requests[0] = nil
	q.requests = q.requests[1:]
	if len(q.requests) == 0 {
		// Drop the drained backing array so it can be garbage-collected.
		q.requests = nil
	}
	return request
}

// IsEmpty reports whether the queue holds no requests.
func (q *EmailRetryQueue) IsEmpty() bool {
	q.mu.Lock()
	defer q.mu.Unlock()
	return len(q.requests) == 0
}

// Size returns the number of queued requests.
func (q *EmailRetryQueue) Size() int {
	q.mu.Lock()
	defer q.mu.Unlock()
	return len(q.requests)
}
package domain
import (
"crypto/rsa"
"encoding/base64"
"math/big"
)
// Jwks is a JSON Web Key Set: the list of keys published under "keys".
type Jwks struct {
	Keys []JWK `json:"keys"`
}

// JWK is a single JSON Web Key. RSA keys populate N and E; EC keys populate X, Y
// and Crv. Those optional members are omitted from JSON when empty.
type JWK struct {
	Kty string `json:"kty"`           // Key type (e.g., "RSA", "EC")
	Kid string `json:"kid"`           // Key ID, used to identify the key
	Use string `json:"use"`           // Public key use (e.g., "sig" for signature, "enc" for encryption)
	Alg string `json:"alg"`           // Algorithm intended for use with the key (e.g., "RS256", "ES256")
	N   string `json:"n,omitempty"`   // RSA modulus (base64url-encoded)
	E   string `json:"e,omitempty"`   // RSA public exponent (base64url-encoded)
	X   string `json:"x,omitempty"`   // EC public key x-coordinate (base64url-encoded)
	Y   string `json:"y,omitempty"`   // EC public key y-coordinate (base64url-encoded)
	Crv string `json:"crv,omitempty"` // EC curve name (e.g., "P-256", "P-384", "P-521")
}

// NewJWK describes publicKey as an RS256 signing key identified by keyID.
//
// The modulus and exponent are written as unpadded base64url encodings of their
// minimal big-endian byte representations. publicKey must be non-nil.
func NewJWK(keyID string, publicKey *rsa.PublicKey) JWK {
	encode := base64.RawURLEncoding.EncodeToString
	modulus := publicKey.N.Bytes()
	exponent := big.NewInt(int64(publicKey.E)).Bytes()

	var key JWK
	key.Kty = "RSA"
	key.Kid = keyID
	key.Use = "sig"
	key.Alg = "RS256"
	key.N = encode(modulus)
	key.E = encode(exponent)
	return key
}
package domain
import (
"github.com/vigiloauth/vigilo/v2/internal/constants"
"github.com/vigiloauth/vigilo/v2/internal/types"
"github.com/vigiloauth/vigilo/v2/internal/utils"
"github.com/vigiloauth/vigilo/v2/internal/web"
)
// DiscoveryJSON represents the OpenID Connect Discovery Document.
// This document provides metadata about the OpenID Provider (OP),
// including supported endpoints, scopes, grant types, and algorithms.
// No field is tagged omitempty, so every member is always emitted.
type DiscoveryJSON struct {
Issuer string `json:"issuer"` // The URL of the OpenID Provider (OP).
AuthorizationEndpoint string `json:"authorization_endpoint"` // The endpoint for authorization requests.
TokenEndpoint string `json:"token_endpoint"` // The endpoint for token requests.
UserInfoEndpoint string `json:"userinfo_endpoint"` // The endpoint for retrieving user information.
JwksURI string `json:"jwks_uri"` // The URL for the JSON Web Key Set (JWKS).
RegistrationEndpoint string `json:"registration_endpoint"` // The endpoint for client registration.
SupportedClaims []string `json:"claims_supported"` // List of supported claims.
SupportedScopes []types.Scope `json:"scopes_supported"` // List of supported scopes.
SupportedResponseTypes []string `json:"response_types_supported"` // List of supported response types.
SupportedGrantTypes []string `json:"grant_types_supported"` // List of supported grant types.
SupportedSubjectTypes []string `json:"subject_types_supported"` // List of supported subject types (e.g., "public").
SupportedIDTokenSigningAlg []string `json:"id_token_signing_alg_values_supported"` // List of supported algorithms for ID token signing.
SupportedIDTokenEncryptionAlg []string `json:"id_token_encryption_alg_values_supported"` // List of supported algorithms for ID token encryption.
SupportedTokenEndpointAuthMethods []string `json:"token_endpoint_auth_methods_supported"` // List of supported token endpoint authentication methods.
}
// NewDiscoveryJSON builds the OpenID Connect discovery document for a provider
// served at baseURL.
//
// The issuer is baseURL + "/oauth2". Every endpoint URL is baseURL followed by the
// matching path from the web package's endpoint tables. The supported scopes,
// response types, grant types and claims come from the keys of the corresponding
// lookup tables via utils.KeysToSlice. Subject types, the ID token signing and
// encryption algorithms, and the token endpoint auth methods are fixed lists.
//
// NOTE(review): if utils.KeysToSlice ranges over a Go map, the order of those lists
// will vary between calls — confirm whether anything depends on a stable order.
func NewDiscoveryJSON(baseURL string) *DiscoveryJSON {
	endpoint := func(path string) string { return baseURL + path }

	doc := &DiscoveryJSON{Issuer: baseURL + "/oauth2"}

	doc.AuthorizationEndpoint = endpoint(web.OAuthEndpoints.Authorize)
	doc.TokenEndpoint = endpoint(web.OAuthEndpoints.Token)
	doc.UserInfoEndpoint = endpoint(web.OIDCEndpoints.UserInfo)
	doc.JwksURI = endpoint(web.OIDCEndpoints.JWKS)
	doc.RegistrationEndpoint = endpoint(web.ClientEndpoints.Register)

	doc.SupportedScopes = utils.KeysToSlice(types.SupportedScopes)
	doc.SupportedResponseTypes = utils.KeysToSlice(constants.SupportedResponseTypes)
	doc.SupportedGrantTypes = utils.KeysToSlice(constants.SupportedGrantTypes)
	doc.SupportedClaims = utils.KeysToSlice(constants.SupportedClaims)

	doc.SupportedSubjectTypes = []string{constants.SubjectTypePublic, constants.SubjectTypePairwise}
	doc.SupportedIDTokenSigningAlg = []string{constants.IDTokenSigningAlgorithmRS256}
	doc.SupportedIDTokenEncryptionAlg = []string{constants.IDTokenEncryptionAlgorithmRSA}
	doc.SupportedTokenEndpointAuthMethods = []string{
		constants.AuthMethodClientSecretBasic,
		constants.AuthMethodClientSecretPost,
		constants.AuthMethodNone,
	}
	return doc
}
package domain
import (
"github.com/golang-jwt/jwt"
domain "github.com/vigiloauth/vigilo/v2/internal/domain/claims"
"github.com/vigiloauth/vigilo/v2/internal/types"
)
// TokenData represents the data associated with a token.
// It has no JSON tags.
type TokenData struct {
Token string // The token string.
ID string // The id associated with the token.
ExpiresAt int64 // The token's expiration time. NOTE(review): presumably Unix seconds — confirm.
TokenID string // The ID of the token.
TokenClaims *TokenClaims // The claims associated with the token.
}
// TokenResponse represents the structure of an OAuth token response.
// This is returned to the client after successful authentication.
// IDToken has no omitempty, so "id_token" is emitted even when empty.
type TokenResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token,omitempty"`
TokenType string `json:"token_type"`
IDToken string `json:"id_token"`
Scope types.Scope `json:"scope,omitempty"`
ExpiresIn int64 `json:"expires_in"`
}
// TokenIntrospectionResponse is the JSON body returned by the token introspection
// endpoint. Member names follow RFC 7662 §2.2. Every member except "active" is
// omitted when empty.
type TokenIntrospectionResponse struct {
	Active          bool   `json:"active"`
	ExpiresAt       int64  `json:"exp,omitempty"`
	IssuedAt        int64  `json:"iat,omitempty"`
	Subject         string `json:"sub,omitempty"`
	Audience        string `json:"aud,omitempty"`
	Issuer          string `json:"iss,omitempty"`
	TokenIdentifier string `json:"jti,omitempty"`
}
// TokenRequest holds the parameters of a request to the token endpoint.
// See ValidateCodeVerifier for the PKCE check on CodeVerifier.
type TokenRequest struct {
GrantType string `json:"grant_type"`
AuthorizationCode string `json:"code"`
RedirectURI string `json:"redirect_uri"`
ClientID string `json:"client_id"`
ClientSecret string `json:"client_secret"`
State string `json:"state"`
CodeVerifier string `json:"code_verifier,omitempty"`
Nonce string `json:"nonce,omitempty"`
}
// TokenClaims is the application-specific claim set carried in issued tokens. It
// embeds *jwt.StandardClaims, so the standard registered claims (exp, iat, sub, ...)
// are promoted onto it. The embedded pointer must be non-nil before those promoted
// fields are accessed.
type TokenClaims struct {
Scopes types.Scope `json:"scopes,omitempty"`
Roles string `json:"roles,omitempty"`
Nonce string `json:"nonce,omitempty"`
AuthTime int64 `json:"auth_time,omitempty"`
RequestedClaims *domain.ClaimsRequest `json:"claims,omitempty"`
ACRValues string `json:"acr,omitempty"`
ClientID string `json:"client_id,omitempty"`
RedirectURI string `json:"redirect_uri,omitempty"`
State string `json:"state,omitempty"`
*jwt.StandardClaims
}
// BearerToken is the "bearer" token type string.
const BearerToken string = "bearer"
// NewTokenIntrospectionResponse builds an active introspection response from
// the standard claims embedded in claims. It does not check expiry itself; Active is
// always true. claims and its embedded StandardClaims must be non-nil. The audience is
// omitted from JSON when empty.
func NewTokenIntrospectionResponse(claims *TokenClaims) *TokenIntrospectionResponse {
	std := claims.StandardClaims
	return &TokenIntrospectionResponse{
		Active:          true,
		ExpiresAt:       std.ExpiresAt,
		IssuedAt:        std.IssuedAt,
		Subject:         std.Subject,
		Issuer:          std.Issuer,
		TokenIdentifier: std.Id,
		Audience:        std.Audience,
	}
}
package domain
import (
"fmt"
"regexp"
"github.com/vigiloauth/vigilo/v2/internal/errors"
)
// codeVerifierPattern matches the PKCE code_verifier alphabet from RFC 7636 §4.1:
// ALPHA / DIGIT / "-" / "." / "_" / "~". It is compiled once at package init rather
// than on every validation.
var codeVerifierPattern = regexp.MustCompile(`^[A-Za-z0-9._~-]+$`)

// ValidateCodeVerifier checks t.CodeVerifier against the PKCE constraints: a length
// of 43 to 128 characters, drawn only from the unreserved set accepted by
// codeVerifierPattern. It returns an ErrCodeInvalidRequest error for the first
// violated rule (length is checked first), or nil if the verifier is valid.
func (t *TokenRequest) ValidateCodeVerifier() error {
	codeVerifierLength := len(t.CodeVerifier)
	if codeVerifierLength < 43 || codeVerifierLength > 128 {
		return errors.New(errors.ErrCodeInvalidRequest, fmt.Sprintf("invalid code verifier length (%d): must be between 43 and 128 characters", codeVerifierLength))
	}
	if !codeVerifierPattern.MatchString(t.CodeVerifier) {
		return errors.New(errors.ErrCodeInvalidRequest, "invalid characters: only A-Z, a-z, 0-9, '-', '.', '_', and '~' are allowed")
	}
	return nil
}
package domain
import (
"fmt"
"slices"
"time"
)
// User represents a user in the system.
// Its profile fields mirror the OpenID Connect standard claims (compare UserInfoResponse).
type User struct {
ID string
PreferredUsername string
Name string
GivenName string
MiddleName string
FamilyName string
Nickname string
Profile string // URL to the user's profile
Picture string // URL to the user's picture
Website string // Personal website URL
Email string
PhoneNumber string
Password string // NewUser documents this as the hashed password.
Gender string
Birthdate string
Zoneinfo string // e.g., "America/New_York"
Locale string // e.g., "en-US"
Address *UserAddress
Roles []string
LastFailedLogin time.Time
CreatedAt time.Time
UpdatedAt time.Time
AccountLocked bool
EmailVerified bool
PhoneNumberVerified bool
}
// UserAddress is a postal address. Only Formatted is omitted from JSON when empty.
type UserAddress struct {
Formatted string `json:"formatted,omitempty"`
StreetAddress string `json:"street_address"`
Locality string `json:"locality"`
Region string `json:"region"`
PostalCode string `json:"postal_code"`
Country string `json:"country"`
}
// UserRegistrationRequest represents the registration request payload.
// Password is received here before hashing.
// NOTE(review): confirm hashing happens before the value reaches User.Password.
type UserRegistrationRequest struct {
Username string `json:"username"`
Nickname string `json:"nickname,omitempty"`
FirstName string `json:"first_name"`
MiddleName string `json:"middle_name,omitempty"`
FamilyName string `json:"family_name"`
Birthdate string `json:"birthdate"`
Email string `json:"email"`
Profile string `json:"profile,omitempty"` // URL to the user's profile page
Picture string `json:"picture,omitempty"` // URL to the user's picture/avatar
Website string `json:"website,omitempty"` // URL to the user's personal website
Gender string `json:"gender"`
PhoneNumber string `json:"phone_number,omitempty"`
Password string `json:"password"`
Address UserAddress `json:"address"`
Scopes []string `json:"scope,omitempty"`
Roles []string `json:"roles,omitempty"`
}
// UserInfoResponse represents the payload for the user info request.
// Every member except "sub" is omitted when empty. The verification flags are
// pointers so that false can be emitted explicitly while nil is omitted.
type UserInfoResponse struct {
Sub string `json:"sub"`
Name string `json:"name,omitempty"`
GivenName string `json:"given_name,omitempty"`
FamilyName string `json:"family_name,omitempty"`
MiddleName string `json:"middle_name,omitempty"`
Nickname string `json:"nickname,omitempty"`
PreferredUsername string `json:"preferred_username,omitempty"`
Profile string `json:"profile,omitempty"`
Picture string `json:"picture,omitempty"`
Website string `json:"website,omitempty"`
Gender string `json:"gender,omitempty"`
Birthdate string `json:"birthdate,omitempty"`
Zoneinfo string `json:"zoneinfo,omitempty"`
Locale string `json:"locale,omitempty"`
Email string `json:"email,omitempty"`
EmailVerified *bool `json:"email_verified,omitempty"`
PhoneNumber string `json:"phone_number,omitempty"`
PhoneNumberVerified *bool `json:"phone_number_verified,omitempty"`
UpdatedAt int64 `json:"updated_at,omitempty"`
Address *UserAddress `json:"address,omitempty"`
}
// UserRegistrationResponse represents the registration response payload.
type UserRegistrationResponse struct {
Username string `json:"username"`
Name string `json:"name"`
Gender string `json:"gender"`
Birthdate string `json:"birthdate"`
Email string `json:"email"`
PhoneNumber string `json:"phone_number,omitempty"`
Address string `json:"address"`
JWTToken string `json:"token"`
}
// UserLoginRequest represents the login request payload.
type UserLoginRequest struct {
Username string `json:"username"`
Password string `json:"password"`
ClientID string
RedirectURI string
}
// UserLoginResponse represents the login response payload.
type UserLoginResponse struct {
UserID string
Username string `json:"username"`
Email string `json:"email"`
AccessToken string `json:"access_token,omitempty"`
RefreshToken string `json:"refresh_token,omitempty"`
OAuthRedirectURL string `json:"oauth_redirect_url,omitempty"`
LastFailedLogin time.Time `json:"last_failed_login"`
Scopes []string `json:"scopes,omitempty"`
Roles []string `json:"roles,omitempty"`
}
// UserPasswordResetRequest represents the password reset request payload.
// Its Validate method checks only NewPassword against the configured password
// policy; Email and ResetToken are not validated there.
type UserPasswordResetRequest struct {
Email string `json:"email"`
ResetToken string `json:"reset_token"`
NewPassword string `json:"new_password"`
}
// UserPasswordResetResponse represents the password reset response payload.
// It carries a single message string, serialized under the "message" key.
type UserPasswordResetResponse struct {
Message string `json:"message"`
}
// UserLoginAttempt represents a user's login attempt.
// NewUserLoginAttempt sets IPAddress, UserAgent, Timestamp and FailedAttempts.
// The remaining fields are not set by that constructor. None of the fields
// carry json tags.
type UserLoginAttempt struct {
UserID string
IPAddress string
Username string
// NOTE(review): holds the submitted password; make sure attempts are never
// logged or persisted with this field populated — confirm with callers.
Password string
ForwardedFor string
Timestamp time.Time
RequestMetadata string
Details string
UserAgent string
FailedAttempts int
}
// NewUser builds a User from the minimal registration triple.
//
// Parameters:
//   - username string: Stored as the user's PreferredUsername.
//   - email string: The user's email address.
//   - password string: The user's password (already hashed by the caller).
//
// Returns:
//   - *User: A new User. LastFailedLogin, AccountLocked and EmailVerified are
//     left at their zero values, and no Address is attached.
func NewUser(username, email, password string) *User {
	u := new(User)
	u.PreferredUsername = username
	u.Email = email
	u.Password = password
	return u
}
// NewUserFromRegistrationRequest create a new user instance from a registration request.
//
// Parameters:
// - req *UserRegistrationRequest: The request.
//
// Returns:
// - *User: A new user instance.
func NewUserFromRegistrationRequest(req *UserRegistrationRequest) *User {
name := fmt.Sprintf("%s %s %s", req.FirstName, req.MiddleName, req.FamilyName)
return &User{
PreferredUsername: req.Username,
Password: req.Password,
Name: name,
GivenName: req.FirstName,
MiddleName: req.MiddleName,
FamilyName: req.FamilyName,
Gender: req.Gender,
Birthdate: req.Birthdate,
Email: req.Email,
PhoneNumber: req.PhoneNumber,
Zoneinfo: time.UTC.String(),
Roles: req.Roles,
Locale: req.Address.Locality,
Website: req.Website,
Profile: req.Profile,
Picture: req.Picture,
Nickname: req.Nickname,
Address: NewUserAddress(
req.Address.StreetAddress,
req.Address.Locality,
req.Address.Region,
req.Address.PostalCode,
req.Address.Country,
),
LastFailedLogin: time.Time{},
AccountLocked: false,
EmailVerified: false,
PhoneNumberVerified: false,
}
}
// NewUserRegistrationRequest returns a registration request populated with
// only the credentials; every other field keeps its zero value.
//
// Parameters:
//   - username string: The requested username.
//   - email string: The email address to register.
//   - password string: The plaintext password submitted for registration.
//
// Returns:
//   - *UserRegistrationRequest: The new request.
func NewUserRegistrationRequest(username, email, password string) *UserRegistrationRequest {
	req := new(UserRegistrationRequest)
	req.Username, req.Email, req.Password = username, email, password
	return req
}
// NewUserRegistrationResponse creates a new UserRegistrationResponse instance.
//
// Address is the user's formatted address. It is left empty when the user has
// no address; users built with NewUser, for example, never get one, and
// dereferencing the nil pointer would panic.
//
// Parameters:
//   - user *User: The created User object.
//   - jwtToken string: The JWT token for the registered user.
//
// Returns:
//   - *UserRegistrationResponse: A new UserRegistrationResponse instance.
func NewUserRegistrationResponse(user *User, jwtToken string) *UserRegistrationResponse {
	var formattedAddress string
	if user.Address != nil {
		formattedAddress = user.Address.Formatted
	}

	return &UserRegistrationResponse{
		Username:    user.PreferredUsername,
		Name:        user.Name,
		Gender:      user.Gender,
		Birthdate:   user.Birthdate,
		Email:       user.Email,
		PhoneNumber: user.PhoneNumber,
		Address:     formattedAddress,
		JWTToken:    jwtToken,
	}
}
// NewUserAddress builds a UserAddress from its components and precomputes the
// Formatted field with formatAddress.
//
// Parameters:
//   - streetAddress string: Street component (house number, street name, PO box).
//   - locality string: City or locality component.
//   - region string: State, province, prefecture or region component.
//   - postalCode string: Zip code or postal code component.
//   - country string: Country name component.
//
// Returns:
//   - *UserAddress: The new address with every component and Formatted set.
func NewUserAddress(streetAddress, locality, region, postalCode, country string) *UserAddress {
	addr := &UserAddress{
		StreetAddress: streetAddress,
		Locality:      locality,
		Region:        region,
		PostalCode:    postalCode,
		Country:       country,
	}
	addr.Formatted = formatAddress(addr.StreetAddress, addr.Locality, addr.Region, addr.PostalCode, addr.Country)
	return addr
}
// NewUserLoginRequest returns a login request holding only the credentials;
// ClientID and RedirectURI are left empty.
//
// Parameters:
//   - username string: The username to authenticate.
//   - password string: The password to authenticate with.
//
// Returns:
//   - *UserLoginRequest: The new request.
func NewUserLoginRequest(username, password string) *UserLoginRequest {
	req := new(UserLoginRequest)
	req.Username = username
	req.Password = password
	return req
}
// NewUserLoginResponse copies the identity fields of an authenticated user into
// a login response. Token, redirect and scope fields are left empty.
//
// Parameters:
//   - user *User: The authenticated user (must be non-nil).
//
// Returns:
//   - *UserLoginResponse: The new response.
func NewUserLoginResponse(user *User) *UserLoginResponse {
	response := new(UserLoginResponse)
	response.UserID = user.ID
	response.Username = user.PreferredUsername
	response.Email = user.Email
	response.Roles = user.Roles
	response.LastFailedLogin = user.LastFailedLogin
	return response
}
// NewUserLoginAttempt creates a UserLoginAttempt stamped with the current time.
//
// Parameters:
//   - ipAddress string: The IP address of the login attempt.
//   - userAgent string: The user agent of the login attempt.
//
// Returns:
//   - *UserLoginAttempt: A new attempt with Timestamp set to time.Now() and
//     FailedAttempts set to 0; UserID, Username, Password, ForwardedFor,
//     RequestMetadata and Details are left empty.
func NewUserLoginAttempt(ipAddress, userAgent string) *UserLoginAttempt {
	return &UserLoginAttempt{
		IPAddress:      ipAddress,
		UserAgent:      userAgent,
		Timestamp:      time.Now(),
		FailedAttempts: 0,
	}
}
// HasRole reports whether role is one of the user's Roles (exact,
// case-sensitive comparison).
func (u *User) HasRole(role string) bool {
	return slices.Index(u.Roles, role) != -1
}
// formatAddress renders an address as three newline-separated lines:
//
//	<street>
//	<locality>, <region> <postal code>
//	<country>
//
// Empty components are not elided, so the separators are always present.
func formatAddress(streetAddress, locality, region, postalCode, country string) string {
	cityLine := fmt.Sprintf("%s, %s %s", locality, region, postalCode)
	return streetAddress + "\n" + cityLine + "\n" + country
}
package domain
import (
"fmt"
"net/mail"
"regexp"
"strings"
"time"
"unicode"
"github.com/vigiloauth/vigilo/v2/idp/config"
"github.com/vigiloauth/vigilo/v2/internal/constants"
"github.com/vigiloauth/vigilo/v2/internal/errors"
)
// Validate checks every UserRegistrationRequest field and reports all problems
// at once rather than stopping at the first.
//
// Checks run in this order: username non-empty, email, password, phone number,
// birthdate, roles. The role check fills in the default role when none was
// supplied, so Validate may modify req.Roles.
//
// Returns:
//   - error: An *errors.ErrorCollection holding every failure, or nil if the
//     request is valid.
func (req *UserRegistrationRequest) Validate() error {
	collected := errors.NewErrorCollection()

	if req.Username == "" {
		collected.Add(errors.New(errors.ErrCodeEmptyInput, "'username' is empty"))
	}

	fieldChecks := []struct {
		value    string
		validate func(string, *errors.ErrorCollection)
	}{
		{req.Email, validateEmail},
		{req.Password, validatePassword},
		{req.PhoneNumber, validatePhoneNumber},
		{req.Birthdate, validateBirthdate},
	}
	for _, check := range fieldChecks {
		check.validate(check.value, collected)
	}

	req.validateRole(collected)

	if !collected.HasErrors() {
		return nil
	}
	return collected
}
// Validate checks the new password against the configured password policy.
// Email and ResetToken are not validated here.
//
// Returns:
//   - error: An *errors.ErrorCollection holding every policy violation, or nil
//     if the new password is acceptable.
func (req *UserPasswordResetRequest) Validate() error {
	collected := errors.NewErrorCollection()
	validatePassword(req.NewPassword, collected)

	if !collected.HasErrors() {
		return nil
	}
	return collected
}
// Validate ensures that both credentials of a UserLoginRequest are present.
// An empty password is reported before an empty username.
//
// Returns:
//   - error: An *errors.ErrorCollection listing each empty field, or nil if
//     both are set.
func (req *UserLoginRequest) Validate() error {
	collected := errors.NewErrorCollection()

	required := []struct{ value, message string }{
		{req.Password, "'password' is empty"},
		{req.Username, "'username' is empty"},
	}
	for _, field := range required {
		if field.value == "" {
			collected.Add(errors.New(errors.ErrCodeEmptyInput, field.message))
		}
	}

	if !collected.HasErrors() {
		return nil
	}
	return collected
}
// validateEmail records at most one problem with email: an ErrCodeEmptyInput
// error if it is empty, or an ErrCodeInvalidFormat error if it is not a valid
// address according to isValidEmailFormat.
//
// Parameters:
//   - email string: The email address to validate.
//   - errorCollection *errors.ErrorCollection: Receives any validation error.
func validateEmail(email string, errorCollection *errors.ErrorCollection) {
	switch {
	case email == "":
		errorCollection.Add(errors.New(errors.ErrCodeEmptyInput, "'email' is empty"))
	case !isValidEmailFormat(email):
		errorCollection.Add(errors.New(errors.ErrCodeInvalidFormat, "invalid email format"))
	}
}
// validatePassword checks password against the server's password policy and
// adds one error to errorCollection for each rule it breaks.
//
// An empty password yields a single ErrCodeEmptyInput error and no further
// checks. Otherwise these checks run independently:
//   - minimum length
//   - uppercase letter (if required)
//   - digit (if required)
//   - symbol (if required)
//
// Length is counted in characters (Unicode code points), not bytes, so
// passwords made of multi-byte characters are not over-counted.
//
// Parameters:
//   - password string: The password to validate.
//   - errorCollection *errors.ErrorCollection: The ErrorCollection to add errors to.
func validatePassword(password string, errorCollection *errors.ErrorCollection) {
	if password == "" {
		errorCollection.Add(errors.New(errors.ErrCodeEmptyInput, "password is empty"))
		return
	}

	passwordConfig := config.GetServerConfig().PasswordConfig()

	if len([]rune(password)) < passwordConfig.MinLength() {
		errorCollection.Add(errors.New(errors.ErrCodePasswordLength, "password does not match required length"))
	}
	if passwordConfig.RequireUppercase() && !containsUppercase(password) {
		errorCollection.Add(errors.New(errors.ErrCodeMissingUppercase, "password is missing a required uppercase letter"))
	}
	if passwordConfig.RequireNumber() && !containsNumber(password) {
		errorCollection.Add(errors.New(errors.ErrCodeMissingNumber, "password is missing a required number"))
	}
	if passwordConfig.RequireSymbol() && !containsSymbol(password) {
		errorCollection.Add(errors.New(errors.ErrCodeMissingSymbol, "password is missing a required symbol"))
	}
}
// validateBirthdate validates the user's password and ensures it follows the
// ISO 8601:2004 YYYY-MM-DD format
//
// Parameters:
// - birthdate: The birthdate to validate.
// - errorCollection *errors.ErrorCollection: The ErrorCollection to add errors to.
func validateBirthdate(birthdate string, errorCollection *errors.ErrorCollection) {
if birthdate == "" {
err := errors.New(errors.ErrCodeEmptyInput, "birthdate is empty")
errorCollection.Add(err)
return
}
const pattern string = `^\d{4}-(0[1-9]|1[0-2])-(0[1-9]|[12]\d|3[01])$`
re := regexp.MustCompile(pattern)
if !re.MatchString(birthdate) {
err := errors.New(errors.ErrCodeInvalidFormat, "invalid birthdate format - must follow the ISO 8601:2004 YYYY-MM-DD format")
errorCollection.Add(err)
return
}
const dateFormat string = "2006-01-02"
if _, err := time.Parse(dateFormat, birthdate); err != nil {
err := errors.New(errors.ErrCodeInvalidDate, "the birthdate provided is an invalid date")
errorCollection.Add(err)
return
}
}
// validatePhoneNumber validates the phone number and makes sure it is in E.164 format.
//
// Parameters:
// - phoneNumber string: The phone number to verify.
// - errorCollection *errors.ErrorCollection: The ErrorCollection to add errors to.
func validatePhoneNumber(phoneNumber string, errorCollection *errors.ErrorCollection) {
if phoneNumber == "" {
return
}
const e164Format string = `^\+[1-9]\d{1,14}$`
re := regexp.MustCompile(e164Format)
if !re.MatchString(phoneNumber) {
err := errors.New(errors.ErrCodeInvalidFormat, "invalid phone number format")
errorCollection.Add(err)
return
}
}
// isValidEmailFormat reports whether email is a single, bare RFC 5322 address
// such as "user@example.com".
//
// mail.ParseAddress also accepts the name-addr form ("Alice <alice@example.com>")
// and tolerates surrounding whitespace. The input is therefore accepted only
// when it is exactly the address the parser extracted, which rejects display
// names and extra text. Addresses with quoted local parts are also rejected.
//
// Parameters:
//   - email string: The email address to validate.
//
// Returns:
//   - bool: True if the email format is valid, false otherwise.
func isValidEmailFormat(email string) bool {
	addr, err := mail.ParseAddress(email)
	if err != nil {
		return false
	}
	return addr.Address == email
}
// containsUppercase reports whether password has at least one rune for which
// unicode.IsUpper is true.
func containsUppercase(password string) bool {
	return strings.ContainsFunc(password, unicode.IsUpper)
}
// containsNumber reports whether password has at least one rune for which
// unicode.IsNumber is true (any Unicode number, not only ASCII digits).
func containsNumber(password string) bool {
	return strings.ContainsFunc(password, unicode.IsNumber)
}
// containsSymbol reports whether password has at least one rune that is
// neither a letter nor a number. Whitespace and punctuation therefore both
// count as symbols.
func containsSymbol(password string) bool {
	isSymbol := func(r rune) bool {
		return !(unicode.IsLetter(r) || unicode.IsNumber(r))
	}
	return strings.ContainsFunc(password, isSymbol)
}
// validateRole assigns constants.UserRole when no roles were requested. It then
// adds an ErrCodeBadRequest error for every role that is not a key of
// constants.ValidRoles. Note that it mutates req.Roles.
func (req *UserRegistrationRequest) validateRole(errorCollection *errors.ErrorCollection) {
	if len(req.Roles) == 0 {
		req.Roles = append(req.Roles, constants.UserRole)
	}

	for _, role := range req.Roles {
		if _, known := constants.ValidRoles[role]; known {
			continue
		}
		message := fmt.Sprintf("invalid role: %s", role)
		errorCollection.Add(errors.New(errors.ErrCodeBadRequest, message))
	}
}
package errors
import "net/http"
// Error codes are the machine-readable identifiers stored in
// VigiloAuthError.ErrorCode, which is serialized as "error".
//
// Each code is expected to have an entry in HTTPStatusCodeMap (its HTTP status)
// and in SystemErrorCodeMap (its VIG-prefixed system code). StatusCode falls
// back to 500 for a code missing from HTTPStatusCodeMap. New leaves SystemCode
// empty for a code missing from SystemErrorCodeMap.
const (
// Validation errors
ErrCodeEmptyInput string = "empty_field"
ErrCodeInvalidPasswordFormat string = "invalid_password_format"
ErrCodePasswordLength string = "invalid_password_length"
ErrCodeMissingUppercase string = "missing_required_uppercase"
ErrCodeMissingNumber string = "missing_required_number"
ErrCodeMissingSymbol string = "missing_required_symbol"
ErrCodeInvalidEmail string = "invalid_email_format"
ErrCodeInvalidFormat string = "invalid_format"
ErrCodeValidationError string = "validation_error"
ErrCodeMissingHeader string = "missing_header"
ErrCodeInvalidContentType string = "invalid_content_type"
ErrCodeInvalidRequest string = "invalid_request"
ErrCodeBadRequest string = "bad_request"
ErrCodeInvalidInput string = "invalid_input"
ErrCodeInvalidDate string = "invalid_date"
ErrCodeNotFound string = "not_found"
// User errors
ErrCodeDuplicateUser string = "duplicate_user"
ErrCodeUserNotFound string = "user_not_found"
ErrCodeInvalidCredentials string = "invalid_credentials"
ErrCodeAccountLocked string = "account_locked"
ErrCodeUnauthorized string = "unauthorized"
ErrCodeLoginRequired string = "login_required"
ErrCodeConsentRequired string = "consent_required"
ErrCodeInsufficientRole string = "insufficient_role"
ErrCodeInteractionRequired string = "interaction_required"
// Token errors
ErrCodeTokenNotFound string = "token_not_found"
ErrCodeExpiredToken string = "token_expired"
ErrCodeTokenCreation string = "token_creation"
ErrCodeInvalidToken string = "invalid_token"
ErrCodeTokenParsing string = "token_parsing"
ErrCodeTokenEncryption string = "token_encrypt_failed"
ErrCodeTokenDecryption string = "token_decrypt_failed"
ErrCodeDuplicateToken string = "duplicate_token"
ErrCodeTokenSigning string = "token_signing"
// Email errors
ErrCodeConnectionFailed string = "connection_failed"
ErrCodeEmailDeliveryFailed string = "delivery_failed"
// Client errors
ErrCodeInvalidClient string = "invalid_client"
ErrCodeInvalidGrant string = "invalid_grant"
ErrCodeInvalidRedirectURI string = "invalid_redirect_uri"
ErrCodeInsufficientScope string = "insufficient_scope"
ErrCodeClientSecretNotAllowed string = "client_secret_not_allowed"
ErrCodeClientNotFound string = "client_not_found"
ErrCodeDuplicateClient string = "duplicate_client"
ErrCodeInvalidResponseType string = "invalid_response_type"
ErrCodeUnauthorizedClient string = "unauthorized_client"
ErrCodeUnsupportedGrantType string = "unsupported_grant_type"
ErrCodeAccessDenied string = "access_denied"
ErrCodeInvalidClientMetadata string = "invalid_client_metadata"
ErrCodeRequestURINotSupported string = "request_uri_not_supported"
ErrCodeRequestObjectNotSupported string = "request_not_supported"
// Session errors
ErrCodeDuplicateSession string = "duplicate_session"
ErrCodeSessionNotFound string = "session_not_found"
ErrCodeInvalidSession string = "invalid_session"
ErrCodeSessionExpired string = "expired_session"
// Middleware Errors
ErrCodeRequestLimitExceeded string = "request_limit_exceeded"
ErrCodeSessionCreation string = "session_create_failed"
ErrCodeSessionSave string = "session_save_failed"
// System errors
ErrCodeInternalServerError string = "server_error"
ErrCodeRequestTimeout string = "request_timeout"
ErrCodeRequestCancelled string = "request_cancelled"
ErrCodeResourceNotFound string = "resource_not_found"
ErrCodeMethodNotAllowed string = "method_not_allowed"
// Authorization Code Errors
ErrCodeInvalidAuthorizationCode string = "invalid_authorization_code"
ErrCodeExpiredAuthorizationCode string = "expired_authorization_code"
ErrCodeAuthorizationCodeNotFound string = "code_not_found"
// Encryption Errors
ErrCodeHashingFailed string = "hashing_failed"
ErrCodeRandomGenerationFailed string = "random_generation_failed"
ErrCodeEncryptionFailed string = "encryption_failed"
ErrCodeDecryptionFailed string = "decryption_failed"
)
// HTTPStatusCodeMap maps each error code to the HTTP status sent for it.
// StatusCode looks codes up here and returns 500 for any code that is missing.
// ErrCodeInvalidFormat maps to 422 Unprocessable Entity, while the other
// validation codes map to 400 Bad Request.
var HTTPStatusCodeMap = map[string]int{
// 400 Bad Request
ErrCodeEmptyInput: http.StatusBadRequest,
ErrCodeMissingNumber: http.StatusBadRequest,
ErrCodeMissingSymbol: http.StatusBadRequest,
ErrCodeMissingUppercase: http.StatusBadRequest,
ErrCodeInvalidPasswordFormat: http.StatusBadRequest,
ErrCodePasswordLength: http.StatusBadRequest,
ErrCodeInvalidEmail: http.StatusBadRequest,
ErrCodeValidationError: http.StatusBadRequest,
ErrCodeMissingHeader: http.StatusBadRequest,
ErrCodeUnauthorizedClient: http.StatusBadRequest,
ErrCodeClientSecretNotAllowed: http.StatusBadRequest,
ErrCodeInvalidResponseType: http.StatusBadRequest,
ErrCodeInvalidContentType: http.StatusBadRequest,
ErrCodeUnsupportedGrantType: http.StatusBadRequest,
ErrCodeInvalidRequest: http.StatusBadRequest,
ErrCodeBadRequest: http.StatusBadRequest,
ErrCodeInvalidClientMetadata: http.StatusBadRequest,
ErrCodeInvalidGrant: http.StatusBadRequest,
ErrCodeInvalidInput: http.StatusBadRequest,
ErrCodeInvalidDate: http.StatusBadRequest,
ErrCodeInteractionRequired: http.StatusBadRequest,
ErrCodeLoginRequired: http.StatusBadRequest,
ErrCodeInvalidRedirectURI: http.StatusBadRequest,
ErrCodeRequestURINotSupported: http.StatusBadRequest,
ErrCodeRequestObjectNotSupported: http.StatusBadRequest,
// 401 Unauthorized
ErrCodeInvalidCredentials: http.StatusUnauthorized,
ErrCodeInvalidClient: http.StatusUnauthorized,
ErrCodeUnauthorized: http.StatusUnauthorized,
ErrCodeExpiredToken: http.StatusUnauthorized,
ErrCodeInvalidToken: http.StatusUnauthorized,
ErrCodeTokenParsing: http.StatusUnauthorized,
ErrCodeConsentRequired: http.StatusUnauthorized,
ErrCodeInvalidSession: http.StatusUnauthorized,
ErrCodeInvalidAuthorizationCode: http.StatusUnauthorized,
ErrCodeExpiredAuthorizationCode: http.StatusUnauthorized,
ErrCodeSessionExpired: http.StatusUnauthorized,
// 404 Not Found
ErrCodeUserNotFound: http.StatusNotFound,
ErrCodeTokenNotFound: http.StatusNotFound,
ErrCodeClientNotFound: http.StatusNotFound,
ErrCodeSessionNotFound: http.StatusNotFound,
ErrCodeResourceNotFound: http.StatusNotFound,
ErrCodeAuthorizationCodeNotFound: http.StatusNotFound,
ErrCodeNotFound: http.StatusNotFound,
// 403 Forbidden
ErrCodeAccessDenied: http.StatusForbidden,
ErrCodeInsufficientRole: http.StatusForbidden,
ErrCodeInsufficientScope: http.StatusForbidden,
// 409 Conflict
ErrCodeDuplicateUser: http.StatusConflict,
ErrCodeDuplicateClient: http.StatusConflict,
ErrCodeDuplicateSession: http.StatusConflict,
// 422 Unprocessable Entity
ErrCodeInvalidFormat: http.StatusUnprocessableEntity,
// 423 Locked
ErrCodeAccountLocked: http.StatusLocked,
// 408 Request Timeout
ErrCodeRequestTimeout: http.StatusRequestTimeout,
ErrCodeRequestCancelled: http.StatusRequestTimeout,
// 500 Internal Server Error
ErrCodeInternalServerError: http.StatusInternalServerError,
ErrCodeTokenCreation: http.StatusInternalServerError,
ErrCodeEmailDeliveryFailed: http.StatusInternalServerError,
ErrCodeSessionCreation: http.StatusInternalServerError,
ErrCodeSessionSave: http.StatusInternalServerError,
ErrCodeTokenEncryption: http.StatusInternalServerError,
ErrCodeTokenDecryption: http.StatusInternalServerError,
ErrCodeDuplicateToken: http.StatusInternalServerError,
ErrCodeTokenSigning: http.StatusInternalServerError,
ErrCodeHashingFailed: http.StatusInternalServerError,
ErrCodeRandomGenerationFailed: http.StatusInternalServerError,
ErrCodeEncryptionFailed: http.StatusInternalServerError,
ErrCodeDecryptionFailed: http.StatusInternalServerError,
// 502 Bad Gateway
ErrCodeConnectionFailed: http.StatusBadGateway,
// 429 Too Many Requests
ErrCodeRequestLimitExceeded: http.StatusTooManyRequests,
// 405 Method Not Allowed
ErrCodeMethodNotAllowed: http.StatusMethodNotAllowed,
}
// Building blocks for system error codes. A system code is prefix, then a
// category segment, then a four-digit sequence number. For example,
// "VIG" + "VAL" + "0001" gives "VIGVAL0001". All categories are three letters
// except authzCodeError ("AUTH").
const (
prefix string = "VIG"
validationError string = "VAL"
userError string = "USR"
tokenError string = "TOK"
emailError string = "EML"
clientError string = "CLI"
sessionError string = "SES"
middlewareError string = "MDW"
systemError string = "SYS"
authzCodeError string = "AUTH"
cryptoError string = "CRY"
)
// SystemErrorCodeMap maps each error code to its VIG-prefixed system code.
// New and Wrap copy the value into VigiloAuthError.SystemCode, which is
// serialized as "error_code". Sequence numbers restart at 0001 within each
// category. A code missing from the map yields an empty system code.
var SystemErrorCodeMap = map[string]string{
// Validation Errors
ErrCodeEmptyInput: prefix + validationError + "0001",
ErrCodeInvalidPasswordFormat: prefix + validationError + "0002",
ErrCodePasswordLength: prefix + validationError + "0003",
ErrCodeMissingUppercase: prefix + validationError + "0004",
ErrCodeMissingNumber: prefix + validationError + "0005",
ErrCodeMissingSymbol: prefix + validationError + "0006",
ErrCodeInvalidEmail: prefix + validationError + "0007",
ErrCodeInvalidFormat: prefix + validationError + "0008",
ErrCodeValidationError: prefix + validationError + "0009",
ErrCodeMissingHeader: prefix + validationError + "0010",
ErrCodeInvalidContentType: prefix + validationError + "0011",
ErrCodeInvalidRequest: prefix + validationError + "0012",
ErrCodeBadRequest: prefix + validationError + "0013",
ErrCodeInvalidInput: prefix + validationError + "0014",
ErrCodeInvalidDate: prefix + validationError + "0015",
ErrCodeNotFound: prefix + validationError + "0016",
// User Errors
ErrCodeDuplicateUser: prefix + userError + "0001",
ErrCodeUserNotFound: prefix + userError + "0002",
ErrCodeInvalidCredentials: prefix + userError + "0003",
ErrCodeAccountLocked: prefix + userError + "0004",
ErrCodeUnauthorized: prefix + userError + "0005",
ErrCodeLoginRequired: prefix + userError + "0006",
ErrCodeConsentRequired: prefix + userError + "0007",
ErrCodeInsufficientRole: prefix + userError + "0008",
ErrCodeInteractionRequired: prefix + userError + "0009",
// Token Errors
ErrCodeTokenNotFound: prefix + tokenError + "0001",
ErrCodeExpiredToken: prefix + tokenError + "0002",
ErrCodeTokenCreation: prefix + tokenError + "0003",
ErrCodeInvalidToken: prefix + tokenError + "0004",
ErrCodeTokenParsing: prefix + tokenError + "0005",
ErrCodeTokenEncryption: prefix + tokenError + "0006",
ErrCodeTokenDecryption: prefix + tokenError + "0007",
ErrCodeDuplicateToken: prefix + tokenError + "0008",
ErrCodeTokenSigning: prefix + tokenError + "0009",
// Email Errors
ErrCodeConnectionFailed: prefix + emailError + "0001",
ErrCodeEmailDeliveryFailed: prefix + emailError + "0002",
// Client Errors
ErrCodeInvalidClient: prefix + clientError + "0001",
ErrCodeInvalidGrant: prefix + clientError + "0002",
ErrCodeInvalidRedirectURI: prefix + clientError + "0003",
ErrCodeInsufficientScope: prefix + clientError + "0004",
ErrCodeClientSecretNotAllowed: prefix + clientError + "0005",
ErrCodeClientNotFound: prefix + clientError + "0006",
ErrCodeDuplicateClient: prefix + clientError + "0007",
ErrCodeInvalidResponseType: prefix + clientError + "0008",
ErrCodeUnauthorizedClient: prefix + clientError + "0009",
ErrCodeUnsupportedGrantType: prefix + clientError + "0010",
ErrCodeAccessDenied: prefix + clientError + "0011",
ErrCodeInvalidClientMetadata: prefix + clientError + "0012",
ErrCodeRequestURINotSupported: prefix + clientError + "0013",
ErrCodeRequestObjectNotSupported: prefix + clientError + "0014",
// Session Errors
ErrCodeDuplicateSession: prefix + sessionError + "0001",
ErrCodeSessionNotFound: prefix + sessionError + "0002",
ErrCodeInvalidSession: prefix + sessionError + "0003",
ErrCodeSessionExpired: prefix + sessionError + "0004",
// Middleware Errors
ErrCodeRequestLimitExceeded: prefix + middlewareError + "0001",
ErrCodeSessionCreation: prefix + middlewareError + "0002",
ErrCodeSessionSave: prefix + middlewareError + "0003",
// System Errors
ErrCodeInternalServerError: prefix + systemError + "0001",
ErrCodeRequestTimeout: prefix + systemError + "0002",
ErrCodeRequestCancelled: prefix + systemError + "0003",
ErrCodeResourceNotFound: prefix + systemError + "0004",
ErrCodeMethodNotAllowed: prefix + systemError + "0005",
// Authorization Code Errors
ErrCodeInvalidAuthorizationCode: prefix + authzCodeError + "0001",
ErrCodeExpiredAuthorizationCode: prefix + authzCodeError + "0002",
ErrCodeAuthorizationCodeNotFound: prefix + authzCodeError + "0003",
// Crypto Errors
ErrCodeHashingFailed: prefix + cryptoError + "0001",
ErrCodeRandomGenerationFailed: prefix + cryptoError + "0002",
ErrCodeEncryptionFailed: prefix + cryptoError + "0003",
ErrCodeDecryptionFailed: prefix + cryptoError + "0004",
}
// StatusCode returns the HTTP status registered for errorCode in
// HTTPStatusCodeMap, or 500 Internal Server Error if the code is unknown.
func StatusCode(errorCode string) int {
	status, found := HTTPStatusCodeMap[errorCode]
	if !found {
		return http.StatusInternalServerError // Default status
	}
	return status
}
package errors
import "fmt"
// ErrorCollection accumulates several errors (typically validation failures)
// so they can be reported together. *ErrorCollection implements error.
type ErrorCollection struct {
	errors []error
}

// NewErrorCollection returns an empty collection. The backing slice is
// allocated up front, so the slice reached through Errors() is non-nil even
// before anything is added.
func NewErrorCollection() *ErrorCollection {
	return &ErrorCollection{errors: make([]error, 0)}
}

// Add appends err to the collection.
func (ec *ErrorCollection) Add(err error) {
	ec.errors = append(ec.errors, err)
}

// Errors returns a pointer to the collection's internal slice. The slice is
// not copied, so later Adds are visible through the pointer.
func (ec *ErrorCollection) Errors() *[]error {
	return &ec.errors
}

// HasErrors reports whether at least one error has been added.
func (ec *ErrorCollection) HasErrors() bool {
	return len(ec.errors) != 0
}

// Error implements the error interface by reporting how many errors were
// collected, e.g. "2 errors".
func (ec *ErrorCollection) Error() string {
	count := len(ec.errors)
	return fmt.Sprintf("%d errors", count)
}
package errors
import (
"context"
"errors"
"fmt"
)
// VigiloAuthError represents a standardized error structure.
//
// Field order and json tags define the serialized error body:
//   - SystemCode ("error_code"): the VIG-prefixed code from SystemErrorCodeMap.
//   - ErrorCode ("error"): the machine-readable code, also the key into
//     HTTPStatusCodeMap.
//   - WrappedErr: never serialized; it is the cause returned by Unwrap.
//   - Errors: set by Wrap when the wrapped error is an *ErrorCollection.
//   - ConsentURL: set by NewConsentRequiredError.
type VigiloAuthError struct {
SystemCode string `json:"error_code"`
ErrorCode string `json:"error"`
ErrorDescription string `json:"error_description"`
ErrorDetails string `json:"error_details,omitempty"`
WrappedErr error `json:"-"`
Errors *[]error `json:"errors,omitempty"`
RedirectURL string `json:"redirect_url,omitempty"`
ConsentURL string `json:"consent_url,omitempty"`
}
// Error implements the error interface. It returns the description alone, or
// "description: details" when ErrorDetails is set.
func (e *VigiloAuthError) Error() string {
	if e.ErrorDetails == "" {
		return e.ErrorDescription
	}
	return fmt.Sprintf("%s: %s", e.ErrorDescription, e.ErrorDetails)
}
// New creates a *VigiloAuthError for errCode with the given description. The
// system code is looked up in SystemErrorCodeMap and is empty for unknown codes.
//
// Parameters:
//   - errCode string: The error code.
//   - errorDescription string: A brief description of the error.
func New(errCode string, errorDescription string) error {
	vigiloErr := &VigiloAuthError{ErrorCode: errCode, ErrorDescription: errorDescription}
	vigiloErr.SystemCode = SystemErrorCodeMap[errCode]
	return vigiloErr
}
// NewInternalServerError creates a new error with default fields.
func NewInternalServerError(errDetails string) error {
return &VigiloAuthError{
SystemCode: SystemErrorCodeMap[ErrCodeInternalServerError],
ErrorCode: ErrCodeInternalServerError,
ErrorDescription: "An unexpected error occurred. Please try again later.",
ErrorDetails: errDetails,
}
}
// NewConsentRequiredError returns a new VigiloAuthError when the user's consent
// is required for the requested scope. The error includes the consent URL.
// SystemCode is filled from SystemErrorCodeMap, as New does, so the serialized
// "error_code" is not empty.
func NewConsentRequiredError(url string) *VigiloAuthError {
	return &VigiloAuthError{
		SystemCode:       SystemErrorCodeMap[ErrCodeConsentRequired],
		ErrorCode:        ErrCodeConsentRequired,
		ErrorDescription: "user consent required for the requested scope(s)",
		ConsentURL:       url,
	}
}
// NewAccessDeniedError returns an ErrCodeAccessDenied error for when the
// resource owner denies an authorization request. SystemCode is filled from
// SystemErrorCodeMap, as New does, so the serialized "error_code" is not empty.
func NewAccessDeniedError() *VigiloAuthError {
	return &VigiloAuthError{
		SystemCode:       SystemErrorCodeMap[ErrCodeAccessDenied],
		ErrorCode:        ErrCodeAccessDenied,
		ErrorDescription: "the resource owner denied the request",
	}
}
// NewSessionCreationError wraps err with the description
// "failed to create new session". The code is taken from err when err is, or
// wraps, a *VigiloAuthError. Returns nil when err is nil.
func NewSessionCreationError(err error) error {
	const description = "failed to create new session"
	return Wrap(err, "", description)
}
// NewRequestValidationError wraps err with the description
// "failed to validate request parameters". An *ErrorCollection becomes an
// ErrCodeValidationError; otherwise the code is taken from err. Returns nil
// when err is nil.
func NewRequestValidationError(err error) error {
	const description = "failed to validate request parameters"
	return Wrap(err, "", description)
}
func NewRequestBodyDecodingError(err error) error {
return New(ErrCodeInvalidRequest, "missing one or more required fields in the request")
}
// NewMethodNotAllowedError returns an ErrCodeMethodNotAllowed error that names
// the rejected HTTP method.
func NewMethodNotAllowedError(method string) error {
	description := fmt.Sprintf("method not allowed: %s", method)
	return New(ErrCodeMethodNotAllowed, description)
}
// NewMissingParametersError returns an ErrCodeInvalidRequest error for a
// request that lacks required parameters.
func NewMissingParametersError() error {
	const description = "missing one or more required parameters"
	return New(ErrCodeInvalidRequest, description)
}
// NewInvalidSessionError returns an ErrCodeInvalidSession error for when the
// session data cannot be retrieved. SystemCode is filled from
// SystemErrorCodeMap, as New does, so the serialized "error_code" is not empty.
func NewInvalidSessionError() error {
	return &VigiloAuthError{
		SystemCode:       SystemErrorCodeMap[ErrCodeInvalidSession],
		ErrorCode:        ErrCodeInvalidSession,
		ErrorDescription: "unable to retrieve session data",
		ErrorDetails:     "session not found or expired",
	}
}
// NewClientAuthenticationError wraps err with the description
// "failed to authenticate request". The code is taken from err when err is, or
// wraps, a *VigiloAuthError. Returns nil when err is nil.
func NewClientAuthenticationError(err error) error {
	const description = "failed to authenticate request"
	return Wrap(err, "", description)
}
// NewFormParsingError wraps a form-parsing failure as an ErrCodeInvalidRequest
// error. Returns nil when err is nil.
func NewFormParsingError(err error) error {
	const description = "unable to parse form"
	return Wrap(err, ErrCodeInvalidRequest, description)
}
// NewTimeoutError wraps err as an ErrCodeRequestTimeout error with the
// description "the request timed out". Returns nil when err is nil.
func NewTimeoutError(err error) error {
	const description = "the request timed out"
	return Wrap(err, ErrCodeRequestTimeout, description)
}
// NewContextCancelledError wraps err as an ErrCodeRequestCancelled error with
// the description "the request was cancelled". Returns nil when err is nil.
func NewContextCancelledError(err error) error {
	const description = "the request was cancelled"
	return Wrap(err, ErrCodeRequestCancelled, description)
}
// NewContextError converts a context error into a *VigiloAuthError.
//
// A deadline error (errors.Is(err, context.DeadlineExceeded)) becomes an
// ErrCodeRequestTimeout error. Anything else is treated as a cancellation and
// becomes an ErrCodeRequestCancelled error. The original error is kept in
// ErrorDetails and WrappedErr. Returns nil when err is nil.
//
// The result is built directly instead of going through Wrap. Wrap sends every
// context error back to NewContextError, so Wrap → NewContextError → Wrap would
// recurse without bound.
func NewContextError(err error) error {
	if err == nil {
		return nil
	}

	code, description := ErrCodeRequestCancelled, "the request was cancelled"
	if errors.Is(err, context.DeadlineExceeded) {
		code, description = ErrCodeRequestTimeout, "the request timed out"
	}

	return &VigiloAuthError{
		SystemCode:       SystemErrorCodeMap[code],
		ErrorCode:        code,
		ErrorDescription: description,
		ErrorDetails:     err.Error(),
		WrappedErr:       err,
	}
}
// IsContextError reports whether err is, or wraps, context.DeadlineExceeded or
// context.Canceled.
func IsContextError(err error) bool {
	switch {
	case errors.Is(err, context.DeadlineExceeded), errors.Is(err, context.Canceled):
		return true
	default:
		return false
	}
}
// Wrap wraps an existing error with additional context.
// If no code is provided, it is taken from the wrapped error when that error
// is, or wraps, a *VigiloAuthError.
//
// Returns nil when err is nil. Otherwise one of three branches applies:
//   - *ErrorCollection (direct type match only): an ErrCodeValidationError
//     error with message as its description and the collected errors attached.
//     SystemCode and WrappedErr are left empty, and code is ignored.
//   - context error: delegated to NewContextError; code and message are ignored.
//   - anything else: code, its system code, message, err.Error() as details,
//     and err as WrappedErr.
//
// NOTE(review): NewContextError must build its result without calling back
// into Wrap with this same error. Otherwise Wrap → NewContextError → Wrap
// recurses without bound.
func Wrap(err error, code string, message string) error {
if err == nil {
return nil
}
if code == "" {
var vigiloErr *VigiloAuthError
if errors.As(err, &vigiloErr) {
code = vigiloErr.ErrorCode
}
}
vigiloError := &VigiloAuthError{}
if e, ok := err.(*ErrorCollection); ok { //nolint:errorlint
vigiloError.ErrorCode = ErrCodeValidationError
vigiloError.ErrorDescription = message
vigiloError.Errors = e.Errors()
vigiloError.ErrorDetails = "one or more validation errors occurred"
} else if IsContextError(err) {
return NewContextError(err)
} else {
vigiloError.SystemCode = SystemErrorCodeMap[code]
vigiloError.ErrorCode = code
vigiloError.ErrorDescription = message
vigiloError.ErrorDetails = err.Error()
vigiloError.WrappedErr = err
}
return vigiloError
}
// Unwrap exposes the underlying cause so errors.Is and errors.As can traverse
// it. It returns nil when no cause was recorded.
func (e *VigiloAuthError) Unwrap() error {
	cause := e.WrappedErr
	return cause
}
// ErrorCode returns the ErrorCode of the first *VigiloAuthError in err's chain,
// or "" when err is nil or contains no such error. errors.As checks err itself
// before unwrapping, so a direct *VigiloAuthError is found immediately.
func ErrorCode(err error) string {
	var vigiloErr *VigiloAuthError
	if err != nil && errors.As(err, &vigiloErr) {
		return vigiloErr.ErrorCode
	}
	return ""
}
// SystemErrorCode returns the SystemCode of the first *VigiloAuthError in err's
// chain, or "" when err is nil or contains no such error.
func SystemErrorCode(err error) string {
	var vigiloErr *VigiloAuthError
	if err != nil && errors.As(err, &vigiloErr) {
		return vigiloErr.SystemCode
	}
	return ""
}
package handlers
import (
"net/http"
"net/url"
"strconv"
"github.com/vigiloauth/vigilo/v2/idp/config"
domain "github.com/vigiloauth/vigilo/v2/internal/domain/audit"
"github.com/vigiloauth/vigilo/v2/internal/errors"
"github.com/vigiloauth/vigilo/v2/internal/utils"
"github.com/vigiloauth/vigilo/v2/internal/web"
)
// AdminHandler serves administrative HTTP endpoints; currently it serves
// audit-event queries.
type AdminHandler struct {
// auditLogger is the source of audit events.
auditLogger domain.AuditLogger
logger *config.Logger
// module is the label passed to logger calls ("Admin Handler").
module string
}
// NewAdminHandler returns an AdminHandler that reads events from auditLogger
// and logs through the server-wide logger.
func NewAdminHandler(auditLogger domain.AuditLogger) *AdminHandler {
	h := &AdminHandler{
		auditLogger: auditLogger,
		module:      "Admin Handler",
	}
	h.logger = config.GetServerConfig().Logger()
	return h
}
// GetAuditEvents handles an admin request for audit events.
//
// Query parameters:
//   - limit: maximum number of events. Defaults to 100 when absent, not an
//     integer, or negative.
//   - offset: number of events to skip. Defaults to 0 when absent, not an
//     integer, or negative.
//   - from, to: passed through unparsed to the audit logger.
//   - UserID, EventType, Success, IP, RequestID, SessionID: filters; see
//     buildFilters.
//
// On success it writes the events as JSON with 200 OK. On any failure it
// writes exactly one error response and returns.
func (h *AdminHandler) GetAuditEvents(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	requestID := utils.GetRequestID(ctx)
	h.logger.Info(h.module, requestID, "[GetAuditEvents]: Processing request")

	query := r.URL.Query()

	limit, err := strconv.Atoi(query.Get("limit"))
	if err != nil || limit < 0 {
		limit = 100
	}
	offset, err := strconv.Atoi(query.Get("offset"))
	if err != nil || offset < 0 {
		offset = 0
	}

	// buildFilters has already written an error response when it returns nil
	// (malformed "Success" value). Continuing would query the audit log anyway
	// and write a second response on top of the error.
	filters := h.buildFilters(w, query)
	if filters == nil {
		return
	}

	events, err := h.auditLogger.GetAuditEvents(ctx, filters, query.Get("from"), query.Get("to"), limit, offset)
	if err != nil {
		wrappedErr := errors.Wrap(err, "", "failed to retrieve audit events")
		web.WriteError(w, wrappedErr)
		return
	}

	web.WriteJSON(w, http.StatusOK, events)
}
// buildFilters collects the non-empty audit filters from query into a map
// keyed by the query-parameter name. UserID, EventType, IP, RequestID and
// SessionID are copied as strings. Success is parsed with strconv.ParseBool.
//
// If Success is present but not a valid boolean, it writes an
// ErrCodeInvalidInput error to w and returns nil. Otherwise it returns a
// non-nil (possibly empty) map.
func (h *AdminHandler) buildFilters(w http.ResponseWriter, query url.Values) map[string]any {
	filters := make(map[string]any)

	for _, key := range []string{"UserID", "EventType", "IP", "RequestID", "SessionID"} {
		if value := query.Get(key); value != "" {
			filters[key] = value
		}
	}

	successStr := query.Get("Success")
	if successStr == "" {
		return filters
	}

	success, err := strconv.ParseBool(successStr)
	if err != nil {
		web.WriteError(w, errors.New(errors.ErrCodeInvalidInput, "invalid 'Success' boolean"))
		return nil
	}
	filters["Success"] = success

	return filters
}
package handlers
import (
"net/http"
"net/url"
"github.com/vigiloauth/vigilo/v2/idp/config"
client "github.com/vigiloauth/vigilo/v2/internal/domain/client"
"github.com/vigiloauth/vigilo/v2/internal/errors"
"github.com/vigiloauth/vigilo/v2/internal/utils"
"github.com/vigiloauth/vigilo/v2/internal/web"
)
// AuthorizationHandler handles HTTP requests related to authorization.
// The authorization decision itself is delegated to a client.ClientAuthorization.
type AuthorizationHandler struct {
// clientAuthorization performs the authorization performed by AuthorizeClient.
clientAuthorization client.ClientAuthorization
// logger is the server-wide logger; module is the label used in its log lines.
logger *config.Logger
module string
}
// NewAuthorizationHandler returns an AuthorizationHandler that delegates
// authorization requests to clientAuthorization and logs through the
// server-wide logger under the "Authorization Handler" module label.
func NewAuthorizationHandler(clientAuthorization client.ClientAuthorization) *AuthorizationHandler {
	h := &AuthorizationHandler{clientAuthorization: clientAuthorization}
	h.logger = config.GetServerConfig().Logger()
	h.module = "Authorization Handler"
	return h
}
// AuthorizeClient is the HTTP handler for the authorization endpoint of the
// authorization code flow.
//
// Parameters are read from the URL query for GET and from the parsed form for
// POST. Any other method is rejected with a method-not-allowed error. The
// request is delegated to ClientAuthorization.Authorize, which also receives
// the ResponseWriter and Request.
//
// Responses:
//   - success: 302 redirect to the URL returned by Authorize.
//   - invalid redirect URI: an HTML error page is rendered, because the error
//     cannot safely be delivered to an unverified redirect URI.
//   - any other error: a wrapped error response.
func (h *AuthorizationHandler) AuthorizeClient(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	requestID := utils.GetRequestID(ctx)
	h.logger.Info(h.module, requestID, "[AuthorizeClient]: Processing request")

	var query url.Values
	switch r.Method {
	case http.MethodGet:
		query = r.URL.Query()
	case http.MethodPost:
		if err := r.ParseForm(); err != nil {
			h.logger.Error(h.module, requestID, "[AuthorizeClient]: Failed to parse form: %v", err)
			web.WriteError(w, errors.New(errors.ErrCodeInvalidRequest, "invalid form data"))
			return
		}
		query = r.Form
	default:
		// Without this branch an unsupported method would proceed with a nil query
		// and reach the authorization service with no parameters at all.
		h.logger.Warn(h.module, requestID, "[AuthorizeClient]: Unsupported HTTP method")
		web.WriteError(w, errors.NewMethodNotAllowedError(r.Method))
		return
	}

	req := client.NewClientAuthorizationRequest(query)
	req.HTTPWriter = w
	req.HTTPRequest = r

	redirectURL, err := h.clientAuthorization.Authorize(ctx, req)
	if err != nil {
		if errors.ErrorCode(err) == errors.ErrCodeInvalidRedirectURI {
			web.RenderErrorPage(w, r, errors.ErrorCode(err), req.RedirectURI)
			return
		}
		wrappedErr := errors.Wrap(err, "", "failed to authorize client")
		h.logger.Error(h.module, requestID, "[AuthorizeClient]: Failed to authorize client: %v", err)
		web.WriteError(w, wrappedErr)
		return
	}

	h.logger.Info(h.module, requestID, "[AuthorizeClient]: Successfully processed request")
	http.Redirect(w, r, redirectURL, http.StatusFound)
}
package handlers
import (
"context"
"fmt"
"net/http"
"github.com/go-chi/chi/v5"
"github.com/vigiloauth/vigilo/v2/idp/config"
"github.com/vigiloauth/vigilo/v2/internal/constants"
client "github.com/vigiloauth/vigilo/v2/internal/domain/client"
"github.com/vigiloauth/vigilo/v2/internal/errors"
"github.com/vigiloauth/vigilo/v2/internal/utils"
"github.com/vigiloauth/vigilo/v2/internal/web"
)
// ClientHandler handles HTTP requests related to client operations.
type ClientHandler struct {
// creator registers new clients (RegisterClient).
creator client.ClientCreator
// manager reads, updates, deletes and rotates secrets of existing clients.
manager client.ClientManager
// logger is the server-wide logger; module is the label used in its log lines.
logger *config.Logger
module string
}
// NewClientHandler wires a ClientHandler to the service that registers new
// clients (creator) and the service that manages existing ones (manager).
// Logging goes through the server-wide logger under "Client Handler".
func NewClientHandler(
	creator client.ClientCreator,
	manager client.ClientManager,
) *ClientHandler {
	h := new(ClientHandler)
	h.creator = creator
	h.manager = manager
	h.logger = config.GetServerConfig().Logger()
	h.module = "Client Handler"
	return h
}
// RegisterClient is the HTTP handler for public client registration.
//
// It decodes a ClientRegistrationRequest from the JSON body and passes it to
// the ClientCreator under a three-second timeout. On success it responds
// 201 Created with the creator's result. Decode and registration failures are
// logged and written as error responses.
func (h *ClientHandler) RegisterClient(w http.ResponseWriter, r *http.Request) {
	ctx, cancel := context.WithTimeout(r.Context(), constants.ThreeSecondTimeout)
	defer cancel()

	reqID := utils.GetRequestID(ctx)
	h.logger.Info(h.module, reqID, "[RegisterClient]: Processing request")

	registration, decodeErr := web.DecodeJSONRequest[client.ClientRegistrationRequest](w, r)
	if decodeErr != nil {
		h.logger.Error(h.module, reqID, "Failed to decode request body: %v", decodeErr)
		web.WriteError(w, errors.NewRequestBodyDecodingError(decodeErr))
		return
	}

	registered, registerErr := h.creator.Register(ctx, registration)
	if registerErr != nil {
		h.logger.Error(h.module, reqID, "Failed to register client: %v", registerErr)
		web.WriteError(w, errors.Wrap(registerErr, "", "failed to register client"))
		return
	}

	h.logger.Info(h.module, reqID, "[RegisterClient]: Successfully processed request")
	web.WriteJSON(w, http.StatusCreated, registered)
}
// RegenerateSecret is the HTTP handler for regenerating client secrets.
//
// The client is identified by the client-ID URL parameter. The manager's
// response is returned as JSON (200 OK). A failure is wrapped and written as an
// error response.
func (h *ClientHandler) RegenerateSecret(w http.ResponseWriter, r *http.Request) {
	ctx, cancel := context.WithTimeout(r.Context(), constants.ThreeSecondTimeout)
	defer cancel()

	reqID := utils.GetRequestID(ctx)
	h.logger.Info(h.module, reqID, "[RegenerateSecret]: Processing request")

	regenerated, err := h.manager.RegenerateClientSecret(ctx, chi.URLParam(r, constants.ClientIDReqField))
	if err != nil {
		web.WriteError(w, errors.Wrap(err, "", "failed to regenerate client_secret"))
		return
	}

	h.logger.Info(h.module, reqID, "[RegenerateSecret]: Successfully processed request")
	web.WriteJSON(w, http.StatusOK, regenerated)
}
// GetClientByID returns the public, read-only view (ID, name and logo URI) of
// the client identified by the client-ID URL parameter as JSON (200 OK).
// Lookup failures are wrapped and written as an error response.
func (h *ClientHandler) GetClientByID(w http.ResponseWriter, r *http.Request) {
	ctx, cancel := context.WithTimeout(r.Context(), constants.ThreeSecondTimeout)
	defer cancel()

	reqID := utils.GetRequestID(ctx)
	h.logger.Info(h.module, reqID, "[GetClientByID]: Processing request")

	found, err := h.manager.GetClientByID(ctx, chi.URLParam(r, constants.ClientIDReqField))
	if err != nil {
		web.WriteError(w, errors.Wrap(err, "", "failed to retrieve client by ID"))
		return
	}

	// Expose only the public fields; the full client record is not returned here.
	web.WriteJSON(w, http.StatusOK, &client.ClientReadResponse{
		ID:      found.ID,
		Name:    found.Name,
		LogoURI: found.LogoURI,
	})
}
// ManageClientConfiguration is the client configuration endpoint.
//
// For the client named by the client-ID URL parameter:
//   - GET returns its configuration,
//   - PUT updates it,
//   - DELETE removes it.
//
// Other methods get a method-not-allowed error. The registration access token
// is read from the request context and validated by the ClientManager.
func (h *ClientHandler) ManageClientConfiguration(w http.ResponseWriter, r *http.Request) {
	ctx, cancel := context.WithTimeout(r.Context(), constants.ThreeSecondTimeout)
	defer cancel()

	// A missing or non-string context value deliberately yields an empty token;
	// the manager is then responsible for rejecting it.
	accessToken, _ := utils.GetValueFromContext(ctx, constants.ContextKeyAccessToken).(string)
	clientID := chi.URLParam(r, constants.ClientIDReqField)

	switch method := r.Method; method {
	case http.MethodGet:
		h.getClient(w, clientID, accessToken, ctx)
	case http.MethodPut:
		h.updateClient(w, r, clientID, accessToken, ctx)
	case http.MethodDelete:
		h.deleteClient(w, clientID, accessToken, ctx)
	default:
		notAllowed := errors.New(errors.ErrCodeMethodNotAllowed, fmt.Sprintf("method '%s' not allowed for this request", method))
		web.WriteError(w, notAllowed)
	}
}
// getClient writes the stored configuration of clientID as JSON (200 OK).
// The ClientManager is given registrationAccessToken for validation. Any
// failure is wrapped and written as an error response.
func (h *ClientHandler) getClient(w http.ResponseWriter, clientID, registrationAccessToken string, ctx context.Context) {
	info, err := h.manager.GetClientInformation(ctx, clientID, registrationAccessToken)
	if err == nil {
		web.WriteJSON(w, http.StatusOK, info)
		return
	}
	web.WriteError(w, errors.Wrap(err, "", "failed to validate and retrieve client information"))
}
// updateClient applies a JSON ClientUpdateRequest to clientID.
//
// It calls ClientManager.UpdateClientInformation with registrationAccessToken
// and writes the updated client information as JSON (200 OK). An undecodable
// body or an update failure produces an error response.
func (h *ClientHandler) updateClient(w http.ResponseWriter, r *http.Request, clientID, registrationAccessToken string, ctx context.Context) {
	update, decodeErr := web.DecodeJSONRequest[client.ClientUpdateRequest](w, r)
	if decodeErr != nil {
		web.WriteError(w, errors.NewRequestBodyDecodingError(decodeErr))
		return
	}

	updated, updateErr := h.manager.UpdateClientInformation(ctx, clientID, registrationAccessToken, update)
	if updateErr != nil {
		web.WriteError(w, errors.Wrap(updateErr, "", "failed to validate and update client"))
		return
	}

	web.WriteJSON(w, http.StatusOK, updated)
}
// deleteClient removes clientID via ClientManager.DeleteClientInformation,
// passing registrationAccessToken along.
//
// On success it sets the no-store header and responds 204 No Content with no
// body. A failure is wrapped and written as an error response.
func (h *ClientHandler) deleteClient(w http.ResponseWriter, clientID, registrationAccessToken string, ctx context.Context) {
	err := h.manager.DeleteClientInformation(ctx, clientID, registrationAccessToken)
	if err != nil {
		web.WriteError(w, errors.Wrap(err, "", "failed to validate and delete client"))
		return
	}
	web.SetNoStoreHeader(w)
	w.WriteHeader(http.StatusNoContent)
}
package handlers
import (
"context"
"net/http"
"net/url"
"strconv"
"github.com/vigiloauth/vigilo/v2/idp/config"
"github.com/vigiloauth/vigilo/v2/internal/constants"
session "github.com/vigiloauth/vigilo/v2/internal/domain/session"
consent "github.com/vigiloauth/vigilo/v2/internal/domain/userconsent"
"github.com/vigiloauth/vigilo/v2/internal/errors"
"github.com/vigiloauth/vigilo/v2/internal/types"
"github.com/vigiloauth/vigilo/v2/internal/utils"
"github.com/vigiloauth/vigilo/v2/internal/web"
)
// ConsentHandler handles HTTP requests for the OAuth user-consent step:
// returning consent details (GET) and recording the user's decision (POST).
type ConsentHandler struct {
// sessionService resolves the logged-in user from the request's session.
sessionService session.SessionService
// consentService retrieves consent details and processes consent decisions.
consentService consent.UserConsentService
// jwtConfig is the server token configuration. NOTE(review): no method in this
// file reads it; confirm it is still needed.
jwtConfig *config.TokenConfig
// logger is the server-wide logger; module is the label used in its log lines.
logger *config.Logger
module string
}
// NewConsentHandler returns a ConsentHandler.
//
// Parameters:
//   - sessionService: resolves the current user from the request session.
//   - consentService: fetches consent details and processes consent decisions.
//
// The token configuration and logger are taken from the global server
// configuration. Log lines use the "User Consent Handler" module label.
func NewConsentHandler(
	sessionService session.SessionService,
	consentService consent.UserConsentService,
) *ConsentHandler {
	h := &ConsentHandler{
		sessionService: sessionService,
		consentService: consentService,
	}
	h.jwtConfig = config.GetServerConfig().TokenConfig()
	h.logger = config.GetServerConfig().Logger()
	h.module = "User Consent Handler"
	return h
}
// HandleUserConsent is the HTTP entry point for the OAuth consent step.
//
// It reads the authorization parameters from the URL query and checks them in
// this order:
//   - If client_id, redirect_uri or scope is missing, it returns a bad-request error.
//   - If no user session exists, it redirects (302) to the login page, carrying
//     the OAuth parameters along.
//
// Otherwise GET returns the consent details and POST records the user's
// decision. Any other method is rejected.
func (h *ConsentHandler) HandleUserConsent(w http.ResponseWriter, r *http.Request) {
	ctx, cancel := context.WithTimeout(r.Context(), constants.ThreeSecondTimeout)
	defer cancel()

	requestID := utils.GetRequestID(ctx)
	h.logger.Info(h.module, requestID, "[UserConsent]: Processing request")

	params := r.URL.Query()
	var (
		clientID     = params.Get(constants.ClientIDReqField)
		redirectURI  = params.Get(constants.RedirectURIReqField)
		scope        = types.Scope(params.Get(constants.ScopeReqField))
		responseType = params.Get(constants.ResponseTypeReqField)
		state        = params.Get(constants.StateReqField)
		nonce        = params.Get(constants.NonceReqField)
		display      = params.Get(constants.DisplayReqField)
		acrValues    = params.Get(constants.ACRReqField)
		claims       = params.Get(constants.ClaimsReqField)
	)

	missingRequired := clientID == "" || redirectURI == "" || scope == ""
	if missingRequired {
		web.WriteError(w, errors.New(errors.ErrCodeBadRequest, "missing required parameters"))
		return
	}

	userID, sessionErr := h.sessionService.GetUserIDFromSession(r)
	if sessionErr != nil {
		// Not logged in: send the user to authenticate, preserving the OAuth context.
		h.logger.Error(h.module, requestID, "[UserConsent]: Failed to retrieve user ID from session: %v", sessionErr)
		loginURL := h.buildLoginURL(clientID, redirectURI, scope.String(), responseType, state, nonce, display)
		http.Redirect(w, r, loginURL, http.StatusFound)
		return
	}

	switch method := r.Method; method {
	case http.MethodGet:
		h.handleGetConsent(w, r, userID, clientID, redirectURI, scope, responseType, state, nonce, display)
	case http.MethodPost:
		h.handlePostConsent(w, r, userID, clientID, redirectURI, scope, responseType, state, nonce, display, acrValues, claims)
	default:
		web.WriteError(w, errors.NewMethodNotAllowedError(method))
	}
}
// handleGetConsent serves GET on the consent endpoint.
//
// It asks the consent service for the consent details of userID/clientID and
// the requested authorization parameters, then writes them as JSON (200 OK).
// A failure is wrapped and written as an error response.
func (h *ConsentHandler) handleGetConsent(
	w http.ResponseWriter,
	r *http.Request,
	userID, clientID, redirectURI string,
	scope types.Scope,
	responseType, state, nonce, display string,
) {
	details, err := h.consentService.GetConsentDetails(userID, clientID, redirectURI, state, scope, responseType, nonce, display, r)
	if err == nil {
		web.WriteJSON(w, http.StatusOK, details)
		return
	}
	web.WriteError(w, errors.Wrap(err, "", "failed to retrieve user consent details"))
}
// handlePostConsent serves POST on the consent endpoint.
//
// It decodes the user's decision (UserConsentRequest) from the JSON body and
// copies the OAuth parameters from the query onto it. It then asks the consent
// service to process the decision. On success the response's RedirectURI is
// set to the authorize endpoint URL, so the client can resume the flow, and the
// response is written as JSON (200 OK).
//
// Exactly one response is written on every path.
func (h *ConsentHandler) handlePostConsent(w http.ResponseWriter,
	r *http.Request,
	userID string,
	clientID string,
	redirectURI string,
	scope types.Scope,
	responseType string,
	state string,
	nonce string,
	display string,
	acrValues string,
	claims string,
) {
	consentRequest, err := web.DecodeJSONRequest[consent.UserConsentRequest](w, r)
	if err != nil {
		// Write a single error response; a second WriteError here would trigger a
		// superfluous WriteHeader and append another error body.
		web.WriteError(w, errors.NewRequestBodyDecodingError(err))
		return
	}

	consentRequest.ResponseType = responseType
	consentRequest.State = state
	consentRequest.Nonce = nonce
	consentRequest.Display = display

	response, err := h.consentService.ProcessUserConsent(
		userID,
		clientID,
		redirectURI,
		scope,
		consentRequest,
		r,
	)
	if err != nil {
		wrappedErr := errors.Wrap(err, "", "failed to process user consent")
		web.WriteError(w, wrappedErr)
		return
	}

	response.RedirectURI = h.buildOAuthRedirectURL(
		clientID,
		redirectURI,
		scope.String(),
		responseType,
		state,
		nonce,
		display,
		acrValues,
		claims,
		response.Approved,
	)
	web.WriteJSON(w, http.StatusOK, response)
}
// buildLoginURL returns the relative "/authenticate" URL the user is sent to
// when no session exists.
//
// The URL always carries client_id, redirect_uri, scope and response_type.
// state and nonce are added only when non-empty. display is kept only when it
// is one of constants.ValidAuthenticationDisplays; otherwise it falls back to
// constants.DisplayPage.
func (h *ConsentHandler) buildLoginURL(
	clientID string,
	redirectURI string,
	scope string,
	responseType string,
	state string,
	nonce string,
	display string,
) string {
	params := url.Values{}
	params.Add(constants.ClientIDReqField, clientID)
	params.Add(constants.RedirectURIReqField, redirectURI)
	params.Add(constants.ScopeReqField, scope)
	params.Add(constants.ResponseTypeReqField, responseType)

	for _, opt := range []struct{ key, value string }{
		{constants.StateReqField, state},
		{constants.NonceReqField, nonce},
	} {
		if opt.value != "" {
			params.Add(opt.key, opt.value)
		}
	}

	effectiveDisplay := constants.DisplayPage
	if display != "" && constants.ValidAuthenticationDisplays[display] {
		effectiveDisplay = display
	}
	params.Add(constants.DisplayReqField, effectiveDisplay)

	return "/authenticate?" + params.Encode()
}
// buildOAuthRedirectURL returns the relative authorize-endpoint URL
// ("/identity" + web.OAuthEndpoints.Authorize) used to resume the flow after
// the consent decision.
//
// client_id, redirect_uri and the consent-approved flag are always present.
// state, scope, response_type, nonce, display, acr_values and claims are added
// only when non-empty.
func (h *ConsentHandler) buildOAuthRedirectURL(
	clientID string,
	redirectURI string,
	scope string,
	responseType string,
	state string,
	nonce string,
	display string,
	acrValues string,
	claims string,
	approved bool,
) string {
	params := url.Values{}
	params.Add(constants.ClientIDReqField, clientID)
	params.Add(constants.RedirectURIReqField, redirectURI)
	params.Add(constants.ConsentApprovedURLValue, strconv.FormatBool(approved))

	optional := []struct{ key, value string }{
		{constants.StateReqField, state},
		{constants.ScopeReqField, scope},
		{constants.ResponseTypeReqField, responseType},
		{constants.NonceReqField, nonce},
		{constants.DisplayReqField, display},
		{constants.ACRReqField, acrValues},
		{constants.ClaimsReqField, claims},
	}
	for _, p := range optional {
		if p.value != "" {
			params.Add(p.key, p.value)
		}
	}

	return "/identity" + web.OAuthEndpoints.Authorize + "?" + params.Encode()
}
package handlers
import (
"context"
"net/http"
"github.com/vigiloauth/vigilo/v2/idp/config"
"github.com/vigiloauth/vigilo/v2/internal/constants"
oidc "github.com/vigiloauth/vigilo/v2/internal/domain/oidc"
token "github.com/vigiloauth/vigilo/v2/internal/domain/token"
"github.com/vigiloauth/vigilo/v2/internal/errors"
"github.com/vigiloauth/vigilo/v2/internal/utils"
"github.com/vigiloauth/vigilo/v2/internal/web"
)
// OIDCHandler handles OpenID Connect-related HTTP requests: UserInfo, JWKS
// and the provider discovery document.
type OIDCHandler struct {
// oidcService supplies user info and the JSON Web Key Set.
oidcService oidc.OIDCService
// logger is the server-wide logger; module is the label used in its log lines.
logger *config.Logger
module string
}
// NewOIDCHandler returns an OIDCHandler backed by oidcService.
//
// Parameters:
//   - oidcService oidc.OIDCService: source of user info and the JWKS.
//
// Returns:
//   - *OIDCHandler: a handler that logs via the server-wide logger under "OIDC Handler".
func NewOIDCHandler(oidcService oidc.OIDCService) *OIDCHandler {
	h := &OIDCHandler{oidcService: oidcService}
	h.logger = config.GetServerConfig().Logger()
	h.module = "OIDC Handler"
	return h
}
// GetUserInfo handles the UserInfo endpoint of the OIDC specification.
//
// Parameters:
//   - w http.ResponseWriter: The HTTP response writer.
//   - r *http.Request: The HTTP request.
//
// Behavior:
//   - Reads the token claims that were stored in the request context.
//   - Missing claims produce an unauthorized error. A value of the wrong type
//     produces an internal server error.
//   - Otherwise the OIDC service resolves the user info, which is returned as
//     JSON (200 OK). Service errors are wrapped and written as an error response.
func (h *OIDCHandler) GetUserInfo(w http.ResponseWriter, r *http.Request) {
	ctx, cancel := context.WithTimeout(r.Context(), constants.ThreeSecondTimeout)
	defer cancel()

	requestID := utils.GetRequestID(ctx)
	h.logger.Info(h.module, requestID, "[UserInfo]: Processing request")

	var tokenClaims *token.TokenClaims
	switch claims := utils.GetValueFromContext(ctx, constants.ContextKeyTokenClaims).(type) {
	case nil:
		h.logger.Warn(h.module, requestID, "[UserInfo]: Token claims not found in context")
		web.WriteError(w, errors.New(errors.ErrCodeUnauthorized, "invalid or missing access token"))
		return
	case *token.TokenClaims:
		tokenClaims = claims
	default:
		h.logger.Error(h.module, requestID, "[UserInfo]: Invalid token claims type in context")
		web.WriteError(w, errors.NewInternalServerError(""))
		return
	}

	userInfo, err := h.oidcService.GetUserInfo(ctx, tokenClaims)
	if err != nil {
		h.logger.Error(h.module, requestID, "[UserInfo]: An error occurred processing the request: %v", err)
		web.WriteError(w, errors.Wrap(err, "", "failed to retrieve the requested user info"))
		return
	}

	web.WriteJSON(w, http.StatusOK, userInfo)
}
// GetJWKS handles the JWKS (JSON Web Key Set) endpoint of the OIDC specification.
//
// Parameters:
//   - w http.ResponseWriter: The HTTP response writer.
//   - r *http.Request: The HTTP request.
//
// Behavior:
//   - Fetches the current key set from the OIDC service and writes it as
//     JSON (200 OK). This path writes no error response.
func (h *OIDCHandler) GetJWKS(w http.ResponseWriter, r *http.Request) {
	ctx, cancel := context.WithTimeout(r.Context(), constants.ThreeSecondTimeout)
	defer cancel()

	h.logger.Info(h.module, utils.GetRequestID(ctx), "[GetJWKS]: Processing request")
	web.WriteJSON(w, http.StatusOK, h.oidcService.GetJwks(ctx))
}
// GetOpenIDConfiguration handles the OpenID Provider Configuration (discovery) endpoint.
//
// Parameters:
//   - w http.ResponseWriter: The HTTP response writer.
//   - r *http.Request: The HTTP request.
//
// Behavior:
//   - Builds the discovery document from the server's configured URL and
//     writes it as JSON (200 OK).
func (h *OIDCHandler) GetOpenIDConfiguration(w http.ResponseWriter, r *http.Request) {
	ctx, cancel := context.WithTimeout(r.Context(), constants.ThreeSecondTimeout)
	defer cancel()

	h.logger.Info(h.module, utils.GetRequestID(ctx), "[GetOpenIDConfiguration]: Processing request")

	baseURL := config.GetServerConfig().URL()
	web.WriteJSON(w, http.StatusOK, oidc.NewDiscoveryJSON(baseURL))
}
package handlers
import (
"context"
"fmt"
"net/http"
"net/url"
"github.com/vigiloauth/vigilo/v2/idp/config"
"github.com/vigiloauth/vigilo/v2/internal/constants"
token "github.com/vigiloauth/vigilo/v2/internal/domain/token"
user "github.com/vigiloauth/vigilo/v2/internal/domain/user"
"github.com/vigiloauth/vigilo/v2/internal/errors"
"github.com/vigiloauth/vigilo/v2/internal/types"
"github.com/vigiloauth/vigilo/v2/internal/utils"
"github.com/vigiloauth/vigilo/v2/internal/web"
)
// TokenHandler serves the token endpoints: issuance (IssueTokens),
// introspection (IntrospectToken) and revocation (RevokeToken). Grant-specific
// logic is delegated to grantProcessor.
type TokenHandler struct {
grantProcessor token.TokenGrantProcessor
// logger is the server-wide logger; module is the label used in its log lines.
logger *config.Logger
module string
}
// NewTokenHandler returns a TokenHandler that delegates all grant processing
// to grantProcessor and logs via the server-wide logger under "Token Handler".
func NewTokenHandler(grantProcessor token.TokenGrantProcessor) *TokenHandler {
	h := &TokenHandler{grantProcessor: grantProcessor}
	h.logger = config.GetServerConfig().Logger()
	h.module = "Token Handler"
	return h
}
// IntrospectToken is the token introspection endpoint.
//
// It parses the request form, reads the token form field, and returns the
// grant processor's introspection result as JSON (200 OK). Form parsing and
// introspection failures produce error responses.
func (h *TokenHandler) IntrospectToken(w http.ResponseWriter, r *http.Request) {
	ctx, cancel := context.WithTimeout(r.Context(), constants.ThreeSecondTimeout)
	defer cancel()

	requestID := utils.GetRequestID(ctx)
	h.logger.Info(h.module, requestID, "[IntrospectToken]: Processing request")

	if parseErr := r.ParseForm(); parseErr != nil {
		web.WriteError(w, errors.NewFormParsingError(parseErr))
		return
	}

	result, introspectErr := h.grantProcessor.IntrospectToken(ctx, r, r.FormValue(constants.TokenReqField))
	if introspectErr != nil {
		h.logger.Error(h.module, requestID, "[IntrospectToken]: Failed to introspect token: %v", introspectErr)
		web.WriteError(w, errors.Wrap(introspectErr, "", "failed to introspect token"))
		return
	}

	web.WriteJSON(w, http.StatusOK, result)
}
// RevokeToken is the token revocation endpoint.
//
// It parses the request form, reads the token form field, and asks the grant
// processor to revoke it. On success it responds 200 OK with a JSON null body.
// Failures produce error responses.
func (h *TokenHandler) RevokeToken(w http.ResponseWriter, r *http.Request) {
	ctx, cancel := context.WithTimeout(r.Context(), constants.ThreeSecondTimeout)
	defer cancel()

	requestID := utils.GetRequestID(ctx)
	h.logger.Info(h.module, requestID, "[RevokeToken]: Processing request")

	if parseErr := r.ParseForm(); parseErr != nil {
		web.WriteError(w, errors.NewFormParsingError(parseErr))
		return
	}

	revokeErr := h.grantProcessor.RevokeToken(ctx, r, r.FormValue(constants.TokenReqField))
	if revokeErr != nil {
		h.logger.Error(h.module, requestID, "[RevokeToken]: Failed to revoke token: %v", revokeErr)
		web.WriteError(w, errors.Wrap(revokeErr, "", "failed to revoke token"))
		return
	}

	web.WriteJSON(w, http.StatusOK, nil)
}
// IssueTokens is the token endpoint.
//
// It parses the form, resolves client credentials (see
// extractClientCredentials) and requires a grant type. It then dispatches to
// the handler for the client-credentials, password, authorization-code or
// refresh-token grant. Any other grant type yields an unsupported-grant-type
// error.
func (h *TokenHandler) IssueTokens(w http.ResponseWriter, r *http.Request) {
	ctx, cancel := context.WithTimeout(r.Context(), constants.ThreeSecondTimeout)
	defer cancel()

	requestID := utils.GetRequestID(ctx)
	h.logger.Info(h.module, requestID, "[IssueTokens]: Processing request")

	if parseErr := r.ParseForm(); parseErr != nil {
		web.WriteError(w, errors.NewFormParsingError(parseErr))
		return
	}

	clientID, clientSecret, credErr := h.extractClientCredentials(r)
	if credErr != nil {
		h.logger.Error(h.module, requestID, "[IssueTokens]: Invalid client credentials: %v", credErr)
		web.WriteError(w, credErr)
		return
	}

	grantType := r.FormValue(constants.GrantTypeReqField)
	scopes := types.Scope(r.FormValue(constants.ScopeReqField))
	if grantType == "" {
		web.WriteError(w, errors.New(errors.ErrCodeInvalidRequest, "one or more required parameters are missing"))
		return
	}

	switch grantType {
	case constants.ClientCredentialsGrantType:
		h.handleClientCredentialsRequest(ctx, w, requestID, clientID, clientSecret, grantType, scopes)
	case constants.PasswordGrantType:
		h.handlePasswordGrantRequest(ctx, w, r, requestID, clientID, clientSecret, grantType, scopes)
	case constants.AuthorizationCodeGrantType:
		h.handleAuthorizationCodeTokenExchange(ctx, w, r, requestID, clientID, clientSecret)
	case constants.RefreshTokenGrantType:
		h.handleRefreshTokenRequest(ctx, w, r, requestID, clientID, clientSecret, grantType, scopes)
	default:
		h.logger.Warn(h.module, requestID, "[IssueTokens]: Unsupported grant type")
		unsupported := errors.New(errors.ErrCodeUnsupportedGrantType, fmt.Sprintf("the provided grant type [%s] is not supported", grantType))
		web.WriteError(w, unsupported)
	}
}
// handleClientCredentialsRequest issues a token for the client-credentials
// grant and writes it as JSON (200 OK). Failures are logged and written as a
// wrapped error response.
func (h *TokenHandler) handleClientCredentialsRequest(
	ctx context.Context,
	w http.ResponseWriter,
	requestID string,
	clientID string,
	clientSecret string,
	requestedGrantType string,
	requestedScopes types.Scope,
) {
	issued, err := h.grantProcessor.IssueClientCredentialsToken(ctx, clientID, clientSecret, requestedGrantType, requestedScopes)
	if err == nil {
		web.WriteJSON(w, http.StatusOK, issued)
		return
	}
	h.logger.Error(h.module, requestID, "Failed to issue token for client credentials grant: %v", err)
	web.WriteError(w, errors.Wrap(err, "", "invalid client credentials or unauthorized grant type/scopes"))
}
// handlePasswordGrantRequest issues tokens for the resource-owner password grant.
//
// A password supplied in the URL query is rejected outright; presumably this
// keeps it out of URLs, which are commonly logged. The username and password
// are read from the form. The grant processor's response is written as JSON
// (200 OK). Failures are logged and written as a wrapped error response.
func (h *TokenHandler) handlePasswordGrantRequest(
	ctx context.Context,
	w http.ResponseWriter,
	r *http.Request,
	requestID string,
	clientID string,
	clientSecret string,
	requestedGrantType string,
	requestedScopes types.Scope,
) {
	passwordInURL := r.URL.Query().Get(constants.PasswordReqField) != ""
	if passwordInURL {
		web.WriteError(w, errors.New(errors.ErrCodeInvalidRequest, "password must not be in the URL"))
		return
	}

	credentials := &user.UserLoginRequest{
		Username: r.FormValue(constants.UsernameReqField),
		Password: r.FormValue(constants.PasswordReqField),
	}

	issued, err := h.grantProcessor.IssueResourceOwnerToken(ctx, clientID, clientSecret, requestedGrantType, requestedScopes, credentials)
	if err != nil {
		h.logger.Error(h.module, requestID, "Failed to issue tokens for password grant: %v", err)
		web.WriteError(w, errors.Wrap(err, "", "invalid credentials or unauthorized grant type/scopes"))
		return
	}
	web.WriteJSON(w, http.StatusOK, issued)
}
// handleAuthorizationCodeTokenExchange exchanges an authorization code for tokens.
//
// It builds a TokenRequest from the grant type, code, redirect URI and state
// form fields plus the resolved client ID. The code verifier and client secret
// are set only when non-empty. The grant processor's response is written as
// JSON (200 OK). Failures are logged and written as a wrapped error response.
func (h *TokenHandler) handleAuthorizationCodeTokenExchange(
	ctx context.Context,
	w http.ResponseWriter,
	r *http.Request,
	requestID string,
	clientID string,
	clientSecret string,
) {
	exchange := &token.TokenRequest{
		GrantType:         r.FormValue(constants.GrantTypeReqField),
		AuthorizationCode: r.FormValue(constants.CodeURLValue),
		RedirectURI:       r.FormValue(constants.RedirectURIReqField),
		ClientID:          clientID,
		State:             r.FormValue(constants.StateReqField),
	}
	if verifier := r.FormValue(constants.CodeVerifierReqField); verifier != "" {
		exchange.CodeVerifier = verifier
	}
	if clientSecret != "" {
		exchange.ClientSecret = clientSecret
	}

	issued, err := h.grantProcessor.ExchangeAuthorizationCode(ctx, exchange)
	if err != nil {
		h.logger.Error(h.module, requestID, "Failed to generate access and refresh tokens: %v", err)
		web.WriteError(w, errors.Wrap(err, "", "failed to generate access & refresh tokens"))
		return
	}

	h.logger.Info(h.module, requestID, "Successfully processed request=[TokenExchange]")
	web.WriteJSON(w, http.StatusOK, issued)
}
// handleRefreshTokenRequest issues new tokens for the refresh-token grant.
//
// The refresh token is read from the form. The no-store header is set on every
// outcome. The grant processor's response is written as JSON (200 OK); failures
// are logged and written as a wrapped error response.
func (h *TokenHandler) handleRefreshTokenRequest(
	ctx context.Context,
	w http.ResponseWriter,
	r *http.Request,
	requestID string,
	clientID string,
	clientSecret string,
	requestedGrantType string,
	requestedScopes types.Scope,
) {
	refreshToken := r.FormValue(constants.RefreshTokenURLValue)
	refreshed, err := h.grantProcessor.RefreshToken(ctx, clientID, clientSecret, requestedGrantType, refreshToken, requestedScopes)

	// Both the success and the error response must not be cached.
	web.SetNoStoreHeader(w)
	if err != nil {
		h.logger.Error(h.module, requestID, "Failed to issue new access token: %v", err)
		web.WriteError(w, errors.Wrap(err, "", "failed to issue new access and refresh tokens"))
		return
	}
	web.WriteJSON(w, http.StatusOK, refreshed)
}
// extractClientCredentials resolves the client's ID and secret for the token endpoint.
//
// HTTP Basic authentication is tried first. Per RFC 6749 §2.3.1, the client_id
// and client_secret are application/x-www-form-urlencoded before being placed in
// the Authorization header, so on that path the secret is URL-decoded.
// (NOTE(review): assumes web.ExtractClientBasicAuth returns the raw, still-encoded
// values — confirm.)
//
// Otherwise the credentials are read with r.FormValue, which has already
// form-decoded them. Decoding again would corrupt secrets containing '+' or '%'
// (for example "a+b" would become "a b"), so no further decoding is done there.
//
// Returns an invalid-client error when no client ID is available.
func (h *TokenHandler) extractClientCredentials(r *http.Request) (string, string, error) {
	clientID, clientSecret, err := web.ExtractClientBasicAuth(r)
	if err != nil {
		clientID = r.FormValue(constants.ClientIDReqField)
		clientSecret = r.FormValue(constants.ClientSecretReqField)
		if clientID == "" {
			return "", "", errors.New(errors.ErrCodeInvalidClient, "missing client identification")
		}
		return clientID, clientSecret, nil
	}

	// Best effort: keep the raw secret when it is not valid percent-encoding.
	if decodedSecret, decodeErr := url.QueryUnescape(clientSecret); decodeErr == nil {
		clientSecret = decodedSecret
	}
	return clientID, clientSecret, nil
}
package handlers
import (
"context"
"net/http"
"net/url"
"time"
"github.com/vigiloauth/vigilo/v2/idp/config"
"github.com/vigiloauth/vigilo/v2/internal/constants"
session "github.com/vigiloauth/vigilo/v2/internal/domain/session"
users "github.com/vigiloauth/vigilo/v2/internal/domain/user"
"github.com/vigiloauth/vigilo/v2/internal/errors"
"github.com/vigiloauth/vigilo/v2/internal/utils"
"github.com/vigiloauth/vigilo/v2/internal/web"
)
// UserHandler handles HTTP requests related to user operations.
type UserHandler struct {
// creator creates user accounts (Register).
creator users.UserCreator
// authenticator checks login credentials (Login, OAuthLogin).
authenticator users.UserAuthenticator
// manager performs account maintenance such as ResetPassword.
manager users.UserManager
// verifier is the account verifier; NOTE(review): presumably used by VerifyAccount.
verifier users.UserVerifier
// sessionService creates and invalidates user sessions.
sessionService session.SessionService
// logger is the server-wide logger; module is the label used in its log lines.
logger *config.Logger
module string
}
// NewUserHandler returns a UserHandler wired to its collaborators.
//
// Parameters:
//
//	creator        users.UserCreator:       creates accounts on registration.
//	authenticator  users.UserAuthenticator: checks login credentials.
//	manager        users.UserManager:       performs password resets and other account changes.
//	verifier       users.UserVerifier:      account verification service.
//	sessionService session.SessionService:  creates and invalidates sessions.
//
// Returns:
//
//	*UserHandler: a handler logging via the server-wide logger under "User Handler".
func NewUserHandler(
	creator users.UserCreator,
	authenticator users.UserAuthenticator,
	manager users.UserManager,
	verifier users.UserVerifier,
	sessionService session.SessionService,
) *UserHandler {
	h := &UserHandler{
		creator:        creator,
		authenticator:  authenticator,
		manager:        manager,
		verifier:       verifier,
		sessionService: sessionService,
	}
	h.logger = config.GetServerConfig().Logger()
	h.module = "User Handler"
	return h
}
// Register creates a user account from a JSON UserRegistrationRequest.
//
// The request is decoded and validated, converted to a user, and handed to the
// UserCreator. The creator's response is written with 201 Created. Decode,
// validation and creation failures produce error responses.
func (h *UserHandler) Register(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	requestID := utils.GetRequestID(ctx)
	h.logger.Info(h.module, requestID, "[Register]: Processing request")

	registration, decodeErr := web.DecodeJSONRequest[users.UserRegistrationRequest](w, r)
	if decodeErr != nil {
		web.WriteError(w, errors.NewRequestBodyDecodingError(decodeErr))
		return
	}
	if validationErr := registration.Validate(); validationErr != nil {
		web.WriteError(w, errors.NewRequestValidationError(validationErr))
		return
	}

	created, createErr := h.creator.CreateUser(ctx, users.NewUserFromRegistrationRequest(registration))
	if createErr != nil {
		web.WriteError(w, errors.Wrap(createErr, "", "failed to create user"))
		return
	}

	web.WriteJSON(w, http.StatusCreated, created)
}
// Login authenticates a user from a JSON UserLoginRequest.
//
// The steps are:
//   - Decode and validate the request.
//   - Check the credentials with the UserAuthenticator.
//   - On success, create a session recording the user ID, remote address, user
//     agent and authentication time (Unix seconds).
//
// The authenticator's response is written as JSON (200 OK). Each failure
// produces an error response.
func (h *UserHandler) Login(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	requestID := utils.GetRequestID(ctx)
	h.logger.Info(h.module, requestID, "[Login]: Processing request")

	loginReq, decodeErr := web.DecodeJSONRequest[users.UserLoginRequest](w, r)
	if decodeErr != nil {
		web.WriteError(w, errors.NewRequestBodyDecodingError(decodeErr))
		return
	}
	if validationErr := loginReq.Validate(); validationErr != nil {
		web.WriteError(w, errors.NewRequestValidationError(validationErr))
		return
	}

	authResult, authErr := h.authenticator.AuthenticateUser(ctx, loginReq)
	if authErr != nil {
		web.WriteError(w, errors.Wrap(authErr, "", "failed to authenticate user"))
		return
	}

	newSession := &session.SessionData{
		UserID:             authResult.UserID,
		IPAddress:          r.RemoteAddr,
		UserAgent:          r.UserAgent(),
		AuthenticationTime: time.Now().Unix(),
	}
	if sessionErr := h.sessionService.CreateSession(w, r, newSession); sessionErr != nil {
		web.WriteError(w, errors.NewSessionCreationError(sessionErr))
		return
	}

	web.WriteJSON(w, http.StatusOK, authResult)
}
// OAuthLogin handles login for the OAuth authorization code flow.
//
// It accepts the same JSON credentials as Login and validates them the same
// way. It then authenticates the user and creates a session recording the user
// ID, remote address, user agent and authentication time.
//
// The login response is returned as JSON (200 OK) with OAuthRedirectURL set.
// That URL is built from the original query plus client_id and redirect_uri, so
// the client can resume the authorization request. Each failure produces an
// error response.
func (h *UserHandler) OAuthLogin(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	requestID := utils.GetRequestID(ctx)
	h.logger.Info(h.module, requestID, "[OAuthLogin]: Processing request")

	query := r.URL.Query()
	clientID := query.Get(constants.ClientIDReqField)
	redirectURI := query.Get(constants.RedirectURIReqField)

	request, err := web.DecodeJSONRequest[users.UserLoginRequest](w, r)
	if err != nil {
		web.WriteError(w, errors.NewRequestBodyDecodingError(err))
		return
	}
	// Apply the same input validation as Login, so both entry points enforce
	// identical rules before the credentials reach the authenticator.
	if err := request.Validate(); err != nil {
		web.WriteError(w, errors.NewRequestValidationError(err))
		return
	}

	response, err := h.authenticator.AuthenticateUser(ctx, request)
	if err != nil {
		wrappedErr := errors.Wrap(err, "", "failed to authenticate user")
		web.WriteError(w, wrappedErr)
		return
	}

	sessionData := &session.SessionData{
		UserID:             response.UserID,
		IPAddress:          r.RemoteAddr,
		UserAgent:          r.UserAgent(),
		AuthenticationTime: time.Now().Unix(),
	}
	if err := h.sessionService.CreateSession(w, r, sessionData); err != nil {
		web.WriteError(w, errors.NewSessionCreationError(err))
		return
	}

	response.OAuthRedirectURL = h.buildOAuthRedirectURL(query, clientID, redirectURI)
	web.WriteJSON(w, http.StatusOK, response)
}
// Logout is the HTTP handler for user logout.
// It invalidates the caller's session through the session service, bounded by
// constants.ThreeSecondTimeout. On success it responds with 200 OK and an empty
// body. If the session cannot be invalidated, the wrapped error is written instead.
func (h *UserHandler) Logout(w http.ResponseWriter, r *http.Request) {
	ctx, cancel := context.WithTimeout(r.Context(), constants.ThreeSecondTimeout)
	defer cancel()

	requestID := utils.GetRequestID(ctx)
	h.logger.Info(h.module, requestID, "[Logout]: Processing request")

	// Pass the timeout-bound context along with the request; previously the timeout
	// context was created but never reached the session service.
	if err := h.sessionService.InvalidateSession(w, r.WithContext(ctx)); err != nil {
		wrappedErr := errors.Wrap(err, "", "failed to invalidate session")
		web.WriteError(w, wrappedErr)
		return
	}

	w.WriteHeader(http.StatusOK)
}
// ResetPassword handles a password reset request.
//
// The JSON body is decoded into a users.UserPasswordResetRequest and validated.
// The email, new password, and reset token are then handed to h.manager.ResetPassword
// under a constants.ThreeSecondTimeout deadline. The manager's response is written
// as 200 OK JSON. Decoding, validation, and manager errors are written to the
// response instead.
func (h *UserHandler) ResetPassword(w http.ResponseWriter, r *http.Request) {
	timeoutCtx, cancel := context.WithTimeout(r.Context(), constants.ThreeSecondTimeout)
	defer cancel()

	h.logger.Info(h.module, utils.GetRequestID(timeoutCtx), "[ResetPassword]: Processing request")

	resetReq, decodeErr := web.DecodeJSONRequest[users.UserPasswordResetRequest](w, r)
	if decodeErr != nil {
		web.WriteError(w, errors.NewRequestBodyDecodingError(decodeErr))
		return
	}

	if validationErr := resetReq.Validate(); validationErr != nil {
		web.WriteError(w, errors.NewRequestValidationError(validationErr))
		return
	}

	result, resetErr := h.manager.ResetPassword(timeoutCtx, resetReq.Email, resetReq.NewPassword, resetReq.ResetToken)
	if resetErr != nil {
		web.WriteError(w, resetErr)
		return
	}

	web.WriteJSON(w, http.StatusOK, result)
}
// VerifyAccount handles email-address verification.
//
// It reads the verification token from the query parameter named by
// constants.TokenReqField and passes it to h.verifier.VerifyEmailAddress under a
// constants.ThreeSecondTimeout deadline. On success it writes 200 OK with an empty
// JSON string body; on failure the wrapped error is written instead.
func (h *UserHandler) VerifyAccount(w http.ResponseWriter, r *http.Request) {
	timeoutCtx, cancel := context.WithTimeout(r.Context(), constants.ThreeSecondTimeout)
	defer cancel()

	h.logger.Info(h.module, utils.GetRequestID(timeoutCtx), "[VerifyAccount]: Processing request")

	verificationToken := r.URL.Query().Get(constants.TokenReqField)
	verifyErr := h.verifier.VerifyEmailAddress(timeoutCtx, verificationToken)
	if verifyErr != nil {
		web.WriteError(w, errors.Wrap(verifyErr, "", "failed to validate user account"))
		return
	}

	web.WriteJSON(w, http.StatusOK, "")
}
// buildOAuthRedirectURL returns the relative URL, under the "/identity" prefix, of
// the authorization endpoint that the client continues to after login.
//
// client_id and redirect_uri are always included, even when empty. The optional
// parameters in the list below are copied from query only when they are non-empty.
func (h *UserHandler) buildOAuthRedirectURL(query url.Values, clientID, redirectURI string) string {
	params := url.Values{}
	params.Add(constants.ClientIDReqField, clientID)
	params.Add(constants.RedirectURIReqField, redirectURI)

	optionalKeys := []string{
		constants.StateReqField,
		constants.ScopeReqField,
		constants.ResponseTypeReqField,
		constants.NonceReqField,
		constants.ConsentApprovedURLValue,
		constants.ACRReqField,
		constants.ClaimsReqField,
	}
	for _, key := range optionalKeys {
		if value := query.Get(key); value != "" {
			params.Add(key, value)
		}
	}

	return "/identity" + web.OAuthEndpoints.Authorize + "?" + params.Encode()
}
package middleware
import (
"fmt"
"net/http"
"slices"
"strings"
"time"
"github.com/vigiloauth/vigilo/v2/idp/config"
"github.com/vigiloauth/vigilo/v2/internal/constants"
token "github.com/vigiloauth/vigilo/v2/internal/domain/token"
"github.com/vigiloauth/vigilo/v2/internal/errors"
"github.com/vigiloauth/vigilo/v2/internal/utils"
web "github.com/vigiloauth/vigilo/v2/internal/web"
)
// maxRequestsForStrictRateLimiting is the value StrictRateLimit passes to
// NewRateLimiter, which uses it as both the refill rate (tokens per second) and the
// bucket capacity.
const maxRequestsForStrictRateLimiting int = 3
// Middleware encapsulates middleware functionalities.
// It holds the token parser and validator used by AuthMiddleware, the server
// configuration it was built from, and the shared rate limiter consulted by
// RateLimit. It also holds the logger and module name used to tag log output.
type Middleware struct {
tokenParser token.TokenParser
tokenValidator token.TokenValidator
serverConfig *config.ServerConfig
rateLimiter *RateLimiter
logger *config.Logger
module string
}
// responseWriter wraps an http.ResponseWriter and records the status code sent to
// the client so that RequestLogger can report it.
type responseWriter struct {
	http.ResponseWriter
	statusCode  int
	wroteHeader bool // true once a final status has been sent, explicitly or implicitly by Write
}

// WriteHeader records the first final status code and forwards every call to the
// wrapped writer. net/http ignores superfluous WriteHeader calls, so later calls must
// not overwrite the recorded code, or the logged status would not match the response.
// Informational 1xx codes other than 101 are forwarded without being recorded,
// because a final status follows them.
func (rw *responseWriter) WriteHeader(statusCode int) {
	isFinal := statusCode >= http.StatusOK || statusCode == http.StatusSwitchingProtocols
	if !rw.wroteHeader && isFinal {
		rw.statusCode = statusCode
		rw.wroteHeader = true
	}
	rw.ResponseWriter.WriteHeader(statusCode)
}

// Write forwards to the wrapped writer and marks the header as sent. net/http
// implicitly sends 200 OK on the first Write when WriteHeader was not called, so any
// later WriteHeader call cannot change the response status.
func (rw *responseWriter) Write(b []byte) (int, error) {
	rw.wroteHeader = true
	return rw.ResponseWriter.Write(b)
}

// Unwrap returns the underlying ResponseWriter. This lets http.ResponseController
// reach optional capabilities (Flush, Hijack, deadlines) that the wrapper hides.
func (rw *responseWriter) Unwrap() http.ResponseWriter {
	return rw.ResponseWriter
}
// NewMiddleware builds a Middleware from the given token parser and validator and
// the global server configuration returned by config.GetServerConfig.
//
// The shared rate limiter allows a burst of MaxRequestsPerMinute requests and then
// refills at MaxRequestsPerMinute/60 tokens per second, so the configured value is
// enforced as a per-minute limit.
func NewMiddleware(
	tokenParser token.TokenParser,
	tokenValidator token.TokenValidator,
) *Middleware {
	serverConfig := config.GetServerConfig()

	maxPerMinute := serverConfig.MaxRequestsPerMinute()
	limiter := NewRateLimiter(maxPerMinute)
	// NewRateLimiter treats its argument as tokens per second. Passing the per-minute
	// setting unchanged would allow 60x the configured rate, so convert the refill rate.
	limiter.rate = float64(maxPerMinute) / 60

	return &Middleware{
		tokenParser:    tokenParser,
		tokenValidator: tokenValidator,
		serverConfig:   serverConfig,
		rateLimiter:    limiter,
		logger:         serverConfig.Logger(),
		module:         "Middleware",
	}
}
// AuthMiddleware is a middleware that checks for a valid JWT access token in the
// Authorization header or, as a fallback, in the POST body.
//
// The token is read from the Authorization Bearer header first. Only when that fails
// and the request is an application/x-www-form-urlencoded POST is the body form
// field named by constants.AccessTokenPost used instead. The token is then parsed and
// validated. On success its claims and raw value are stored in the request context
// (ContextKeyTokenClaims, ContextKeyAccessToken) before next is called. Any failure
// writes an error response and stops the chain.
func (m *Middleware) AuthMiddleware() func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ctx := r.Context()
			requestID := utils.GetRequestID(ctx)

			tokenString, authHeaderErr := web.ExtractBearerToken(r)
			if authHeaderErr != nil {
				// Fall back to the body only when the header produced no token. Otherwise a
				// form POST carrying a valid Bearer header would be rejected for lacking a
				// body token. HasPrefix tolerates parameters such as "; charset=UTF-8".
				isFormPost := r.Method == http.MethodPost &&
					strings.HasPrefix(r.Header.Get("Content-Type"), "application/x-www-form-urlencoded")
				if !isFormPost {
					m.logger.Warn(m.module, requestID, "[AuthMiddleware]: Access token not found in header, and POST body check conditions not met or token not found in body.")
					web.WriteError(w, errors.Wrap(authHeaderErr, errors.ErrCodeUnauthorized, "missing or invalid authorization header"))
					return
				}

				m.logger.Debug(m.module, requestID, "[AuthMiddleware]: Bearer token not found in header, attempting to check POST body for access_token parameter.")
				if parseErr := r.ParseForm(); parseErr != nil {
					m.logger.Warn(m.module, requestID, "[AuthMiddleware]: Failed to parse form body: %v", parseErr)
					web.WriteError(w, errors.Wrap(parseErr, errors.ErrCodeInvalidRequest, "failed to parse request body"))
					return
				}

				// PostForm holds only body parameters, so a token placed in the URL query
				// string (where it would leak into access logs) is not accepted.
				bodyToken := r.PostForm.Get(constants.AccessTokenPost)
				if bodyToken == "" {
					m.logger.Warn(m.module, requestID, "[AuthMiddleware]: Access token not found in header or POST body.")
					web.WriteError(w, errors.New(errors.ErrCodeUnauthorized, "missing or invalid access token"))
					return
				}
				m.logger.Debug(m.module, requestID, "[AuthMiddleware]: Found access_token in POST body.")
				tokenString = bodyToken
			}

			if tokenString == "" {
				m.logger.Warn(m.module, requestID, "[AuthMiddleware]: tokenString is empty after all extraction attempts.")
				web.WriteError(w, errors.New(errors.ErrCodeUnauthorized, "missing or invalid access token after extraction attempts"))
				return
			}

			claims, parseErr := m.tokenParser.ParseToken(ctx, tokenString)
			if parseErr != nil {
				m.logger.Warn(m.module, requestID, "[AuthMiddleware]: Failed to parse token: %s", parseErr)
				web.WriteError(w, errors.Wrap(parseErr, errors.ErrCodeTokenParsing, "failed to parse token"))
				return
			}

			m.logger.Debug(m.module, requestID, "[AuthMiddleWare]: Attempting to validate token")
			if validateErr := m.tokenValidator.ValidateToken(ctx, tokenString); validateErr != nil {
				m.logger.Warn(m.module, requestID, "[AuthMiddleware]: Failed to validate token: %s", validateErr)
				web.WriteError(w, errors.Wrap(validateErr, errors.ErrCodeUnauthorized, "an error occurred validating the access token"))
				return
			}

			ctx = utils.AddKeyValueToContext(ctx, constants.ContextKeyTokenClaims, claims)
			ctx = utils.AddKeyValueToContext(ctx, constants.ContextKeyAccessToken, tokenString)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}
// RequestLogger logs one debug line per request after the downstream handler
// finishes. The line includes method, path, recorded status code, client IP,
// duration, User-Agent, and whether an Authorization header was present. The
// header value itself is never formatted or logged.
//
// The client IP is the first entry of the X-Forwarded-For header
// (constants.XForwardedHeader) when present, otherwise r.RemoteAddr. That header is
// client-controlled unless a trusted proxy overwrites it.
func (m *Middleware) RequestLogger(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		startTime := time.Now()

		clientIP := r.RemoteAddr
		if forwardedFor := r.Header.Get(constants.XForwardedHeader); forwardedFor != "" {
			first, _, _ := strings.Cut(forwardedFor, ",")
			clientIP = strings.TrimSpace(first)
		}

		// Only the presence of credentials is logged; the previous code also built a
		// "scheme:credential" string that was never used, needlessly copying the secret.
		hasAuth := r.Header.Get(constants.AuthorizationHeader) != ""

		wrappedWriter := &responseWriter{
			ResponseWriter: w,
			statusCode:     http.StatusOK, // net/http sends 200 when the handler never calls WriteHeader
		}
		next.ServeHTTP(wrappedWriter, r)

		duration := time.Since(startTime)
		userAgent := r.Header.Get("User-Agent")
		m.logger.Debug(m.module, utils.GetRequestID(r.Context()),
			"Method=[%s] | URL=[%s] | Status=[%d] | IP=[%s] | Duration=[%v] | User-Agent=[%s] | Auth=[%v]",
			r.Method, r.URL.Path, wrappedWriter.statusCode, clientIP, duration, userAgent, hasAuth,
		)
	})
}
// WithRole is a middleware that checks if an access token has sufficient privileges to access resources.
func (m *Middleware) WithRole(requiredRole string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
requestID := utils.GetRequestID(ctx)
var claims *token.TokenClaims
if val := utils.GetValueFromContext(ctx, constants.ContextKeyTokenClaims); val != nil {
claims, _ = val.(*token.TokenClaims)
} else {
m.logger.Error(m.module, requestID, "[WithRole]: An error occurred accessing token from context")
web.WriteError(w, errors.NewInternalServerError(""))
return
}
roles := strings.Split(claims.Roles, " ")
hasRole := slices.Contains(roles, requiredRole)
if !hasRole {
err := errors.New(errors.ErrCodeInsufficientRole, "the request requires higher privileges than provided by the access token")
web.WriteError(w, err)
return
}
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
// RedirectToHTTPS is a middleware that passes TLS requests (r.TLS != nil) on to next.
// Plain-HTTP requests are answered with a redirect to their HTTPS equivalent via
// redirectToHttps.
func (m *Middleware) RedirectToHTTPS(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ctx := r.Context()
		requestID := utils.GetRequestID(ctx)

		if r.TLS != nil {
			next.ServeHTTP(w, r.WithContext(ctx))
			return
		}

		m.logger.Debug(m.module, requestID, "[RedirectToHTTPS]: Redirecting request to HTTPS")
		redirectToHttps(w, r)
	})
}
// RateLimit is a middleware that consults the Middleware's shared rate limiter. It
// rejects the request with ErrCodeRequestLimitExceeded when no token is available.
// The limiter is one bucket for all requests; the request ID only tags log lines.
func (m *Middleware) RateLimit(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ctx := r.Context()
		requestID := utils.GetRequestID(ctx)

		if allowed := m.rateLimiter.Allow(requestID); !allowed {
			m.logger.Warn(m.module, requestID, "[RateLimit]: Rate limit exceeded for url=[%s]", r.URL.Path)
			web.WriteError(w, errors.New(errors.ErrCodeRequestLimitExceeded, "too many requests"))
			return
		}

		next.ServeHTTP(w, r.WithContext(ctx))
	})
}
// StrictRateLimit applies stricter rate limiting for sensitive operations.
// Each call creates its own limiter via NewRateLimiter(maxRequestsForStrictRateLimiting).
// That bucket is shared by every request through the returned handler, but not with
// handlers wrapped by separate calls.
func (m *Middleware) StrictRateLimit(next http.Handler) http.Handler {
	limiter := NewRateLimiter(maxRequestsForStrictRateLimiting)

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ctx := r.Context()
		requestID := utils.GetRequestID(ctx)

		if limiter.Allow(requestID) {
			next.ServeHTTP(w, r.WithContext(ctx))
			return
		}

		m.logger.Warn(m.module, requestID, "[StrictRateLimit]: Strict rate limit exceeded for url=[%s]", r.URL.Path)
		web.WriteError(w, errors.New(errors.ErrCodeRequestLimitExceeded, "rate limit exceeded for sensitive operations"))
	})
}
// RequireRequestMethod returns a middleware that only forwards requests whose method
// equals requestMethod. Any other method is rejected with ErrCodeMethodNotAllowed and
// an Allow header naming the accepted method.
func (m *Middleware) RequireRequestMethod(requestMethod string) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ctx := r.Context()
			requestID := utils.GetRequestID(ctx)

			if r.Method != requestMethod {
				m.logger.Warn(m.module, requestID, "[RequireRequestMethod]: Invalid request method received for url=%s", r.URL.Path)
				// RFC 9110 §15.5.6: a 405 response must carry an Allow header listing the
				// supported methods. It must be set before the error response is written.
				w.Header().Set("Allow", requestMethod)
				err := errors.New(errors.ErrCodeMethodNotAllowed, fmt.Sprintf("method %s not allowed for this request", r.Method))
				web.WriteError(w, err)
				return
			}

			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}
// WithContextValues ensures each request's context contains a request ID, the
// client IP address, and the User-Agent.
//
// The request ID is taken from constants.RequestIDHeader, or generated when absent,
// and is echoed back in the same response header. The client IP is the first entry
// of the X-Forwarded-For header (constants.XForwardedHeader) when present, otherwise
// r.RemoteAddr.
func (m *Middleware) WithContextValues(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ctx := r.Context()

		requestID := r.Header.Get(constants.RequestIDHeader)
		if requestID == "" {
			requestID = constants.RequestIDPrefix + utils.GenerateUUID()
		}
		w.Header().Set(constants.RequestIDHeader, requestID)

		ipAddress := r.RemoteAddr
		if forwardedFor := r.Header.Get(constants.XForwardedHeader); forwardedFor != "" {
			// X-Forwarded-For has the form "client, proxy1, proxy2". Keep only the
			// originating client, as RequestLogger does, instead of storing the whole list.
			first, _, _ := strings.Cut(forwardedFor, ",")
			ipAddress = strings.TrimSpace(first)
		}

		ctx = utils.AddKeyValueToContext(ctx, constants.ContextKeyIPAddress, ipAddress)
		ctx = utils.AddKeyValueToContext(ctx, constants.ContextKeyRequestID, requestID)
		ctx = utils.AddKeyValueToContext(ctx, constants.ContextKeyUserAgent, r.UserAgent())
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}
// RequiresContentType creates middleware that validates the request Content-Type
// header against the provided contentType (e.g., "application/json").
//
// OPTIONS, GET, and HEAD requests pass through unchecked. For any other method, a
// missing header is rejected with ErrCodeInvalidContentType, as is a header that does
// not start with contentType. Because only the prefix is checked, parameters such as
// "; charset=utf-8" are accepted.
func (m *Middleware) RequiresContentType(contentType string) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ctx := r.Context()
			requestID := utils.GetRequestID(ctx)

			switch r.Method {
			case http.MethodOptions, http.MethodGet, http.MethodHead:
				m.logger.Debug(m.module, requestID, "[RequiresContentType]: Passing request to next handler")
				next.ServeHTTP(w, r)
				return
			}

			ct := r.Header.Get("Content-Type")
			if ct == "" {
				m.logger.Warn(m.module, requestID, "[RequiresContentType]: Content-Type header is missing in request")
				err := errors.New(errors.ErrCodeInvalidContentType, "Content-Type header is required")
				web.WriteError(w, err)
				return
			}

			if !strings.HasPrefix(ct, contentType) {
				err := errors.New(
					errors.ErrCodeInvalidContentType,
					fmt.Sprintf("unsupported Content-Type, expected: %s", contentType),
				)
				// Log the Content-Type actually received; the expected one is in the error.
				m.logger.Warn(m.module, requestID, "[RequiresContentType]: Unsupported Content-Type=[%s] received for request", ct)
				web.WriteError(w, err)
				return
			}

			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}
// redirectToHttps redirects an HTTP request to the same host and request URI over
// HTTPS using 308 Permanent Redirect.
//
// The path and query come from r.URL.RequestURI(), which keeps the original
// percent-encoding (e.g. %3F, %2F). Rebuilding the target from the decoded
// r.URL.Path would turn an encoded "?" in the path into a query separator.
// The host is taken from the request's Host header as received.
func redirectToHttps(w http.ResponseWriter, r *http.Request) {
	target := "https://" + r.Host + r.URL.RequestURI()
	http.Redirect(w, r, target, http.StatusPermanentRedirect)
}
package middleware
import (
"sync"
"time"
"github.com/vigiloauth/vigilo/v2/idp/config"
)
// RateLimiter implements a token bucket rate limiting algorithm.
// Allow refills the bucket at rate tokens per second, capped at capacity, and each
// allowed request consumes one token. mu guards tokens and lastUpdate during Allow.
type RateLimiter struct {
rate float64 // Number of tokens to add per second
capacity float64 // Maximum number of tokens that can be stored
tokens float64 // Current number of available tokens
lastUpdate time.Time // Timestamp of the last token update
mu sync.Mutex // Mutex to protect concurrent access
logger *config.Logger // Debug/warn output from Allow
module string // Module name used to tag log lines
}
// NewRateLimiter creates a RateLimiter whose bucket starts full.
// rate is used both as the refill rate (tokens per second) and as the bucket
// capacity, so up to rate requests may be allowed in an immediate burst.
func NewRateLimiter(rate int) *RateLimiter {
	limit := float64(rate)

	limiter := &RateLimiter{
		rate:       limit,
		capacity:   limit,
		tokens:     limit,
		lastUpdate: time.Now(),
		logger:     config.GetServerConfig().Logger(),
		module:     "Rate Limiter",
	}
	return limiter
}
// Allow reports whether one more request may proceed and consumes a token if so.
// Before deciding, it refills the bucket by (seconds since the previous call) * rate,
// capped at capacity. requestID is used only to tag log lines.
func (rl *RateLimiter) Allow(requestID string) bool {
	rl.logger.Debug(rl.module, requestID, "[Allow]: Verifying if the request exceeds the rate limit")

	rl.mu.Lock()
	defer rl.mu.Unlock()

	now := time.Now()
	refill := now.Sub(rl.lastUpdate).Seconds() * rl.rate
	rl.lastUpdate = now
	rl.tokens = min(rl.tokens+refill, rl.capacity)

	allowed := rl.tokens >= 1.0
	if allowed {
		rl.logger.Debug(rl.module, requestID, "[Allow]: Request is valid")
		rl.tokens--
	} else {
		rl.logger.Warn(rl.module, requestID, "[Allow]: Request is invalid as it exceeds the rate limit")
	}
	return allowed
}
package mocks
import (
"context"
"time"
domain "github.com/vigiloauth/vigilo/v2/internal/domain/audit"
)
// Compile-time check that *MockAuditRepository satisfies domain.AuditRepository.
var _ domain.AuditRepository = (*MockAuditRepository)(nil)

// MockAuditRepository is a test double for domain.AuditRepository. Each method
// forwards its arguments to the matching *Func field; a nil field panics when called.
type MockAuditRepository struct {
	StoreAuditEventFunc func(ctx context.Context, event *domain.AuditEvent) error
	GetAuditEventsFunc  func(ctx context.Context, filters map[string]any, from time.Time, to time.Time, limit, offset int) ([]*domain.AuditEvent, error)
	DeleteEventFunc     func(ctx context.Context, eventID string) error
}

// StoreAuditEvent delegates to StoreAuditEventFunc.
func (mock *MockAuditRepository) StoreAuditEvent(ctx context.Context, auditEvent *domain.AuditEvent) error {
	return mock.StoreAuditEventFunc(ctx, auditEvent)
}

// GetAuditEvents delegates to GetAuditEventsFunc.
func (mock *MockAuditRepository) GetAuditEvents(ctx context.Context, filters map[string]any, from time.Time, to time.Time, limit, offset int) ([]*domain.AuditEvent, error) {
	return mock.GetAuditEventsFunc(ctx, filters, from, to, limit, offset)
}

// DeleteEvent delegates to DeleteEventFunc.
func (mock *MockAuditRepository) DeleteEvent(ctx context.Context, id string) error {
	return mock.DeleteEventFunc(ctx, id)
}
package mocks
import (
"context"
"time"
domain "github.com/vigiloauth/vigilo/v2/internal/domain/audit"
)
// Compile-time check that *MockAuditLogger satisfies domain.AuditLogger.
var _ domain.AuditLogger = (*MockAuditLogger)(nil)

// MockAuditLogger is a test double for domain.AuditLogger. Each method forwards its
// arguments to the matching *Func field; a nil field panics when called.
type MockAuditLogger struct {
	StoreEventFunc      func(ctx context.Context, eventType domain.EventType, success bool, action domain.ActionType, method domain.MethodType, err error)
	DeleteOldEventsFunc func(ctx context.Context, olderThan time.Time) error
	GetAuditEventsFunc  func(ctx context.Context, filters map[string]any, from string, to string, limit, offset int) ([]*domain.AuditEvent, error)
}

// StoreEvent delegates to StoreEventFunc.
func (mock *MockAuditLogger) StoreEvent(ctx context.Context, eventType domain.EventType, success bool, action domain.ActionType, method domain.MethodType, err error) {
	mock.StoreEventFunc(ctx, eventType, success, action, method, err)
}

// DeleteOldEvents delegates to DeleteOldEventsFunc.
func (mock *MockAuditLogger) DeleteOldEvents(ctx context.Context, cutoff time.Time) error {
	return mock.DeleteOldEventsFunc(ctx, cutoff)
}

// GetAuditEvents delegates to GetAuditEventsFunc.
func (mock *MockAuditLogger) GetAuditEvents(ctx context.Context, filters map[string]any, from string, to string, limit, offset int) ([]*domain.AuditEvent, error) {
	return mock.GetAuditEventsFunc(ctx, filters, from, to, limit, offset)
}
package mocks
import (
"context"
authz "github.com/vigiloauth/vigilo/v2/internal/domain/authorization"
authzCode "github.com/vigiloauth/vigilo/v2/internal/domain/authzcode"
client "github.com/vigiloauth/vigilo/v2/internal/domain/client"
token "github.com/vigiloauth/vigilo/v2/internal/domain/token"
user "github.com/vigiloauth/vigilo/v2/internal/domain/user"
)
// Compile-time check that *MockAuthorizationService satisfies authz.AuthorizationService.
var _ authz.AuthorizationService = (*MockAuthorizationService)(nil)

// MockAuthorizationService is a test double for authz.AuthorizationService. Each
// method forwards its arguments to the matching *Func field; a nil field panics when called.
type MockAuthorizationService struct {
	AuthorizeClientFunc          func(ctx context.Context, authorizationRequest *client.ClientAuthorizationRequest) (string, error)
	AuthorizeTokenExchangeFunc   func(ctx context.Context, tokenRequest *token.TokenRequest) (*authzCode.AuthorizationCodeData, error)
	AuthorizeUserInfoRequestFunc func(ctx context.Context, accessTokenClaims *token.TokenClaims) (*user.User, error)
	UpdateAuthorizationCodeFunc  func(ctx context.Context, authData *authzCode.AuthorizationCodeData) error
}

// AuthorizeClient delegates to AuthorizeClientFunc.
func (mock *MockAuthorizationService) AuthorizeClient(ctx context.Context, req *client.ClientAuthorizationRequest) (string, error) {
	return mock.AuthorizeClientFunc(ctx, req)
}

// AuthorizeTokenExchange delegates to AuthorizeTokenExchangeFunc.
func (mock *MockAuthorizationService) AuthorizeTokenExchange(ctx context.Context, tokenReq *token.TokenRequest) (*authzCode.AuthorizationCodeData, error) {
	return mock.AuthorizeTokenExchangeFunc(ctx, tokenReq)
}

// AuthorizeUserInfoRequest delegates to AuthorizeUserInfoRequestFunc.
func (mock *MockAuthorizationService) AuthorizeUserInfoRequest(ctx context.Context, claims *token.TokenClaims) (*user.User, error) {
	return mock.AuthorizeUserInfoRequestFunc(ctx, claims)
}

// UpdateAuthorizationCode delegates to UpdateAuthorizationCodeFunc.
func (mock *MockAuthorizationService) UpdateAuthorizationCode(ctx context.Context, codeData *authzCode.AuthorizationCodeData) error {
	return mock.UpdateAuthorizationCodeFunc(ctx, codeData)
}
package mocks
import (
"context"
domain "github.com/vigiloauth/vigilo/v2/internal/domain/authzcode"
client "github.com/vigiloauth/vigilo/v2/internal/domain/client"
)
// Compile-time check that *MockAuthorizationCodeCreator satisfies domain.AuthorizationCodeCreator.
var _ domain.AuthorizationCodeCreator = (*MockAuthorizationCodeCreator)(nil)

// MockAuthorizationCodeCreator is a test double whose GenerateAuthorizationCode
// forwards to GenerateAuthorizationCodeFunc (which must be non-nil when called).
type MockAuthorizationCodeCreator struct {
	GenerateAuthorizationCodeFunc func(ctx context.Context, request *client.ClientAuthorizationRequest) (string, error)
}

// GenerateAuthorizationCode delegates to GenerateAuthorizationCodeFunc.
func (mock *MockAuthorizationCodeCreator) GenerateAuthorizationCode(ctx context.Context, req *client.ClientAuthorizationRequest) (string, error) {
	return mock.GenerateAuthorizationCodeFunc(ctx, req)
}
package mocks
import (
"context"
domain "github.com/vigiloauth/vigilo/v2/internal/domain/authzcode"
client "github.com/vigiloauth/vigilo/v2/internal/domain/client"
)
// Compile-time check that *MockAuthorizationCodeIssuer satisfies domain.AuthorizationCodeIssuer.
var _ domain.AuthorizationCodeIssuer = (*MockAuthorizationCodeIssuer)(nil)

// MockAuthorizationCodeIssuer is a test double whose IssueAuthorizationCode
// forwards to IssueAuthorizationCodeFunc (which must be non-nil when called).
type MockAuthorizationCodeIssuer struct {
	IssueAuthorizationCodeFunc func(ctx context.Context, req *client.ClientAuthorizationRequest) (string, error)
}

// IssueAuthorizationCode delegates to IssueAuthorizationCodeFunc.
func (mock *MockAuthorizationCodeIssuer) IssueAuthorizationCode(ctx context.Context, authReq *client.ClientAuthorizationRequest) (string, error) {
	return mock.IssueAuthorizationCodeFunc(ctx, authReq)
}
package mocks
import (
"context"
authzCode "github.com/vigiloauth/vigilo/v2/internal/domain/authzcode"
)
// Compile-time check that *MockAuthorizationCodeManager satisfies authzCode.AuthorizationCodeManager.
var _ authzCode.AuthorizationCodeManager = (*MockAuthorizationCodeManager)(nil)

// MockAuthorizationCodeManager is a test double for authzCode.AuthorizationCodeManager.
// Each method forwards to the matching *Func field; a nil field panics when called.
type MockAuthorizationCodeManager struct {
	RevokeAuthorizationCodeFunc func(ctx context.Context, code string) error
	GetAuthorizationCodeFunc    func(ctx context.Context, code string) (*authzCode.AuthorizationCodeData, error)
	UpdateAuthorizationCodeFunc func(ctx context.Context, authData *authzCode.AuthorizationCodeData) error
}

// RevokeAuthorizationCode delegates to RevokeAuthorizationCodeFunc.
func (mock *MockAuthorizationCodeManager) RevokeAuthorizationCode(ctx context.Context, authorizationCode string) error {
	return mock.RevokeAuthorizationCodeFunc(ctx, authorizationCode)
}

// GetAuthorizationCode delegates to GetAuthorizationCodeFunc.
func (mock *MockAuthorizationCodeManager) GetAuthorizationCode(ctx context.Context, authorizationCode string) (*authzCode.AuthorizationCodeData, error) {
	return mock.GetAuthorizationCodeFunc(ctx, authorizationCode)
}

// UpdateAuthorizationCode delegates to UpdateAuthorizationCodeFunc.
func (mock *MockAuthorizationCodeManager) UpdateAuthorizationCode(ctx context.Context, codeData *authzCode.AuthorizationCodeData) error {
	return mock.UpdateAuthorizationCodeFunc(ctx, codeData)
}
package mocks
import (
"context"
"time"
authz "github.com/vigiloauth/vigilo/v2/internal/domain/authzcode"
)
// Compile-time check that *MockAuthorizationCodeRepository satisfies authz.AuthorizationCodeRepository.
var _ authz.AuthorizationCodeRepository = (*MockAuthorizationCodeRepository)(nil)

// MockAuthorizationCodeRepository is a test double for authz.AuthorizationCodeRepository.
// Each method forwards to the matching *Func field; a nil field panics when called.
type MockAuthorizationCodeRepository struct {
	StoreAuthorizationCodeFunc  func(ctx context.Context, code string, data *authz.AuthorizationCodeData, expiresAt time.Time) error
	GetAuthorizationCodeFunc    func(ctx context.Context, code string) (*authz.AuthorizationCodeData, error)
	DeleteAuthorizationCodeFunc func(ctx context.Context, code string) error
	UpdateAuthorizationCodeFunc func(ctx context.Context, code string, authData *authz.AuthorizationCodeData) error
}

// StoreAuthorizationCode delegates to StoreAuthorizationCodeFunc.
func (mock *MockAuthorizationCodeRepository) StoreAuthorizationCode(ctx context.Context, key string, codeData *authz.AuthorizationCodeData, expiry time.Time) error {
	return mock.StoreAuthorizationCodeFunc(ctx, key, codeData, expiry)
}

// GetAuthorizationCode delegates to GetAuthorizationCodeFunc.
func (mock *MockAuthorizationCodeRepository) GetAuthorizationCode(ctx context.Context, key string) (*authz.AuthorizationCodeData, error) {
	return mock.GetAuthorizationCodeFunc(ctx, key)
}

// DeleteAuthorizationCode delegates to DeleteAuthorizationCodeFunc.
func (mock *MockAuthorizationCodeRepository) DeleteAuthorizationCode(ctx context.Context, key string) error {
	return mock.DeleteAuthorizationCodeFunc(ctx, key)
}

// UpdateAuthorizationCode delegates to UpdateAuthorizationCodeFunc.
func (mock *MockAuthorizationCodeRepository) UpdateAuthorizationCode(ctx context.Context, key string, codeData *authz.AuthorizationCodeData) error {
	return mock.UpdateAuthorizationCodeFunc(ctx, key, codeData)
}
package mocks
import (
"context"
domain "github.com/vigiloauth/vigilo/v2/internal/domain/authzcode"
client "github.com/vigiloauth/vigilo/v2/internal/domain/client"
)
// Compile-time check that *MockAuthorizationCodeRequestValidator satisfies domain.AuthorizationCodeValidator.
var _ domain.AuthorizationCodeValidator = (*MockAuthorizationCodeRequestValidator)(nil)

// MockAuthorizationCodeRequestValidator is a test double for domain.AuthorizationCodeValidator.
// Each method forwards to the matching *Func field; a nil field panics when called.
type MockAuthorizationCodeRequestValidator struct {
	ValidateRequestFunc           func(ctx context.Context, req *client.ClientAuthorizationRequest) error
	ValidateAuthorizationCodeFunc func(ctx context.Context, code, clientID, redirectURI string) error
	ValidatePKCEFunc              func(ctx context.Context, authzCodeData *domain.AuthorizationCodeData, codeVerifier string) error
}

// ValidateRequest delegates to ValidateRequestFunc.
func (mock *MockAuthorizationCodeRequestValidator) ValidateRequest(ctx context.Context, authReq *client.ClientAuthorizationRequest) error {
	return mock.ValidateRequestFunc(ctx, authReq)
}

// ValidateAuthorizationCode delegates to ValidateAuthorizationCodeFunc.
func (mock *MockAuthorizationCodeRequestValidator) ValidateAuthorizationCode(ctx context.Context, code, id, uri string) error {
	return mock.ValidateAuthorizationCodeFunc(ctx, code, id, uri)
}

// ValidatePKCE delegates to ValidatePKCEFunc.
func (mock *MockAuthorizationCodeRequestValidator) ValidatePKCE(ctx context.Context, codeData *domain.AuthorizationCodeData, verifier string) error {
	return mock.ValidatePKCEFunc(ctx, codeData, verifier)
}
package mocks
import (
"context"
"net/http"
domain "github.com/vigiloauth/vigilo/v2/internal/domain/client"
"github.com/vigiloauth/vigilo/v2/internal/types"
)
// Compile-time check that *MockClientAuthenticator satisfies domain.ClientAuthenticator.
var _ domain.ClientAuthenticator = (*MockClientAuthenticator)(nil)

// MockClientAuthenticator is a test double for domain.ClientAuthenticator. Each
// method forwards to the matching *Func field; a nil field panics when called.
type MockClientAuthenticator struct {
	AuthenticateRequestFunc func(ctx context.Context, r *http.Request, requiredScope types.Scope) error
	AuthenticateClientFunc  func(ctx context.Context, req *domain.ClientAuthenticationRequest) error
}

// AuthenticateRequest delegates to AuthenticateRequestFunc.
func (mock *MockClientAuthenticator) AuthenticateRequest(ctx context.Context, httpReq *http.Request, scope types.Scope) error {
	return mock.AuthenticateRequestFunc(ctx, httpReq, scope)
}

// AuthenticateClient delegates to AuthenticateClientFunc.
func (mock *MockClientAuthenticator) AuthenticateClient(ctx context.Context, authReq *domain.ClientAuthenticationRequest) error {
	return mock.AuthenticateClientFunc(ctx, authReq)
}
package mocks
import (
"context"
domain "github.com/vigiloauth/vigilo/v2/internal/domain/client"
)
// Compile-time check that *MockClientAuthorization satisfies domain.ClientAuthorization.
var _ domain.ClientAuthorization = (*MockClientAuthorization)(nil)

// MockClientAuthorization is a test double whose Authorize forwards to AuthorizeFunc
// (which must be non-nil when called).
type MockClientAuthorization struct {
	AuthorizeFunc func(ctx context.Context, request *domain.ClientAuthorizationRequest) (string, error)
}

// Authorize delegates to AuthorizeFunc.
func (mock *MockClientAuthorization) Authorize(ctx context.Context, authReq *domain.ClientAuthorizationRequest) (string, error) {
	return mock.AuthorizeFunc(ctx, authReq)
}
package mocks
import (
"context"
domain "github.com/vigiloauth/vigilo/v2/internal/domain/client"
)
// Compile-time check that *MockClientCreator satisfies domain.ClientCreator.
var _ domain.ClientCreator = (*MockClientCreator)(nil)

// MockClientCreator is a test double whose Register forwards to RegisterFunc
// (which must be non-nil when called).
type MockClientCreator struct {
	RegisterFunc func(ctx context.Context, req *domain.ClientRegistrationRequest) (*domain.ClientRegistrationResponse, error)
}

// Register delegates to RegisterFunc.
func (mock *MockClientCreator) Register(ctx context.Context, registration *domain.ClientRegistrationRequest) (*domain.ClientRegistrationResponse, error) {
	return mock.RegisterFunc(ctx, registration)
}
package mocks
import (
"context"
domain "github.com/vigiloauth/vigilo/v2/internal/domain/client"
)
// Compile-time check that *MockClientManager satisfies domain.ClientManager.
var _ domain.ClientManager = (*MockClientManager)(nil)

// MockClientManager is a test double for domain.ClientManager. Each method forwards
// to the matching *Func field; a nil field panics when called.
type MockClientManager struct {
	RegenerateClientSecretFunc  func(ctx context.Context, clientID string) (*domain.ClientSecretRegenerationResponse, error)
	GetClientByIDFunc           func(ctx context.Context, clientID string) (*domain.Client, error)
	GetClientInformationFunc    func(ctx context.Context, clientID string, registrationAccessToken string) (*domain.ClientInformationResponse, error)
	UpdateClientInformationFunc func(ctx context.Context, clientID string, registrationAccessToken string, req *domain.ClientUpdateRequest) (*domain.ClientInformationResponse, error)
	DeleteClientInformationFunc func(ctx context.Context, clientID string, registrationAccessToken string) error
}

// RegenerateClientSecret delegates to RegenerateClientSecretFunc.
func (mock *MockClientManager) RegenerateClientSecret(ctx context.Context, id string) (*domain.ClientSecretRegenerationResponse, error) {
	return mock.RegenerateClientSecretFunc(ctx, id)
}

// GetClientByID delegates to GetClientByIDFunc.
func (mock *MockClientManager) GetClientByID(ctx context.Context, id string) (*domain.Client, error) {
	return mock.GetClientByIDFunc(ctx, id)
}

// GetClientInformation delegates to GetClientInformationFunc.
func (mock *MockClientManager) GetClientInformation(ctx context.Context, id string, accessToken string) (*domain.ClientInformationResponse, error) {
	return mock.GetClientInformationFunc(ctx, id, accessToken)
}

// UpdateClientInformation delegates to UpdateClientInformationFunc.
func (mock *MockClientManager) UpdateClientInformation(ctx context.Context, id string, accessToken string, update *domain.ClientUpdateRequest) (*domain.ClientInformationResponse, error) {
	return mock.UpdateClientInformationFunc(ctx, id, accessToken, update)
}

// DeleteClientInformation delegates to DeleteClientInformationFunc.
func (mock *MockClientManager) DeleteClientInformation(ctx context.Context, id string, accessToken string) error {
	return mock.DeleteClientInformationFunc(ctx, id, accessToken)
}
package mocks
import (
"context"
client "github.com/vigiloauth/vigilo/v2/internal/domain/client"
)
// Compile-time check that *MockClientRepository satisfies client.ClientRepository.
var _ client.ClientRepository = (*MockClientRepository)(nil)

// MockClientRepository is a test double for client.ClientRepository. Each method
// forwards to the matching *Func field; a nil field panics when called.
type MockClientRepository struct {
	SaveClientFunc       func(ctx context.Context, c *client.Client) error
	GetClientByIDFunc    func(ctx context.Context, clientID string) (*client.Client, error)
	DeleteClientByIDFunc func(ctx context.Context, clientID string) error
	UpdateClientFunc     func(ctx context.Context, c *client.Client) error
	IsExistingIDFunc     func(ctx context.Context, clientID string) bool
}

// SaveClient delegates to SaveClientFunc.
func (mock *MockClientRepository) SaveClient(ctx context.Context, c *client.Client) error {
	return mock.SaveClientFunc(ctx, c)
}

// GetClientByID delegates to GetClientByIDFunc.
func (mock *MockClientRepository) GetClientByID(ctx context.Context, id string) (*client.Client, error) {
	return mock.GetClientByIDFunc(ctx, id)
}

// DeleteClientByID delegates to DeleteClientByIDFunc.
func (mock *MockClientRepository) DeleteClientByID(ctx context.Context, id string) error {
	return mock.DeleteClientByIDFunc(ctx, id)
}

// UpdateClient delegates to UpdateClientFunc.
func (mock *MockClientRepository) UpdateClient(ctx context.Context, c *client.Client) error {
	return mock.UpdateClientFunc(ctx, c)
}

// IsExistingID delegates to IsExistingIDFunc.
func (mock *MockClientRepository) IsExistingID(ctx context.Context, id string) bool {
	return mock.IsExistingIDFunc(ctx, id)
}
package mocks
import (
"context"
domain "github.com/vigiloauth/vigilo/v2/internal/domain/client"
)
// Compile-time check that *MockClientValidator satisfies domain.ClientValidator.
var _ domain.ClientValidator = (*MockClientValidator)(nil)

// MockClientValidator is a test double for domain.ClientValidator. Each method
// forwards to the matching *Func field; a nil field panics when called.
type MockClientValidator struct {
	ValidateRegistrationRequestFunc              func(ctx context.Context, req *domain.ClientRegistrationRequest) error
	ValidateUpdateRequestFunc                    func(ctx context.Context, req *domain.ClientUpdateRequest) error
	ValidateAuthorizationRequestFunc             func(ctx context.Context, req *domain.ClientAuthorizationRequest) error
	ValidateRedirectURIFunc                      func(ctx context.Context, redirectURI string, existingClient *domain.Client) error
	ValidateClientAndRegistrationAccessTokenFunc func(ctx context.Context, clientID string, registrationAccessToken string) error
}

// ValidateRegistrationRequest delegates to ValidateRegistrationRequestFunc.
func (mock *MockClientValidator) ValidateRegistrationRequest(ctx context.Context, registration *domain.ClientRegistrationRequest) error {
	return mock.ValidateRegistrationRequestFunc(ctx, registration)
}

// ValidateUpdateRequest delegates to ValidateUpdateRequestFunc.
func (mock *MockClientValidator) ValidateUpdateRequest(ctx context.Context, update *domain.ClientUpdateRequest) error {
	return mock.ValidateUpdateRequestFunc(ctx, update)
}

// ValidateAuthorizationRequest delegates to ValidateAuthorizationRequestFunc.
func (mock *MockClientValidator) ValidateAuthorizationRequest(ctx context.Context, authReq *domain.ClientAuthorizationRequest) error {
	return mock.ValidateAuthorizationRequestFunc(ctx, authReq)
}

// ValidateRedirectURI delegates to ValidateRedirectURIFunc.
func (mock *MockClientValidator) ValidateRedirectURI(ctx context.Context, uri string, registered *domain.Client) error {
	return mock.ValidateRedirectURIFunc(ctx, uri, registered)
}

// ValidateClientAndRegistrationAccessToken delegates to ValidateClientAndRegistrationAccessTokenFunc.
func (mock *MockClientValidator) ValidateClientAndRegistrationAccessToken(ctx context.Context, id string, accessToken string) error {
	return mock.ValidateClientAndRegistrationAccessTokenFunc(ctx, id, accessToken)
}
package mocks
import (
"context"
"net/http"
"time"
domain "github.com/vigiloauth/vigilo/v2/internal/domain/cookies"
)
// Compile-time check that *MockHTTPCookieService satisfies domain.HTTPCookieService.
var _ domain.HTTPCookieService = (*MockHTTPCookieService)(nil)

// MockHTTPCookieService is a test double for domain.HTTPCookieService. Each method
// forwards to the matching *Func field; a nil field panics when called.
type MockHTTPCookieService struct {
	SetSessionCookieFunc   func(ctx context.Context, w http.ResponseWriter, token string, expirationTime time.Duration)
	ClearSessionCookieFunc func(ctx context.Context, w http.ResponseWriter)
	GetSessionCookieFunc   func(r *http.Request) (*http.Cookie, error)
}

// SetSessionCookie delegates to SetSessionCookieFunc.
func (mock *MockHTTPCookieService) SetSessionCookie(ctx context.Context, rw http.ResponseWriter, sessionToken string, ttl time.Duration) {
	mock.SetSessionCookieFunc(ctx, rw, sessionToken, ttl)
}

// ClearSessionCookie delegates to ClearSessionCookieFunc.
func (mock *MockHTTPCookieService) ClearSessionCookie(ctx context.Context, rw http.ResponseWriter) {
	mock.ClearSessionCookieFunc(ctx, rw)
}

// GetSessionCookie delegates to GetSessionCookieFunc.
func (mock *MockHTTPCookieService) GetSessionCookie(req *http.Request) (*http.Cookie, error) {
	return mock.GetSessionCookieFunc(req)
}
package mocks
import domain "github.com/vigiloauth/vigilo/v2/internal/domain/crypto"
// Compile-time assertion that *MockCryptographer implements domain.Cryptographer.
var _ domain.Cryptographer = (*MockCryptographer)(nil)
// MockCryptographer is a test double for domain.Cryptographer. Each method
// passes its arguments straight through to the "...Func" field of the same
// name and returns whatever that function returns. Tests must set every field
// they exercise; calling through a nil field panics.
type MockCryptographer struct {
	EncryptStringFunc        func(plainStr, secretKey string) (string, error)
	DecryptStringFunc        func(encryptedStr, secretKey string) (string, error)
	EncryptBytesFunc         func(plainBytes []byte, secretKey string) (string, error)
	DecryptBytesFunc         func(encryptedBytes, secretKey string) ([]byte, error)
	HashStringFunc           func(plainStr string) (string, error)
	GenerateRandomStringFunc func(length int) (string, error)
}

// EncryptString forwards to EncryptStringFunc.
func (c *MockCryptographer) EncryptString(plainStr string, secretKey string) (string, error) {
	return c.EncryptStringFunc(plainStr, secretKey)
}

// DecryptString forwards to DecryptStringFunc.
func (c *MockCryptographer) DecryptString(encryptedStr string, secretKey string) (string, error) {
	return c.DecryptStringFunc(encryptedStr, secretKey)
}

// EncryptBytes forwards to EncryptBytesFunc.
func (c *MockCryptographer) EncryptBytes(plainBytes []byte, secretKey string) (string, error) {
	return c.EncryptBytesFunc(plainBytes, secretKey)
}

// DecryptBytes forwards to DecryptBytesFunc.
func (c *MockCryptographer) DecryptBytes(encryptedBytes, secretKey string) ([]byte, error) {
	return c.DecryptBytesFunc(encryptedBytes, secretKey)
}

// HashString forwards to HashStringFunc.
func (c *MockCryptographer) HashString(plainStr string) (string, error) {
	return c.HashStringFunc(plainStr)
}

// GenerateRandomString forwards to GenerateRandomStringFunc.
func (c *MockCryptographer) GenerateRandomString(length int) (string, error) {
	return c.GenerateRandomStringFunc(length)
}
package mocks
import (
domain "github.com/vigiloauth/vigilo/v2/internal/domain/email"
"gopkg.in/gomail.v2"
)
// Compile-time assertion that *MockGoMailer implements domain.Mailer.
var _ domain.Mailer = (*MockGoMailer)(nil)

// MockGoMailer is a test double for domain.Mailer backed by gomail types.
// Each method delegates to its "...Func" field; a nil field panics when the
// corresponding method is called.
type MockGoMailer struct {
	DialFunc        func(host string, port int, username string, password string) (gomail.SendCloser, error)
	DialAndSendFunc func(host string, port int, username string, password string, message ...*gomail.Message) error
	NewMessageFunc  func(request *domain.EmailRequest, body string, subject string, fromAddress string) *gomail.Message
}

// Dial forwards to DialFunc.
func (gm *MockGoMailer) Dial(host string, port int, username string, password string) (gomail.SendCloser, error) {
	return gm.DialFunc(host, port, username, password)
}

// DialAndSend forwards to DialAndSendFunc, spreading message as variadic arguments.
func (gm *MockGoMailer) DialAndSend(host string, port int, username string, password string, message ...*gomail.Message) error {
	return gm.DialAndSendFunc(host, port, username, password, message...)
}

// NewMessage forwards to NewMessageFunc.
func (gm *MockGoMailer) NewMessage(request *domain.EmailRequest, body string, subject string, fromAddress string) *gomail.Message {
	return gm.NewMessageFunc(request, body, subject, fromAddress)
}
package mocks
import (
"context"
domain "github.com/vigiloauth/vigilo/v2/internal/domain/email"
)
// Compile-time assertion that *MockEmailService implements domain.EmailService.
var _ domain.EmailService = (*MockEmailService)(nil)

// MockEmailService is a test double for domain.EmailService. Each method
// delegates to its "...Func" field, which must be set before use.
type MockEmailService struct {
	SendEmailFunc          func(ctx context.Context, request *domain.EmailRequest) error
	TestConnectionFunc     func() error
	GetEmailRetryQueueFunc func() *domain.EmailRetryQueue
}

// SendEmail forwards to SendEmailFunc.
func (es *MockEmailService) SendEmail(ctx context.Context, request *domain.EmailRequest) error {
	return es.SendEmailFunc(ctx, request)
}

// TestConnection forwards to TestConnectionFunc.
func (es *MockEmailService) TestConnection() error {
	return es.TestConnectionFunc()
}

// GetEmailRetryQueue forwards to GetEmailRetryQueueFunc.
func (es *MockEmailService) GetEmailRetryQueue() *domain.EmailRetryQueue {
	return es.GetEmailRetryQueueFunc()
}
package mocks
import (
"context"
jwt "github.com/vigiloauth/vigilo/v2/internal/domain/jwt"
tokens "github.com/vigiloauth/vigilo/v2/internal/domain/token"
)
// Compile-time assertion that *MockJWTService implements jwt.JWTService.
var _ jwt.JWTService = (*MockJWTService)(nil)

// MockJWTService is a test double for jwt.JWTService. Each method delegates
// to its "...Func" field; calling a method whose field is nil panics.
type MockJWTService struct {
	ParseWithClaimsFunc func(ctx context.Context, tokenString string) (*tokens.TokenClaims, error)
	SignTokenFunc       func(ctx context.Context, claims *tokens.TokenClaims) (string, error)
}

// ParseWithClaims forwards to ParseWithClaimsFunc.
func (js *MockJWTService) ParseWithClaims(ctx context.Context, tokenString string) (*tokens.TokenClaims, error) {
	return js.ParseWithClaimsFunc(ctx, tokenString)
}

// SignToken forwards to SignTokenFunc.
func (js *MockJWTService) SignToken(ctx context.Context, claims *tokens.TokenClaims) (string, error) {
	return js.SignTokenFunc(ctx, claims)
}
package mocks
import (
"context"
domain "github.com/vigiloauth/vigilo/v2/internal/domain/login"
user "github.com/vigiloauth/vigilo/v2/internal/domain/user"
)
// Compile-time assertion that *MockLoginAttemptRepository implements domain.LoginAttemptRepository.
var _ domain.LoginAttemptRepository = (*MockLoginAttemptRepository)(nil)

// MockLoginAttemptRepository is a test double for domain.LoginAttemptRepository.
// Each method delegates to its "...Func" field, which must be set before use.
type MockLoginAttemptRepository struct {
	SaveLoginAttemptFunc         func(ctx context.Context, attempt *user.UserLoginAttempt) error
	GetLoginAttemptsByUserIDFunc func(ctx context.Context, userID string) ([]*user.UserLoginAttempt, error)
}

// SaveLoginAttempt forwards to SaveLoginAttemptFunc.
func (lr *MockLoginAttemptRepository) SaveLoginAttempt(ctx context.Context, attempt *user.UserLoginAttempt) error {
	return lr.SaveLoginAttemptFunc(ctx, attempt)
}

// GetLoginAttemptsByUserID forwards to GetLoginAttemptsByUserIDFunc.
func (lr *MockLoginAttemptRepository) GetLoginAttemptsByUserID(ctx context.Context, userID string) ([]*user.UserLoginAttempt, error) {
	return lr.GetLoginAttemptsByUserIDFunc(ctx, userID)
}
package mocks
import (
"context"
domain "github.com/vigiloauth/vigilo/v2/internal/domain/login"
user "github.com/vigiloauth/vigilo/v2/internal/domain/user"
)
// Compile-time assertion that *MockLoginAttemptService implements domain.LoginAttemptService.
var _ domain.LoginAttemptService = (*MockLoginAttemptService)(nil)

// MockLoginAttemptService is a test double for domain.LoginAttemptService.
// Each method delegates to its "...Func" field; a nil field panics when called.
type MockLoginAttemptService struct {
	SaveLoginAttemptFunc         func(ctx context.Context, attempt *user.UserLoginAttempt) error
	GetLoginAttemptsByUserIDFunc func(ctx context.Context, userID string) ([]*user.UserLoginAttempt, error)
	HandleFailedLoginAttemptFunc func(ctx context.Context, user *user.User, attempt *user.UserLoginAttempt) error
}

// SaveLoginAttempt forwards to SaveLoginAttemptFunc.
func (ls *MockLoginAttemptService) SaveLoginAttempt(ctx context.Context, attempt *user.UserLoginAttempt) error {
	return ls.SaveLoginAttemptFunc(ctx, attempt)
}

// GetLoginAttemptsByUserID forwards to GetLoginAttemptsByUserIDFunc.
func (ls *MockLoginAttemptService) GetLoginAttemptsByUserID(ctx context.Context, userID string) ([]*user.UserLoginAttempt, error) {
	return ls.GetLoginAttemptsByUserIDFunc(ctx, userID)
}

// HandleFailedLoginAttempt forwards to HandleFailedLoginAttemptFunc.
func (ls *MockLoginAttemptService) HandleFailedLoginAttempt(ctx context.Context, user *user.User, attempt *user.UserLoginAttempt) error {
	return ls.HandleFailedLoginAttemptFunc(ctx, user, attempt)
}
package mocks
import (
"context"
"net/http"
domain "github.com/vigiloauth/vigilo/v2/internal/domain/session"
)
// Compile-time assertion that *MockSessionManager implements domain.SessionManager.
var _ domain.SessionManager = (*MockSessionManager)(nil)

// MockSessionManager is a test double for domain.SessionManager. Each method
// delegates to its "...Func" field, which must be set before use.
type MockSessionManager struct {
	GetUserIDFromSessionFunc      func(ctx context.Context, r *http.Request) (string, error)
	GetUserAuthenticationTimeFunc func(ctx context.Context, r *http.Request) (int64, error)
}

// GetUserIDFromSession forwards to GetUserIDFromSessionFunc.
func (sm *MockSessionManager) GetUserIDFromSession(ctx context.Context, r *http.Request) (string, error) {
	return sm.GetUserIDFromSessionFunc(ctx, r)
}

// GetUserAuthenticationTime forwards to GetUserAuthenticationTimeFunc.
func (sm *MockSessionManager) GetUserAuthenticationTime(ctx context.Context, r *http.Request) (int64, error) {
	return sm.GetUserAuthenticationTimeFunc(ctx, r)
}
package mocks
import (
"context"
session "github.com/vigiloauth/vigilo/v2/internal/domain/session"
)
// Compile-time assertion that *MockSessionRepository implements session.SessionRepository.
var _ session.SessionRepository = (*MockSessionRepository)(nil)

// MockSessionRepository is a test double for session.SessionRepository. Each
// method delegates to its "...Func" field; a nil field panics when called.
type MockSessionRepository struct {
	SaveSessionFunc       func(ctx context.Context, sessionData *session.SessionData) error
	GetSessionByIDFunc    func(ctx context.Context, sessionID string) (*session.SessionData, error)
	UpdateSessionByIDFunc func(ctx context.Context, sessionID string, sessionData *session.SessionData) error
	DeleteSessionByIDFunc func(ctx context.Context, sessionID string) error
}

// SaveSession forwards to SaveSessionFunc.
func (sr *MockSessionRepository) SaveSession(ctx context.Context, sessionData *session.SessionData) error {
	return sr.SaveSessionFunc(ctx, sessionData)
}

// GetSessionByID forwards to GetSessionByIDFunc.
func (sr *MockSessionRepository) GetSessionByID(ctx context.Context, sessionID string) (*session.SessionData, error) {
	return sr.GetSessionByIDFunc(ctx, sessionID)
}

// UpdateSessionByID forwards to UpdateSessionByIDFunc.
func (sr *MockSessionRepository) UpdateSessionByID(ctx context.Context, sessionID string, sessionData *session.SessionData) error {
	return sr.UpdateSessionByIDFunc(ctx, sessionID, sessionData)
}

// DeleteSessionByID forwards to DeleteSessionByIDFunc.
func (sr *MockSessionRepository) DeleteSessionByID(ctx context.Context, sessionID string) error {
	return sr.DeleteSessionByIDFunc(ctx, sessionID)
}
package mocks
import (
"net/http"
session "github.com/vigiloauth/vigilo/v2/internal/domain/session"
)
// Compile-time assertion that *MockSessionService implements session.SessionService.
var _ session.SessionService = (*MockSessionService)(nil)

// MockSessionService is a test double for session.SessionService. Each method
// delegates to its "...Func" field, which must be set before use.
type MockSessionService struct {
	CreateSessionFunc        func(w http.ResponseWriter, r *http.Request, sessionData *session.SessionData) error
	InvalidateSessionFunc    func(w http.ResponseWriter, r *http.Request) error
	GetUserIDFromSessionFunc func(r *http.Request) (string, error)
	UpdateSessionFunc        func(r *http.Request, sessionData *session.SessionData) error
	GetSessionDataFunc       func(r *http.Request) (*session.SessionData, error)
}

// CreateSession forwards to CreateSessionFunc.
func (ss *MockSessionService) CreateSession(w http.ResponseWriter, r *http.Request, sessionData *session.SessionData) error {
	return ss.CreateSessionFunc(w, r, sessionData)
}

// InvalidateSession forwards to InvalidateSessionFunc.
func (ss *MockSessionService) InvalidateSession(w http.ResponseWriter, r *http.Request) error {
	return ss.InvalidateSessionFunc(w, r)
}

// GetUserIDFromSession forwards to GetUserIDFromSessionFunc.
func (ss *MockSessionService) GetUserIDFromSession(r *http.Request) (string, error) {
	return ss.GetUserIDFromSessionFunc(r)
}

// UpdateSession forwards to UpdateSessionFunc.
func (ss *MockSessionService) UpdateSession(r *http.Request, sessionData *session.SessionData) error {
	return ss.UpdateSessionFunc(r, sessionData)
}

// GetSessionData forwards to GetSessionDataFunc.
func (ss *MockSessionService) GetSessionData(r *http.Request) (*session.SessionData, error) {
	return ss.GetSessionDataFunc(r)
}
package mocks
import (
"context"
"time"
claims "github.com/vigiloauth/vigilo/v2/internal/domain/claims"
token "github.com/vigiloauth/vigilo/v2/internal/domain/token"
"github.com/vigiloauth/vigilo/v2/internal/types"
)
// Compile-time assertion that *MockTokenCreator implements token.TokenCreator.
var _ token.TokenCreator = (*MockTokenCreator)(nil)

// MockTokenCreator is a test double for token.TokenCreator. Each method
// delegates to its "...Func" field; a nil field panics when called.
type MockTokenCreator struct {
	CreateAccessTokenFunc           func(ctx context.Context, subject string, audience string, scopes types.Scope, roles string, nonce string) (string, error)
	CreateRefreshTokenFunc          func(ctx context.Context, subject string, audience string, scopes types.Scope, roles string, nonce string) (string, error)
	CreateAccessTokenWithClaimsFunc func(ctx context.Context, subject string, audience string, scopes types.Scope, roles string, nonce string, claims *claims.ClaimsRequest) (string, error)
	CreateIDTokenFunc               func(ctx context.Context, userID string, clientID string, scopes types.Scope, nonce string, acrValues string, authTime time.Time) (string, error)
}

// CreateAccessToken forwards to CreateAccessTokenFunc.
func (tc *MockTokenCreator) CreateAccessToken(ctx context.Context, subject string, audience string, scopes types.Scope, roles string, nonce string) (string, error) {
	return tc.CreateAccessTokenFunc(ctx, subject, audience, scopes, roles, nonce)
}

// CreateRefreshToken forwards to CreateRefreshTokenFunc.
func (tc *MockTokenCreator) CreateRefreshToken(ctx context.Context, subject string, audience string, scopes types.Scope, roles string, nonce string) (string, error) {
	return tc.CreateRefreshTokenFunc(ctx, subject, audience, scopes, roles, nonce)
}

// CreateAccessTokenWithClaims forwards to CreateAccessTokenWithClaimsFunc.
func (tc *MockTokenCreator) CreateAccessTokenWithClaims(ctx context.Context, subject string, audience string, scopes types.Scope, roles string, nonce string, claims *claims.ClaimsRequest) (string, error) {
	return tc.CreateAccessTokenWithClaimsFunc(ctx, subject, audience, scopes, roles, nonce, claims)
}

// CreateIDToken forwards to CreateIDTokenFunc.
func (tc *MockTokenCreator) CreateIDToken(ctx context.Context, userID string, clientID string, scopes types.Scope, nonce string, acrValues string, authTime time.Time) (string, error) {
	return tc.CreateIDTokenFunc(ctx, userID, clientID, scopes, nonce, acrValues, authTime)
}
package mocks
import (
"context"
"net/http"
token "github.com/vigiloauth/vigilo/v2/internal/domain/token"
users "github.com/vigiloauth/vigilo/v2/internal/domain/user"
"github.com/vigiloauth/vigilo/v2/internal/types"
)
// Compile-time assertion that *MockTokenGrantProcessor implements token.TokenGrantProcessor.
var _ token.TokenGrantProcessor = (*MockTokenGrantProcessor)(nil)

// MockTokenGrantProcessor is a test double for token.TokenGrantProcessor.
// Each method delegates to its "...Func" field, which must be set before use.
type MockTokenGrantProcessor struct {
	IssueClientCredentialsTokenFunc func(ctx context.Context, clientID, clientSecret, grantType string, scopes types.Scope) (*token.TokenResponse, error)
	IssueResourceOwnerTokenFunc     func(ctx context.Context, clientID, clientSecret, grantType string, scopes types.Scope, user *users.UserLoginRequest) (*token.TokenResponse, error)
	RefreshTokenFunc                func(ctx context.Context, clientID, clientSecret, grantType, refreshToken string, scopes types.Scope) (*token.TokenResponse, error)
	ExchangeAuthorizationCodeFunc   func(ctx context.Context, req *token.TokenRequest) (*token.TokenResponse, error)
	IntrospectTokenFunc             func(ctx context.Context, r *http.Request, tokenStr string) (*token.TokenIntrospectionResponse, error)
	RevokeTokenFunc                 func(ctx context.Context, r *http.Request, tokenStr string) error
}

// IssueClientCredentialsToken forwards to IssueClientCredentialsTokenFunc.
func (gp *MockTokenGrantProcessor) IssueClientCredentialsToken(ctx context.Context, clientID, clientSecret, grantType string, scopes types.Scope) (*token.TokenResponse, error) {
	return gp.IssueClientCredentialsTokenFunc(ctx, clientID, clientSecret, grantType, scopes)
}

// IssueResourceOwnerToken forwards to IssueResourceOwnerTokenFunc.
func (gp *MockTokenGrantProcessor) IssueResourceOwnerToken(ctx context.Context, clientID, clientSecret, grantType string, scopes types.Scope, user *users.UserLoginRequest) (*token.TokenResponse, error) {
	return gp.IssueResourceOwnerTokenFunc(ctx, clientID, clientSecret, grantType, scopes, user)
}

// RefreshToken forwards to RefreshTokenFunc.
func (gp *MockTokenGrantProcessor) RefreshToken(ctx context.Context, clientID, clientSecret, grantType, refreshToken string, scopes types.Scope) (*token.TokenResponse, error) {
	return gp.RefreshTokenFunc(ctx, clientID, clientSecret, grantType, refreshToken, scopes)
}

// ExchangeAuthorizationCode forwards to ExchangeAuthorizationCodeFunc.
func (gp *MockTokenGrantProcessor) ExchangeAuthorizationCode(ctx context.Context, req *token.TokenRequest) (*token.TokenResponse, error) {
	return gp.ExchangeAuthorizationCodeFunc(ctx, req)
}

// IntrospectToken forwards to IntrospectTokenFunc.
func (gp *MockTokenGrantProcessor) IntrospectToken(ctx context.Context, r *http.Request, tokenStr string) (*token.TokenIntrospectionResponse, error) {
	return gp.IntrospectTokenFunc(ctx, r, tokenStr)
}

// RevokeToken forwards to RevokeTokenFunc.
func (gp *MockTokenGrantProcessor) RevokeToken(ctx context.Context, r *http.Request, tokenStr string) error {
	return gp.RevokeTokenFunc(ctx, r, tokenStr)
}
package mocks
import (
"context"
"time"
claims "github.com/vigiloauth/vigilo/v2/internal/domain/claims"
token "github.com/vigiloauth/vigilo/v2/internal/domain/token"
"github.com/vigiloauth/vigilo/v2/internal/types"
)
// Compile-time assertion that *MockTokenIssuer implements token.TokenIssuer.
var _ token.TokenIssuer = (*MockTokenIssuer)(nil)

// MockTokenIssuer is a test double for token.TokenIssuer. Each method
// delegates to its "...Func" field; a nil field panics when called.
type MockTokenIssuer struct {
	IssueTokenPairFunc   func(ctx context.Context, subject string, audience string, scopes types.Scope, roles string, nonce string, claims *claims.ClaimsRequest) (string, string, error)
	IssueIDTokenFunc     func(ctx context.Context, subject string, audience string, scopes types.Scope, nonce string, acrValues string, authTime time.Time) (string, error)
	IssueAccessTokenFunc func(ctx context.Context, subject string, audience string, scopes types.Scope, roles string, nonce string) (string, error)
}

// IssueTokenPair forwards to IssueTokenPairFunc.
func (ti *MockTokenIssuer) IssueTokenPair(ctx context.Context, subject string, audience string, scopes types.Scope, roles string, nonce string, claims *claims.ClaimsRequest) (string, string, error) {
	return ti.IssueTokenPairFunc(ctx, subject, audience, scopes, roles, nonce, claims)
}

// IssueIDToken forwards to IssueIDTokenFunc.
func (ti *MockTokenIssuer) IssueIDToken(ctx context.Context, subject string, audience string, scopes types.Scope, nonce string, acrValues string, authTime time.Time) (string, error) {
	return ti.IssueIDTokenFunc(ctx, subject, audience, scopes, nonce, acrValues, authTime)
}

// IssueAccessToken forwards to IssueAccessTokenFunc.
func (ti *MockTokenIssuer) IssueAccessToken(ctx context.Context, subject string, audience string, scopes types.Scope, roles string, nonce string) (string, error) {
	return ti.IssueAccessTokenFunc(ctx, subject, audience, scopes, roles, nonce)
}
package mocks
import (
"context"
tokens "github.com/vigiloauth/vigilo/v2/internal/domain/token"
)
// Compile-time assertion that *MockTokenManager implements tokens.TokenManager.
var _ tokens.TokenManager = (*MockTokenManager)(nil)

// MockTokenManager is a test double for tokens.TokenManager. Each method
// delegates to its "...Func" field, which must be set before use.
type MockTokenManager struct {
	IntrospectFunc          func(ctx context.Context, tokenStr string) *tokens.TokenIntrospectionResponse
	RevokeFunc              func(ctx context.Context, tokenStr string) error
	GetTokenDataFunc        func(ctx context.Context, tokenStr string) (*tokens.TokenData, error)
	DeleteTokenFunc         func(ctx context.Context, token string) error
	BlacklistTokenFunc      func(ctx context.Context, token string) error
	DeleteExpiredTokensFunc func(ctx context.Context) error
}

// Introspect forwards to IntrospectFunc.
func (tm *MockTokenManager) Introspect(ctx context.Context, tokenStr string) *tokens.TokenIntrospectionResponse {
	return tm.IntrospectFunc(ctx, tokenStr)
}

// Revoke forwards to RevokeFunc.
func (tm *MockTokenManager) Revoke(ctx context.Context, tokenStr string) error {
	return tm.RevokeFunc(ctx, tokenStr)
}

// GetTokenData forwards to GetTokenDataFunc.
func (tm *MockTokenManager) GetTokenData(ctx context.Context, tokenStr string) (*tokens.TokenData, error) {
	return tm.GetTokenDataFunc(ctx, tokenStr)
}

// DeleteToken forwards to DeleteTokenFunc.
func (tm *MockTokenManager) DeleteToken(ctx context.Context, token string) error {
	return tm.DeleteTokenFunc(ctx, token)
}

// BlacklistToken forwards to BlacklistTokenFunc.
func (tm *MockTokenManager) BlacklistToken(ctx context.Context, token string) error {
	return tm.BlacklistTokenFunc(ctx, token)
}

// DeleteExpiredTokens forwards to DeleteExpiredTokensFunc.
func (tm *MockTokenManager) DeleteExpiredTokens(ctx context.Context) error {
	return tm.DeleteExpiredTokensFunc(ctx)
}
package mocks
import (
"context"
token "github.com/vigiloauth/vigilo/v2/internal/domain/token"
)
// Compile-time assertion that *MockTokenParser implements token.TokenParser.
var _ token.TokenParser = (*MockTokenParser)(nil)

// MockTokenParser is a test double for token.TokenParser; ParseToken
// delegates to ParseTokenFunc, which must be set before use.
type MockTokenParser struct {
	ParseTokenFunc func(ctx context.Context, tokenString string) (*token.TokenClaims, error)
}

// ParseToken forwards to ParseTokenFunc.
func (tp *MockTokenParser) ParseToken(ctx context.Context, tokenStr string) (*token.TokenClaims, error) {
	return tp.ParseTokenFunc(ctx, tokenStr)
}
package mocks
import (
"context"
"time"
token "github.com/vigiloauth/vigilo/v2/internal/domain/token"
)
// Compile-time assertion that *MockTokenRepository implements token.TokenRepository.
var _ token.TokenRepository = (*MockTokenRepository)(nil)

// MockTokenRepository is a test double for token.TokenRepository. Each method
// delegates to its "...Func" field; a nil field panics when called.
type MockTokenRepository struct {
	SaveTokenFunc          func(ctx context.Context, token string, id string, tokenData *token.TokenData, expiration time.Time) error
	IsTokenBlacklistedFunc func(ctx context.Context, token string) (bool, error)
	GetTokenFunc           func(ctx context.Context, token string) (*token.TokenData, error)
	DeleteTokenFunc        func(ctx context.Context, token string) error
	BlacklistTokenFunc     func(ctx context.Context, token string) error
	ExistsByTokenIDFunc    func(ctx context.Context, tokenID string) (bool, error)
	GetExpiredTokensFunc   func(ctx context.Context) ([]*token.TokenData, error)
}

// SaveToken forwards to SaveTokenFunc.
func (tr *MockTokenRepository) SaveToken(ctx context.Context, token string, id string, tokenData *token.TokenData, expiration time.Time) error {
	return tr.SaveTokenFunc(ctx, token, id, tokenData, expiration)
}

// IsTokenBlacklisted forwards to IsTokenBlacklistedFunc.
func (tr *MockTokenRepository) IsTokenBlacklisted(ctx context.Context, token string) (bool, error) {
	return tr.IsTokenBlacklistedFunc(ctx, token)
}

// GetToken forwards to GetTokenFunc.
func (tr *MockTokenRepository) GetToken(ctx context.Context, token string) (*token.TokenData, error) {
	return tr.GetTokenFunc(ctx, token)
}

// DeleteToken forwards to DeleteTokenFunc.
func (tr *MockTokenRepository) DeleteToken(ctx context.Context, token string) error {
	return tr.DeleteTokenFunc(ctx, token)
}

// BlacklistToken forwards to BlacklistTokenFunc.
func (tr *MockTokenRepository) BlacklistToken(ctx context.Context, token string) error {
	return tr.BlacklistTokenFunc(ctx, token)
}

// ExistsByTokenID forwards to ExistsByTokenIDFunc.
func (tr *MockTokenRepository) ExistsByTokenID(ctx context.Context, tokenID string) (bool, error) {
	return tr.ExistsByTokenIDFunc(ctx, tokenID)
}

// GetExpiredTokens forwards to GetExpiredTokensFunc.
func (tr *MockTokenRepository) GetExpiredTokens(ctx context.Context) ([]*token.TokenData, error) {
	return tr.GetExpiredTokensFunc(ctx)
}
package mocks
import (
"context"
token "github.com/vigiloauth/vigilo/v2/internal/domain/token"
)
// Compile-time assertion that *MockTokenValidator implements token.TokenValidator.
var _ token.TokenValidator = (*MockTokenValidator)(nil)

// MockTokenValidator is a test double for token.TokenValidator;
// ValidateToken delegates to ValidateTokenFunc, which must be set before use.
type MockTokenValidator struct {
	ValidateTokenFunc func(ctx context.Context, tokenStr string) error
}

// ValidateToken forwards to ValidateTokenFunc.
func (tv *MockTokenValidator) ValidateToken(ctx context.Context, tokenStr string) error {
	return tv.ValidateTokenFunc(ctx, tokenStr)
}
package mocks
import (
"context"
users "github.com/vigiloauth/vigilo/v2/internal/domain/user"
)
// Compile-time assertion that *MockUserAuthenticator implements users.UserAuthenticator.
var _ users.UserAuthenticator = (*MockUserAuthenticator)(nil)

// MockUserAuthenticator is a test double for users.UserAuthenticator;
// AuthenticateUser delegates to AuthenticateUserFunc, which must be set before use.
type MockUserAuthenticator struct {
	AuthenticateUserFunc func(ctx context.Context, request *users.UserLoginRequest) (*users.UserLoginResponse, error)
}

// AuthenticateUser forwards to AuthenticateUserFunc.
func (ua *MockUserAuthenticator) AuthenticateUser(ctx context.Context, request *users.UserLoginRequest) (*users.UserLoginResponse, error) {
	return ua.AuthenticateUserFunc(ctx, request)
}
package mocks
import (
"context"
users "github.com/vigiloauth/vigilo/v2/internal/domain/user"
)
// Compile-time assertion that *MockUserCreator implements users.UserCreator.
var _ users.UserCreator = (*MockUserCreator)(nil)

// MockUserCreator is a test double for users.UserCreator; CreateUser
// delegates to CreateUserFunc, which must be set before use.
type MockUserCreator struct {
	CreateUserFunc func(ctx context.Context, user *users.User) (*users.UserRegistrationResponse, error)
}

// CreateUser forwards to CreateUserFunc.
func (uc *MockUserCreator) CreateUser(ctx context.Context, user *users.User) (*users.UserRegistrationResponse, error) {
	return uc.CreateUserFunc(ctx, user)
}
package mocks
import (
"context"
users "github.com/vigiloauth/vigilo/v2/internal/domain/user"
)
// Compile-time assertion that *MockUserManager implements users.UserManager.
var _ users.UserManager = (*MockUserManager)(nil)

// MockUserManager is a test double for users.UserManager. Each method
// delegates to its "...Func" field; a nil field panics when called.
type MockUserManager struct {
	GetUserByUsernameFunc     func(ctx context.Context, username string) (*users.User, error)
	GetUserByIDFunc           func(ctx context.Context, userID string) (*users.User, error)
	DeleteUnverifiedUsersFunc func(ctx context.Context) error
	ResetPasswordFunc         func(ctx context.Context, userEmail, newPassword, resetToken string) (*users.UserPasswordResetResponse, error)
}

// GetUserByUsername forwards to GetUserByUsernameFunc.
func (um *MockUserManager) GetUserByUsername(ctx context.Context, username string) (*users.User, error) {
	return um.GetUserByUsernameFunc(ctx, username)
}

// GetUserByID forwards to GetUserByIDFunc.
func (um *MockUserManager) GetUserByID(ctx context.Context, userID string) (*users.User, error) {
	return um.GetUserByIDFunc(ctx, userID)
}

// DeleteUnverifiedUsers forwards to DeleteUnverifiedUsersFunc.
func (um *MockUserManager) DeleteUnverifiedUsers(ctx context.Context) error {
	return um.DeleteUnverifiedUsersFunc(ctx)
}

// ResetPassword forwards to ResetPasswordFunc.
func (um *MockUserManager) ResetPassword(ctx context.Context, userEmail, newPassword, resetToken string) (*users.UserPasswordResetResponse, error) {
	return um.ResetPasswordFunc(ctx, userEmail, newPassword, resetToken)
}
package mocks
import (
"context"
user "github.com/vigiloauth/vigilo/v2/internal/domain/user"
)
// Compile-time assertion that *MockUserRepository implements user.UserRepository.
var _ user.UserRepository = (*MockUserRepository)(nil)

// MockUserRepository is a test double for user.UserRepository. Each method
// delegates to its "...Func" field, which must be set before use.
type MockUserRepository struct {
	AddUserFunc                          func(ctx context.Context, user *user.User) error
	GetUserByIDFunc                      func(ctx context.Context, userID string) (*user.User, error)
	DeleteUserByIDFunc                   func(ctx context.Context, userID string) error
	UpdateUserFunc                       func(ctx context.Context, user *user.User) error
	GetUserByEmailFunc                   func(ctx context.Context, email string) (*user.User, error)
	GetUserByUsernameFunc                func(ctx context.Context, username string) (*user.User, error)
	FindUnverifiedUsersOlderThanWeekFunc func(ctx context.Context) ([]*user.User, error)
}

// AddUser forwards to AddUserFunc.
func (ur *MockUserRepository) AddUser(ctx context.Context, user *user.User) error {
	return ur.AddUserFunc(ctx, user)
}

// GetUserByID forwards to GetUserByIDFunc.
func (ur *MockUserRepository) GetUserByID(ctx context.Context, userID string) (*user.User, error) {
	return ur.GetUserByIDFunc(ctx, userID)
}

// DeleteUserByID forwards to DeleteUserByIDFunc.
func (ur *MockUserRepository) DeleteUserByID(ctx context.Context, userID string) error {
	return ur.DeleteUserByIDFunc(ctx, userID)
}

// UpdateUser forwards to UpdateUserFunc.
func (ur *MockUserRepository) UpdateUser(ctx context.Context, user *user.User) error {
	return ur.UpdateUserFunc(ctx, user)
}

// GetUserByEmail forwards to GetUserByEmailFunc.
func (ur *MockUserRepository) GetUserByEmail(ctx context.Context, email string) (*user.User, error) {
	return ur.GetUserByEmailFunc(ctx, email)
}

// GetUserByUsername forwards to GetUserByUsernameFunc.
func (ur *MockUserRepository) GetUserByUsername(ctx context.Context, username string) (*user.User, error) {
	return ur.GetUserByUsernameFunc(ctx, username)
}

// FindUnverifiedUsersOlderThanWeek forwards to FindUnverifiedUsersOlderThanWeekFunc.
func (ur *MockUserRepository) FindUnverifiedUsersOlderThanWeek(ctx context.Context) ([]*user.User, error) {
	return ur.FindUnverifiedUsersOlderThanWeekFunc(ctx)
}
package mocks
import (
"context"
domain "github.com/vigiloauth/vigilo/v2/internal/domain/user"
)
// Compile-time assertion that *MockUserVerifier implements domain.UserVerifier.
var _ domain.UserVerifier = (*MockUserVerifier)(nil)

// MockUserVerifier is a test double for domain.UserVerifier;
// VerifyEmailAddress delegates to VerifyEmailAddressFunc, which must be set before use.
type MockUserVerifier struct {
	VerifyEmailAddressFunc func(ctx context.Context, verificationCode string) error
}

// VerifyEmailAddress forwards to VerifyEmailAddressFunc.
func (uv *MockUserVerifier) VerifyEmailAddress(ctx context.Context, verificationCode string) error {
	return uv.VerifyEmailAddressFunc(ctx, verificationCode)
}
package mocks
import (
"context"
domain "github.com/vigiloauth/vigilo/v2/internal/domain/userconsent"
"github.com/vigiloauth/vigilo/v2/internal/types"
)
// Compile-time assertion that *MockUserConsentRepository implements domain.UserConsentRepository.
var _ domain.UserConsentRepository = (*MockUserConsentRepository)(nil)

// MockUserConsentRepository is a test double for domain.UserConsentRepository.
// Each method delegates to its "...Func" field; a nil field panics when called.
type MockUserConsentRepository struct {
	HasConsentFunc    func(ctx context.Context, userID, clientID string, scope types.Scope) (bool, error)
	SaveConsentFunc   func(ctx context.Context, userID, clientID string, scope types.Scope) error
	RevokeConsentFunc func(ctx context.Context, userID, clientID string) error
}

// HasConsent forwards to HasConsentFunc.
func (cr *MockUserConsentRepository) HasConsent(ctx context.Context, userID, clientID string, scope types.Scope) (bool, error) {
	return cr.HasConsentFunc(ctx, userID, clientID, scope)
}

// SaveConsent forwards to SaveConsentFunc.
func (cr *MockUserConsentRepository) SaveConsent(ctx context.Context, userID, clientID string, scope types.Scope) error {
	return cr.SaveConsentFunc(ctx, userID, clientID, scope)
}

// RevokeConsent forwards to RevokeConsentFunc.
func (cr *MockUserConsentRepository) RevokeConsent(ctx context.Context, userID, clientID string) error {
	return cr.RevokeConsentFunc(ctx, userID, clientID)
}
package mocks
import (
"context"
"net/http"
user "github.com/vigiloauth/vigilo/v2/internal/domain/userconsent"
"github.com/vigiloauth/vigilo/v2/internal/types"
)
// Compile-time assertion that *MockUserConsentService implements user.UserConsentService.
var _ user.UserConsentService = (*MockUserConsentService)(nil)

// MockUserConsentService is a test double for user.UserConsentService. Each
// method delegates to its "...Func" field, which must be set before use.
type MockUserConsentService struct {
	CheckUserConsentFunc   func(ctx context.Context, userID, clientID string, scope types.Scope) (bool, error)
	SaveUserConsentFunc    func(ctx context.Context, userID, clientID string, scope types.Scope) error
	RevokeConsentFunc      func(ctx context.Context, userID, clientID string) error
	GetConsentDetailsFunc  func(userID, clientID, redirectURI, state string, scope types.Scope, responseType, nonce, display string, r *http.Request) (*user.UserConsentResponse, error)
	ProcessUserConsentFunc func(userID, clientID, redirectURI string, scope types.Scope, consentRequest *user.UserConsentRequest, r *http.Request) (*user.UserConsentResponse, error)
}

// GetConsentDetails forwards to GetConsentDetailsFunc.
func (cs *MockUserConsentService) GetConsentDetails(userID, clientID, redirectURI, state string, scope types.Scope, responseType, nonce, display string, r *http.Request) (*user.UserConsentResponse, error) {
	return cs.GetConsentDetailsFunc(userID, clientID, redirectURI, state, scope, responseType, nonce, display, r)
}

// ProcessUserConsent forwards to ProcessUserConsentFunc.
func (cs *MockUserConsentService) ProcessUserConsent(userID, clientID, redirectURI string, scope types.Scope, consentRequest *user.UserConsentRequest, r *http.Request) (*user.UserConsentResponse, error) {
	return cs.ProcessUserConsentFunc(userID, clientID, redirectURI, scope, consentRequest, r)
}

// CheckUserConsent forwards to CheckUserConsentFunc.
func (cs *MockUserConsentService) CheckUserConsent(ctx context.Context, userID, clientID string, scope types.Scope) (bool, error) {
	return cs.CheckUserConsentFunc(ctx, userID, clientID, scope)
}

// SaveUserConsent forwards to SaveUserConsentFunc.
func (cs *MockUserConsentService) SaveUserConsent(ctx context.Context, userID, clientID string, scope types.Scope) error {
	return cs.SaveUserConsentFunc(ctx, userID, clientID, scope)
}

// RevokeConsent forwards to RevokeConsentFunc.
func (cs *MockUserConsentService) RevokeConsent(ctx context.Context, userID, clientID string) error {
	return cs.RevokeConsentFunc(ctx, userID, clientID)
}
package repository
import (
	"context"
	"slices"
	"sync"
	"time"

	"github.com/vigiloauth/vigilo/v2/idp/config"
	domain "github.com/vigiloauth/vigilo/v2/internal/domain/audit"
	"github.com/vigiloauth/vigilo/v2/internal/errors"
	"github.com/vigiloauth/vigilo/v2/internal/utils"
)
var (
	// logger is obtained from the server configuration when this package is initialized.
	logger = config.GetServerConfig().Logger()
	// instance is the lazily created singleton returned by
	// GetInMemoryAuditEventRepository; once guards its creation.
	instance *InMemoryAuditEventRepository
	once     sync.Once
	// Compile-time assertion that *InMemoryAuditEventRepository implements domain.AuditRepository.
	_ domain.AuditRepository = (*InMemoryAuditEventRepository)(nil)
)

// module is the component name passed to every logger call in this file.
const module string = "InMemoryAuditEventRepository"
// InMemoryAuditEventRepository keeps audit events in a process-local map.
type InMemoryAuditEventRepository struct {
	events map[string]*domain.AuditEvent // keyed by AuditEvent.EventID; holds the caller's pointers, not copies
	mu     sync.RWMutex                  // guards events
}
// GetInMemoryAuditEventRepository returns the shared InMemoryAuditEventRepository.
//
// The repository is built on the first call (guarded by a sync.Once) with an
// empty event map; every later call returns that same instance.
//
// Returns:
// - *InMemoryAuditEventRepository: The singleton instance of InMemoryAuditEventRepository.
func GetInMemoryAuditEventRepository() *InMemoryAuditEventRepository {
	build := func() {
		logger.Debug(module, "", "Creating new instance of InMemoryAuditEventRepository")
		instance = &InMemoryAuditEventRepository{events: make(map[string]*domain.AuditEvent)}
	}
	once.Do(build)
	return instance
}
// ResetInMemoryAuditEventRepository discards every stored audit event so tests
// can start from an empty store. It is a no-op when the singleton has not been
// created yet.
func ResetInMemoryAuditEventRepository() {
	if instance == nil {
		return
	}

	logger.Debug(module, "", "Resetting instance")
	instance.mu.Lock()
	defer instance.mu.Unlock()
	instance.events = make(map[string]*domain.AuditEvent)
}
// StoreAuditEvent saves event under its EventID, replacing any event already
// stored with the same ID. The pointer itself is kept, not a copy.
//
// Parameters:
// - ctx Context: The context for managing timeouts and cancellations (currently unused).
// - event *AuditEvent: The audit event to be stored.
//
// Returns:
// - error: Always nil for the in-memory implementation.
func (r *InMemoryAuditEventRepository) StoreAuditEvent(ctx context.Context, event *domain.AuditEvent) error {
	key := event.EventID

	r.mu.Lock()
	defer r.mu.Unlock()
	r.events[key] = event

	return nil
}
// GetAuditEvents retrieves audit events that match the provided filters and time range.
//
// Matching events are ordered by timestamp (ties broken by event ID) before
// offset/limit are applied. The backing store is a map, whose iteration order
// is randomized, so without a stable order successive pages could overlap or
// skip events.
//
// Parameters:
// - ctx Context: The context for managing timeouts and cancellations.
// - filters map[string]any: A map of filter keys and values to apply.
// - from time.Time: The start time of the time range to filter events.
// - to time.Time: The end time of the time range to filter events.
// - limit int: The maximum number of events to return. Values <= 0 yield an empty page.
// - offset int: The number of events to skip (for pagination). Negative values are treated as 0.
//
// Returns:
// - []*AuditEvent: A slice of matching audit events.
// - error: An error if the retrieval fails, otherwise nil.
func (r *InMemoryAuditEventRepository) GetAuditEvents(
	ctx context.Context,
	filters map[string]any,
	from time.Time,
	to time.Time,
	limit int,
	offset int,
) ([]*domain.AuditEvent, error) {
	r.mu.RLock()
	defer r.mu.RUnlock()

	auditEvents, err := r.getFilteredEvents(ctx, filters, from, to)
	if err != nil {
		logger.Error(module, utils.GetRequestID(ctx), "[GetAuditEvents]: An error occurred retrieving filtered events: %v", err)
		return nil, err
	}

	// Give pagination a deterministic order; the slice is freshly built by
	// getFilteredEvents, so reordering it does not touch the store.
	slices.SortFunc(auditEvents, func(a, b *domain.AuditEvent) int {
		if c := a.Timestamp.Compare(b.Timestamp); c != 0 {
			return c
		}
		switch {
		case a.EventID < b.EventID:
			return -1
		case a.EventID > b.EventID:
			return 1
		default:
			return 0
		}
	})

	// Clamp negative inputs instead of panicking on an out-of-range slice, and
	// bound the page by the remaining length first so start+limit cannot overflow.
	start := min(max(offset, 0), len(auditEvents))
	end := start + min(max(limit, 0), len(auditEvents)-start)
	return auditEvents[start:end], nil
}
// DeleteEvent removes the event stored under eventID. Deleting an ID that is
// not present is not an error.
//
// Parameters:
// - ctx Context: The context for managing timeouts and cancellations (currently unused).
// - eventID string: The ID of the event to delete.
//
// Returns:
// - error: Always nil for the in-memory implementation.
func (r *InMemoryAuditEventRepository) DeleteEvent(ctx context.Context, eventID string) error {
	r.mu.Lock()
	delete(r.events, eventID)
	r.mu.Unlock()
	return nil
}
// getFilteredEvents returns the stored events whose timestamps fall within
// [from, to] and that match every entry in filters. Timestamps and bounds are
// compared in UTC, truncated to whole seconds.
//
// Supported filter keys are "UserID", "EventType", "IP", "RequestID" and
// "SessionID" (string values) and "Success" (bool value). An event is excluded
// when a filter value has the wrong type or does not match. An unknown key is
// logged as a warning and also excludes the event.
//
// It does not lock r.mu itself; GetAuditEvents calls it while holding the read
// lock. Stored events are only read, never written, so concurrent readers
// holding the read lock do not race.
//
// Returns an error wrapping ctx.Err() if the context is done while iterating.
//
// nolint
func (r *InMemoryAuditEventRepository) getFilteredEvents(
	ctx context.Context,
	filters map[string]any,
	from time.Time,
	to time.Time,
) ([]*domain.AuditEvent, error) {
	// The range bounds are loop-invariant; normalize them once.
	from = from.UTC().Truncate(time.Second)
	to = to.UTC().Truncate(time.Second)

	var auditEvents []*domain.AuditEvent
loop:
	for _, event := range r.events {
		select {
		case <-ctx.Done():
			return nil, errors.Wrap(ctx.Err(), errors.ErrCodeRequestTimeout, "the request timed out")
		default:
		}

		// Compare a normalized copy rather than writing it back into the stored
		// event: only a read lock is held, so mutating shared events would race
		// with other readers and permanently discard stored precision.
		ts := event.Timestamp.UTC().Truncate(time.Second)
		if ts.Before(from) || ts.After(to) {
			continue
		}

		for key, value := range filters {
			switch key {
			case "UserID":
				if v, ok := value.(string); !ok || event.UserID != v {
					continue loop //nolint:nlreturn
				}
			case "EventType":
				if v, ok := value.(string); !ok || event.EventType.String() != v {
					continue loop //nolint:nlreturn
				}
			case "Success":
				if v, ok := value.(bool); !ok || event.Success != v {
					continue loop //nolint:nlreturn
				}
			case "IP":
				if v, ok := value.(string); !ok || event.IP != v {
					continue loop //nolint:nlreturn
				}
			case "RequestID":
				if v, ok := value.(string); !ok || event.RequestID != v {
					continue loop //nolint:nlreturn
				}
			case "SessionID":
				if v, ok := value.(string); !ok || event.SessionID != v {
					continue loop //nolint:nlreturn
				}
			default:
				logger.Warn(module, utils.GetRequestID(ctx), "[GetAuditEvents]: Unknown filter: %s", key)
				continue loop //nolint:nlreturn
			}
		}

		auditEvents = append(auditEvents, event)
	}

	return auditEvents, nil
}
package repository
import (
"context"
"sync"
"time"
"github.com/vigiloauth/vigilo/v2/idp/config"
authz "github.com/vigiloauth/vigilo/v2/internal/domain/authzcode"
"github.com/vigiloauth/vigilo/v2/internal/errors"
)
// module is the component name passed to every logger call in this file.
const module = "InMemoryAuthorizationCodeRepository"

var (
	// logger is obtained from the server configuration when this package is initialized.
	logger = config.GetServerConfig().Logger()
	// Compile-time assertion that *InMemoryAuthorizationCodeRepository implements authz.AuthorizationCodeRepository.
	_ authz.AuthorizationCodeRepository = (*InMemoryAuthorizationCodeRepository)(nil)
	// instance is the lazily created singleton returned by
	// GetInMemoryAuthorizationCodeRepository; once guards its creation.
	instance *InMemoryAuthorizationCodeRepository
	once     sync.Once
)
// InMemoryAuthorizationCodeRepository keeps authorization codes in a map guarded by a read-write mutex.
type InMemoryAuthorizationCodeRepository struct {
// codes maps the authorization code string to its stored entry.
codes map[string]codeEntry
// mu guards codes.
mu sync.RWMutex
}
// codeEntry represents a stored authorization code with expiration.
// NOTE(review): ExpiresAt is recorded, but no method in this file compares it against the
// current time; expiry enforcement presumably happens elsewhere — confirm.
type codeEntry struct {
Data *authz.AuthorizationCodeData
ExpiresAt time.Time
}
// GetInMemoryAuthorizationCodeRepository returns the shared InMemoryAuthorizationCodeRepository,
// creating it with an empty code map on the first call. Creation runs inside once.Do,
// so concurrent first callers all receive the same instance.
//
// Returns:
//   - *InMemoryAuthorizationCodeRepository: The shared repository instance.
func GetInMemoryAuthorizationCodeRepository() *InMemoryAuthorizationCodeRepository {
	once.Do(func() {
		logger.Debug(module, "", "Creating new instance of InMemoryAuthorizationCodeRepository")
		repo := new(InMemoryAuthorizationCodeRepository)
		repo.codes = make(map[string]codeEntry)
		instance = repo
	})

	return instance
}
// StoreAuthorizationCode saves an authorization code together with its data and
// expiry time, overwriting any entry already stored under the same code.
//
// Parameters:
//   - ctx context.Context: Accepted for interface compatibility; not used.
//   - code string: The authorization code, used as the map key.
//   - data *authz.AuthorizationCodeData: The data associated with the code.
//   - expiresAt time.Time: The moment the code expires.
//
// Returns:
//   - error: Always nil.
func (s *InMemoryAuthorizationCodeRepository) StoreAuthorizationCode(ctx context.Context, code string, data *authz.AuthorizationCodeData, expiresAt time.Time) error {
	entry := codeEntry{Data: data, ExpiresAt: expiresAt}

	s.mu.Lock()
	defer s.mu.Unlock()
	s.codes[code] = entry

	return nil
}
// GetAuthorizationCode retrieves the data associated with an authorization code.
//
// NOTE(review): the stored ExpiresAt is not checked here; an expired code is still
// returned. Confirm that the calling service enforces expiry.
//
// Parameters:
//   - ctx context.Context: Accepted for interface compatibility; not used.
//   - code string: The authorization code to look up.
//
// Returns:
//   - *authz.AuthorizationCodeData: The associated data if found.
//   - error: An ErrCodeInvalidAuthorizationCode error when the code is unknown.
func (s *InMemoryAuthorizationCodeRepository) GetAuthorizationCode(ctx context.Context, code string) (*authz.AuthorizationCodeData, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	entry, exists := s.codes[code]
	if !exists {
		// The code itself is deliberately not logged: it is a bearer credential that can be
		// exchanged for tokens, and log output is typically far less protected than the store.
		logger.Debug(module, "", "[GetAuthorizationCode]: Code does not exist")
		return nil, errors.New(errors.ErrCodeInvalidAuthorizationCode, "authorization code does not exist")
	}

	return entry.Data, nil
}
// DeleteAuthorizationCode removes an authorization code, typically after it has been
// redeemed. Removing a code that is not stored is a no-op.
//
// Parameters:
//   - ctx context.Context: Accepted for interface compatibility; not used.
//   - code string: The authorization code to remove.
//
// Returns:
//   - error: Always nil.
func (s *InMemoryAuthorizationCodeRepository) DeleteAuthorizationCode(ctx context.Context, code string) error {
	s.mu.Lock()
	delete(s.codes, code)
	s.mu.Unlock()

	return nil
}
// UpdateAuthorizationCode replaces the data stored for an authorization code while
// keeping the expiry time recorded by StoreAuthorizationCode. If the code is not
// stored yet, a new entry with a zero expiry time is created (upsert).
//
// Parameters:
//   - ctx context.Context: Accepted for interface compatibility; not used.
//   - code string: The authorization code to update.
//   - authData *authz.AuthorizationCodeData: The new authorization code data.
//
// Returns:
//   - error: Always nil.
func (s *InMemoryAuthorizationCodeRepository) UpdateAuthorizationCode(ctx context.Context, code string, authData *authz.AuthorizationCodeData) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Read-modify-write so ExpiresAt survives the update; writing a fresh codeEntry
	// would reset it to the zero time. For unknown codes the zero entry is used.
	entry := s.codes[code]
	entry.Data = authData
	s.codes[code] = entry

	return nil
}
package repository
import (
"context"
"sync"
"github.com/vigiloauth/vigilo/v2/idp/config"
client "github.com/vigiloauth/vigilo/v2/internal/domain/client"
"github.com/vigiloauth/vigilo/v2/internal/errors"
"github.com/vigiloauth/vigilo/v2/internal/utils"
)
var (
// logger is the server logger, obtained from the server configuration when the package is initialised.
logger = config.GetServerConfig().Logger()
// Compile-time check that InMemoryClientRepository satisfies client.ClientRepository.
_ client.ClientRepository = (*InMemoryClientRepository)(nil)
// instance is the lazily created singleton returned by GetInMemoryClientRepository;
// once guards its creation.
instance *InMemoryClientRepository
once sync.Once
)
// module is the component name passed to every log call in this file.
const module = "InMemoryClientRepository"
// InMemoryClientRepository provides an in-memory implementation of client.ClientRepository.
// It uses a map to store clients and a read-write mutex for concurrency control.
type InMemoryClientRepository struct {
// data maps a client ID to its client record.
data map[string]*client.Client
// mu guards data.
mu sync.RWMutex
}
// NewInMemoryClientRepository creates an empty, independent InMemoryClientRepository.
// Unlike GetInMemoryClientRepository, every call returns a fresh instance.
//
// Returns:
//   - *InMemoryClientRepository: A repository holding no clients.
func NewInMemoryClientRepository() *InMemoryClientRepository {
	repo := new(InMemoryClientRepository)
	repo.data = make(map[string]*client.Client)

	return repo
}
// GetInMemoryClientRepository returns the shared InMemoryClientRepository. The instance
// is created with an empty client map on the first call; sync.Once guarantees that
// only one instance is ever built.
//
// Returns:
//   - *InMemoryClientRepository: The shared repository instance.
func GetInMemoryClientRepository() *InMemoryClientRepository {
	once.Do(func() {
		logger.Debug(module, "", "Creating new instance of InMemoryClientRepository")
		repo := new(InMemoryClientRepository)
		repo.data = make(map[string]*client.Client)
		instance = repo
	})

	return instance
}
// ResetInMemoryClientRepository discards every client held by the shared instance.
// It does nothing if the instance has not been created. Intended for tests.
func ResetInMemoryClientRepository() {
	if instance == nil {
		return
	}

	logger.Debug(module, "", "Resetting instance")
	instance.mu.Lock()
	defer instance.mu.Unlock()
	instance.data = make(map[string]*client.Client)
}
// SaveClient inserts a client keyed by its ID, refusing to overwrite an existing one.
//
// Parameters:
//   - ctx context.Context: Accepted for interface compatibility; not used.
//   - client *client.Client: The client to store.
//
// Returns:
//   - error: An ErrCodeDuplicateClient error if the ID is already taken, nil otherwise.
func (cs *InMemoryClientRepository) SaveClient(ctx context.Context, client *client.Client) error {
	cs.mu.Lock()
	defer cs.mu.Unlock()

	id := client.ID
	if _, taken := cs.data[id]; taken {
		logger.Error(module, "", "[SaveClient]: Failed to save client. Duplicate ID")
		return errors.New(errors.ErrCodeDuplicateClient, "client already exists with given ID")
	}

	cs.data[id] = client
	return nil
}
// GetClientByID looks up a client by its ID.
//
// Parameters:
//   - ctx context.Context: Used only to obtain the request ID for logging.
//   - clientID string: The ID of the client to retrieve.
//
// Returns:
//   - *client.Client: The stored client, or nil when not found.
//   - error: An ErrCodeClientNotFound error when no client has that ID.
func (cs *InMemoryClientRepository) GetClientByID(ctx context.Context, clientID string) (*client.Client, error) {
	cs.mu.RLock()
	defer cs.mu.RUnlock()

	requestID := utils.GetRequestID(ctx)
	if c, ok := cs.data[clientID]; ok {
		return c, nil
	}

	logger.Debug(module, requestID, "[GetClientByID]: No client found using the given ID=%s", clientID)
	return nil, errors.New(errors.ErrCodeClientNotFound, "client not found")
}
// DeleteClientByID removes a client from the store by its ID. Deleting an unknown
// ID is a no-op.
//
// Parameters:
//   - ctx context.Context: Accepted for interface compatibility; not used.
//   - clientID string: The ID of the client to delete.
//
// Returns:
//   - error: Always nil.
func (cs *InMemoryClientRepository) DeleteClientByID(ctx context.Context, clientID string) error {
	cs.mu.Lock()
	delete(cs.data, clientID)
	cs.mu.Unlock()

	return nil
}
// UpdateClient replaces the stored client that has the same ID as the given one.
//
// Parameters:
//   - ctx context.Context: Used only to obtain the request ID for logging.
//   - client *client.Client: The updated client; its ID selects the entry to replace.
//
// Returns:
//   - error: An ErrCodeClientNotFound error if no client has that ID, nil otherwise.
func (cs *InMemoryClientRepository) UpdateClient(ctx context.Context, client *client.Client) error {
	cs.mu.Lock()
	defer cs.mu.Unlock()

	requestID := utils.GetRequestID(ctx)
	id := client.ID
	if _, known := cs.data[id]; known {
		cs.data[id] = client
		return nil
	}

	logger.Debug(module, requestID, "[UpdateClient]: No client found using the given ID=%s", id)
	return errors.New(errors.ErrCodeClientNotFound, "client not found using provided ID")
}
// IsExistingID reports whether a client with the given ID is stored.
//
// Parameters:
//   - ctx context.Context: Accepted for interface compatibility; not used.
//   - clientID string: The client ID to check.
//
// Returns:
//   - bool: True if a client with that ID exists, otherwise false.
func (cs *InMemoryClientRepository) IsExistingID(ctx context.Context, clientID string) bool {
	// Reading the map without the lock would race with SaveClient/UpdateClient/DeleteClientByID.
	cs.mu.RLock()
	defer cs.mu.RUnlock()

	_, clientExists := cs.data[clientID]
	return clientExists
}
package repository
import (
"context"
"sync"
"github.com/vigiloauth/vigilo/v2/idp/config"
domain "github.com/vigiloauth/vigilo/v2/internal/domain/login"
user "github.com/vigiloauth/vigilo/v2/internal/domain/user"
"github.com/vigiloauth/vigilo/v2/internal/errors"
)
// Package constants for the in-memory login attempt repository.
const (
// module is the component name passed to every log call in this file.
module = "InMemoryLoginAttemptRepository"
// maxStoredLoginAttempts defines the maximum number of login attempts stored per user;
// SaveLoginAttempt trims the oldest attempts beyond this limit.
maxStoredLoginAttempts = 100
)
var (
// logger is the server logger, obtained from the server configuration when the package is initialised.
logger = config.GetServerConfig().Logger()
// Compile-time check that InMemoryLoginAttemptRepository satisfies domain.LoginAttemptRepository.
_ domain.LoginAttemptRepository = (*InMemoryLoginAttemptRepository)(nil)
// instance is the lazily created singleton returned by GetInMemoryLoginRepository;
// once guards its creation.
instance *InMemoryLoginAttemptRepository
once sync.Once
)
// InMemoryLoginAttemptRepository is a store for login attempts.
// It uses an in-memory map to store login attempts, keyed by user ID.
type InMemoryLoginAttemptRepository struct {
// attempts maps a user ID to that user's attempts in insertion order (oldest first).
attempts map[string][]*user.UserLoginAttempt
// mu guards attempts.
mu sync.RWMutex
}
// GetInMemoryLoginRepository returns the shared InMemoryLoginAttemptRepository,
// creating it with an empty attempt map on the first call (guarded by sync.Once).
//
// Returns:
//   - *InMemoryLoginAttemptRepository: The shared repository instance.
func GetInMemoryLoginRepository() *InMemoryLoginAttemptRepository {
	once.Do(func() {
		logger.Debug(module, "", "Creating new instance of InMemoryLoginAttemptRepository")
		repo := new(InMemoryLoginAttemptRepository)
		repo.attempts = make(map[string][]*user.UserLoginAttempt)
		instance = repo
	})

	return instance
}
// ResetInMemoryLoginAttemptStore discards every recorded login attempt held by the
// shared instance. It does nothing if the instance has not been created. Intended for tests.
func ResetInMemoryLoginAttemptStore() {
	if instance == nil {
		return
	}

	logger.Debug(module, "", "Resetting instance")
	instance.mu.Lock()
	defer instance.mu.Unlock()
	instance.attempts = make(map[string][]*user.UserLoginAttempt)
}
// SaveLoginAttempt appends a login attempt to its user's history and then trims the
// history so that at most maxStoredLoginAttempts entries are kept.
//
// Parameters:
//   - ctx context.Context: Accepted for interface compatibility; not used.
//   - attempt *user.UserLoginAttempt: The attempt to record; its UserID selects the history.
//
// Returns:
//   - error: Always nil.
func (s *InMemoryLoginAttemptRepository) SaveLoginAttempt(ctx context.Context, attempt *user.UserLoginAttempt) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	userID := attempt.UserID
	s.attempts[userID] = append(s.attempts[userID], attempt)
	s.trimLoginAttempts(userID)

	return nil
}
// GetLoginAttemptsByUserID returns the recorded login attempts for a user, oldest first.
// The returned slice is the repository's own slice, not a copy.
//
// Parameters:
//   - ctx context.Context: Accepted for interface compatibility; not used.
//   - userID string: The user whose attempts are requested.
//
// Returns:
//   - []*user.UserLoginAttempt: The user's attempts.
//   - error: An ErrCodeNotFound error when nothing is recorded for the user.
func (s *InMemoryLoginAttemptRepository) GetLoginAttemptsByUserID(ctx context.Context, userID string) ([]*user.UserLoginAttempt, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	if attempts, ok := s.attempts[userID]; ok {
		return attempts, nil
	}

	return nil, errors.New(errors.ErrCodeNotFound, "failed to retrieve user login attempts")
}
// trimLoginAttempts keeps only the newest maxStoredLoginAttempts entries for a user,
// dropping however many of the oldest entries exceed the limit (not just one).
// The caller must hold s.mu for writing.
//
// Parameters:
//   - userID string: The user whose history is trimmed.
func (s *InMemoryLoginAttemptRepository) trimLoginAttempts(userID string) {
	attempts := s.attempts[userID]
	if excess := len(attempts) - maxStoredLoginAttempts; excess > 0 {
		s.attempts[userID] = attempts[excess:]
	}
}
package repository
import (
"context"
"sync"
"github.com/vigiloauth/vigilo/v2/idp/config"
"github.com/vigiloauth/vigilo/v2/internal/errors"
"github.com/vigiloauth/vigilo/v2/internal/utils"
session "github.com/vigiloauth/vigilo/v2/internal/domain/session"
)
var (
// logger is the server logger, obtained from the server configuration when the package is initialised.
logger = config.GetServerConfig().Logger()
// Compile-time check that InMemorySessionRepository satisfies session.SessionRepository.
_ session.SessionRepository = (*InMemorySessionRepository)(nil)
// instance is the lazily created singleton returned by GetInMemorySessionRepository;
// once guards its creation.
instance *InMemorySessionRepository
once sync.Once
)
// module is the component name passed to every log call in this file.
const module = "InMemorySessionRepository"
// InMemorySessionRepository stores sessions in a map guarded by a read-write mutex.
type InMemorySessionRepository struct {
// data maps a session ID to its session data.
data map[string]*session.SessionData
// mu guards data.
mu sync.RWMutex
}
// NewInMemorySessionRepository creates an empty, independent InMemorySessionRepository.
// Unlike GetInMemorySessionRepository, each call returns a fresh instance.
func NewInMemorySessionRepository() *InMemorySessionRepository {
	repo := new(InMemorySessionRepository)
	repo.data = make(map[string]*session.SessionData)

	return repo
}
// GetInMemorySessionRepository returns the shared InMemorySessionRepository, creating
// it with an empty session map on the first call (guarded by sync.Once).
func GetInMemorySessionRepository() *InMemorySessionRepository {
	once.Do(func() {
		logger.Debug(module, "", "Creating new instance of InMemorySessionRepository")
		repo := new(InMemorySessionRepository)
		repo.data = make(map[string]*session.SessionData)
		instance = repo
	})

	return instance
}
// ResetInMemorySessionRepository discards every session held by the shared instance.
// It does nothing if the instance has not been created. Intended for tests.
func ResetInMemorySessionRepository() {
	if instance == nil {
		return
	}

	logger.Debug(module, "", "Resetting instance")
	instance.mu.Lock()
	defer instance.mu.Unlock()
	instance.data = make(map[string]*session.SessionData)
}
// SaveSession stores a new session keyed by sessionData.ID, refusing to overwrite
// an existing session with the same ID.
//
// Parameters:
//   - ctx context.Context: Used only to obtain the request ID for logging.
//   - sessionData *session.SessionData: The session to store.
//
// Returns:
//   - error: An ErrCodeDuplicateSession error if the ID is already in use, nil otherwise.
func (s *InMemorySessionRepository) SaveSession(ctx context.Context, sessionData *session.SessionData) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	requestID := utils.GetRequestID(ctx)
	id := sessionData.ID
	if _, duplicate := s.data[id]; duplicate {
		logger.Debug(module, requestID, "[SaveSession]: Failed to save session as it already exists")
		return errors.New(errors.ErrCodeDuplicateSession, "session already exists with the given ID")
	}

	s.data[id] = sessionData
	return nil
}
// GetSessionByID looks up the session stored under the given ID.
//
// Parameters:
//   - ctx context.Context: Used only to obtain the request ID for logging.
//   - sessionID string: The session ID to look up.
//
// Returns:
//   - *session.SessionData: The stored session, or nil when not found.
//   - error: An ErrCodeSessionNotFound error when no session has that ID.
func (s *InMemorySessionRepository) GetSessionByID(ctx context.Context, sessionID string) (*session.SessionData, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	requestID := utils.GetRequestID(ctx)
	if data, found := s.data[sessionID]; found {
		return data, nil
	}

	logger.Debug(module, requestID, "[GetSessionByID]: No session exists with the given ID=%s", sessionID)
	return nil, errors.New(errors.ErrCodeSessionNotFound, "no session found with the given ID")
}
// UpdateSessionByID replaces the session stored under sessionID with sessionData.
// An empty UserID, State or ClientID in sessionData means "unchanged": the value is
// inherited from the currently stored session before the replacement is stored.
//
// Parameters:
//   - ctx context.Context: Used only to obtain the request ID for logging.
//   - sessionID string: The ID of the session to update.
//   - sessionData *session.SessionData: The new session data; empty identity fields are
//     filled in place from the existing session.
//
// Returns:
//   - error: An ErrCodeSessionNotFound error if no session has that ID, nil otherwise.
func (s *InMemorySessionRepository) UpdateSessionByID(ctx context.Context, sessionID string, sessionData *session.SessionData) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	requestID := utils.GetRequestID(ctx)
	existingSession, ok := s.data[sessionID]
	if !ok {
		logger.Error(module, requestID, "[UpdateSessionByID]: No session exists with the given ID=[%s]", utils.TruncateSensitive(sessionID))
		return errors.New(errors.ErrCodeSessionNotFound, "session does not exist with the given ID")
	}

	// Merge into the record that is actually stored. Writing the non-empty fields into
	// existingSession and then storing sessionData discarded the merge and mutated the
	// previously stored object.
	if sessionData.UserID == "" {
		sessionData.UserID = existingSession.UserID
	}
	if sessionData.State == "" {
		sessionData.State = existingSession.State
	}
	if sessionData.ClientID == "" {
		sessionData.ClientID = existingSession.ClientID
	}

	s.data[sessionID] = sessionData
	return nil
}
// DeleteSessionByID removes the session stored under the given ID. Deleting an
// unknown ID is a no-op.
//
// Parameters:
//   - ctx context.Context: Accepted for interface compatibility; not used.
//   - sessionID string: The ID of the session to delete.
//
// Returns:
//   - error: Always nil.
func (s *InMemorySessionRepository) DeleteSessionByID(ctx context.Context, sessionID string) error {
	s.mu.Lock()
	delete(s.data, sessionID)
	s.mu.Unlock()

	return nil
}
package repository
import (
"context"
"sync"
"time"
"github.com/vigiloauth/vigilo/v2/idp/config"
domain "github.com/vigiloauth/vigilo/v2/internal/domain/token"
"github.com/vigiloauth/vigilo/v2/internal/errors"
"github.com/vigiloauth/vigilo/v2/internal/utils"
)
var (
// logger is the server logger, obtained from the server configuration when the package is initialised.
logger = config.GetServerConfig().Logger()
// Compile-time check that InMemoryTokenRepository satisfies domain.TokenRepository.
_ domain.TokenRepository = (*InMemoryTokenRepository)(nil)
// instance is the lazily created singleton returned by GetInMemoryTokenRepository;
// once guards its creation.
instance *InMemoryTokenRepository
once sync.Once
)
// module is the component name passed to every log call in this file.
const module = "InMemoryTokenRepository"
// InMemoryTokenRepository implements a token store using an in-memory map.
type InMemoryTokenRepository struct {
// tokens maps an active token string to its data.
tokens map[string]*domain.TokenData
// blacklist maps a blacklisted token string to the data it had when blacklisted
// (nil if the token was never saved).
blacklist map[string]*domain.TokenData
// mu guards tokens and blacklist.
mu sync.RWMutex
}
// GetInMemoryTokenRepository returns the shared InMemoryTokenRepository, creating it
// with empty token and blacklist maps on the first call (guarded by sync.Once).
// No background cleanup is started; expired tokens are only removed when
// IsTokenBlacklisted encounters them or when a caller deletes them explicitly.
//
// Returns:
//   - *InMemoryTokenRepository: The shared repository instance.
func GetInMemoryTokenRepository() *InMemoryTokenRepository {
	once.Do(func() {
		logger.Debug(module, "", "Creating new instance of InMemoryTokenRepository")
		repo := new(InMemoryTokenRepository)
		repo.tokens = make(map[string]*domain.TokenData)
		repo.blacklist = make(map[string]*domain.TokenData)
		instance = repo
	})

	return instance
}
// ResetInMemoryTokenRepository clears both the active tokens and the blacklist of the
// shared instance. It does nothing if the instance has not been created. Intended for tests.
func ResetInMemoryTokenRepository() {
	if instance == nil {
		return
	}

	logger.Debug(module, "", "Resetting instance")
	instance.mu.Lock()
	defer instance.mu.Unlock()
	instance.tokens = make(map[string]*domain.TokenData)
	instance.blacklist = make(map[string]*domain.TokenData)
}
// SaveToken stores a token's data unless the token is blacklisted, in which case the
// call is silently ignored.
//
// Parameters:
//   - ctx context.Context: Used only to obtain the request ID for logging.
//   - tokenStr string: The token string, used as the map key.
//   - id string: Accepted for interface compatibility; not used by this implementation.
//   - tokenData *domain.TokenData: The data associated with the token.
//   - expiration time.Time: Accepted for interface compatibility; not used by this implementation.
//
// Returns:
//   - error: Always nil.
func (b *InMemoryTokenRepository) SaveToken(ctx context.Context, tokenStr string, id string, tokenData *domain.TokenData, expiration time.Time) error {
	b.mu.Lock()
	defer b.mu.Unlock()

	requestID := utils.GetRequestID(ctx)
	if _, isBlacklisted := b.blacklist[tokenStr]; !isBlacklisted {
		b.tokens[tokenStr] = tokenData
		return nil
	}

	logger.Debug(module, requestID, "[SaveToken]: Token=%s is blacklisted and will not be saved", truncateToken(tokenStr))
	return nil
}
// GetToken returns the data stored for an active token. Expiry is not checked here.
//
// Parameters:
//   - ctx context.Context: Used only to obtain the request ID for logging.
//   - token string: The token string to look up.
//
// Returns:
//   - *domain.TokenData: The stored data, or nil when the token is not active.
//   - error: An ErrCodeTokenNotFound error when the token is not in the active set.
func (b *InMemoryTokenRepository) GetToken(ctx context.Context, token string) (*domain.TokenData, error) {
	b.mu.RLock()
	defer b.mu.RUnlock()

	requestID := utils.GetRequestID(ctx)
	if data, ok := b.tokens[token]; ok {
		return data, nil
	}

	logger.Debug(module, requestID, "[GetToken]: Token not found")
	return nil, errors.New(errors.ErrCodeTokenNotFound, "token not found or expired")
}
// BlacklistToken moves a token from the active set to the blacklist. A token that
// was never saved is still blacklisted, with nil data.
//
// Parameters:
//   - ctx context.Context: Accepted for interface compatibility; not used.
//   - token string: The token string to blacklist.
//
// Returns:
//   - error: Always nil.
func (b *InMemoryTokenRepository) BlacklistToken(ctx context.Context, token string) error {
	b.mu.Lock()
	defer b.mu.Unlock()

	b.blacklist[token] = b.tokens[token]
	delete(b.tokens, token)

	return nil
}
// IsTokenBlacklisted reports whether a token should be treated as revoked. A token
// counts as blacklisted if it is in the blacklist, or if it is active but its
// TokenClaims.ExpiresAt (Unix seconds) is in the past. In the latter case the expired
// token is removed from the active set.
//
// Parameters:
//   - ctx context.Context: Used only to obtain the request ID for logging.
//   - token string: The token string to check.
//
// Returns:
//   - bool: True if the token is blacklisted or expired, false otherwise.
//   - error: Always nil.
func (b *InMemoryTokenRepository) IsTokenBlacklisted(ctx context.Context, token string) (bool, error) {
	// A write lock is required because expired tokens are deleted below; calling delete
	// on the map while holding only a read lock races with concurrent readers.
	b.mu.Lock()
	defer b.mu.Unlock()

	requestID := utils.GetRequestID(ctx)
	if _, blacklisted := b.blacklist[token]; blacklisted {
		logger.Debug(module, requestID, "[IsTokenBlacklisted]: Token is blacklisted")
		return true, nil
	}

	data, exists := b.tokens[token]
	if !exists {
		logger.Debug(module, requestID, "[IsTokenBlacklisted]: Token is not blacklisted")
		return false, nil
	}

	expirationTime := time.Unix(data.TokenClaims.ExpiresAt, 0)
	if time.Now().After(expirationTime) {
		logger.Debug(module, requestID, "[IsTokenBlacklisted]: Deleting expired token")
		delete(b.tokens, token)
		return true, nil
	}

	return false, nil
}
// DeleteToken removes an active token, together with any blacklist entry for the same
// token string.
//
// Parameters:
//   - ctx context.Context: Used only to obtain the request ID for logging.
//   - token string: The token string to delete.
//
// Returns:
//   - error: An ErrCodeTokenNotFound error when the token is not active, nil otherwise.
func (b *InMemoryTokenRepository) DeleteToken(ctx context.Context, token string) error {
	b.mu.Lock()
	defer b.mu.Unlock()

	requestID := utils.GetRequestID(ctx)
	_, active := b.tokens[token]
	if !active {
		logger.Warn(module, requestID, "[DeleteToken]: Attempted to delete non-existent token=[%s]", truncateToken(token))
		return errors.New(errors.ErrCodeTokenNotFound, "token not found")
	}

	delete(b.tokens, token)
	delete(b.blacklist, token)
	logger.Debug(module, requestID, "[DeleteToken]: Successfully deleted token=[%s]", truncateToken(token))

	return nil
}
// ExistsByTokenID reports whether any active token carries the given token ID.
// The search is a linear scan over all active tokens.
//
// Parameters:
//   - ctx context.Context: Accepted for interface compatibility; not used.
//   - tokenID string: The token ID to search for.
//
// Returns:
//   - bool: True if an active token has that ID, false otherwise.
//   - error: Always nil.
func (b *InMemoryTokenRepository) ExistsByTokenID(ctx context.Context, tokenID string) (bool, error) {
	// Ranging over the map without the lock races with SaveToken/DeleteToken/BlacklistToken.
	b.mu.RLock()
	defer b.mu.RUnlock()

	for _, data := range b.tokens {
		// SaveToken stores whatever pointer it is given, so guard against nil entries.
		if data != nil && data.TokenID == tokenID {
			return true, nil
		}
	}

	return false, nil
}
// GetExpiredTokens collects every active token whose TokenClaims.ExpiresAt (Unix
// seconds) is before the current time. Tokens are not removed.
//
// Parameters:
//   - ctx context.Context: Accepted for interface compatibility; not used.
//
// Returns:
//   - []*domain.TokenData: The expired tokens; an empty, non-nil slice when there are none.
//   - error: Always nil.
func (b *InMemoryTokenRepository) GetExpiredTokens(ctx context.Context) ([]*domain.TokenData, error) {
	b.mu.RLock()
	defer b.mu.RUnlock()

	now := time.Now()
	expired := make([]*domain.TokenData, 0)
	for _, data := range b.tokens {
		if expiresAt := time.Unix(data.TokenClaims.ExpiresAt, 0); now.After(expiresAt) {
			expired = append(expired, data)
		}
	}

	return expired, nil
}
// truncateToken shortens a token for safe logging. Tokens longer than ten bytes are
// cut to their first ten bytes followed by "..."; shorter tokens are returned unchanged.
func truncateToken(token string) string {
	const visiblePrefix = 10
	if len(token) <= visiblePrefix {
		return token
	}

	return token[:visiblePrefix] + "..."
}
package repository
import (
"context"
"sync"
"time"
"github.com/vigiloauth/vigilo/v2/idp/config"
user "github.com/vigiloauth/vigilo/v2/internal/domain/user"
"github.com/vigiloauth/vigilo/v2/internal/errors"
"github.com/vigiloauth/vigilo/v2/internal/utils"
)
var (
// logger is the server logger, obtained from the server configuration when the package is initialised.
logger = config.GetServerConfig().Logger()
// Compile-time check that InMemoryUserRepository satisfies user.UserRepository.
_ user.UserRepository = (*InMemoryUserRepository)(nil)
// instance is the lazily created singleton returned by GetInMemoryUserRepository;
// once guards its creation.
instance *InMemoryUserRepository
once sync.Once
)
// module is the component name passed to every log call in this file.
const module = "InMemoryUserRepository"
// InMemoryUserRepository implements the user.UserRepository interface using an in-memory map.
type InMemoryUserRepository struct {
// users maps a user ID to its user record.
users map[string]*user.User
// mu guards users.
mu sync.RWMutex
}
// GetInMemoryUserRepository returns the shared InMemoryUserRepository, creating it with
// an empty user map on the first call (guarded by sync.Once).
//
// Returns:
//   - *InMemoryUserRepository: The shared repository instance.
func GetInMemoryUserRepository() *InMemoryUserRepository {
	once.Do(func() {
		logger.Debug(module, "", "Creating new instance of InMemoryUserRepository")
		repo := new(InMemoryUserRepository)
		repo.users = make(map[string]*user.User)
		instance = repo
	})

	return instance
}
// ResetInMemoryUserRepository discards every user held by the shared instance.
// It does nothing if the instance has not been created. Intended for tests.
func ResetInMemoryUserRepository() {
	if instance == nil {
		return
	}

	logger.Debug(module, "", "Resetting instance")
	instance.mu.Lock()
	defer instance.mu.Unlock()
	instance.users = make(map[string]*user.User)
}
// AddUser adds a new user to the store. Both the e-mail address and the ID must be
// unused by any stored user.
//
// Parameters:
//   - ctx context.Context: Checked for cancellation up front; also supplies the request ID for logging.
//   - user *user.User: The user to add.
//
// Returns:
//   - error: A context error if ctx is already done; an ErrCodeDuplicateUser error if the
//     e-mail or ID is already taken; nil on success.
func (u *InMemoryUserRepository) AddUser(ctx context.Context, user *user.User) error {
	requestID := utils.GetRequestID(ctx)
	if err := ctx.Err(); err != nil {
		logger.Debug(module, requestID, "[AddUser]: Context already cancelled")
		return errors.NewContextError(err) //nolint:wrapcheck
	}

	// One write lock covers both uniqueness checks and the insert, so two concurrent
	// calls cannot both pass the checks. The e-mail scan is inlined rather than calling
	// GetUserByEmail, which acquires u.mu itself (sync.RWMutex forbids recursive read
	// locking: it can deadlock when a writer is waiting).
	u.mu.Lock()
	defer u.mu.Unlock()

	for _, existing := range u.users {
		if existing.Email == user.Email {
			logger.Debug(module, requestID, "[AddUser]: user already exists with the provided email")
			return errors.New(errors.ErrCodeDuplicateUser, "user already exists with the provided email")
		}
	}

	if _, ok := u.users[user.ID]; ok {
		logger.Error(module, requestID, "[AddUser]: user already exists with the given ID=[%s]", user.ID)
		return errors.New(errors.ErrCodeDuplicateUser, "user already exists with the provided ID")
	}

	u.users[user.ID] = user
	return nil
}
// GetUserByUsername finds a user whose PreferredUsername equals username. If several
// users match, which one is returned is unspecified (map iteration order).
//
// Parameters:
//   - ctx context.Context: Checked for cancellation up front; also supplies the request ID for logging.
//   - username string: The username to search for.
//
// Returns:
//   - *user.User: The matching user, or nil if none matches.
//   - error: A context error if ctx is already done, or an ErrCodeUserNotFound error when no user matches.
func (u *InMemoryUserRepository) GetUserByUsername(ctx context.Context, username string) (*user.User, error) {
	requestID := utils.GetRequestID(ctx)
	if err := ctx.Err(); err != nil {
		logger.Debug(module, requestID, "[GetUserByUsername]: Context already cancelled")
		return nil, errors.NewContextError(err) //nolint:wrapcheck
	}

	u.mu.RLock()
	defer u.mu.RUnlock()

	for _, candidate := range u.users {
		if candidate.PreferredUsername != username {
			continue
		}
		logger.Debug(module, requestID, "[GetUserByUsername]: User found with the given username=[%s]", username)
		return candidate, nil
	}

	logger.Debug(module, requestID, "[GetUserByUsername]: User not found with the given username=[%s]", username)
	return nil, errors.New(errors.ErrCodeUserNotFound, "user not found by username")
}
// GetUserByID looks up a user by ID.
//
// Parameters:
//   - ctx context.Context: Checked for cancellation up front; also supplies the request ID for logging.
//   - userID string: The ID of the user to retrieve.
//
// Returns:
//   - *user.User: The stored user, or nil when not found.
//   - error: A context error if ctx is already done, or an ErrCodeUserNotFound error when no user has that ID.
func (u *InMemoryUserRepository) GetUserByID(ctx context.Context, userID string) (*user.User, error) {
	requestID := utils.GetRequestID(ctx)
	if err := ctx.Err(); err != nil {
		logger.Debug(module, requestID, "[GetUserByID]: Context already cancelled")
		return nil, errors.NewContextError(err) //nolint:wrapcheck
	}

	u.mu.RLock()
	defer u.mu.RUnlock()

	if found, ok := u.users[userID]; ok {
		return found, nil
	}

	logger.Debug(module, requestID, "[GetUserByID]: User not found with the given ID=[%s]", userID)
	return nil, errors.New(errors.ErrCodeUserNotFound, "user not found")
}
// GetUserByEmail finds a user whose Email equals email. Unlike GetUserByID and
// GetUserByUsername, a miss is not an error: it returns (nil, nil).
//
// Parameters:
//   - ctx context.Context: Checked for cancellation up front; also supplies the request ID for logging.
//   - email string: The e-mail address to search for.
//
// Returns:
//   - *user.User: The matching user, or nil if none matches.
//   - error: A context error if ctx is already done, otherwise nil.
func (u *InMemoryUserRepository) GetUserByEmail(ctx context.Context, email string) (*user.User, error) {
	requestID := utils.GetRequestID(ctx)
	if err := ctx.Err(); err != nil {
		logger.Debug(module, requestID, "[GetUserByEmail]: Context already cancelled")
		return nil, errors.NewContextError(err) //nolint:wrapcheck
	}

	u.mu.RLock()
	defer u.mu.RUnlock()

	for _, candidate := range u.users {
		if candidate.Email != email {
			continue
		}
		logger.Debug(module, requestID, "[GetUserByEmail]: User found with the given email=[%s]", email)
		return candidate, nil
	}

	logger.Debug(module, requestID, "[GetUserByEmail]: User not found with the given email=[%s]", email)
	return nil, nil
}
// DeleteUserByID removes the user with the given ID. Deleting an unknown ID is a no-op.
//
// Parameters:
//   - ctx context.Context: Accepted for interface compatibility; not used.
//   - userID string: The ID of the user to delete.
//
// Returns:
//   - error: Always nil.
func (u *InMemoryUserRepository) DeleteUserByID(ctx context.Context, userID string) error {
	u.mu.Lock()
	delete(u.users, userID)
	u.mu.Unlock()

	return nil
}
// UpdateUser replaces the stored user that has the same ID as the given one.
//
// Parameters:
//   - ctx context.Context: Checked for cancellation up front; also supplies the request ID for logging.
//   - user *user.User: The updated user; its ID selects the entry to replace.
//
// Returns:
//   - error: A context error if ctx is already done; an ErrCodeUserNotFound error if no
//     user has that ID; nil on success.
func (u *InMemoryUserRepository) UpdateUser(ctx context.Context, user *user.User) error {
	requestID := utils.GetRequestID(ctx)
	if err := ctx.Err(); err != nil {
		logger.Debug(module, requestID, "[UpdateUser]: Context already cancelled")
		return errors.NewContextError(err) //nolint
	}

	u.mu.Lock()
	defer u.mu.Unlock()

	if _, known := u.users[user.ID]; known {
		u.users[user.ID] = user
		return nil
	}

	logger.Debug(module, requestID, "[UpdateUser]: User not found with the given ID=[%s]", user.ID)
	return errors.New(errors.ErrCodeUserNotFound, "user does not exist with the provided ID")
}
// FindUnverifiedUsersOlderThanWeek returns users whose e-mail is not verified and whose
// account was created more than seven days ago.
//
// Parameters:
//   - ctx context.Context: Checked for cancellation up front; also supplies the request ID for logging.
//
// Returns:
//   - []*user.User: The matching users (nil when none match).
//   - error: A context error if ctx is already done, otherwise nil.
func (u *InMemoryUserRepository) FindUnverifiedUsersOlderThanWeek(ctx context.Context) ([]*user.User, error) {
	requestID := utils.GetRequestID(ctx)
	if err := ctx.Err(); err != nil {
		logger.Debug(module, requestID, "[FindUnverifiedUsersOlderThanWeek]: Context already cancelled")
		return nil, errors.NewContextError(err) //nolint:wrapcheck
	}

	// Ranging over the map without the lock races with AddUser/UpdateUser/DeleteUserByID.
	u.mu.RLock()
	defer u.mu.RUnlock()

	var expiredUsers []*user.User
	oneWeekAgo := time.Now().AddDate(0, 0, -7)
	for _, candidate := range u.users {
		if !candidate.EmailVerified && candidate.CreatedAt.Before(oneWeekAgo) {
			expiredUsers = append(expiredUsers, candidate)
		}
	}

	return expiredUsers, nil
}
package repository
import (
"context"
"strings"
"sync"
"time"
"slices"
"github.com/vigiloauth/vigilo/v2/idp/config"
consent "github.com/vigiloauth/vigilo/v2/internal/domain/userconsent"
"github.com/vigiloauth/vigilo/v2/internal/errors"
"github.com/vigiloauth/vigilo/v2/internal/types"
"github.com/vigiloauth/vigilo/v2/internal/utils"
)
var (
// logger is the server logger, obtained from the server configuration when the package is initialised.
logger = config.GetServerConfig().Logger()
// Compile-time check that InMemoryUserConsentRepository satisfies consent.UserConsentRepository.
_ consent.UserConsentRepository = (*InMemoryUserConsentRepository)(nil)
// instance is the lazily created singleton returned by GetInMemoryUserConsentRepository;
// once guards its creation.
instance *InMemoryUserConsentRepository
once sync.Once
)
// module is the component name passed to every log call in this file.
const module = "InMemoryUserConsentRepository"
// InMemoryUserConsentRepository implements the consent.UserConsentRepository interface using an in-memory map.
type InMemoryUserConsentRepository struct {
// data maps the key produced by createConsentKey(userID, clientID) to the consent record.
data map[string]*consent.UserConsentRecord
// mu guards data.
mu sync.RWMutex
}
// GetInMemoryUserConsentRepository returns the shared InMemoryUserConsentRepository,
// creating it with an empty record map on the first call (guarded by sync.Once).
//
// Returns:
//   - *InMemoryUserConsentRepository: The shared repository instance.
func GetInMemoryUserConsentRepository() *InMemoryUserConsentRepository {
	once.Do(func() {
		logger.Debug(module, "", "Creating new instance of InMemoryUserConsentRepository")
		repo := new(InMemoryUserConsentRepository)
		repo.data = make(map[string]*consent.UserConsentRecord)
		instance = repo
	})

	return instance
}
// ResetInMemoryUserConsentRepository discards every consent record held by the shared
// instance. It does nothing if the instance has not been created. Intended for tests.
func ResetInMemoryUserConsentRepository() {
	if instance == nil {
		return
	}

	logger.Debug(module, "", "Resetting instance")
	instance.mu.Lock()
	defer instance.mu.Unlock()
	instance.data = make(map[string]*consent.UserConsentRecord)
}
// HasConsent reports whether a user has already granted a client every scope in
// requestedScope. Scopes are compared as whitespace-separated tokens.
//
// Parameters:
//   - ctx context.Context: Used only to obtain the request ID for logging.
//   - userID string: The ID of the user.
//   - clientID string: The ID of the client application.
//   - requestedScope types.Scope: The requested scope(s).
//
// Returns:
//   - bool: True if a record exists and covers every requested scope, false otherwise.
//   - error: nil when no record exists or all scopes are covered; an
//     ErrCodeInsufficientScope error when a requested scope was not previously granted.
func (c *InMemoryUserConsentRepository) HasConsent(ctx context.Context, userID string, clientID string, requestedScope types.Scope) (bool, error) {
	c.mu.RLock()
	defer c.mu.RUnlock()

	requestID := utils.GetRequestID(ctx)
	key := createConsentKey(userID, clientID)

	record, ok := c.data[key]
	if !ok {
		logger.Debug(module, requestID, "[HasConsent]: Record does not exist with given consent key=[%s]", utils.TruncateSensitive(key))
		return false, nil
	}

	granted := strings.Fields(record.Scope.String())
	for _, scope := range strings.Fields(requestedScope.String()) {
		if slices.Contains(granted, scope) {
			continue
		}
		logger.Error(module, requestID, "[HasConsent]: The requested scope=[%s] was not previously granted", scope)
		return false, errors.New(errors.ErrCodeInsufficientScope, "at least one requested scope wasn't previously granted")
	}

	return true, nil
}
// SaveConsent records that the user granted the client the given scope, replacing any
// existing record for the same user/client pair. CreatedAt is set to the current time.
//
// Parameters:
//   - ctx context.Context: Unused by the in-memory implementation.
//   - userID string: The ID of the user.
//   - clientID string: The ID of the client application.
//   - scope types.Scope: The granted scope(s).
//
// Returns:
//   - error: Always nil for this implementation.
func (c *InMemoryUserConsentRepository) SaveConsent(ctx context.Context, userID, clientID string, scope types.Scope) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	record := &consent.UserConsentRecord{
		UserID:    userID,
		ClientID:  clientID,
		Scope:     scope,
		CreatedAt: time.Now(),
	}
	c.data[createConsentKey(userID, clientID)] = record
	return nil
}
// RevokeConsent deletes the consent record for the user/client pair. Revoking a pair
// with no record is a no-op.
//
// Parameters:
//   - ctx context.Context: Unused by the in-memory implementation.
//   - userID string: The ID of the user.
//   - clientID string: The ID of the client application.
//
// Returns:
//   - error: Always nil for this implementation.
func (c *InMemoryUserConsentRepository) RevokeConsent(ctx context.Context, userID, clientID string) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	delete(c.data, createConsentKey(userID, clientID))
	return nil
}
// createConsentKey builds the map key for a (user, client) pair by joining the two IDs
// with a "::" separator.
func createConsentKey(userID, clientID string) string {
	return strings.Join([]string{userID, clientID}, "::")
}
package routes
import "net/http"
// RouteGroup is a named set of routes registered together under a shared middleware stack.
type RouteGroup struct {
	Name       string                            // label used in registration logs
	Middleware []func(http.Handler) http.Handler // applied to every route in Routes
	Routes     []Route
}

// Route describes a single endpoint: the HTTP methods and path it answers, the handler
// to run, and any middleware specific to this route.
type Route struct {
	Methods     []string // HTTP methods; when empty, getHTTPMethods falls back to GET
	Pattern     string
	Handler     http.HandlerFunc
	Middleware  []func(http.Handler) http.Handler // route-level middleware
	Description string
}

// RouteBuilder assembles a Route through a fluent chain of setters ending in Build.
type RouteBuilder struct {
	route Route
}

// NewRoute starts a new builder whose route has an empty, non-nil middleware slice.
func NewRoute() *RouteBuilder {
	builder := &RouteBuilder{}
	builder.route.Middleware = []func(http.Handler) http.Handler{}
	return builder
}

// apply runs mutate against the route under construction and returns the builder for chaining.
func (rb *RouteBuilder) apply(mutate func(*Route)) *RouteBuilder {
	mutate(&rb.route)
	return rb
}

// SetMethods replaces the HTTP methods the route answers.
func (rb *RouteBuilder) SetMethods(methods ...string) *RouteBuilder {
	return rb.apply(func(r *Route) { r.Methods = methods })
}

// SetPattern sets the URL pattern the route is mounted on.
func (rb *RouteBuilder) SetPattern(pattern string) *RouteBuilder {
	return rb.apply(func(r *Route) { r.Pattern = pattern })
}

// SetHandler sets the handler invoked for the route.
func (rb *RouteBuilder) SetHandler(handler http.HandlerFunc) *RouteBuilder {
	return rb.apply(func(r *Route) { r.Handler = handler })
}

// SetMiddleware appends route-level middleware; repeated calls accumulate rather than replace.
func (rb *RouteBuilder) SetMiddleware(middleware ...func(http.Handler) http.Handler) *RouteBuilder {
	return rb.apply(func(r *Route) { r.Middleware = append(r.Middleware, middleware...) })
}

// SetDescription sets a human-readable description of the route.
func (rb *RouteBuilder) SetDescription(description string) *RouteBuilder {
	return rb.apply(func(r *Route) { r.Description = description })
}

// Build returns a copy of the configured Route. Its slice fields still share backing
// arrays with the builder.
func (rb *RouteBuilder) Build() Route {
	return rb.route
}

// getHTTPMethods returns the route's methods, defaulting to GET when none were set.
func (r Route) getHTTPMethods() []string {
	if len(r.Methods) == 0 {
		return []string{http.MethodGet}
	}
	return r.Methods
}
package routes
import (
"net/http"
"strings"
"github.com/go-chi/chi/v5"
"github.com/go-chi/cors"
"github.com/vigiloauth/vigilo/v2/idp/config"
"github.com/vigiloauth/vigilo/v2/internal/constants"
"github.com/vigiloauth/vigilo/v2/internal/container"
"github.com/vigiloauth/vigilo/v2/internal/errors"
"github.com/vigiloauth/vigilo/v2/internal/middleware"
"github.com/vigiloauth/vigilo/v2/internal/web"
)
// maxAge is passed as the CORS MaxAge option in applyGlobalMiddleware; go-chi/cors
// interprets it as the number of seconds a browser may cache a preflight response.
const maxAge int = 300
// RouterConfig holds the router and the collaborators needed to install middleware,
// error handlers, and route groups on it (see Init).
//
// Fields:
//   - router chi.Router: The router that middleware and routes are registered on.
//   - middleware *middleware.Middleware: Supplies the global, group, and route-level middleware.
//   - logger *config.Logger: Logger for router setup events.
//   - module string: Component label passed to every logger call.
//   - forceHTTPS bool: When true, the RedirectToHTTPS middleware is installed globally.
//   - enableRequestLogging bool: When true, the RequestLogger middleware is installed globally.
//   - handlerRegistry *container.HandlerRegistry: Source of the handlers wired into each route group.
type RouterConfig struct {
router chi.Router
middleware *middleware.Middleware
logger *config.Logger
module string
forceHTTPS bool
enableRequestLogging bool
handlerRegistry *container.HandlerRegistry
}
// NewRouterConfig bundles a router with the logger, middleware, and handler registry it
// will be configured with. Nothing is registered until Init is called.
//
// Parameters:
//   - router chi.Router: The router to configure.
//   - logger *config.Logger: Logger for router setup events.
//   - forceHTTPS bool: Whether Init installs the HTTPS redirect middleware.
//   - enableRequestLogging bool: Whether Init installs the request logging middleware.
//   - middleware *middleware.Middleware: Provider of middleware used by the routes.
//   - handlerRegistry *container.HandlerRegistry: Provider of the route handlers.
//
// Returns:
//   - *RouterConfig: The new configuration, with its module label set to "Router Config".
func NewRouterConfig(
	router chi.Router,
	logger *config.Logger,
	forceHTTPS bool,
	enableRequestLogging bool,
	middleware *middleware.Middleware,
	handlerRegistry *container.HandlerRegistry,
) *RouterConfig {
	return &RouterConfig{
		router:               router,
		middleware:           middleware,
		logger:               logger,
		module:               "Router Config",
		forceHTTPS:           forceHTTPS,
		enableRequestLogging: enableRequestLogging,
		handlerRegistry:      handlerRegistry,
	}
}
// Router exposes the underlying chi.Router so callers can mount it on a server or add
// further routes after Init.
func (rc *RouterConfig) Router() chi.Router {
	r := rc.router
	return r
}
// Init configures the router in three ordered steps: global middleware, then the
// not-found/method-not-allowed handlers, then every route group. Each step is preceded
// by a debug log line.
func (rc *RouterConfig) Init() {
	steps := []struct {
		message string
		run     func()
	}{
		{"Registering global middleware...", rc.applyGlobalMiddleware},
		{"Registering error handlers...", rc.setupErrorHandlers},
		{"Registering route groups...", rc.setupRouteGroups},
	}

	for _, step := range steps {
		rc.logger.Debug(rc.module, "", "%s", step.message)
		step.run()
	}
}
// applyGlobalMiddleware installs router-wide middleware in this order: context values,
// rate limiting, HTTPS redirect (only when forceHTTPS is set), request logging (only when
// enableRequestLogging is set), and CORS. chi runs router middleware in registration
// order, so this order is significant.
func (rc *RouterConfig) applyGlobalMiddleware() {
rc.router.Use(rc.middleware.WithContextValues)
rc.router.Use(rc.middleware.RateLimit)
if rc.forceHTTPS {
rc.logger.Info(rc.module, "", "The Vigilo Identity Provider is running on HTTPS")
rc.router.Use(rc.middleware.RedirectToHTTPS)
} else {
rc.logger.Warn(rc.module, "", "The Vigilo Identity Provider is running on HTTP. It is recommended to enable HTTPS in production environments")
}
if rc.enableRequestLogging {
rc.logger.Warn(rc.module, "", "Request logging is enabled. It is recommended to disable this in production environments.")
rc.router.Use(rc.middleware.RequestLogger)
}
// NOTE(review): AllowedOrigins "*" is combined with AllowCredentials true. Per the CORS
// spec, browsers reject credentialed responses that carry a wildcard
// Access-Control-Allow-Origin. Confirm how go-chi/cors treats this combination and
// whether explicit origins should be configured instead.
rc.router.Use(cors.Handler(cors.Options{
AllowedOrigins: []string{"*"},
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"},
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", constants.RequestIDHeader},
ExposedHeaders: []string{"Link", constants.RequestIDHeader},
AllowCredentials: true,
MaxAge: maxAge,
}))
}
// setupErrorHandlers installs the router's not-found and method-not-allowed handlers.
// Both log a warning tagged with the request ID (when present) and write a structured
// error response via web.WriteError.
func (rc *RouterConfig) setupErrorHandlers() {
	// requestIDOf returns the request ID placed in the context by the context-values
	// middleware, or "" if it is absent. The comma-ok assertion keeps the error handlers
	// from panicking if the stored value is not a plain string.
	requestIDOf := func(r *http.Request) string {
		if id, ok := r.Context().Value(constants.ContextKeyRequestID).(string); ok {
			return id
		}
		return ""
	}

	rc.router.NotFound(func(w http.ResponseWriter, r *http.Request) {
		rc.logger.Warn(rc.module, requestIDOf(r), "Resource not found: %s", r.URL)
		web.WriteError(w, errors.New(errors.ErrCodeResourceNotFound, "resource not found"))
	})

	rc.router.MethodNotAllowed(func(w http.ResponseWriter, r *http.Request) {
		rc.logger.Warn(rc.module, requestIDOf(r), "Method not allowed: %s", r.Method)
		web.WriteError(w, errors.New(errors.ErrCodeMethodNotAllowed, "method not allowed"))
	})
}
// setupRouteGroups builds every route group (admin, OIDC, client, user, consent,
// authorization, token), then registers them with the router in that order. All groups
// are built before the first one is registered.
func (rc *RouterConfig) setupRouteGroups() {
	for _, group := range []RouteGroup{
		rc.getAdminRoutes(),
		rc.getOIDCRoutes(),
		rc.getClientRoutes(),
		rc.getUserRoutes(),
		rc.getConsentRoutes(),
		rc.getAuthorizationRoutes(),
		rc.getTokenRoutes(),
	} {
		rc.registerRouteGroup(group)
	}
}
// registerRouteGroup mounts every route of group inside a chi sub-group so the group's
// middleware applies only to those routes. Route-level middleware wraps each handler
// individually, and each handler is bound only to the HTTP methods the route declares
// (GET when it declares none).
func (rc *RouterConfig) registerRouteGroup(group RouteGroup) {
	rc.logger.Info(rc.module, "", "Registering route group: %s", group.Name)
	if len(group.Routes) == 0 {
		rc.logger.Warn(rc.module, "", "No routes found for group: %s", group.Name)
	}

	rc.router.Group(func(r chi.Router) {
		for _, mw := range group.Middleware {
			r.Use(mw)
		}

		for _, route := range group.Routes {
			handler := route.Handler
			if len(route.Middleware) > 0 {
				handler = rc.chainMiddleware(route.Handler, route.Middleware...)
			}

			// Bind each declared method explicitly. HandleFunc would match every HTTP
			// method and ignore route.Methods. Undeclared methods now fall through to the
			// router's MethodNotAllowed handler.
			methods := route.getHTTPMethods()
			for _, method := range methods {
				r.Method(method, route.Pattern, handler)
			}

			if len(methods) > 1 {
				rc.logger.Debug(rc.module, "", "Registered routes: [%s] %s", strings.Join(methods, ", "), route.Pattern)
			} else {
				rc.logger.Debug(rc.module, "", "Registered route: %s %s", methods[0], route.Pattern)
			}
		}
	})
}
// chainMiddleware wraps handler so that middleware[0] is the outermost layer and the
// last element is the innermost one, closest to handler. With no middleware, handler is
// returned unchanged.
func (rc *RouterConfig) chainMiddleware(handler http.HandlerFunc, middleware ...func(http.Handler) http.Handler) http.HandlerFunc {
	wrapped := handler
	for idx := range middleware {
		// Walk from the last middleware to the first so the first ends up outermost.
		layer := middleware[len(middleware)-1-idx]
		wrapped = layer(wrapped).ServeHTTP
	}
	return wrapped
}
package routes
import (
"fmt"
"net/http"
"github.com/vigiloauth/vigilo/v2/internal/constants"
"github.com/vigiloauth/vigilo/v2/internal/web"
)
// getAdminRoutes defines the administrator endpoints. The whole group sits behind the
// auth middleware and a role check for constants.AdminRole.
func (rc *RouterConfig) getAdminRoutes() RouteGroup {
	rc.logger.Debug(rc.module, "", "Defining Admin Routes")
	adminHandler := rc.handlerRegistry.AdminHandler()

	groupMiddleware := []func(http.Handler) http.Handler{
		rc.middleware.AuthMiddleware(),
		rc.middleware.WithRole(constants.AdminRole),
	}

	auditEvents := NewRoute().
		SetDescription("Get audit events").
		SetPattern(web.AdminEndpoints.GetAuditEvents).
		SetHandler(adminHandler.GetAuditEvents).
		SetMethods(http.MethodGet).
		Build()

	return RouteGroup{
		Name:       "Admin Routes",
		Middleware: groupMiddleware,
		Routes:     []Route{auditEvents},
	}
}
// getOIDCRoutes defines the OpenID Connect endpoints: an authenticated UserInfo route
// (GET and POST) plus the public JWKS and discovery documents.
func (rc *RouterConfig) getOIDCRoutes() RouteGroup {
	rc.logger.Debug(rc.module, "", "Defining OIDC Routes")
	oidcHandler := rc.handlerRegistry.OIDCHandler()

	userInfo := NewRoute().
		SetMethods(http.MethodGet, http.MethodPost).
		SetMiddleware(rc.middleware.AuthMiddleware()).
		SetPattern(web.OIDCEndpoints.UserInfo).
		SetHandler(oidcHandler.GetUserInfo).
		SetDescription("Get user info").
		Build()

	// Public GET routes (no authentication required).
	publicGet := func(pattern string, h http.HandlerFunc, description string) Route {
		return NewRoute().
			SetMethods(http.MethodGet).
			SetPattern(pattern).
			SetHandler(h).
			SetDescription(description).
			Build()
	}
	jwks := publicGet(web.OIDCEndpoints.JWKS, oidcHandler.GetJWKS, "Get JSON web key sets")
	discovery := publicGet(web.OIDCEndpoints.Discovery, oidcHandler.GetOpenIDConfiguration, "Get OIDC configuration")

	return RouteGroup{
		Name:   "Open ID Connect Routes",
		Routes: []Route{userInfo, jwks, discovery},
	}
}
// getClientRoutes defines the client endpoints: dynamic registration, configuration
// management, secret regeneration (strictly rate limited), and lookup by ID. Routes that
// act on an existing client append a "/{client_id}"-style URL parameter built from
// constants.ClientIDReqField.
func (rc *RouterConfig) getClientRoutes() RouteGroup {
	rc.logger.Debug(rc.module, "", "Defining Client Routes")
	clientHandler := rc.handlerRegistry.ClientHandler()
	clientIDParam := fmt.Sprintf("/{%s}", constants.ClientIDReqField)

	// Basic client registration.
	register := NewRoute().
		SetMethods(http.MethodPost).
		SetMiddleware(rc.middleware.RequiresContentType(constants.ContentTypeJSON)).
		SetPattern(web.ClientEndpoints.Register).
		SetHandler(clientHandler.RegisterClient).
		SetDescription("Register new client").
		Build()

	// Client configuration management.
	configuration := NewRoute().
		SetMethods(http.MethodGet, http.MethodPut, http.MethodDelete).
		SetMiddleware(rc.middleware.AuthMiddleware()).
		SetPattern(web.ClientEndpoints.ClientConfiguration + clientIDParam).
		SetHandler(clientHandler.ManageClientConfiguration).
		SetDescription("Manage client configuration").
		Build()

	// Sensitive operation: strict rate limiting runs before authentication.
	regenerateSecret := NewRoute().
		SetMiddleware(
			rc.middleware.StrictRateLimit,
			rc.middleware.AuthMiddleware(),
			rc.middleware.RequiresContentType(constants.ContentTypeJSON),
		).
		SetMethods(http.MethodPost).
		SetPattern(web.ClientEndpoints.RegenerateSecret + clientIDParam).
		SetHandler(clientHandler.RegenerateSecret).
		SetDescription("Regenerate client secret").
		Build()

	getByID := NewRoute().
		SetMethods(http.MethodGet).
		SetPattern(web.ClientEndpoints.GetClientByID + clientIDParam).
		SetHandler(clientHandler.GetClientByID).
		SetDescription("Get client by ID").
		Build()

	return RouteGroup{
		Name:   "Client Routes",
		Routes: []Route{register, configuration, regenerateSecret, getByID},
	}
}
// getUserRoutes defines the user endpoints: an authenticated logout route plus the public
// verification, registration, login, password-reset, and OAuth login routes.
func (rc *RouterConfig) getUserRoutes() RouteGroup {
	rc.logger.Debug(rc.module, "", "Defining User Routes")
	userHandler := rc.handlerRegistry.UserHandler()

	logout := NewRoute().
		SetMethods(http.MethodPost).
		SetMiddleware(rc.middleware.AuthMiddleware()).
		SetPattern(web.UserEndpoints.Logout).
		SetHandler(userHandler.Logout).
		SetDescription("User logout").
		Build()

	// Public routes (no authentication required).
	public := func(method, pattern string, h http.HandlerFunc, description string) Route {
		return NewRoute().
			SetMethods(method).
			SetPattern(pattern).
			SetHandler(h).
			SetDescription(description).
			Build()
	}

	return RouteGroup{
		Name: "User Routes",
		Routes: []Route{
			logout,
			public(http.MethodGet, web.UserEndpoints.Verify, userHandler.VerifyAccount, "User account verification"),
			public(http.MethodPost, web.UserEndpoints.Registration, userHandler.Register, "User registration"),
			public(http.MethodPost, web.UserEndpoints.Login, userHandler.Login, "Basic user authentication"),
			public(http.MethodPatch, web.UserEndpoints.ResetPassword, userHandler.ResetPassword, "User password reset"),
			public(http.MethodPost, web.OAuthEndpoints.Authenticate, userHandler.OAuthLogin, "OAuth user authentication"),
		},
	}
}
// getConsentRoutes defines the user consent endpoint (GET and POST), served by the OAuth
// handler. Every route in the group requires a JSON content type.
func (rc *RouterConfig) getConsentRoutes() RouteGroup {
	rc.logger.Debug(rc.module, "", "Defining User Consent Routes")
	handler := rc.handlerRegistry.OAuthHandler()
	return RouteGroup{
		// Named after the consent endpoint it registers, so the "Registering route group"
		// log line identifies it correctly.
		Name: "User Consent Routes",
		Middleware: []func(http.Handler) http.Handler{
			// NOTE(review): this also applies to GET requests, which usually carry no body;
			// confirm RequiresContentType tolerates them.
			rc.middleware.RequiresContentType(constants.ContentTypeJSON),
		},
		Routes: []Route{
			NewRoute().
				SetMethods(http.MethodGet, http.MethodPost).
				SetPattern(web.OAuthEndpoints.UserConsent).
				SetHandler(handler.HandleUserConsent).
				SetDescription("Manage user consent").
				Build(),
		},
	}
}
// getAuthorizationRoutes defines the OAuth authorization endpoint (GET and POST), which
// requires a form-urlencoded content type.
func (rc *RouterConfig) getAuthorizationRoutes() RouteGroup {
	rc.logger.Debug(rc.module, "", "Defining Authorization Routes")
	authzHandler := rc.handlerRegistry.AuthorizationHandler()

	// NOTE(review): the form-urlencoded requirement also covers GET authorization requests,
	// whose parameters normally arrive in the query string; confirm RequiresContentType
	// tolerates them.
	authorize := NewRoute().
		SetMethods(http.MethodGet, http.MethodPost).
		SetPattern(web.OAuthEndpoints.Authorize).
		SetHandler(authzHandler.AuthorizeClient).
		SetMiddleware(rc.middleware.RequiresContentType(constants.ContentTypeFormURLEncoded)).
		SetDescription("Client authorization").
		Build()

	return RouteGroup{
		Name:   "Authorization Handler",
		Routes: []Route{authorize},
	}
}
// getTokenRoutes defines the POST-only token endpoints (issuance, introspection,
// revocation). The whole group requires a form-urlencoded content type.
func (rc *RouterConfig) getTokenRoutes() RouteGroup {
	rc.logger.Debug(rc.module, "", "Defining Token Routes")
	tokenHandler := rc.handlerRegistry.TokenHandler()

	groupMiddleware := []func(http.Handler) http.Handler{
		rc.middleware.RequiresContentType(constants.ContentTypeFormURLEncoded),
	}

	endpoints := []struct {
		pattern     string
		handler     http.HandlerFunc
		description string
	}{
		{web.OAuthEndpoints.Token, tokenHandler.IssueTokens, "Token issuance"},
		{web.OAuthEndpoints.IntrospectToken, tokenHandler.IntrospectToken, "Token introspection"},
		{web.OAuthEndpoints.RevokeToken, tokenHandler.RevokeToken, "Token revocation"},
	}

	routes := make([]Route, 0, len(endpoints))
	for _, ep := range endpoints {
		routes = append(routes, NewRoute().
			SetMethods(http.MethodPost).
			SetPattern(ep.pattern).
			SetHandler(ep.handler).
			SetDescription(ep.description).
			Build())
	}

	return RouteGroup{
		Name:       "Token Handler",
		Middleware: groupMiddleware,
		Routes:     routes,
	}
}
package service
import (
"context"
"time"
"github.com/vigiloauth/vigilo/v2/idp/config"
audit "github.com/vigiloauth/vigilo/v2/internal/domain/audit"
"github.com/vigiloauth/vigilo/v2/internal/errors"
"github.com/vigiloauth/vigilo/v2/internal/utils"
)
// Compile-time check that *auditLogger satisfies audit.AuditLogger.
var _ audit.AuditLogger = (*auditLogger)(nil)
// auditLogger stores and queries audit events through an audit.AuditRepository and
// reports repository failures through its logger.
type auditLogger struct {
auditRepo audit.AuditRepository // persistence backend for audit events
logger *config.Logger
module string // component label passed to every logger call
}
// NewAuditLogger returns an audit.AuditLogger backed by auditRepo that logs through the
// server-configured logger under the "Audit Logger" module label.
func NewAuditLogger(auditRepo audit.AuditRepository) audit.AuditLogger {
	al := &auditLogger{module: "Audit Logger"}
	al.auditRepo = auditRepo
	al.logger = config.GetServerConfig().Logger()
	return al
}
// StoreEvent saves an AuditEvent to the repository.
// If an error occurs storing the audit event, no error will be returned so that the flow is not disrupted.
//
// Parameters:
// - ctx Context: The context for managing timeouts, cancellations, and for storing/retrieving event metadata.
// - eventType EventType: The type of event to store.
// - success bool: True if the event was successful, otherwise false.
// - action ActionType: The action that is to be audited.
// - method MethodType: The method used (password, email, etc).
// - err error: The error if applicable, otherwise nil.
func (a *auditLogger) StoreEvent(
ctx context.Context,
eventType audit.EventType,
success bool,
action audit.ActionType,
method audit.MethodType,
err error,
) {
requestID := utils.GetRequestID(ctx)
var errCode string
if e, ok := err.(*errors.VigiloAuthError); ok { //nolint:errorlint
errCode = e.ErrorCode
}
event := audit.NewAuditEvent(ctx, eventType, success, action, method, errCode)
if err := a.auditRepo.StoreAuditEvent(ctx, event); err != nil {
a.logger.Error(a.module, requestID, "[StoreEvent]: Failed to store audit event: %v", err)
return
}
}
// GetAuditEvents retrieves audit events that match the provided filters and time range.
//
// Parameters:
//   - ctx context.Context: The context for managing timeouts and cancellations.
//   - filters map[string]any: A map of filter keys and values to apply.
//   - fromStr string: Start of the time range, as an RFC3339 timestamp.
//   - toStr string: End of the time range, as an RFC3339 timestamp; must not be before fromStr.
//   - limit int: The maximum number of events to return.
//   - offset int: The number of events to skip (for pagination).
//
// Returns:
//   - []*audit.AuditEvent: The matching audit events.
//   - error: An ErrCodeInvalidInput error if either timestamp is malformed or the range is
//     reversed; a wrapped repository error if retrieval fails; otherwise nil.
func (a *auditLogger) GetAuditEvents(ctx context.Context, filters map[string]any, fromStr string, toStr string, limit, offset int) ([]*audit.AuditEvent, error) {
	requestID := utils.GetRequestID(ctx)

	from, err := time.Parse(time.RFC3339, fromStr)
	if err != nil {
		a.logger.Error(a.module, requestID, "[GetAuditEvents]: Invalid 'from' timestamp format=[%s]", fromStr)
		return nil, errors.Wrap(err, errors.ErrCodeInvalidInput, "invalid 'from' timestamp - must be in RFC3339 format")
	}

	to, err := time.Parse(time.RFC3339, toStr)
	if err != nil {
		a.logger.Error(a.module, requestID, "[GetAuditEvents]: Invalid 'to' timestamp format=[%s]", toStr)
		return nil, errors.Wrap(err, errors.ErrCodeInvalidInput, "invalid 'to' timestamp - must be in RFC3339 format")
	}

	// Reject a reversed range explicitly instead of passing it to the repository and
	// returning a silently empty result.
	if to.Before(from) {
		a.logger.Error(a.module, requestID, "[GetAuditEvents]: 'to' timestamp=[%s] is before 'from' timestamp=[%s]", toStr, fromStr)
		return nil, errors.New(errors.ErrCodeInvalidInput, "invalid time range - 'to' must not be before 'from'")
	}

	events, err := a.auditRepo.GetAuditEvents(ctx, filters, from, to, limit, offset)
	if err != nil {
		a.logger.Error(a.module, requestID, "[GetAuditEvents]: An error occurred retrieving audit events: %v", err)
		return nil, errors.Wrap(err, "", "failed to retrieve audit events")
	}

	return events, nil
}
// DeleteOldEvents deletes every audit event in the range [zero time, olderThan] as
// queried from the repository.
//
// Events are fetched and deleted in batches of up to 1000. Each batch is deleted before
// the next fetch, so every fetch starts at offset 0. The loop ends when a batch comes back
// smaller than the batch size. Termination therefore relies on the repository applying
// deletions; the context is checked before each batch so a caller can bound the work.
//
// Parameters:
//   - ctx context.Context: The context for managing timeouts and cancellations.
//   - olderThan time.Time: Upper bound of the time range whose events are deleted.
//
// Returns:
//   - error: An error if retrieval or deletion fails or the context is done; otherwise nil.
func (a *auditLogger) DeleteOldEvents(ctx context.Context, olderThan time.Time) error {
	const limit int = 1000
	const offset int = 0

	total := 0
	for {
		if err := ctx.Err(); err != nil {
			a.logger.Error(a.module, "", "[DeleteOldEvents]: Context done after deleting %d audit events: %v", total, err)
			return errors.Wrap(err, "", "failed to delete old audit events")
		}

		events, err := a.auditRepo.GetAuditEvents(ctx, map[string]any{}, time.Time{}, olderThan, limit, offset)
		if err != nil {
			a.logger.Error(a.module, "", "[DeleteOldEvents]: An error retrieving old audit events: %v", err)
			return errors.Wrap(err, "", "failed to retrieve old audit events")
		}

		for _, event := range events {
			if err := a.auditRepo.DeleteEvent(ctx, event.EventID); err != nil {
				a.logger.Error(a.module, "", "[DeleteOldEvents]: An error occurred deleting event ID=[%s]: %v", event.EventID, err)
				return errors.Wrap(err, "", "failed to delete audit event")
			}
		}
		total += len(events)

		// A short (or empty) batch means nothing older remains.
		if len(events) < limit {
			break
		}
	}

	if total == 0 {
		a.logger.Info(a.module, "", "[DeleteOldEvents]: No audit events to remove in the given time period")
	}
	return nil
}
package service
import (
"context"
"github.com/vigiloauth/vigilo/v2/idp/config"
authz "github.com/vigiloauth/vigilo/v2/internal/domain/authorization"
authzCode "github.com/vigiloauth/vigilo/v2/internal/domain/authzcode"
client "github.com/vigiloauth/vigilo/v2/internal/domain/client"
token "github.com/vigiloauth/vigilo/v2/internal/domain/token"
user "github.com/vigiloauth/vigilo/v2/internal/domain/user"
consent "github.com/vigiloauth/vigilo/v2/internal/domain/userconsent"
"github.com/vigiloauth/vigilo/v2/internal/errors"
"github.com/vigiloauth/vigilo/v2/internal/types"
"github.com/vigiloauth/vigilo/v2/internal/utils"
)
// Compile-time check that *authorizationService satisfies authz.AuthorizationService.
var _ authz.AuthorizationService = (*authorizationService)(nil)
// authorizationService implements authz.AuthorizationService. It coordinates the
// authorization-code, consent, client, user, and token collaborators used to authorize
// token exchanges and UserInfo requests.
type authorizationService struct {
authzCodeService authzCode.AuthorizationCodeManager
userConsentService consent.UserConsentService
clientManager client.ClientManager
clientValidator client.ClientValidator
userManager user.UserManager
tokenManager token.TokenManager // used to blacklist tokens minted from a replayed code
logger *config.Logger
module string // component label passed to every logger call
}
// NewAuthorizationService wires the given collaborators into an authz.AuthorizationService
// that logs through the server-configured logger under the "Authorization Service" label.
func NewAuthorizationService(
	authzCodeService authzCode.AuthorizationCodeManager,
	userConsentService consent.UserConsentService,
	tokenManager token.TokenManager,
	clientManager client.ClientManager,
	clientValidator client.ClientValidator,
	userManager user.UserManager,
) authz.AuthorizationService {
	svc := &authorizationService{module: "Authorization Service"}
	svc.authzCodeService = authzCodeService
	svc.userConsentService = userConsentService
	svc.tokenManager = tokenManager
	svc.clientManager = clientManager
	svc.clientValidator = clientValidator
	svc.userManager = userManager
	svc.logger = config.GetServerConfig().Logger()
	return svc
}
// AuthorizeTokenExchange validates an authorization-code token request.
//
// Once the code has been retrieved, it is marked as used when this method returns,
// whether the exchange succeeded or failed. A code that was already used is rejected,
// and any access token recorded against it (AccessTokenHash) is blacklisted.
//
// Parameters:
//   - ctx context.Context: The context for managing timeouts and cancellations.
//   - tokenRequest *token.TokenRequest: The token request carrying the client credentials,
//     authorization code, and optional PKCE code verifier.
//
// Returns:
//   - *authzCode.AuthorizationCodeData: The code's data when every check passes.
//   - error: An error if the code cannot be retrieved, was already used, or client or
//     PKCE validation fails.
func (s *authorizationService) AuthorizeTokenExchange(
	ctx context.Context,
	tokenRequest *token.TokenRequest,
) (code *authzCode.AuthorizationCodeData, err error) {
	requestID := utils.GetRequestID(ctx)

	codeData, getErr := s.authzCodeService.GetAuthorizationCode(ctx, tokenRequest.AuthorizationCode)
	if getErr != nil {
		return nil, errors.Wrap(getErr, "", "failed to retrieve authorization code")
	}

	defer func() {
		if codeData == nil {
			return
		}
		if markErr := s.markAuthorizationCodeAsUsed(ctx, codeData); markErr != nil {
			s.logger.Error(s.module, requestID, "[AuthorizeTokenExchange]: Failed to mark authorization code as used: %v", markErr)
		}
	}()

	if codeData.Used {
		// Replay of a consumed code: revoke the access token issued from it, if one was recorded.
		if codeData.AccessTokenHash != "" {
			s.revokeAccessToken(ctx, codeData.AccessTokenHash)
		}
		return nil, errors.New(errors.ErrCodeInvalidGrant, "authorization code has already been used")
	}

	if clientErr := s.validateClient(ctx, codeData, tokenRequest); clientErr != nil {
		s.logger.Error(s.module, requestID, "[AuthorizeTokenExchange]: Failed to validate client=[%s]: %v", tokenRequest.ClientID, clientErr)
		return nil, errors.Wrap(clientErr, "", "failed to validate client")
	}

	if pkceErr := s.handlePKCEValidation(codeData, tokenRequest); pkceErr != nil {
		return nil, pkceErr
	}

	return codeData, nil
}
// AuthorizeUserInfoRequest checks whether the scopes in an already-validated access token
// permit a /userinfo call and, if so, loads the token's subject.
//
// The token itself is not validated here; callers must do that before calling this method.
//
// Parameters:
//   - ctx context.Context: The context for managing timeouts and cancellations.
//   - claims *token.TokenClaims: Claims from a validated access token. Scopes must include
//     types.OpenIDScope; Subject names the user; Audience names the client whose
//     registered scopes are checked.
//
// Returns:
//   - *user.User: The user identified by claims.Subject when authorization succeeds.
//   - error: An error if claims are nil, the openid scope is missing, the user cannot be
//     retrieved, or the client's scopes do not cover the request; otherwise nil.
func (s *authorizationService) AuthorizeUserInfoRequest(ctx context.Context, claims *token.TokenClaims) (*user.User, error) {
	requestID := utils.GetRequestID(ctx)
	s.logger.Debug(s.module, requestID, "[AuthorizeUserInfoRequest]: Starting user info authorization request")

	if claims == nil {
		s.logger.Error(s.module, requestID, "[AuthorizeUserInfoRequest]: Token claims provided are nil")
		return nil, errors.New(errors.ErrCodeEmptyInput, "token claims provided are empty")
	}

	scopes := types.ParseScopesString(claims.Scopes.String())
	if !types.ContainsScope(scopes, types.OpenIDScope) {
		return nil, errors.New(errors.ErrCodeInsufficientScope, "bearer access token has insufficient privileges")
	}

	subject, err := s.userManager.GetUserByID(ctx, claims.Subject)
	if err != nil {
		s.logger.Error(s.module, requestID, "[AuthorizeUserInfoRequest]: An error occurred retrieving the user: %v", err)
		return nil, errors.Wrap(err, "", "an error occurred retrieving the specified user")
	}

	if scopeErr := s.validateClientScopes(ctx, claims.Audience, scopes); scopeErr != nil {
		s.logger.Error(s.module, requestID, "[AuthorizeUserInfoRequest]: An error occurred retrieving and validating the client: %v", scopeErr)
		return nil, errors.Wrap(scopeErr, "", "an error occurred validating the client's scopes")
	}

	return subject, nil
}
// UpdateAuthorizationCode persists changes to authorization code data through the
// authorization code manager, logging and wrapping any failure.
//
// Parameters:
//   - ctx context.Context: The context for managing timeouts and cancellations.
//   - authData *authzCode.AuthorizationCodeData: The code data to persist.
//
// Returns:
//   - error: A wrapped error if the update fails, otherwise nil.
func (s *authorizationService) UpdateAuthorizationCode(ctx context.Context, authData *authzCode.AuthorizationCodeData) error {
	err := s.authzCodeService.UpdateAuthorizationCode(ctx, authData)
	if err == nil {
		return nil
	}

	s.logger.Error(s.module, utils.GetRequestID(ctx), "[UpdateAuthorizationCode]: Failed to update code: %v", err)
	return errors.Wrap(err, "", "failed to update authorization code")
}
// validateClientScopes loads the client and ensures it may use every requested scope.
// A client with CanRequestScopes set passes without a per-scope check; otherwise each
// scope must satisfy HasScope.
func (s *authorizationService) validateClientScopes(ctx context.Context, clientID string, requestedScopes []types.Scope) error {
	retrievedClient, err := s.clientManager.GetClientByID(ctx, clientID)
	if err != nil {
		return errors.Wrap(err, errors.ErrCodeUnauthorized, "invalid client credentials")
	}

	if retrievedClient.CanRequestScopes {
		return nil
	}

	for _, scope := range requestedScopes {
		if !retrievedClient.HasScope(scope) {
			return errors.New(errors.ErrCodeInsufficientScope, "bearer access token has insufficient privileges")
		}
	}
	return nil
}
// validateClient checks the token request's client against the registered client and the
// authorization code. The client must exist; a confidential client must present a matching
// secret; and the request's client ID must equal the one the code was issued to.
func (s *authorizationService) validateClient(ctx context.Context, code *authzCode.AuthorizationCodeData, tokenRequest *token.TokenRequest) error {
	requestID := utils.GetRequestID(ctx)
	s.logger.Debug(s.module, requestID, "Starting client validation process")

	// Named "registered" rather than "client" to avoid shadowing the client package import.
	registered, err := s.clientManager.GetClientByID(ctx, tokenRequest.ClientID)
	if err != nil {
		s.logger.Error(s.module, requestID, "An error occurred retrieving the client by ID: %v", err)
		return errors.New(errors.ErrCodeInvalidClient, "invalid client")
	}

	secretMismatch := registered.IsConfidential() && !registered.SecretsMatch(tokenRequest.ClientSecret)
	if secretMismatch {
		s.logger.Error(s.module, requestID, "Failed to validate client: client secret from token request does not match with a registered client")
		return errors.New(errors.ErrCodeInvalidClient, "invalid client credentials")
	}

	if tokenRequest.ClientID != code.ClientID {
		s.logger.Error(s.module, requestID, "Failed to validate client: client ID from token request does not match with a registered client")
		return errors.New(errors.ErrCodeInvalidGrant, "authorization code client ID and request client ID do no match")
	}

	return nil
}
// handlePKCEValidation enforces PKCE for codes that were issued with a code challenge.
// Codes without a challenge skip the check. Otherwise the token request must carry a code
// verifier, and that verifier must pass tokenRequest.ValidateCodeVerifier.
func (s *authorizationService) handlePKCEValidation(authzCodeData *authzCode.AuthorizationCodeData, tokenRequest *token.TokenRequest) error {
	switch {
	case authzCodeData.CodeChallenge == "":
		s.logger.Debug(s.module, "", "PKCE is not required for this request. Skipping validation")
		return nil
	case tokenRequest.CodeVerifier == "":
		s.logger.Error(s.module, "", "Missing code verifier for PKCE")
		return errors.New(errors.ErrCodeInvalidRequest, "missing code verifier for PKCE")
	}

	// NOTE(review): nothing in this function passes authzCodeData.CodeChallenge or
	// CodeChallengeMethod to the verifier check. RFC 7636 §4.6 requires comparing the
	// transformed verifier with the stored challenge; confirm that happens elsewhere.
	if err := tokenRequest.ValidateCodeVerifier(); err != nil {
		s.logger.Error(s.module, "", "Failed to validate code verifier: %v", err)
		return errors.Wrap(err, "", "an error occurred validating the provided code verifier")
	}
	return nil
}
// revokeAccessToken blacklists the given token through the token manager. Failures are
// logged only (best effort), so the caller's error path is not replaced.
func (s *authorizationService) revokeAccessToken(ctx context.Context, token string) {
	err := s.tokenManager.BlacklistToken(ctx, token)
	if err == nil {
		return
	}
	s.logger.Error(s.module, utils.GetRequestID(ctx), "[revokeAccessToken]: Failed to blacklist token: %v", err)
}
// markAuthorizationCodeAsUsed sets Used on the code data and persists it. The in-memory
// flag is set before persisting, so it stays set even if the update fails.
func (s *authorizationService) markAuthorizationCodeAsUsed(ctx context.Context, authzCodeData *authzCode.AuthorizationCodeData) error {
	authzCodeData.Used = true

	updateErr := s.authzCodeService.UpdateAuthorizationCode(ctx, authzCodeData)
	if updateErr == nil {
		return nil
	}

	s.logger.Error(s.module, utils.GetRequestID(ctx), "[AuthorizeTokenExchange]: Failed to mark code as used: %v", updateErr)
	return errors.Wrap(updateErr, "", "failed to mark the authorization code as used")
}
package service
import (
"context"
"encoding/base64"
"fmt"
"time"
"github.com/vigiloauth/vigilo/v2/idp/config"
domain "github.com/vigiloauth/vigilo/v2/internal/domain/authzcode"
client "github.com/vigiloauth/vigilo/v2/internal/domain/client"
crypto "github.com/vigiloauth/vigilo/v2/internal/domain/crypto"
"github.com/vigiloauth/vigilo/v2/internal/errors"
"github.com/vigiloauth/vigilo/v2/internal/utils"
)
// Compile-time check that *authorizationCodeCreator satisfies domain.AuthorizationCodeCreator.
var _ domain.AuthorizationCodeCreator = (*authorizationCodeCreator)(nil)
// authorizationCodeLength is the length argument passed to GenerateRandomString when a
// non-PKCE authorization code is generated in handlePKCECreation.
const authorizationCodeLength int = 32
// authorizationCodeCreator generates authorization codes and stores them, together with
// the request data they were issued for, in an AuthorizationCodeRepository.
type authorizationCodeCreator struct {
repo domain.AuthorizationCodeRepository
cryptographer crypto.Cryptographer // source of random strings for authorization codes
codeLifeTime time.Duration // added to the code's CreatedAt to compute its stored expiration
logger *config.Logger
module string // component label passed to every logger call
}
// NewAuthorizationCodeCreator returns a domain.AuthorizationCodeCreator that stores codes
// in repo, draws randomness from cryptographer, and takes its code lifetime and logger
// from the server configuration.
func NewAuthorizationCodeCreator(
	repo domain.AuthorizationCodeRepository,
	cryptographer crypto.Cryptographer,
) domain.AuthorizationCodeCreator {
	serverCfg := config.GetServerConfig()
	lifetime := serverCfg.AuthorizationCodeDuration()
	log := serverCfg.Logger()

	return &authorizationCodeCreator{
		repo:          repo,
		cryptographer: cryptographer,
		codeLifeTime:  lifetime,
		logger:        log,
		module:        "Authorization Code Creator",
	}
}
// GenerateAuthorizationCode creates an authorization code for the request and stores it,
// with the request's metadata, until CreatedAt plus the configured code lifetime.
//
// Parameters:
//   - ctx context.Context: The context for managing timeouts and cancellations.
//   - req *client.ClientAuthorizationRequest: The authorization request the code is issued for.
//
// Returns:
//   - string: The generated authorization code.
//   - error: An error if code generation (including PKCE handling) or storage fails.
func (c *authorizationCodeCreator) GenerateAuthorizationCode(ctx context.Context, req *client.ClientAuthorizationRequest) (string, error) {
	requestID := utils.GetRequestID(ctx)
	issuedAt := time.Now()

	codeData := &domain.AuthorizationCodeData{
		UserID:                 req.UserID,
		ClientID:               req.ClientID,
		RedirectURI:            req.RedirectURI,
		Scope:                  req.Scope,
		CreatedAt:              issuedAt,
		Used:                   false,
		Nonce:                  req.Nonce,
		UserAuthenticationTime: req.UserAuthenticationTime.UTC(),
	}

	// Optional OIDC request parameters are copied only when present.
	if claims := req.ClaimsRequest; claims != nil {
		codeData.ClaimsRequest = claims
	}
	if acr := req.ACRValues; acr != "" {
		codeData.ACRValues = acr
	}

	// handlePKCECreation fills in codeData.Code (and the challenge fields when PKCE applies).
	if pkceErr := c.handlePKCECreation(requestID, codeData, req); pkceErr != nil {
		c.logger.Error(c.module, requestID, "[GenerateAuthorizationCode] Error handling PKCE creation: %v", pkceErr)
		return "", errors.Wrap(pkceErr, "", "Failed to handle PKCE creation")
	}

	expiresAt := issuedAt.Add(c.codeLifeTime)
	if storeErr := c.repo.StoreAuthorizationCode(ctx, codeData.Code, codeData, expiresAt); storeErr != nil {
		c.logger.Error(c.module, requestID, "[GenerateAuthorizationCode] Error creating authorization code in repository: %v", storeErr)
		return "", errors.Wrap(storeErr, errors.ErrCodeInternalServerError, "Failed to create authorization code in repository")
	}

	return codeData.Code, nil
}
// handlePKCECreation assigns codeData.Code.
//
// When the client requires PKCE (req.Client.RequiresPKCE), the code is derived by
// generateAuthorizationCodeForPKCE and the request's code challenge and challenge
// method are copied onto codeData. Otherwise, a plain random string of
// authorizationCodeLength is used.
func (c *authorizationCodeCreator) handlePKCECreation(
	requestID string,
	codeData *domain.AuthorizationCodeData,
	req *client.ClientAuthorizationRequest,
) error {
	if !req.Client.RequiresPKCE {
		plainCode, genErr := c.cryptographer.GenerateRandomString(authorizationCodeLength)
		if genErr != nil {
			c.logger.Error(c.module, requestID, "[handlePKCECreation] Error generating random string for authorization code: %v", genErr)
			return errors.Wrap(genErr, "", "Failed to generate authorization code")
		}
		codeData.Code = plainCode
		return nil
	}

	pkceCode, genErr := c.generateAuthorizationCodeForPKCE(requestID, req)
	if genErr != nil {
		c.logger.Error(c.module, requestID, "[handlePKCECreation] Error generating PKCE authorization code: %v", genErr)
		return errors.Wrap(genErr, "", "Failed to generate PKCE authorization code")
	}
	codeData.Code = pkceCode
	codeData.CodeChallenge = req.CodeChallenge
	codeData.CodeChallengeMethod = req.CodeChallengeMethod
	return nil
}
// generateAuthorizationCodeForPKCE produces a code value for PKCE clients. It
// joins a random string, the code challenge, and the challenge method with "|",
// then encodes the result with unpadded URL-safe base64.
func (c *authorizationCodeCreator) generateAuthorizationCodeForPKCE(
	requestID string,
	req *client.ClientAuthorizationRequest,
) (string, error) {
	randomPart, genErr := c.cryptographer.GenerateRandomString(authorizationCodeLength)
	if genErr != nil {
		c.logger.Error(c.module, requestID, "[generateAuthorizationCodeForPKCE] Error generating random string for PKCE authorization code: %v", genErr)
		return "", errors.Wrap(genErr, "", "Failed to generate PKCE authorization code")
	}

	payload := []byte(fmt.Sprintf("%s|%s|%s", randomPart, req.CodeChallenge, req.CodeChallengeMethod))
	return base64.RawURLEncoding.EncodeToString(payload), nil
}
package service
import (
"context"
"github.com/vigiloauth/vigilo/v2/idp/config"
domain "github.com/vigiloauth/vigilo/v2/internal/domain/authzcode"
client "github.com/vigiloauth/vigilo/v2/internal/domain/client"
"github.com/vigiloauth/vigilo/v2/internal/errors"
"github.com/vigiloauth/vigilo/v2/internal/utils"
)
// Compile-time assertion that *authorizationCodeIssuer implements domain.AuthorizationCodeIssuer.
var _ domain.AuthorizationCodeIssuer = (*authorizationCodeIssuer)(nil)
// authorizationCodeIssuer issues authorization codes by delegating generation to
// an AuthorizationCodeCreator and logging any failure.
type authorizationCodeIssuer struct {
// creator performs the actual code generation and storage.
creator domain.AuthorizationCodeCreator
logger *config.Logger
// module is the component name passed to every logger call.
module string
}
// NewAuthorizationCodeIssuer returns a domain.AuthorizationCodeIssuer that delegates
// code generation to creator. The logger comes from the global server configuration.
func NewAuthorizationCodeIssuer(
	creator domain.AuthorizationCodeCreator,
) domain.AuthorizationCodeIssuer {
	issuer := &authorizationCodeIssuer{
		creator: creator,
		module:  "Authorization Code Issuer",
	}
	issuer.logger = config.GetServerConfig().Logger()
	return issuer
}
// IssueAuthorizationCode asks the underlying creator for a new authorization code
// for the given client request.
//
// Parameters:
// - ctx Context: The context for managing timeouts and cancellations.
// - req *ClientAuthorizationRequest: The request containing the metadata to generate an authorization code.
//
// Returns:
// - string: The generated authorization code.
// - error: A wrapped error if the creator fails.
func (c *authorizationCodeIssuer) IssueAuthorizationCode(
	ctx context.Context,
	req *client.ClientAuthorizationRequest,
) (string, error) {
	code, genErr := c.creator.GenerateAuthorizationCode(ctx, req)
	if genErr == nil {
		return code, nil
	}

	c.logger.Error(c.module, utils.GetRequestID(ctx), "[IssueAuthorizationCode] Error generating authorization code: %v", genErr)
	return "", errors.Wrap(genErr, "", "failed to generate authorization code")
}
package service
import (
"context"
"time"
"github.com/vigiloauth/vigilo/v2/idp/config"
authz "github.com/vigiloauth/vigilo/v2/internal/domain/authzcode"
"github.com/vigiloauth/vigilo/v2/internal/errors"
"github.com/vigiloauth/vigilo/v2/internal/utils"
)
// Compile-time assertion that *authorizationCodeService implements authz.AuthorizationCodeManager.
var _ authz.AuthorizationCodeManager = (*authorizationCodeService)(nil)
// authorizationCodeService retrieves, updates, and revokes stored authorization codes.
type authorizationCodeService struct {
// repo is the backing store for authorization code data.
repo authz.AuthorizationCodeRepository
// codeLifeTime is read from the server config in the constructor; it is not
// referenced by the methods visible here.
codeLifeTime time.Duration
logger *config.Logger
// module is the component name passed to every logger call.
module string
}
// NewAuthorizationCodeManager returns an authz.AuthorizationCodeManager backed by repo.
// The code lifetime and logger are taken from the global server configuration.
func NewAuthorizationCodeManager(
	repo authz.AuthorizationCodeRepository,
) authz.AuthorizationCodeManager {
	serverCfg := config.GetServerConfig()
	manager := &authorizationCodeService{
		module: "Authorization Code Service",
		repo:   repo,
	}
	manager.codeLifeTime = serverCfg.AuthorizationCodeDuration()
	manager.logger = serverCfg.Logger()
	return manager
}
// RevokeAuthorizationCode explicitly invalidates a code by marking it as used and
// persisting the change.
//
// Parameters:
// - ctx Context: The context for managing timeouts and cancellations.
// - code string: The authorization code to revoke.
//
// Returns:
// - error: An error if the code cannot be retrieved, does not exist, or cannot be updated.
func (c *authorizationCodeService) RevokeAuthorizationCode(ctx context.Context, code string) error {
	requestID := utils.GetRequestID(ctx)
	codeData, err := c.repo.GetAuthorizationCode(ctx, code)
	if err != nil {
		c.logger.Error(c.module, requestID, "[RevokeAuthorizationCode]: Failed to retrieve authorization code: %v", err)
		return errors.Wrap(err, "", "failed to revoke the authorization code")
	}
	// A (nil, nil) result would otherwise panic on the field write below.
	if codeData == nil {
		c.logger.Error(c.module, requestID, "[RevokeAuthorizationCode]: Authorization code not found")
		return errors.New(errors.ErrCodeInvalidGrant, "authorization code not found")
	}

	codeData.Used = true
	if err := c.repo.UpdateAuthorizationCode(ctx, code, codeData); err != nil {
		c.logger.Error(c.module, requestID, "[RevokeAuthorizationCode]: Failed to update authorization code: %v", err)
		return errors.NewInternalServerError(err.Error()) //nolint:wrapcheck
	}
	return nil
}
// UpdateAuthorizationCode writes authData back to the repository, keyed by authData.Code.
//
// Parameters:
// - ctx Context: The context for managing timeouts and cancellations.
// - authData (*authz.AuthorizationCodeData): The authorization code data to be updated.
//
// Returns:
// - error: A wrapped error if the repository update fails, or nil on success.
func (c *authorizationCodeService) UpdateAuthorizationCode(ctx context.Context, authData *authz.AuthorizationCodeData) error {
	updateErr := c.repo.UpdateAuthorizationCode(ctx, authData.Code, authData)
	if updateErr == nil {
		return nil
	}

	c.logger.Error(c.module, utils.GetRequestID(ctx), "[UpdateAuthorizationCode]: Failed to update authorization code: %v", updateErr)
	return errors.Wrap(updateErr, "", "failed to update the authorization code")
}
// GetAuthorizationCode retrieves the authorization code data for a given code.
//
// Parameters:
// - ctx Context: The context for managing timeouts and cancellations; its request ID is used for logging.
// - code string: The authorization code to retrieve.
//
// Returns:
// - *AuthorizationCodeData: The data returned by the repository.
// - error: A wrapped error if the repository lookup fails.
func (c *authorizationCodeService) GetAuthorizationCode(ctx context.Context, code string) (*authz.AuthorizationCodeData, error) {
	retrievedCode, err := c.repo.GetAuthorizationCode(ctx, code)
	if err != nil {
		c.logger.Error(c.module, utils.GetRequestID(ctx), "[GetAuthorizationCode]: An error occurred retrieving the authorization code: %v", err)
		return nil, errors.Wrap(err, "", "error retrieving the authorization code")
	}
	return retrievedCode, nil
}
package service
import (
"context"
"fmt"
"github.com/vigiloauth/vigilo/v2/idp/config"
"github.com/vigiloauth/vigilo/v2/internal/constants"
domain "github.com/vigiloauth/vigilo/v2/internal/domain/authzcode"
client "github.com/vigiloauth/vigilo/v2/internal/domain/client"
"github.com/vigiloauth/vigilo/v2/internal/errors"
"github.com/vigiloauth/vigilo/v2/internal/types"
"github.com/vigiloauth/vigilo/v2/internal/utils"
)
// Compile-time assertion that *authorizationCodeValidator implements domain.AuthorizationCodeValidator.
var _ domain.AuthorizationCodeValidator = (*authorizationCodeValidator)(nil)
// authorizationCodeValidator validates authorization requests, stored authorization
// codes, and PKCE code verifiers.
type authorizationCodeValidator struct {
// repo is used to look up stored authorization codes.
repo domain.AuthorizationCodeRepository
clientValidator client.ClientValidator
clientAuthenticator client.ClientAuthenticator
logger *config.Logger
// module is the component name passed to every logger call.
module string
}
// NewAuthorizationCodeValidator returns a domain.AuthorizationCodeValidator that
// uses repo for code lookups and the given client validator/authenticator for
// request checks. The logger comes from the global server configuration.
//
// The log module name identifies this component as the validator, so its log
// lines are not confused with those of the authorization code issuer.
func NewAuthorizationCodeValidator(
	repo domain.AuthorizationCodeRepository,
	clientValidator client.ClientValidator,
	clientAuthenticator client.ClientAuthenticator,
) domain.AuthorizationCodeValidator {
	return &authorizationCodeValidator{
		repo:                repo,
		clientValidator:     clientValidator,
		clientAuthenticator: clientAuthenticator,
		logger:              config.GetServerConfig().Logger(),
		module:              "Authorization Code Validator",
	}
}
// ValidateRequest checks the validity of the client authorization request and then
// authenticates the client for the authorization code grant.
//
// Parameters:
// - ctx Context: The context for managing timeouts and cancellations.
// - req *ClientAuthorizationRequest: The request to validate. req.Client must be set,
//   since its Secret is forwarded to client authentication.
//
// Returns:
// - error: A wrapped error if request validation or client authentication fails.
func (c *authorizationCodeValidator) ValidateRequest(
	ctx context.Context,
	req *client.ClientAuthorizationRequest,
) error {
	requestID := utils.GetRequestID(ctx)
	if err := c.clientValidator.ValidateAuthorizationRequest(ctx, req); err != nil {
		c.logger.Error(c.module, requestID, "[ValidateRequest] Error validating authorization request: %v", err)
		return errors.Wrap(err, "", "failed to validate authorization request")
	}

	clientAuthRequest := &client.ClientAuthenticationRequest{
		ClientID:        req.ClientID,
		ClientSecret:    req.Client.Secret,
		RequestedScopes: req.Scope,
		RedirectURI:     req.RedirectURI,
		RequestedGrant:  constants.AuthorizationCodeGrantType,
	}
	if err := c.clientAuthenticator.AuthenticateClient(ctx, clientAuthRequest); err != nil {
		c.logger.Error(c.module, requestID, "[ValidateRequest] Error authenticating client: %v", err)
		return errors.Wrap(err, "", "failed to authenticate client")
	}
	return nil
}
// ValidateAuthorizationCode checks that a stored code exists, has not been used,
// and was issued to the given client for the given redirect URI.
//
// Parameters:
// - ctx Context: The context for managing timeouts and cancellations.
// - code string: The authorization code to validate.
// - clientID string: The client requesting validation.
// - redirectURI string: The redirect URI to verify.
//
// Returns:
// - error: An invalid_grant error if the code is missing, already used, or does not
//   match the client ID or redirect URI; a wrapped error if the lookup fails.
func (c *authorizationCodeValidator) ValidateAuthorizationCode(
	ctx context.Context,
	code string,
	clientID string,
	redirectURI string,
) error {
	requestID := utils.GetRequestID(ctx)
	codeData, err := c.repo.GetAuthorizationCode(ctx, code)
	if err != nil {
		c.logger.Error(c.module, requestID, "[ValidateAuthorizationCode] Error retrieving authorization code: %v", err)
		return errors.Wrap(err, "", "failed to retrieve authorization code")
	}
	// A (nil, nil) result would otherwise panic on the field reads below.
	if codeData == nil {
		c.logger.Error(c.module, requestID, "[ValidateAuthorizationCode] Authorization code not found")
		return errors.New(errors.ErrCodeInvalidGrant, "authorization code not found")
	}

	switch {
	case codeData.Used:
		c.logger.Error(c.module, requestID, "[ValidateAuthorizationCode] Authorization code has already been used")
		return errors.New(errors.ErrCodeInvalidGrant, "authorization code has already been used")
	case codeData.ClientID != clientID:
		c.logger.Error(c.module, requestID, "[ValidateAuthorizationCode] Authorization code client ID and request client ID do not match")
		return errors.New(errors.ErrCodeInvalidGrant, "authorization code client ID and request client ID do not match")
	case codeData.RedirectURI != redirectURI:
		c.logger.Error(c.module, requestID, "[ValidateAuthorizationCode] Authorization code redirect URI and request redirect URI do not match")
		return errors.New(errors.ErrCodeInvalidGrant, "authorization code redirect URI and request redirect URI do not match")
	}
	return nil
}
// ValidatePKCE verifies a PKCE code verifier against the challenge stored with an
// authorization code.
//
// For the "S256" method, the verifier is hashed with utils.EncodeSHA256 before
// comparison. For "plain", it is compared directly. Any other method is rejected.
//
// Parameters:
// - ctx Context: The context for managing timeouts and cancellations.
// - authzCodeData (*authz.AuthorizationCodeData): The authorization code data containing the code challenge and method.
// - codeVerifier (string): The code verifier provided by the client during the token exchange.
//
// Returns:
// - error: invalid_grant if the verifier does not match the challenge, an unauthorized
//   error for an unsupported challenge method, or nil on success.
func (c *authorizationCodeValidator) ValidatePKCE(
	ctx context.Context,
	authzCodeData *domain.AuthorizationCodeData,
	codeVerifier string,
) error {
	requestID := utils.GetRequestID(ctx)
	method := authzCodeData.CodeChallengeMethod

	var derivedChallenge string
	switch method {
	case types.SHA256CodeChallengeMethod:
		derivedChallenge = utils.EncodeSHA256(codeVerifier)
	case types.PlainCodeChallengeMethod:
		derivedChallenge = codeVerifier
	default:
		return errors.New(errors.ErrCodeUnauthorized, fmt.Sprintf("unsupported code challenge method: %v", method))
	}

	if derivedChallenge != authzCodeData.CodeChallenge {
		c.logger.Error(c.module, requestID, "[ValidatePKCE]: The provided code challenge does not match with the code verifier.")
		return errors.New(errors.ErrCodeInvalidGrant, "invalid code verifier")
	}
	return nil
}
package service
import (
"context"
"net/http"
"strings"
"github.com/vigiloauth/vigilo/v2/idp/config"
"github.com/vigiloauth/vigilo/v2/internal/constants"
clients "github.com/vigiloauth/vigilo/v2/internal/domain/client"
tokens "github.com/vigiloauth/vigilo/v2/internal/domain/token"
"github.com/vigiloauth/vigilo/v2/internal/errors"
"github.com/vigiloauth/vigilo/v2/internal/types"
"github.com/vigiloauth/vigilo/v2/internal/utils"
"github.com/vigiloauth/vigilo/v2/internal/web"
)
// Compile-time assertion that *clientAuthenticator implements clients.ClientAuthenticator.
var _ clients.ClientAuthenticator = (*clientAuthenticator)(nil)
// clientAuthenticator authenticates clients from HTTP requests (Basic or Bearer
// credentials) and from explicit authentication requests.
type clientAuthenticator struct {
// clientRepo is used to look up registered clients by ID.
clientRepo clients.ClientRepository
// tokenValidator and tokenParser handle Bearer tokens.
tokenValidator tokens.TokenValidator
tokenParser tokens.TokenParser
logger *config.Logger
// module is the component name passed to every logger call.
module string
}
// NewClientAuthenticator returns a clients.ClientAuthenticator wired to the given
// client repository and token validator/parser. The logger comes from the global
// server configuration.
func NewClientAuthenticator(
	clientRepo clients.ClientRepository,
	tokenValidator tokens.TokenValidator,
	tokenParser tokens.TokenParser,
) clients.ClientAuthenticator {
	authenticator := &clientAuthenticator{module: "Client Request Authenticator"}
	authenticator.clientRepo = clientRepo
	authenticator.tokenValidator = tokenValidator
	authenticator.tokenParser = tokenParser
	authenticator.logger = config.GetServerConfig().Logger()
	return authenticator
}
// AuthenticateRequest authenticates an HTTP request based on its Authorization
// header. Basic credentials and Bearer tokens are dispatched to their respective
// helpers; any other header value (including a missing header) is rejected.
//
// Parameters:
// - ctx context.Context: The context for managing timeouts and cancellations.
// - r *http.Request: The HTTP request to authenticate.
// - requiredScope types.Scope: The scope required to access the requested resource.
//
// Returns:
// - error: An error if authentication fails or the required scope is not met.
func (c *clientAuthenticator) AuthenticateRequest(
	ctx context.Context,
	r *http.Request,
	requiredScope types.Scope,
) error {
	header := r.Header.Get(constants.AuthorizationHeader)

	if strings.HasPrefix(header, constants.BasicAuthHeader) {
		return c.authenticateWithBasicAuth(ctx, r, requiredScope)
	}
	if strings.HasPrefix(header, constants.BearerAuthHeader) {
		return c.authenticateWithBearerToken(ctx, r, requiredScope)
	}
	return errors.New(errors.ErrCodeInvalidClient, "failed to authorize client: missing authorization header")
}
// AuthenticateClient authenticates the client using provided credentials
// and authorizes access by validating redirect URI, secret, scopes, and grant type.
//
// Checks performed, in order:
// - the client exists;
// - if a redirect URI is given, it is registered for the client;
// - if a secret is given, the client is confidential and the secret matches;
// - if the client cannot request arbitrary scopes (CanRequestScopes is false),
//   every requested scope is one of its scopes;
// - if a grant type is given, the client is allowed to use it.
//
// Parameters:
// - ctx Context: The context for managing timeouts and cancellations.
// - req *ClientAuthenticationRequest: The request containing client credentials and required scopes.
//
// Returns:
// - error: An error if authentication or authorization fails.
func (c *clientAuthenticator) AuthenticateClient(
	ctx context.Context,
	req *clients.ClientAuthenticationRequest,
) error {
	requestID := utils.GetRequestID(ctx)
	existingClient, err := c.clientRepo.GetClientByID(ctx, req.ClientID)
	if err != nil {
		c.logger.Error(c.module, requestID, "[AuthenticateClient]: Failed to retrieve client by ID: %v", err)
		return errors.Wrap(err, errors.ErrCodeInvalidClient, "failed to retrieve client")
	}

	if req.RedirectURI != "" && !existingClient.HasRedirectURI(req.RedirectURI) {
		c.logger.Error(c.module, requestID, "[AuthenticateClient]: Invalid redirect URI: %s", req.RedirectURI)
		return errors.New(errors.ErrCodeInvalidRedirectURI, "invalid redirect URI")
	}

	// NOTE(review): the secret is only checked when one is supplied, so a
	// confidential client that omits its secret is not rejected here; confirm
	// callers enforce secret presence where it is required.
	if req.ClientSecret != "" {
		if !existingClient.IsConfidential() {
			return errors.New(errors.ErrCodeUnauthorizedClient, "client is not confidential")
		}
		if !existingClient.SecretsMatch(req.ClientSecret) {
			return errors.New(errors.ErrCodeInvalidClient, "invalid credentials")
		}
	}

	if !existingClient.CanRequestScopes {
		// strings.Fields (rather than strings.Split) yields no entries for an
		// empty scope string and skips repeated spaces. strings.Split would
		// produce a "" scope that no client holds, rejecting requests that ask
		// for no scopes at all.
		for _, scope := range strings.Fields(req.RequestedScopes.String()) {
			if !existingClient.HasScope(types.Scope(scope)) {
				return errors.New(errors.ErrCodeInsufficientScope, "client does not have the required scope(s)")
			}
		}
	}

	if req.RequestedGrant != "" && !existingClient.HasGrantType(req.RequestedGrant) {
		return errors.New(errors.ErrCodeUnauthorizedClient, "client does not have the required grant type")
	}
	return nil
}
// authenticateWithBearerToken handles a Bearer Authorization header. The token is
// extracted, validated, and parsed. Its Audience claim is then authenticated as
// the client ID, with requiredScope as the requested scope.
func (c *clientAuthenticator) authenticateWithBearerToken(
	ctx context.Context,
	r *http.Request,
	requiredScope types.Scope,
) error {
	requestID := utils.GetRequestID(ctx)

	rawToken, extractErr := web.ExtractBearerToken(r)
	if extractErr != nil {
		c.logger.Error(c.module, requestID, "[authenticateWithBearerToken]: Failed to extract bearer token from header: %v", extractErr)
		return errors.Wrap(extractErr, errors.ErrCodeInvalidGrant, "failed to extract bearer token from header")
	}

	if validateErr := c.tokenValidator.ValidateToken(ctx, rawToken); validateErr != nil {
		c.logger.Error(c.module, requestID, "[authenticateWithBearerToken]: Failed to validate token: %v", validateErr)
		return errors.Wrap(validateErr, "", "failed to validate bearer token")
	}

	claims, parseErr := c.tokenParser.ParseToken(ctx, rawToken)
	if parseErr != nil {
		c.logger.Error(c.module, requestID, "[authenticateWithBearerToken]: Failed to parse bearer token: %v", parseErr)
		return errors.Wrap(parseErr, "", "failed to parse bearer token")
	}

	authReq := &clients.ClientAuthenticationRequest{
		ClientID:        claims.Audience,
		RequestedScopes: requiredScope,
	}
	if authErr := c.AuthenticateClient(ctx, authReq); authErr != nil {
		c.logger.Error(c.module, requestID, "[authenticateWithBearerToken]: Failed to authenticate client: %v", authErr)
		return errors.Wrap(authErr, "", "failed to authenticate client")
	}
	return nil
}
// authenticateWithBasicAuth handles a Basic Authorization header. It extracts
// the client ID and secret and authenticates them with requiredScope as the
// requested scope.
func (c *clientAuthenticator) authenticateWithBasicAuth(
	ctx context.Context,
	r *http.Request,
	requiredScope types.Scope,
) error {
	requestID := utils.GetRequestID(ctx)

	id, secret, credErr := web.ExtractClientBasicAuth(r)
	if credErr != nil {
		c.logger.Error(c.module, requestID, "[authenticateWithBasicAuth]: Failed to retrieve client credentials: %v", credErr)
		return errors.Wrap(credErr, "", "failed to extract client credentials from auth header")
	}

	authErr := c.AuthenticateClient(ctx, &clients.ClientAuthenticationRequest{
		ClientID:        id,
		ClientSecret:    secret,
		RequestedScopes: requiredScope,
	})
	if authErr != nil {
		c.logger.Error(c.module, requestID, "[authenticateWithBasicAuth]: Failed to authenticate client: %v", authErr)
		return errors.Wrap(authErr, "", "failed to authenticate client")
	}
	return nil
}
package service
import (
"context"
"net/http"
"net/url"
"strconv"
"time"
"github.com/vigiloauth/vigilo/v2/idp/config"
"github.com/vigiloauth/vigilo/v2/internal/constants"
authzCode "github.com/vigiloauth/vigilo/v2/internal/domain/authzcode"
domain "github.com/vigiloauth/vigilo/v2/internal/domain/claims"
client "github.com/vigiloauth/vigilo/v2/internal/domain/client"
session "github.com/vigiloauth/vigilo/v2/internal/domain/session"
consent "github.com/vigiloauth/vigilo/v2/internal/domain/userconsent"
"github.com/vigiloauth/vigilo/v2/internal/errors"
"github.com/vigiloauth/vigilo/v2/internal/utils"
"github.com/vigiloauth/vigilo/v2/internal/web"
)
// Compile-time assertion that *clientAuthorization implements client.ClientAuthorization.
var _ client.ClientAuthorization = (*clientAuthorization)(nil)
// clientAuthorization drives the authorization endpoint flow: request validation,
// login and consent checks, and issuing the authorization code redirect.
type clientAuthorization struct {
validator client.ClientValidator
manager client.ClientManager
// session provides the authenticated user ID and authentication time.
session session.SessionManager
// consent checks whether the user already granted the requested scopes.
consent consent.UserConsentService
// issuer creates authorization codes once the user is authenticated and consented.
issuer authzCode.AuthorizationCodeIssuer
logger *config.Logger
// module is the component name passed to every logger call.
module string
}
// NewClientAuthorization returns a client.ClientAuthorization composed of the given
// validator, client manager, session manager, consent service, and code issuer.
// The logger comes from the global server configuration.
func NewClientAuthorization(
	validator client.ClientValidator,
	manager client.ClientManager,
	session session.SessionManager,
	consent consent.UserConsentService,
	issuer authzCode.AuthorizationCodeIssuer,
) client.ClientAuthorization {
	ca := &clientAuthorization{module: "Client Authorization"}
	ca.validator, ca.manager = validator, manager
	ca.session, ca.consent, ca.issuer = session, consent, issuer
	ca.logger = config.GetServerConfig().Logger()
	return ca
}
// Authorize handles the authorization logic for a client request.
//
// The flow is:
// 1. Reject unsupported request_uri / request parameters and a missing response_type.
// 2. Load and validate the client and request.
// 3. Enforce login (prompt=login, max_age, prompt=none).
// 4. Enforce consent.
// 5. Issue an authorization code and build the success redirect.
//
// Parameters:
// - ctx Context: The context for managing timeouts and cancellations.
// - req *ClientAuthorizationRequest: The client authorization request. Client,
//   UserID and UserAuthenticationTime are populated on it as the flow proceeds.
//
// Returns:
// - string: The redirect URL (login, consent, error, or success redirect).
// - error: An error if the client is unknown, the request is invalid, or code issuance fails.
func (c *clientAuthorization) Authorize(
	ctx context.Context,
	req *client.ClientAuthorizationRequest,
) (string, error) {
	requestID := utils.GetRequestID(ctx)

	// NOTE(review): the error redirects below target req.RedirectURI before the
	// client and its redirect URI have been validated. Confirm this cannot be
	// used as an open redirect (RFC 6749 §4.1.2.1 forbids redirecting to an
	// unverified URI).
	if req.RequestURI != "" {
		return web.BuildErrorURL(
			errors.ErrCodeRequestURINotSupported,
			"request URIs are not currently supported",
			req.State,
			req.RedirectURI,
		), nil
	}
	if req.RequestObject != "" {
		return web.BuildErrorURL(
			errors.ErrCodeRequestObjectNotSupported,
			"request objects are not currently supported",
			req.State,
			req.RedirectURI,
		), nil
	}
	if req.ResponseType == "" {
		return web.BuildErrorURL(
			errors.ErrCodeInvalidRequest,
			"response_type is required",
			req.State,
			req.RedirectURI,
		), nil
	}

	// Named registeredClient so it does not shadow the client package.
	registeredClient, err := c.manager.GetClientByID(ctx, req.ClientID)
	if err != nil {
		c.logger.Error(c.module, requestID, "[Authorize]: Failed to get client by ID: %v", err)
		return "", errors.New(errors.ErrCodeUnauthorizedClient, "invalid client credentials")
	}
	req.Client = registeredClient

	if err := c.validator.ValidateAuthorizationRequest(ctx, req); err != nil {
		c.logger.Error(c.module, requestID, "[Authorize]: Authorization request validation failed: %v", err)
		return "", errors.Wrap(err, "", "failed to authorize request")
	}

	if c.shouldForceLogin(ctx, req) {
		return c.buildLoginRedirectURL(req), nil
	}

	userID, isAuthenticated := c.isUserAuthenticated(ctx, requestID, req.HTTPRequest)
	if c.shouldRejectUnauthenticatedUser(req, isAuthenticated) {
		return c.buildLoginRequiredErrorURL(req), nil
	}
	if !isAuthenticated {
		return c.buildLoginRedirectURL(req), nil
	}

	req.UserID = userID
	req.UserAuthenticationTime = c.getUserAuthenticationTime(ctx, requestID, req.HTTPRequest)

	if c.shouldRejectMissingConsent(ctx, req, isAuthenticated) {
		return c.buildConsentRequiredErrorURL(req), nil
	}
	// Named consentURL so it does not shadow the net/url package.
	if consentURL := c.handleUserConsent(ctx, req); consentURL != "" {
		return consentURL, nil
	}

	authCode, err := c.issuer.IssueAuthorizationCode(ctx, req)
	if err != nil {
		// Argument order matches the logger's (module, requestID, format, ...) convention.
		c.logger.Error(c.module, requestID, "[Authorize]: Failed to issue authorization code: %v", err)
		return "", errors.New(errors.ErrCodeInternalServerError, "failed to issue authorization code")
	}
	return c.buildRedirectURL(req.RedirectURI, authCode, req.State, req.Nonce), nil
}
// shouldForceLogin reports whether the user must (re-)authenticate. This is true
// in the following cases:
// - prompt=login is requested;
// - max_age is "0";
// - max_age cannot be parsed;
// - the session's authentication time cannot be read;
// - more than max_age seconds have elapsed since that time.
// With no max_age and no prompt=login, it returns false.
func (c *clientAuthorization) shouldForceLogin(ctx context.Context, req *client.ClientAuthorizationRequest) bool {
	if req.Prompt == constants.PromptLogin {
		return true
	}

	switch req.MaxAge {
	case "":
		return false
	case "0":
		return true
	}

	maxAgeSeconds, parseErr := strconv.ParseInt(req.MaxAge, 10, 64)
	if parseErr != nil {
		c.logger.Warn(c.module, utils.GetRequestID(ctx), "[shouldForceLogin]: Failed to parse max_age: %v", parseErr)
		return true
	}

	lastAuthTime, sessionErr := c.session.GetUserAuthenticationTime(ctx, req.HTTPRequest)
	if sessionErr != nil {
		return true
	}

	elapsed := time.Now().Unix() - lastAuthTime
	c.logger.Debug(c.module, utils.GetRequestID(ctx), "[shouldForceLogin]: seconds since last login: %d", elapsed)
	return elapsed > maxAgeSeconds
}
// buildLoginRedirectURL builds a redirect to the "authenticate" step that carries
// the original authorization request parameters.
func (c *clientAuthorization) buildLoginRedirectURL(req *client.ClientAuthorizationRequest) string {
	scopes := req.Scope.String()
	serializedClaims := domain.SerializeClaims(req.ClaimsRequest)
	return web.BuildRedirectURL(
		req.ClientID, req.RedirectURI, scopes, req.ResponseType,
		req.State, req.Nonce, req.Prompt, req.Display, req.ACRValues,
		serializedClaims, "authenticate",
	)
}
// buildLoginRequiredErrorURL builds a login_required error redirect to the request's
// redirect URI, preserving its state.
func (c *clientAuthorization) buildLoginRequiredErrorURL(req *client.ClientAuthorizationRequest) string {
	const description = "authentication required to continue"
	return web.BuildErrorURL(errors.ErrCodeLoginRequired, description, req.State, req.RedirectURI)
}
// isUserAuthenticated returns the session's user ID and true when the session
// lookup succeeds with a non-empty ID; otherwise it returns "" and false.
func (c *clientAuthorization) isUserAuthenticated(ctx context.Context, requestID string, r *http.Request) (string, bool) {
	userID, err := c.session.GetUserIDFromSession(ctx, r)
	switch {
	case err != nil:
		c.logger.Warn(c.module, requestID, "[isUserAuthenticated]: User is not authenticated: %v", err)
		return "", false
	case userID == "":
		return "", false
	default:
		return userID, true
	}
}
// shouldRejectUnauthenticatedUser reports whether an unauthenticated user asked for
// prompt=none, in which case no interactive login may be shown.
func (c *clientAuthorization) shouldRejectUnauthenticatedUser(req *client.ClientAuthorizationRequest, isAuthenticated bool) bool {
	if isAuthenticated {
		return false
	}
	return req.Prompt == constants.PromptNone
}
// getUserAuthenticationTime returns the time the user last authenticated, as
// reported by the session manager in Unix seconds.
//
// It returns the zero time.Time when the session lookup fails or when the reported
// value is not positive. Zero is treated as "unknown" rather than converted with
// time.Unix, because time.Unix(0, 0) is the Unix epoch, not the zero Time.
// Converting it would make an unset value slip past IsZero checks.
func (c *clientAuthorization) getUserAuthenticationTime(ctx context.Context, requestID string, r *http.Request) time.Time {
	authTime, err := c.session.GetUserAuthenticationTime(ctx, r)
	if err != nil {
		c.logger.Warn(c.module, requestID, "[getUserAuthenticationTime]: Failed to get session data: %v", err)
		return time.Time{}
	}
	if authTime <= 0 {
		return time.Time{}
	}
	return time.Unix(authTime, 0)
}
// shouldRejectMissingConsent reports whether an authenticated prompt=none request
// lacks previously granted consent. The consent lookup only runs when the
// prompt and authentication conditions already hold.
func (c *clientAuthorization) shouldRejectMissingConsent(ctx context.Context, req *client.ClientAuthorizationRequest, isAuthenticated bool) bool {
	if req.Prompt != constants.PromptNone || !isAuthenticated {
		return false
	}
	return !c.hasPreConfiguredConsent(ctx, req)
}
// hasPreConfiguredConsent reports whether the user already consented to the client
// for the requested scopes. A failed lookup is logged and treated as no consent.
func (c *clientAuthorization) hasPreConfiguredConsent(ctx context.Context, req *client.ClientAuthorizationRequest) bool {
	hasConsent, err := c.consent.CheckUserConsent(ctx, req.UserID, req.ClientID, req.Scope)
	if err == nil {
		return hasConsent
	}

	c.logger.Error(c.module, utils.GetRequestID(ctx), "Failed to check user consent, user=[%s]: %v", utils.TruncateSensitive(req.UserID), err)
	return false
}
// buildConsentRequiredErrorURL builds a consent_required error redirect to the
// request's redirect URI, preserving its state.
func (c *clientAuthorization) buildConsentRequiredErrorURL(req *client.ClientAuthorizationRequest) string {
	const description = "consent required to continue"
	return web.BuildErrorURL(errors.ErrCodeConsentRequired, description, req.State, req.RedirectURI)
}
// handleUserConsent returns a redirect to the "consent" step when the user has no
// stored consent and has not approved consent in this request. Otherwise it
// returns "". The stored-consent lookup always runs first.
func (c *clientAuthorization) handleUserConsent(ctx context.Context, req *client.ClientAuthorizationRequest) string {
	if c.hasPreConfiguredConsent(ctx, req) || req.ConsentApproved {
		return ""
	}

	c.logger.Warn(c.module, utils.GetRequestID(ctx), "Consent required, redirecting to consent URL")
	scopes := req.Scope.String()
	serializedClaims := domain.SerializeClaims(req.ClaimsRequest)
	return web.BuildRedirectURL(
		req.ClientID, req.RedirectURI, scopes, req.ResponseType,
		req.State, req.Nonce, req.Prompt, req.Display, req.ACRValues,
		serializedClaims, "consent",
	)
}
// buildRedirectURL builds the success redirect to redirectURI. It carries the
// authorization code and, when non-empty, the state and nonce as query parameters.
//
// A registered redirect URI may already include a query component, which must be
// retained (RFC 6749 §3.1.2). The parameters are therefore joined with "&" in that
// case instead of appending a second "?". A URI ending in a bare "?" gets no
// extra separator.
func (c *clientAuthorization) buildRedirectURL(redirectURI, code, state, nonce string) string {
	queryParams := url.Values{}
	queryParams.Add(constants.CodeURLValue, code)
	if state != "" {
		queryParams.Add(constants.StateReqField, state)
	}
	if nonce != "" {
		queryParams.Add(constants.NonceReqField, nonce)
	}

	separator := "?"
	// On a parse error, fall back to the plain "?" separator.
	if parsed, err := url.Parse(redirectURI); err == nil {
		switch {
		case parsed.RawQuery != "":
			separator = "&"
		case parsed.ForceQuery:
			separator = ""
		}
	}
	return redirectURI + separator + queryParams.Encode()
}
package service
import (
"context"
"fmt"
"time"
"github.com/vigiloauth/vigilo/v2/idp/config"
"github.com/vigiloauth/vigilo/v2/internal/constants"
clients "github.com/vigiloauth/vigilo/v2/internal/domain/client"
encryption "github.com/vigiloauth/vigilo/v2/internal/domain/crypto"
tokens "github.com/vigiloauth/vigilo/v2/internal/domain/token"
"github.com/vigiloauth/vigilo/v2/internal/errors"
"github.com/vigiloauth/vigilo/v2/internal/types"
"github.com/vigiloauth/vigilo/v2/internal/utils"
"github.com/vigiloauth/vigilo/v2/internal/web"
)
// Compile-time assertion that *clientCreator implements clients.ClientCreator.
var _ clients.ClientCreator = (*clientCreator)(nil)
// clientCreator handles dynamic client registration: validation, secret
// generation, registration access token issuance, and persistence.
type clientCreator struct {
repo clients.ClientRepository
validator clients.ClientValidator
// issuer issues the registration access token for newly registered clients.
issuer tokens.TokenIssuer
// encryption generates and hashes client secrets.
encryption encryption.Cryptographer
logger *config.Logger
// module is the component name passed to every logger call.
module string
}
// NewClientCreator returns a clients.ClientCreator using the given repository,
// validator, token issuer, and cryptographer. The logger comes from the global
// server configuration.
func NewClientCreator(
	repo clients.ClientRepository,
	validator clients.ClientValidator,
	issuer tokens.TokenIssuer,
	encryption encryption.Cryptographer,
) clients.ClientCreator {
	creator := &clientCreator{module: "Client Creator"}
	creator.repo = repo
	creator.validator = validator
	creator.issuer = issuer
	creator.encryption = encryption
	creator.logger = config.GetServerConfig().Logger()
	return creator
}
// Register validates a client registration request and creates the client. It
// generates a hashed secret for confidential clients, issues a registration access
// token scoped to the client's scopes, and saves the client.
//
// Parameters:
// - ctx context.Context: The context for the operation; its request ID is used for logging.
// - req *ClientRegistrationRequest: The registration request.
//
// Returns:
// - *ClientRegistrationResponse: The response built from the saved client.
// - error: A wrapped error if validation, secret generation, token issuance, or saving fails.
func (c *clientCreator) Register(
	ctx context.Context,
	req *clients.ClientRegistrationRequest,
) (*clients.ClientRegistrationResponse, error) {
	requestID := utils.GetRequestID(ctx)
	if err := c.validator.ValidateRegistrationRequest(ctx, req); err != nil {
		c.logger.Error(c.module, requestID, "[Register]: Failed to validate client: %v", err)
		return nil, errors.Wrap(err, "", "failed to validate client")
	}

	client := clients.NewClientFromRegistrationRequest(req)
	client.ID = constants.ClientIDPrefix + utils.GenerateUUID()
	if client.Type == types.ConfidentialClient {
		if err := c.generateClientSecret(requestID, client); err != nil {
			c.logger.Error(c.module, requestID, "[Register]: Failed to generate client secret: %v", err)
			return nil, errors.Wrap(err, "", "failed to generate secret")
		}
	}

	requestedScopes := types.CombineScopes(client.Scopes...)
	registrationAccessToken, err := c.issuer.IssueAccessToken(
		ctx,
		client.ID, "",
		requestedScopes, "", "",
	)
	if err != nil {
		c.logger.Error(c.module, requestID, "[Register]: Failed to generate registration access token: %v", err)
		return nil, errors.Wrap(err, "", "failed to generate the registration access token")
	}

	// A single timestamp keeps CreatedAt, UpdatedAt and IDIssuedAt identical.
	now := time.Now()
	client.CreatedAt, client.UpdatedAt, client.IDIssuedAt = now, now, now
	client.RegistrationClientURI = c.buildClientConfigurationEndpoint(client.ID)
	client.RegistrationAccessToken = registrationAccessToken

	if err := c.repo.SaveClient(ctx, client); err != nil {
		c.logger.Error(c.module, requestID, "[Register]: Failed to save client: %v", err)
		return nil, errors.Wrap(err, "", "failed to register client")
	}

	// NOTE(review): generateClientSecret stores only the hashed secret in
	// client.Secret. Confirm the registration response still gives the client
	// a usable plaintext secret.
	return clients.NewClientRegistrationResponseFromClient(client), nil
}
// generateClientSecret creates a random secret for client, hashes it, and stores
// only the hash in client.Secret. It also sets SecretExpiration to 0.
func (c *clientCreator) generateClientSecret(requestID string, client *clients.Client) error {
	const clientSecretLength int = 32

	plainSecret, genErr := c.encryption.GenerateRandomString(clientSecretLength)
	if genErr != nil {
		c.logger.Error(c.module, requestID, "[Register]: Failed to generate client secret: %v", genErr)
		return errors.New(errors.ErrCodeRandomGenerationFailed, "failed to generate client secret")
	}

	hashedSecret, hashErr := c.encryption.HashString(plainSecret)
	if hashErr != nil {
		c.logger.Error(c.module, requestID, "[Register]: Failed to encrypt client secret: %v", hashErr)
		return errors.New(errors.ErrCodeHashingFailed, "failed to hash client secret")
	}

	client.Secret, client.SecretExpiration = hashedSecret, 0
	return nil
}
// buildClientConfigurationEndpoint returns the client configuration URL for
// clientID. It is formed from the server URL, the client configuration endpoint
// path, and the ID.
func (c *clientCreator) buildClientConfigurationEndpoint(clientID string) string {
	return fmt.Sprintf(
		"%s%s/%s",
		config.GetServerConfig().URL(),
		web.ClientEndpoints.ClientConfiguration,
		clientID,
	)
}
package service
import (
"context"
"time"
"github.com/vigiloauth/vigilo/v2/idp/config"
"github.com/vigiloauth/vigilo/v2/internal/constants"
clients "github.com/vigiloauth/vigilo/v2/internal/domain/client"
crypto "github.com/vigiloauth/vigilo/v2/internal/domain/crypto"
"github.com/vigiloauth/vigilo/v2/internal/errors"
"github.com/vigiloauth/vigilo/v2/internal/types"
"github.com/vigiloauth/vigilo/v2/internal/utils"
"github.com/vigiloauth/vigilo/v2/internal/web"
)
// Compile-time check that *clientManager satisfies clients.ClientManager.
var _ clients.ClientManager = (*clientManager)(nil)
// clientSecretLength is the length requested from GenerateRandomString when a
// client secret is regenerated.
const clientSecretLength int = 32
// clientManager implements clients.ClientManager: looking up, updating and
// deleting registered clients and regenerating confidential client secrets.
type clientManager struct {
// repo persists and retrieves client records.
repo clients.ClientRepository
// validator checks update requests and registration access tokens.
validator clients.ClientValidator
// authenticator authenticates the client before its secret is regenerated.
authenticator clients.ClientAuthenticator
// cryptographer generates random secrets and hashes them.
cryptographer crypto.Cryptographer
logger *config.Logger
// module is the component name passed to every log call.
module string
}
// NewClientManager assembles a clients.ClientManager from its collaborators.
// The logger is taken from the global server configuration and every log line
// is tagged with the "Client Manager" module name.
func NewClientManager(
	repo clients.ClientRepository,
	validator clients.ClientValidator,
	authenticator clients.ClientAuthenticator,
	cryptographer crypto.Cryptographer,
) clients.ClientManager {
	manager := &clientManager{module: "Client Manager"}
	manager.repo = repo
	manager.validator = validator
	manager.authenticator = authenticator
	manager.cryptographer = cryptographer
	manager.logger = config.GetServerConfig().Logger()
	return manager
}
// RegenerateClientSecret regenerates the client secret for a given client ID.
// It returns a response containing the new plaintext client secret and the
// time the client record was updated; only the hash of the secret is stored.
//
// Parameters:
//   - ctx context.Context: The context for the operation.
//   - clientID string: The ID of the client for which to regenerate the secret.
//
// Returns:
//   - *ClientSecretRegenerationResponse: A pointer to ClientSecretRegenerationResponse containing the new secret and update time.
//   - error: An error if the client cannot be found, is not confidential, fails
//     authentication, or the secret cannot be generated, hashed or persisted.
func (c *clientManager) RegenerateClientSecret(
	ctx context.Context,
	clientID string,
) (*clients.ClientSecretRegenerationResponse, error) {
	requestID := utils.GetRequestID(ctx)

	client, err := c.repo.GetClientByID(ctx, clientID)
	if err != nil {
		c.logger.Error(c.module, requestID, "[RegenerateClientSecret]: Failed to retrieve client: %v", err)
		return nil, errors.Wrap(err, errors.ErrCodeUnauthorized, "failed to retrieve client")
	}

	if !client.IsConfidential() {
		return nil, errors.New(errors.ErrCodeInvalidClient, "invalid client credentials")
	}

	// NOTE(review): the request is authenticated with the secret already stored
	// on the record (which this package stores hashed) rather than a secret the
	// caller supplied — confirm this is the intended check.
	req := &clients.ClientAuthenticationRequest{
		ClientID:       clientID,
		ClientSecret:   client.Secret,
		RequestedGrant: constants.ClientCredentialsGrantType,
	}
	if err := c.authenticator.AuthenticateClient(ctx, req); err != nil {
		c.logger.Error(c.module, requestID, "[RegenerateClientSecret]: Failed to authenticate request: %v", err)
		return nil, errors.Wrap(err, "", "failed to validate client")
	}

	clientSecret, err := c.cryptographer.GenerateRandomString(clientSecretLength)
	if err != nil {
		c.logger.Error(c.module, requestID, "[RegenerateClientSecret]: Failed to generate client secret: %v", err)
		return nil, errors.NewInternalServerError(err.Error()) //nolint:wrapcheck
	}

	// Hash into a local first: assigning straight into client.Secret would
	// clobber the current secret on failure, and the repository may hand out
	// a shared record.
	hashedSecret, err := c.cryptographer.HashString(clientSecret)
	if err != nil {
		c.logger.Error(c.module, requestID, "[RegenerateClientSecret]: Failed to encrypt client secret: %v", err)
		return nil, errors.NewInternalServerError(err.Error()) //nolint:wrapcheck
	}

	client.Secret = hashedSecret
	client.UpdatedAt = time.Now()
	if err := c.repo.UpdateClient(ctx, client); err != nil {
		c.logger.Error(c.module, requestID, "[RegenerateClientSecret]: Failed to update client: %v", err)
		return nil, errors.Wrap(err, "", "failed to update client")
	}

	return &clients.ClientSecretRegenerationResponse{
		ClientID:     clientID,
		ClientSecret: clientSecret,
		UpdatedAt:    client.UpdatedAt,
	}, nil
}
// GetClientByID looks up a client by its ID in the repository.
//
// Parameters:
//   - ctx context.Context: The context for the operation.
//   - clientID string: The ID of the client to retrieve.
//
// Returns:
//   - *Client: The client returned by the repository on success, otherwise nil.
//   - error: The repository error wrapped as "failed to retrieve client", or nil.
func (c *clientManager) GetClientByID(ctx context.Context, clientID string) (*clients.Client, error) {
	reqID := utils.GetRequestID(ctx)

	found, lookupErr := c.repo.GetClientByID(ctx, clientID)
	if lookupErr == nil {
		return found, nil
	}

	c.logger.Error(c.module, reqID, "[GetClientByID]: Failed to retrieve client by ID: %v", lookupErr)
	return nil, errors.Wrap(lookupErr, "", "failed to retrieve client")
}
// GetClientInformation retrieves client information by client ID and registration access token.
// It returns a response containing the client information.
//
// Parameters:
//   - ctx context.Context: The context for the operation.
//   - clientID string: The ID of the client to retrieve information for.
//   - registrationAccessToken string: The registration access token for authentication.
//
// Returns:
//   - *ClientInformationResponse: A pointer to ClientInformationResponse containing the client information.
//   - error: An error if validation or the client lookup fails, or nil if successful.
func (c *clientManager) GetClientInformation(
	ctx context.Context,
	clientID string,
	registrationAccessToken string,
) (*clients.ClientInformationResponse, error) {
	requestID := utils.GetRequestID(ctx)
	if err := c.validator.ValidateClientAndRegistrationAccessToken(ctx, clientID, registrationAccessToken); err != nil {
		c.logger.Error(c.module, requestID, "[GetClientInformation]: Failed to validate request: %v", err)
		return nil, errors.Wrap(err, "", "failed to validate client")
	}

	// Validation just succeeded, but the lookup can still fail (repository
	// error, concurrent delete). Ignoring the error would dereference a nil
	// client below and panic.
	client, err := c.repo.GetClientByID(ctx, clientID)
	if err != nil {
		c.logger.Error(c.module, requestID, "[GetClientInformation]: Failed to retrieve client: %v", err)
		return nil, errors.Wrap(err, "", "failed to retrieve client")
	}

	// NOTE(review): this URI is the registration endpoint, whereas
	// clientCreator.buildClientConfigurationEndpoint stores
	// <URL><ClientConfiguration>/<clientID> as the client's RegistrationClientURI;
	// RFC 7592 describes registration_client_uri as the client configuration
	// endpoint. Also, client.Secret is stored hashed by this package — confirm
	// the response is meant to expose that value.
	registrationClientURI := config.GetServerConfig().BaseURL() + web.ClientEndpoints.Register
	return clients.NewClientInformationResponse(
		client.ID,
		client.Secret,
		registrationClientURI,
		registrationAccessToken,
	), nil
}
// UpdateClientInformation updates the client information for a given client ID.
//
// Parameters:
//   - ctx context.Context: The context for the operation.
//   - clientID string: The ID of the client to update.
//   - registrationAccessToken string: The registration access token for authentication.
//   - request *ClientUpdateRequest: A pointer to ClientUpdateRequest containing the updated information.
//
// Returns:
//   - *ClientInformationResponse: A pointer to ClientInformationResponse containing the updated client information.
//   - error: An error if validation, the lookup, the secret check or the update fails, or nil if successful.
func (c *clientManager) UpdateClientInformation(
	ctx context.Context,
	clientID string,
	registrationAccessToken string,
	request *clients.ClientUpdateRequest,
) (*clients.ClientInformationResponse, error) {
	requestID := utils.GetRequestID(ctx)

	// NOTE(review): the update request is validated before request.Type is set
	// below, so type-dependent checks in the validator see whatever Type the
	// caller supplied — confirm this ordering is intended.
	if err := c.validator.ValidateUpdateRequest(ctx, request); err != nil {
		c.logger.Error(c.module, requestID, "[UpdateClientInformation]: Failed to validate request: %v", err)
		return nil, errors.Wrap(err, "", "failed to validate request")
	}

	if err := c.validator.ValidateClientAndRegistrationAccessToken(ctx, clientID, registrationAccessToken); err != nil {
		c.logger.Error(c.module, requestID, "[UpdateClientInformation]: Failed to validate request: %v", err)
		return nil, errors.Wrap(err, "", "failed to validate client")
	}

	// Validation just succeeded, but the lookup can still fail (repository
	// error, concurrent delete). Ignoring the error would dereference a nil
	// client below and panic.
	client, err := c.repo.GetClientByID(ctx, clientID)
	if err != nil {
		c.logger.Error(c.module, requestID, "[UpdateClientInformation]: Failed to retrieve client: %v", err)
		return nil, errors.Wrap(err, "", "failed to retrieve client")
	}

	// Confidential clients must prove possession of their current secret
	// before their registration can be changed.
	if client.IsConfidential() {
		request.Type = types.ConfidentialClient
		if !client.SecretsMatch(request.Secret) {
			c.logger.Error(c.module, requestID, "[UpdateClientInformation]: Client secret's don't match")
			return nil, errors.New(errors.ErrCodeUnauthorized, "the provided client secret is invalid or does not match the registered credentials")
		}
	}

	client.UpdateValues(request)
	if err := c.repo.UpdateClient(ctx, client); err != nil {
		c.logger.Error(c.module, requestID, "[UpdateClientInformation]: Failed to update client: %v", err)
		return nil, errors.Wrap(err, "", "failed to update client")
	}

	registrationClientURI := config.GetServerConfig().BaseURL() + web.ClientEndpoints.Register
	return clients.NewClientInformationResponse(
		client.ID,
		client.Secret,
		registrationClientURI,
		registrationAccessToken,
	), nil
}
// DeleteClientInformation removes a client after checking the registration
// access token presented for it.
//
// Parameters:
//   - ctx context.Context: The context for the operation.
//   - clientID string: The ID of the client to delete.
//   - registrationAccessToken string: The registration access token for authentication.
//
// Returns:
//   - error: A wrapped validation or repository error, or nil once the client is deleted.
func (c *clientManager) DeleteClientInformation(ctx context.Context, clientID string, registrationAccessToken string) error {
	reqID := utils.GetRequestID(ctx)

	validationErr := c.validator.ValidateClientAndRegistrationAccessToken(ctx, clientID, registrationAccessToken)
	if validationErr != nil {
		c.logger.Error(c.module, reqID, "[DeleteClientInformation]: Failed to validate request")
		return errors.Wrap(validationErr, "", "failed to validate client")
	}

	deleteErr := c.repo.DeleteClientByID(ctx, clientID)
	if deleteErr != nil {
		c.logger.Error(c.module, reqID, "[DeleteClientInformation]: Failed to delete client: %v", deleteErr)
		return errors.Wrap(deleteErr, "", "failed to delete client")
	}

	return nil
}
package service
import (
"context"
"encoding/json"
"fmt"
"net"
"net/http"
"net/url"
"regexp"
"strings"
"time"
"github.com/vigiloauth/vigilo/v2/idp/config"
"github.com/vigiloauth/vigilo/v2/internal/constants"
clients "github.com/vigiloauth/vigilo/v2/internal/domain/client"
tokens "github.com/vigiloauth/vigilo/v2/internal/domain/token"
"github.com/vigiloauth/vigilo/v2/internal/errors"
"github.com/vigiloauth/vigilo/v2/internal/types"
"github.com/vigiloauth/vigilo/v2/internal/utils"
)
// Compile-time check that *clientValidator satisfies clients.ClientValidator.
var _ clients.ClientValidator = (*clientValidator)(nil)
// HTTPS is the scheme required for confidential-client redirect URIs,
// non-custom-scheme public-client redirect URIs, and sector identifier URIs.
const HTTPS string = "https"
// clientValidator implements clients.ClientValidator: validation of client
// registration, update and authorization requests, redirect URIs, and
// registration access tokens.
type clientValidator struct {
// repo is used to look up the client a registration access token belongs to.
repo clients.ClientRepository
// manager blacklists registration access tokens that fail validation.
manager tokens.TokenManager
// validator validates registration access tokens.
validator tokens.TokenValidator
// parser extracts claims (the subject) from registration access tokens.
parser tokens.TokenParser
logger *config.Logger
// module is the component name passed to every log call.
module string
// sectorURIFetchTimeout is used as the http.Client Timeout when fetching a
// sector identifier URI. NewClientValidator does not set it, so it is zero
// unless assigned elsewhere (for http.Client, zero means no timeout).
sectorURIFetchTimeout time.Duration
}
// NewClientValidator assembles a clients.ClientValidator from the client
// repository and the token manager/validator/parser used for registration
// access tokens. The logger comes from the global server configuration;
// sectorURIFetchTimeout is left at its zero value.
func NewClientValidator(
	repo clients.ClientRepository,
	manager tokens.TokenManager,
	validator tokens.TokenValidator,
	parser tokens.TokenParser,
) clients.ClientValidator {
	logger := config.GetServerConfig().Logger()
	return &clientValidator{
		module:    "Client Request Validator",
		logger:    logger,
		repo:      repo,
		manager:   manager,
		validator: validator,
		parser:    parser,
	}
}
// ValidateRegistrationRequest validates a dynamic client registration request.
// It performs these checks in order:
//   - requires a client name;
//   - validates the application type and token endpoint auth method;
//   - infers req.Type (and possibly defaults) via determineClientType;
//   - validates grant/response types, URIs and scopes.
//
// The first failing check is logged and its error returned. URI failures are
// additionally wrapped with a redirect-URI message.
func (c *clientValidator) ValidateRegistrationRequest(ctx context.Context, req *clients.ClientRegistrationRequest) error {
	requestID := utils.GetRequestID(ctx)

	if req.Name == "" {
		c.logger.Warn(c.module, requestID, "[ValidateRegistrationRequest]: client_name is empty")
		return errors.New(errors.ErrCodeInvalidClientMetadata, "client_name is empty")
	}

	// logFailure records a failed step at error level and returns its error unchanged.
	logFailure := func(format string, stepErr error) error {
		c.logger.Error(c.module, requestID, format, stepErr)
		return stepErr
	}

	if err := c.validateApplicationType(requestID, req); err != nil {
		return logFailure("[ValidateClientRegistrationRequest]: An error occurred validating the application type: %v", err)
	}
	if err := c.validateTokenEndpointAuthMethod(requestID, req); err != nil {
		return logFailure("[ValidateClientRegistrationRequest]: An error occurred validating the token endpoint auth method: %v", err)
	}

	// The remaining checks branch on req.Type, so it is inferred first.
	c.determineClientType(req)

	if err := c.validateGrantAndResponseTypes(requestID, req); err != nil {
		return logFailure("[ValidateClientRegistrationRequest]: An error occurred validating grant and response types: %v", err)
	}
	if err := c.validateURIS(requestID, req); err != nil {
		logged := logFailure("[ValidateClientRegistrationRequest]: An error occurred validating client URIS: %v", err)
		return errors.Wrap(logged, "", "the value of one or more redirection URIs is invalid")
	}
	if err := c.validateScopes(requestID, req); err != nil {
		return logFailure("[ValidateClientRegistrationRequest]: An error occurred validating scopes: %v", err)
	}

	return nil
}
// ValidateUpdateRequest validates a client update request. It runs the grant
// type, URI, scope and response type checks in that order. The first failure
// is logged and its error returned unchanged.
func (c *clientValidator) ValidateUpdateRequest(ctx context.Context, req *clients.ClientUpdateRequest) error {
	requestID := utils.GetRequestID(ctx)

	checks := []struct {
		run       func(string, clients.ClientRequest) error
		logFormat string
	}{
		{c.validateGrantType, "[ValidateClientUpdateRequest]: An error occurred validating grant types: %v"},
		{c.validateURIS, "[ValidateClientUpdateRequest]: An error occurred validating client URIS: %v"},
		{c.validateScopes, "[ValidateClientUpdateRequest]: An error occurred validating scopes: %v"},
		{c.validateResponseTypes, "[ValidateClientUpdateRequest]: An error occurred validating response types: %v"},
	}

	for _, check := range checks {
		if err := check.run(requestID, req); err != nil {
			c.logger.Error(c.module, requestID, check.logFormat, err)
			return err
		}
	}

	return nil
}
// ValidateAuthorizationRequest checks an authorization request against its
// client:
//   - the redirect URI must be registered;
//   - the client must have the authorization_code grant;
//   - the client must have the 'code' response type and the requested one.
//
// Public clients must send a code challenge. When a challenge is present, a
// missing method defaults to 'plain' (written back into req), and both the
// method and the challenge are validated.
func (c *clientValidator) ValidateAuthorizationRequest(ctx context.Context, req *clients.ClientAuthorizationRequest) error {
	requestID := utils.GetRequestID(ctx)

	switch {
	case !req.Client.HasRedirectURI(req.RedirectURI):
		return errors.New(errors.ErrCodeInvalidRedirectURI, "the client provided an unregistered redirect URI")
	case !req.Client.HasGrantType(constants.AuthorizationCodeGrantType):
		c.logger.Error(c.module, requestID, "Failed to validate client authorization: client does not have the required grant types")
		return errors.New(errors.ErrCodeInvalidGrant, "authorization code grant is required for this request")
	case !req.Client.HasResponseType(constants.CodeResponseType) || !req.Client.HasResponseType(req.ResponseType):
		c.logger.Error(c.module, requestID, "Failed to validate client authorization request: client does not have the code response type")
		return errors.New(errors.ErrCodeInvalidClient, "code response type is required to receive an authorization code")
	}

	// Without a code challenge there is nothing PKCE-related to validate,
	// but public clients are not allowed to skip PKCE.
	if req.CodeChallenge == "" {
		if req.Client.Type == types.PublicClient {
			return errors.New(errors.ErrCodeInvalidRequest, "public clients are required to use PKCE")
		}
		return nil
	}

	if req.CodeChallengeMethod == "" {
		c.logger.Warn(c.module, requestID, "Code challenge method was not provided, defaulting to 'plain'")
		req.CodeChallengeMethod = types.PlainCodeChallengeMethod
	}

	if err := c.validateCodeChallengeMethod(requestID, req.CodeChallengeMethod); err != nil {
		c.logger.Error(c.module, requestID, "Failed to validate authorization request: %v", err)
		return err
	}
	if err := c.validateCodeChallenge(requestID, req.CodeChallenge); err != nil {
		c.logger.Error(c.module, requestID, "Failed to validate authorization request: %v", err)
		return err
	}

	return nil
}
// ValidateRedirectURI checks that redirectURI is well formed and uses an
// allowed scheme. The scheme rules depend on client.Type: public and
// confidential clients have separate rules, and any other type is rejected.
// The URI must also be registered on client.
func (c *clientValidator) ValidateRedirectURI(ctx context.Context, redirectURI string, client *clients.Client) error {
	requestID := utils.GetRequestID(ctx)

	// Parse the redirect URI under validation. The previous code parsed
	// requestID here, so the scheme checks below never saw the redirect URI.
	parsedURI, err := utils.ParseURI(redirectURI)
	if err != nil {
		return errors.Wrap(err, "", "invalid redirect URI format")
	}

	if err := utils.ValidateRedirectURIScheme(parsedURI); err != nil {
		return errors.Wrap(err, "", "failed to validate URL scheme")
	}

	switch client.Type {
	case types.PublicClient:
		if err := utils.ValidatePublicURIScheme(parsedURI); err != nil {
			return errors.Wrap(err, "", "failed to validate public client redirect URI")
		}
	case types.ConfidentialClient:
		if err := utils.ValidateConfidentialURIScheme(parsedURI); err != nil {
			return errors.Wrap(err, "", "failed to valid confidential client redirect URI")
		}
	default:
		c.logger.Error(c.module, requestID, "[ValidateRedirectURI]: Invalid client type '%s'", client.Type.String())
		return errors.New(errors.ErrCodeInvalidClient, "invalid client type: must be confidential or public")
	}

	if !client.HasRedirectURI(redirectURI) {
		c.logger.Error(c.module, requestID, "[ValidateRedirectURI]: Client=[%s] does not have requested redirect URI=[%s]",
			utils.TruncateSensitive(client.ID),
			utils.SanitizeURL(redirectURI),
		)
		return errors.New(errors.ErrCodeInvalidRequest, "invalid redirect_uri")
	}

	return nil
}
// ValidateClientAndRegistrationAccessToken checks that clientID refers to an
// existing client and that registrationAccessToken is a valid token whose
// subject equals that client's ID.
//
// Whenever it returns a non-nil error, the deferred function blacklists
// registrationAccessToken through the token manager (presumably so a token
// that failed validation cannot be reused). A failure to blacklist is only
// logged as a warning and does not change the returned error.
//
// Errors:
//   - client lookup failure: ErrCodeUnauthorized "invalid client credentials"
//     (the repository error itself is not included);
//   - token validation or parse failure: the underlying error, wrapped;
//   - subject mismatch: ErrCodeUnauthorized.
func (c *clientValidator) ValidateClientAndRegistrationAccessToken(
ctx context.Context,
clientID string,
registrationAccessToken string,
) (err error) {
requestID := utils.GetRequestID(ctx)
// err is a named result, so this deferred check observes the value that is
// actually returned by any return statement below.
defer func() {
if err != nil {
// This inner err deliberately shadows the named result; it only holds the
// blacklist outcome and must not replace the validation error.
if err := c.manager.BlacklistToken(ctx, registrationAccessToken); err != nil {
c.logger.Warn(c.module, requestID, "[ValidateClientAndRegistrationAccessToken]: Failed to blacklist registration access token: %v", err)
}
}
}()
// ":=" here reuses the named result err (same scope) and declares client.
client, err := c.repo.GetClientByID(ctx, clientID)
if err != nil {
return errors.New(errors.ErrCodeUnauthorized, "invalid client credentials")
}
if err := c.validator.ValidateToken(ctx, registrationAccessToken); err != nil {
c.logger.Error(c.module, requestID, "[ValidateClientAndRegistrationAccessToken]: Failed to validate registration access token: %v", err)
return errors.Wrap(err, "", "invalid registration access token")
}
tokenClaims, err := c.parser.ParseToken(ctx, registrationAccessToken)
if err != nil {
c.logger.Error(c.module, requestID, "[ValidateClientAndRegistrationAccessToken]: Failed to parse registration access token: %v", err)
return errors.Wrap(err, "", "invalid registration access token")
}
// The token must have been issued for this very client.
if client.ID != tokenClaims.Subject {
c.logger.Error(c.module, requestID, "[ValidateClientAndRegistrationAccessToken]: the registration access token subject does not match with the client ID in the request")
return errors.New(errors.ErrCodeUnauthorized, "the registration access token subject does not match with the client ID in the request")
}
return nil
}
// validateApplicationType accepts an empty application type; the type is then
// determined later. Otherwise the type must be in
// constants.ValidApplicationTypes. Unknown types are logged and rejected with
// ErrCodeInvalidClientMetadata.
func (c *clientValidator) validateApplicationType(requestID string, req *clients.ClientRegistrationRequest) error {
	appType := req.ApplicationType

	switch {
	case appType == "":
		c.logger.Debug(c.module, requestID, "No application type given, will be determined dynamically")
		return nil
	case constants.ValidApplicationTypes[appType]:
		return nil
	}

	c.logger.Error(c.module, requestID, "Invalid application type provided: %v", appType)
	return errors.New(
		errors.ErrCodeInvalidClientMetadata,
		fmt.Sprintf("invalid application type: %s", appType),
	)
}
// validateTokenEndpointAuthMethod accepts an empty auth method, logging a
// warning. Otherwise the method must be in types.SupportedTokenEndpointAuthMethods;
// anything else is rejected with ErrCodeInvalidClientMetadata.
func (c *clientValidator) validateTokenEndpointAuthMethod(requestID string, req *clients.ClientRegistrationRequest) error {
	method := req.TokenEndpointAuthMethod

	if method == "" {
		c.logger.Warn(c.module, requestID, "No token endpoint auth method provided")
		return nil
	}
	if types.SupportedTokenEndpointAuthMethods[method] {
		return nil
	}

	return errors.New(
		errors.ErrCodeInvalidClientMetadata,
		fmt.Sprintf("invalid token endpoint auth method: %s", method),
	)
}
// determineClientType fills in req.Type from the registration metadata.
//
// If neither the auth method nor the application type is given, the client
// defaults to confidential, client_secret_basic and a web application. If only
// the application type is given, web means confidential and native means
// public. Otherwise the auth method decides:
//   - "none" means public;
//   - client_secret_basic or client_secret_post means confidential;
//   - any other method leaves req.Type untouched.
func (c *clientValidator) determineClientType(req *clients.ClientRegistrationRequest) {
	authMethod, appType := req.TokenEndpointAuthMethod, req.ApplicationType

	if authMethod == "" {
		if appType == "" {
			req.Type = types.ConfidentialClient
			req.TokenEndpointAuthMethod = types.ClientSecretBasicTokenAuth
			req.ApplicationType = constants.WebApplicationType
			return
		}

		switch appType {
		case constants.WebApplicationType:
			req.Type = types.ConfidentialClient
			return
		case constants.NativeApplicationType:
			req.Type = types.PublicClient
			return
		}
		// Any other application type falls through to the auth-method switch.
	}

	switch authMethod {
	case types.NoTokenAuth:
		req.Type = types.PublicClient
	case types.ClientSecretBasicTokenAuth, types.ClientSecretPostTokenAuth:
		req.Type = types.ConfidentialClient
	}
}
// validateGrantAndResponseTypes checks that the grant_types and response_types
// of a registration request are mutually consistent:
//   - at least one grant type is required;
//   - response_types are required when authorization_code or implicit is requested;
//   - public clients may not request client_credentials or password, and a
//     public client requesting authorization_code gets req.RequiresPKCE = true;
//   - authorization_code requires the 'code' response type, and 'code' requires authorization_code;
//   - implicit requires 'id_token' and/or 'token' without 'code', unless
//     authorization_code is also requested (hybrid);
//   - 'id_token'/'token' without 'code' requires the implicit grant;
//   - refresh_token requires authorization_code or password.
//
// ValidateRegistrationRequest calls determineClientType before this, so
// req.Type has already been inferred when the public-client rules run.
// nolint
func (c *clientValidator) validateGrantAndResponseTypes(requestID string, req *clients.ClientRegistrationRequest) error {
if len(req.GrantTypes) == 0 {
c.logger.Error(c.module, requestID, "No grant types were requested")
return errors.New(errors.ErrCodeInvalidClientMetadata, "at least one grant_type must be requested")
}
if len(req.ResponseTypes) == 0 && (utils.Contains(req.GrantTypes, constants.AuthorizationCodeGrantType) || utils.Contains(req.GrantTypes, constants.ImplicitGrantType)) {
return errors.New(errors.ErrCodeInvalidClientMetadata, "response_types are required for authorization_code or implicit grant types")
}
if req.Type == types.PublicClient {
if utils.Contains(req.GrantTypes, constants.ClientCredentialsGrantType) {
c.logger.Warn(c.module, requestID, "Validation failed: Public client requested client_credentials grant")
return errors.New(errors.ErrCodeInvalidClientMetadata, "public clients cannot request the client_credentials grant")
}
if utils.Contains(req.GrantTypes, constants.PasswordGrantType) {
c.logger.Warn(c.module, requestID, "Validation failed: Public client requested password grant")
return errors.New(errors.ErrCodeInvalidClientMetadata, "public clients cannot request the password grant")
}
// The req.Type check repeats the enclosing condition and is always true here.
if req.Type == types.PublicClient && utils.Contains(req.GrantTypes, constants.AuthorizationCodeGrantType) {
req.RequiresPKCE = true
}
}
requestsAuthCodeGrant := utils.Contains(req.GrantTypes, constants.AuthorizationCodeGrantType)
requestsImplicitGrant := utils.Contains(req.GrantTypes, constants.ImplicitGrantType)
// NOTE(review): response types are matched with utils.ContainsResponseType
// here but with utils.Contains in validateResponseTypes — confirm whether the
// two differ (e.g. for space-separated combinations like "code id_token").
requestsCodeResponseType := utils.ContainsResponseType(req.ResponseTypes, constants.CodeResponseType)
requestsIDTokenResponseType := utils.ContainsResponseType(req.ResponseTypes, constants.IDTokenResponseType)
requestsTokenResponseType := utils.ContainsResponseType(req.ResponseTypes, constants.TokenResponseType)
if requestsAuthCodeGrant || requestsImplicitGrant {
if !requestsCodeResponseType && !requestsIDTokenResponseType && !requestsTokenResponseType {
c.logger.Warn(c.module, requestID, "Validation failed: Auth Code/Implicit grants requested without corresponding response types.")
return errors.New(errors.ErrCodeInvalidClientMetadata, "client requesting authorization_code or implicit grants must include 'code', 'id_token', or 'token' in response_types")
}
}
if requestsAuthCodeGrant && !requestsCodeResponseType {
c.logger.Warn(c.module, requestID, "Validation failed: Auth Code grant requested without 'code' response type.")
return errors.New(errors.ErrCodeInvalidClientMetadata, "client requesting authorization_code grant must include 'code' in response_types")
}
if requestsCodeResponseType && !requestsAuthCodeGrant {
c.logger.Warn(c.module, requestID, "Validation failed: 'code' response type requested without Auth Code grant.")
return errors.New(errors.ErrCodeInvalidClientMetadata, "client requesting response_types including 'code' must include authorization_code grant")
}
// "Implicit-only" means id_token and/or token were requested without code.
usesImplicitResponseTypes := requestsIDTokenResponseType || requestsTokenResponseType
requestsImplicitFlowWithoutCode := usesImplicitResponseTypes && !requestsCodeResponseType
if requestsImplicitGrant && !requestsImplicitFlowWithoutCode && !requestsAuthCodeGrant { // Requests Implicit grant, but no implicit-only response types, and no Auth Code grant for hybrid
c.logger.Warn(c.module, requestID, "Validation failed: Implicit grant requested without corresponding response types or hybrid.")
return errors.New(errors.ErrCodeInvalidClientMetadata, "client requesting implicit grant must include 'id_token', 'token', or both in response_types (or request a hybrid flow)")
}
if requestsImplicitFlowWithoutCode && !requestsImplicitGrant { // Requests implicit-only response types, but no Implicit grant requested
c.logger.Warn(c.module, requestID, "Validation failed: Implicit flow response types requested without Implicit grant.")
return errors.New(errors.ErrCodeInvalidClientMetadata, "client requesting response_types including 'id_token' or 'token' (without 'code') must include implicit grant")
}
// refresh_token is only meaningful alongside a grant that can issue refresh tokens.
if utils.Contains(req.GrantTypes, constants.RefreshTokenGrantType) {
if !requestsAuthCodeGrant && !utils.Contains(req.GrantTypes, constants.PasswordGrantType) {
c.logger.Warn(c.module, requestID, "Validation failed: Refresh Token grant requested without a valid issuing grant.")
return errors.New(errors.ErrCodeInvalidClientMetadata, "refresh_token grant requires a grant type capable of issuing refresh tokens (e.g., authorization_code or password)")
}
}
return nil
}
// clientRedirectMobileSchemePattern matches a URI that starts with a scheme
// (a letter followed by letters, digits, '+', '.' or '-') and "://". It is
// compiled once at package initialisation instead of on every validateURIS call.
var clientRedirectMobileSchemePattern = regexp.MustCompile(`^[a-zA-Z][a-zA-Z0-9+.-]*:\/\/`)

// isLoopbackHTTPRedirect reports whether u is a plain-http URI whose host is
// exactly "localhost" or a loopback IP literal (127.0.0.0/8 or ::1). The host
// is taken from the parsed URI, so look-alikes such as
// "http://localhost.example.com", "http://localhost@example.com" or
// "http://127.0.0.1.example.com" do not match.
func isLoopbackHTTPRedirect(u *url.URL) bool {
	if u.Scheme != "http" {
		return false
	}
	host := u.Hostname()
	if host == "localhost" {
		return true
	}
	ip := net.ParseIP(host)
	return ip != nil && ip.IsLoopback()
}

// validateURIS validates the URI metadata of a client request:
//   - redirect URIs are required for public clients, and for confidential
//     clients using the authorization code grant;
//   - jwks_uri and logo_uri, when present, must parse as request URIs;
//   - a sector identifier URI, when present, must list every redirect URI;
//   - each redirect URI must be non-empty, wildcard-free and parseable.
//
// Per-client-type redirect URI rules:
//   - confidential clients need HTTPS, a non-IP (or loopback IP) host and no fragment;
//   - public clients need HTTPS unless they use a custom (mobile) scheme of at
//     least four characters.
//
// Plain-http loopback redirect URIs (see isLoopbackHTTPRedirect) skip the
// per-URI checks.
func (c *clientValidator) validateURIS(requestID string, req clients.ClientRequest) error { //nolint
	if req.GetType() == types.ConfidentialClient && req.HasGrantType(constants.AuthorizationCodeGrantType) && len(req.GetRedirectURIS()) == 0 {
		return errors.New(errors.ErrCodeInvalidClientMetadata, "redirect URI(s) are required for confidential clients using the authorization code grant type")
	}
	if req.GetType() == types.PublicClient && len(req.GetRedirectURIS()) == 0 {
		c.logger.Warn(c.module, requestID, "Validation failed: redirect_uris is empty for public client")
		return errors.New(errors.ErrCodeInvalidClientMetadata, "redirect URI(s) are required for public clients")
	}

	if req.GetJwksURI() != "" {
		if _, err := url.ParseRequestURI(req.GetJwksURI()); err != nil {
			c.logger.Warn(c.module, requestID, "Invalid jwks_uri provided: %s", req.GetJwksURI())
			return errors.New(errors.ErrCodeInvalidClientMetadata, "invalid jwks_uri format")
		}
	}

	if err := c.validateSectorIdentifierURI(requestID, req.GetRedirectURIS(), req.GetSectorIdentifierURI()); err != nil {
		c.logger.Error(c.module, requestID, "Failed to validate sector identifier URI: %v", err)
		return err
	}

	if req.GetLogoURI() != "" {
		if _, err := url.ParseRequestURI(req.GetLogoURI()); err != nil {
			c.logger.Warn(c.module, requestID, "Invalid logo_uri: %s", req.GetLogoURI())
			return errors.New(errors.ErrCodeInvalidClientMetadata, "invalid logo_uri format")
		}
	}

	for _, uri := range req.GetRedirectURIS() {
		if uri == "" {
			c.logger.Warn(c.module, requestID, "Empty redirect_uri found")
			return errors.New(errors.ErrCodeInvalidRedirectURI, "'redirect_uri' is empty")
		}

		// Decide the loopback exemption on the parsed host. A string-prefix
		// test on "http://localhost" also matched hosts like
		// "localhost.attacker.com" and let them skip every check below.
		parsedURI, parseErr := url.Parse(uri)
		if parseErr == nil && isLoopbackHTTPRedirect(parsedURI) {
			continue
		}

		if utils.ContainsWildcard(uri) {
			c.logger.Warn(c.module, requestID, "Redirect URI contains wildcard: %s", uri)
			return errors.New(errors.ErrCodeInvalidRedirectURI, "redirect URIs cannot have wildcards")
		}

		if parseErr != nil {
			c.logger.Warn(c.module, requestID, "Malformed redirect URI: %s", uri)
			return errors.New(errors.ErrCodeInvalidRedirectURI, fmt.Sprintf("malformed redirect URI: %s", uri))
		}

		switch req.GetType() {
		case types.ConfidentialClient:
			if parsedURI.Scheme != HTTPS {
				c.logger.Warn(c.module, requestID, "Confidential client redirect URI is not using HTTPS: %s", uri)
				return errors.New(errors.ErrCodeInvalidRedirectURI, "confidential clients must use HTTPS")
			}
			if net.ParseIP(parsedURI.Hostname()) != nil && !utils.IsLoopbackIP(parsedURI.Hostname()) {
				c.logger.Warn(c.module, requestID, "Confidential client redirect URI is using IP address: %s", uri)
				return errors.New(errors.ErrCodeInvalidRedirectURI, "IP address not allowed as redirect URI hosts")
			}
			if parsedURI.Fragment != "" {
				c.logger.Warn(c.module, requestID, "Confidential client redirect URI contains fragment: %s", uri)
				return errors.New(errors.ErrCodeInvalidRedirectURI, "fragment component not allowed")
			}
		case types.PublicClient:
			isMobileScheme := clientRedirectMobileSchemePattern.MatchString(uri) && parsedURI.Scheme != "http" && parsedURI.Scheme != HTTPS
			if isMobileScheme {
				const uriLength int = 4
				if len(parsedURI.Scheme) < uriLength {
					c.logger.Warn(c.module, requestID, "Mobile URI scheme is too short: %s", uri)
					return errors.New(errors.ErrCodeInvalidRedirectURI, "mobile URI scheme is too short")
				}
			} else if parsedURI.Scheme != HTTPS {
				c.logger.Warn(c.module, requestID, "Public client redirect URI is not using HTTPS: %s", uri)
				return errors.New(errors.ErrCodeInvalidRedirectURI, "public clients must use HTTPS")
			}
		}
	}

	return nil
}
// validateScopes checks every requested scope against types.SupportedScopes.
// A request with no scopes is accepted as-is. When scopes are present and
// valid but types.OpenIDScope is missing, it is appended via req.SetScopes.
func (c *clientValidator) validateScopes(requestID string, req clients.ClientRequest) error {
	scopes := req.GetScopes()
	if len(scopes) == 0 {
		return nil
	}

	for _, s := range scopes {
		if _, supported := types.SupportedScopes[s]; supported {
			continue
		}
		c.logger.Warn(c.module, requestID, "Unsupported scope: %s", s.String())
		return errors.New(errors.ErrCodeInvalidClientMetadata, fmt.Sprintf("scope '%s' is not supported", s))
	}

	if utils.Contains(scopes, types.OpenIDScope) {
		return nil
	}

	req.SetScopes(append(scopes, types.OpenIDScope))
	c.logger.Info(c.module, requestID, "Adding default 'oidc' scope to client")
	return nil
}
// validateSectorIdentifierURI checks an OpenID Connect sector_identifier_uri.
// An empty value is accepted. Otherwise the URI must parse and use HTTPS, and
// it is fetched with a bounded timeout. The response must be 200 with a JSON
// Content-Type, and its body must be a non-empty JSON array of strings that
// contains every entry of redirectURIs.
//
// NOTE(review): the URI is client-supplied, so this is an outbound request to
// an arbitrary host (SSRF surface) and redirects are followed by the default
// http.Client policy; consider restricting destinations.
func (c *clientValidator) validateSectorIdentifierURI(requestID string, redirectURIs []string, sectorIdentifierURI string) error {
	// defaultSectorURIFetchTimeout applies when sectorURIFetchTimeout is unset:
	// an http.Client with a zero Timeout never times out, so a slow remote host
	// could otherwise stall registration indefinitely.
	const defaultSectorURIFetchTimeout = 5 * time.Second
	// maxSectorURIResponseBytes caps how much of the remote document is decoded.
	const maxSectorURIResponseBytes int64 = 1 << 20

	if sectorIdentifierURI == "" {
		return nil
	}

	parsedURI, err := url.Parse(sectorIdentifierURI)
	if err != nil {
		c.logger.Warn(c.module, requestID, "Malformed sector identifier URI: %s", sectorIdentifierURI)
		return errors.New(errors.ErrCodeInvalidClientMetadata, fmt.Sprintf("malformed sector identifier URI: %s", sectorIdentifierURI))
	}
	if parsedURI.Scheme != HTTPS {
		return errors.New(errors.ErrCodeInvalidClientMetadata, "sector identifier URI must use HTTPS")
	}

	timeout := c.sectorURIFetchTimeout
	if timeout <= 0 {
		timeout = defaultSectorURIFetchTimeout
	}
	client := http.Client{
		Timeout: timeout,
	}

	resp, err := client.Get(sectorIdentifierURI)
	if err != nil {
		c.logger.Warn(c.module, requestID, "Failed to fetch sector identifier URI (%s): %v", sectorIdentifierURI, err)
		return errors.Wrap(err, errors.ErrCodeInvalidClientMetadata, "failed to fetch sector identifier URI")
	}
	defer resp.Body.Close() //nolint:errcheck

	if resp.StatusCode != http.StatusOK {
		c.logger.Warn(c.module, requestID, "Sector identifier URI (%s) returned non-200 status: %d", sectorIdentifierURI, resp.StatusCode)
		return errors.New(errors.ErrCodeInvalidClientMetadata, fmt.Sprintf("sector identifier URI returned non-200 status: %d", resp.StatusCode))
	}

	contentType := resp.Header.Get("Content-Type")
	if !strings.Contains(contentType, "application/json") {
		c.logger.Warn(c.module, requestID, "Sector identifier URI (%s) returned unexpected Content-Type: %s", sectorIdentifierURI, contentType)
		return errors.New(errors.ErrCodeInvalidClientMetadata, fmt.Sprintf("sector identifier URI returned unexpected Content-Type: %s", contentType))
	}

	// MaxBytesReader with a nil ResponseWriter behaves like a limited reader
	// that reports an error once the cap is exceeded, so an oversized body
	// fails decoding instead of being read in full.
	limitedBody := http.MaxBytesReader(nil, resp.Body, maxSectorURIResponseBytes)

	var fetchedRedirectURIs []string
	if err := json.NewDecoder(limitedBody).Decode(&fetchedRedirectURIs); err != nil {
		c.logger.Warn(c.module, requestID, "Failed to decode JSON from sector identifier URI (%s): %v", sectorIdentifierURI, err)
		return errors.Wrap(err, errors.ErrCodeInvalidClientMetadata, "failed to decode JSON from sector identifier URI")
	}
	if len(fetchedRedirectURIs) == 0 {
		c.logger.Warn(c.module, requestID, "Sector identifier URI (%s) returned an empty array", sectorIdentifierURI)
		return errors.New(errors.ErrCodeInvalidClientMetadata, "sector identifier URI returned an empty array")
	}

	// Index the fetched URIs once rather than rescanning them per redirect URI.
	allowed := make(map[string]struct{}, len(fetchedRedirectURIs))
	for _, fetchedURI := range fetchedRedirectURIs {
		allowed[fetchedURI] = struct{}{}
	}

	for _, providedURI := range redirectURIs {
		if _, ok := allowed[providedURI]; ok {
			continue
		}
		c.logger.Warn(c.module, requestID, "Redirect URI '%s' not found in sector identifier URI (%s)", providedURI, sectorIdentifierURI)
		return errors.New(errors.ErrCodeInvalidClientMetadata, fmt.Sprintf("redirect URI '%s' not found in sector identifier URI (%s)", providedURI, sectorIdentifierURI))
	}

	return nil
}
// validateGrantType checks the grant types of a client request:
//   - the list must be non-empty;
//   - every grant type must be in constants.SupportedGrantTypes;
//   - public clients may not use client_credentials or password;
//   - refresh_token may not be the only grant type, since refresh tokens have
//     to be issued by some other grant.
func (c *clientValidator) validateGrantType(requestID string, req clients.ClientRequest) error {
	grantTypes := req.GetGrantTypes()
	if len(grantTypes) == 0 {
		c.logger.Warn(c.module, requestID, "Grant type validation failed: grant_types is empty")
		return errors.New(errors.ErrCodeInvalidClientMetadata, "grant_types is empty")
	}

	validGrantTypes := constants.SupportedGrantTypes
	for _, grantType := range grantTypes {
		if _, ok := validGrantTypes[grantType]; !ok {
			c.logger.Warn(c.module, requestID, "Unsupported grant type provided: %s", grantType)
			return errors.New(errors.ErrCodeInvalidClientMetadata, fmt.Sprintf("grant type %s is not supported", grantType))
		}

		if req.GetType() == types.PublicClient {
			if grantType == constants.ClientCredentialsGrantType || grantType == constants.PasswordGrantType {
				return errors.New(errors.ErrCodeInvalidClientMetadata, fmt.Sprintf("grant type %s is not supported for public clients", grantType))
			}
		}

		// The previous check compared the length against 0, which can never be
		// true inside this loop, so a refresh_token-only request slipped through.
		if grantType == constants.RefreshTokenGrantType && len(grantTypes) == 1 {
			return errors.New(errors.ErrCodeInvalidClientMetadata, fmt.Sprintf("%s requires another grant type", grantType))
		}
	}

	return nil
}
// validateResponseTypes checks the response types of a client request. The
// list must be non-empty, and every entry must be in
// constants.SupportedResponseTypes. The types must also fit the requested
// grants:
//   - authorization_code or device_code needs 'code';
//   - implicit needs 'token';
//   - 'id_token' needs one of authorization_code, device_code or implicit.
func (c *clientValidator) validateResponseTypes(requestID string, req clients.ClientRequest) error {
	responseTypes := req.GetResponseTypes()
	if len(responseTypes) == 0 {
		c.logger.Warn(c.module, requestID, "Response type validation failed: response_types is empty")
		return errors.New(errors.ErrCodeInvalidClientMetadata, "response_types is empty")
	}

	for _, rt := range responseTypes {
		if _, supported := constants.SupportedResponseTypes[rt]; !supported {
			c.logger.Warn(c.module, requestID, "Unsupported response type: %s", rt)
			return errors.New(errors.ErrCodeInvalidClientMetadata, fmt.Sprintf("response type '%s' is not supported", rt))
		}
	}

	grantTypes := req.GetGrantTypes()
	needsCode := utils.Contains(grantTypes, constants.AuthorizationCodeGrantType) || utils.Contains(grantTypes, constants.DeviceCodeGrantType)
	isImplicit := utils.Contains(grantTypes, constants.ImplicitGrantType)
	hasIDToken := utils.Contains(responseTypes, constants.IDTokenResponseType)
	hasCode := utils.Contains(responseTypes, constants.CodeResponseType)
	hasToken := utils.Contains(responseTypes, constants.TokenResponseType)

	switch {
	case needsCode && !hasCode:
		c.logger.Warn(c.module, requestID, "Incompatible response type: 'code' is required for grant types 'authorization_code' or 'device_code'")
		return errors.New(errors.ErrCodeInvalidClientMetadata, "code response type is required for the authorization code or device code grant type")
	case isImplicit && !hasToken:
		c.logger.Warn(c.module, requestID, "Incompatible response type: 'token' is required for the 'implicit' grant type")
		return errors.New(errors.ErrCodeInvalidClientMetadata, "token response type is required for the implicit flow grant type")
	case hasIDToken && !needsCode && !isImplicit:
		c.logger.Warn(c.module, requestID, "Incompatible response type: 'id_token' requires 'authorization_code' or 'implicit' grant type")
		return errors.New(errors.ErrCodeInvalidClientMetadata, "ID token response type is only allowed with the authorization code, device code or implicit flow grant types")
	}

	return nil
}
// validateCodeChallengeMethod reports whether codeChallengeMethod is one of the
// keys of types.SupportedCodeChallengeMethods.
//
// Parameters:
//   - requestID string: The request ID used to correlate log entries.
//   - codeChallengeMethod types.CodeChallengeMethod: The method sent by the client.
//
// Returns:
//   - error: An ErrCodeInvalidRequest error when the method is unsupported, otherwise nil.
func (c *clientValidator) validateCodeChallengeMethod(requestID string, codeChallengeMethod types.CodeChallengeMethod) error {
	_, supported := types.SupportedCodeChallengeMethods[codeChallengeMethod]
	if supported {
		return nil
	}

	c.logger.Error(c.module, requestID, "Failed to validate authorization request: invalid code challenge method: %s", codeChallengeMethod)
	message := fmt.Sprintf("invalid code challenge method: '%s'. Valid methods are 'plain' and 'SHA-256'", codeChallengeMethod)
	return errors.New(errors.ErrCodeInvalidRequest, message)
}
// pkceCodeChallengePattern matches strings made only of the characters accepted
// by validateCodeChallenge: A-Z, a-z, 0-9, '-', '.', '_' and '~'.
// It is compiled once at package initialization instead of on every call.
var pkceCodeChallengePattern = regexp.MustCompile(`^[A-Za-z0-9._~-]+$`)

// validateCodeChallenge checks a PKCE code challenge's length and character set.
//
// Parameters:
//   - requestID string: The request ID used to correlate log entries.
//   - codeChallenge string: The code challenge sent by the client.
//
// Returns:
//   - error: An ErrCodeInvalidRequest error when the challenge is not 43-128
//     characters long or contains characters outside the allowed set, otherwise nil.
func (c *clientValidator) validateCodeChallenge(requestID, codeChallenge string) error {
	const (
		minCodeChallengeLength = 43
		maxCodeChallengeLength = 128
	)

	codeChallengeLength := len(codeChallenge)
	if codeChallengeLength < minCodeChallengeLength || codeChallengeLength > maxCodeChallengeLength {
		c.logger.Error(c.module, requestID, "Failed to validate code challenge: code challenge does not meet length requirements")
		return errors.New(
			errors.ErrCodeInvalidRequest,
			fmt.Sprintf("invalid code challenge length (%d): must be between 43 and 128 characters", codeChallengeLength),
		)
	}

	if !pkceCodeChallengePattern.MatchString(codeChallenge) {
		c.logger.Error(c.module, requestID, "Failed to validate code challenge: contains invalid characters")
		// The message now lists every character the pattern accepts ('.' and '~' were missing).
		return errors.New(errors.ErrCodeInvalidRequest, "invalid characters: only A-Z, a-z, 0-9, '-', '.', '_', and '~' are allowed")
	}

	return nil
}
package service
import (
"context"
"net/http"
"time"
"github.com/vigiloauth/vigilo/v2/idp/config"
cookies "github.com/vigiloauth/vigilo/v2/internal/domain/cookies"
"github.com/vigiloauth/vigilo/v2/internal/errors"
"github.com/vigiloauth/vigilo/v2/internal/utils"
)
// Compile-time check that *httpCookieService satisfies cookies.HTTPCookieService.
var _ cookies.HTTPCookieService = (*httpCookieService)(nil)

// httpCookieService reads and writes the session cookie. Its settings are
// copied from the server configuration when the service is constructed.
type httpCookieService struct {
	sessionCookieName string // name of the cookie that carries the session ID
	domain            string // value written to the cookie's Domain attribute
	enableHTTPS       bool   // selects the Secure/SameSite attributes (see getCookieSecuritySettings)

	logger *config.Logger
	module string // module label attached to every log entry
}
// NewHTTPCookieService returns an HTTPCookieService whose cookie name, domain,
// HTTPS setting and logger are taken from the global server configuration.
func NewHTTPCookieService() cookies.HTTPCookieService {
	serverConfig := config.GetServerConfig()

	svc := &httpCookieService{module: "HTTP Cookie Service"}
	svc.sessionCookieName = serverConfig.SessionCookieName()
	svc.domain = serverConfig.Domain()
	svc.enableHTTPS = serverConfig.ForceHTTPS()
	svc.logger = serverConfig.Logger()

	return svc
}
// SetSessionCookie writes the session ID to an HttpOnly cookie scoped to path "/"
// and the configured domain. The cookie expires expirationTime from now. Its
// Secure and SameSite attributes come from getCookieSecuritySettings.
//
// Parameters:
//   - ctx Context: The context for managing timeouts and cancellations.
//   - w http.ResponseWriter: The HTTP response writer.
//   - sessionID string: The session ID to set in the cookie.
//   - expirationTime time.Duration: How long from now the cookie remains valid.
func (c *httpCookieService) SetSessionCookie(ctx context.Context, w http.ResponseWriter, sessionID string, expirationTime time.Duration) {
	requestID := utils.GetRequestID(ctx)
	// The session ID is a bearer credential; never write it to logs in full.
	c.logger.Debug(c.module, requestID, "[SetSessionCookie]: Setting session cookie with ID=[%s], expiration=[%s], HTTPS=[%t]",
		utils.TruncateSensitive(sessionID),
		expirationTime,
		c.enableHTTPS,
	)

	sameSiteMode, secureFlag := c.getCookieSecuritySettings()
	http.SetCookie(w, &http.Cookie{
		Name:     c.sessionCookieName,
		Value:    sessionID,
		Expires:  time.Now().Add(expirationTime),
		HttpOnly: true,
		Secure:   secureFlag,
		SameSite: sameSiteMode,
		Path:     "/",
		Domain:   c.domain,
	})
}
// ClearSessionCookie instructs the client to delete the session cookie. It
// overwrites the cookie with an empty value that has already expired, using
// the same name, path, domain and security attributes as SetSessionCookie.
//
// Parameters:
//   - ctx Context: The context for managing timeouts and cancellations.
//   - w http.ResponseWriter: The HTTP response writer.
func (c *httpCookieService) ClearSessionCookie(ctx context.Context, w http.ResponseWriter) {
	requestID := utils.GetRequestID(ctx)
	c.logger.Debug(c.module, requestID, "[ClearSessionCookie]: Clearing session cookie for [%s]", c.sessionCookieName)

	sameSiteMode, secureFlag := c.getCookieSecuritySettings()
	http.SetCookie(w, &http.Cookie{
		Name:  c.sessionCookieName,
		Value: "",
		// MaxAge < 0 makes net/http emit "Max-Age=0", which deletes the cookie
		// immediately regardless of client clock skew; Expires is kept for
		// clients that ignore Max-Age.
		MaxAge:   -1,
		Expires:  time.Now().Add(-time.Hour),
		HttpOnly: true,
		Secure:   secureFlag,
		SameSite: sameSiteMode,
		Path:     "/",
		Domain:   c.domain,
	})
}
// GetSessionCookie retrieves the session cookie (named by the configured
// session cookie name) from the request.
//
// Parameters:
//   - r *http.Request: The HTTP request carrying the cookie.
//
// Returns:
//   - *http.Cookie: The session cookie if present, otherwise nil.
//   - error: An ErrCodeMissingHeader error wrapping the lookup failure
//     when the cookie is absent, otherwise nil.
func (c *httpCookieService) GetSessionCookie(r *http.Request) (*http.Cookie, error) {
	requestID := utils.GetRequestID(r.Context())
	c.logger.Debug(c.module, requestID, "[GetSessionCookie]: Attempting to retrieve session cookie")

	cookie, err := r.Cookie(c.sessionCookieName)
	if err != nil {
		c.logger.Error(c.module, requestID, "[GetSessionCookie]: Failed to retrieve session cookie: %v", err)
		return nil, errors.Wrap(err, errors.ErrCodeMissingHeader, "failed to retrieve cookie from request")
	}

	return cookie, nil
}
// getCookieSecuritySettings returns the SameSite mode and Secure flag for the
// session cookie. It returns SameSite=None with Secure=true when HTTPS is
// enabled, and SameSite=Strict with Secure=false otherwise. Presumably None is
// paired with Secure because browsers reject SameSite=None cookies that are
// not Secure.
func (c *httpCookieService) getCookieSecuritySettings() (http.SameSite, bool) {
	sameSite, secure := http.SameSiteStrictMode, false
	if c.enableHTTPS {
		sameSite, secure = http.SameSiteNoneMode, true
	}
	return sameSite, secure
}
package service
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"io"
"github.com/vigiloauth/vigilo/v2/idp/config"
domain "github.com/vigiloauth/vigilo/v2/internal/domain/crypto"
"github.com/vigiloauth/vigilo/v2/internal/errors"
"golang.org/x/crypto/bcrypt"
)
// Compile-time check that *cryptographer satisfies domain.Cryptographer.
var _ domain.Cryptographer = (*cryptographer)(nil)

// Byte sizes used by the AES-GCM helpers.
const (
	keyLength   int = 32 // required length of the decoded secret key (AES-256)
	nonceLength int = 12 // length of the random nonce prefixed to each ciphertext
)

// cryptographer implements domain.Cryptographer. It uses AES-GCM for
// reversible encryption and bcrypt for one-way hashing.
type cryptographer struct {
	logger *config.Logger
	module string // module label attached to every log entry
}
// NewCryptographer returns a domain.Cryptographer that logs through the server
// configuration's logger under the "Cryptographer" module name.
func NewCryptographer() domain.Cryptographer {
	c := new(cryptographer)
	c.logger = config.GetServerConfig().Logger()
	c.module = "Cryptographer"
	return c
}
// EncryptString encrypts plainStr with AES-256-GCM. A fresh random nonce is
// generated for every call.
//
// Parameters:
//   - plainStr string: The plaintext to encrypt.
//   - secretKey string: A standard-base64 encoding of a 32-byte AES-256 key.
//
// Returns:
//   - string: Standard base64 of the nonce followed by the sealed ciphertext.
//   - error: Non-nil if the key is invalid or any cipher or nonce step fails.
func (c *cryptographer) EncryptString(plainStr, secretKey string) (string, error) {
	keyBytes, keyErr := c.decodeAndValidateSecretKey(secretKey)
	if keyErr != nil {
		c.logger.Error(c.module, "", "[EncryptString]: Failed to decode or validate secret key: %v", keyErr)
		return "", errors.Wrap(keyErr, "", "failed to decode or validate secret key")
	}

	aesBlock, blockErr := c.generateAESCipherBlock(keyBytes)
	if blockErr != nil {
		c.logger.Error(c.module, "", "[EncryptString]: Failed to create AES cipher block: %v", blockErr)
		return "", errors.Wrap(blockErr, "", "failed to create AES cipher block")
	}

	nonce, nonceErr := c.generateRandomNonce()
	if nonceErr != nil {
		c.logger.Error(c.module, "", "[EncryptString]: Failed to generate random nonce: %v", nonceErr)
		return "", errors.Wrap(nonceErr, "", "failed to generate random nonce")
	}

	gcm, gcmErr := c.generateGCMCipher(aesBlock)
	if gcmErr != nil {
		c.logger.Error(c.module, "", "[EncryptString]: Failed to create GCM cipher: %v", gcmErr)
		return "", errors.Wrap(gcmErr, "", "failed to create GCM cipher")
	}

	// Output layout: nonce || ciphertext+tag, so decryption can recover the nonce.
	sealed := gcm.Seal(nil, nonce, []byte(plainStr), nil)
	payload := append(nonce, sealed...)
	return base64.StdEncoding.EncodeToString(payload), nil
}
// DecryptString reverses EncryptString. It decodes the standard-base64 input,
// splits off the leading nonce, and opens the remaining AES-256-GCM ciphertext.
//
// Parameters:
//   - encryptedStr string: Standard base64 of nonce || ciphertext, as produced by EncryptString.
//   - secretKey string: A standard-base64 encoding of the 32-byte key used for encryption.
//
// Returns:
//   - string: The decrypted plaintext.
//   - error: Non-nil if the key or input is invalid, the decoded data is shorter
//     than the nonce, or authentication/decryption fails.
func (c *cryptographer) DecryptString(encryptedStr, secretKey string) (string, error) {
	key, err := c.decodeAndValidateSecretKey(secretKey)
	if err != nil {
		c.logger.Error(c.module, "", "[DecryptString]: Failed to decode or validate secret key: %v", err)
		return "", errors.Wrap(err, "", "failed to decode or validate secret key")
	}

	cipherText, err := c.decodeBase64String(encryptedStr)
	if err != nil {
		c.logger.Error(c.module, "", "[DecryptString]: Failed to decode base64 string: %v", err)
		return "", errors.Wrap(err, "", "failed to decode base64 string")
	}

	// Guard the slice below: input shorter than the nonce would otherwise panic.
	if len(cipherText) < nonceLength {
		c.logger.Error(c.module, "", "[DecryptString]: Invalid encrypted data length: %d bytes", len(cipherText))
		return "", errors.New(errors.ErrCodeInvalidInput, "encrypted data must be at least 12 bytes for nonce")
	}
	nonce, sealed := cipherText[:nonceLength], cipherText[nonceLength:]

	block, err := c.generateAESCipherBlock(key)
	if err != nil {
		c.logger.Error(c.module, "", "[DecryptString]: Failed to create AES cipher block: %v", err)
		return "", errors.Wrap(err, "", "failed to create AES cipher block")
	}

	aesGCM, err := c.generateGCMCipher(block)
	if err != nil {
		c.logger.Error(c.module, "", "[DecryptString]: Failed to create GCM cipher: %v", err)
		return "", errors.Wrap(err, "", "failed to create GCM cipher")
	}

	plaintext, err := c.decryptCipherText(aesGCM, nonce, sealed)
	if err != nil {
		c.logger.Error(c.module, "", "[DecryptString]: Failed to decrypt cipher text: %v", err)
		return "", errors.Wrap(err, "", "failed to decrypt cipher text")
	}

	return string(plaintext), nil
}
// EncryptBytes encrypts plainBytes with AES-256-GCM using a fresh random nonce.
//
// Parameters:
//   - plainBytes []byte: The data to encrypt.
//   - secretKey string: A standard-base64 encoding of a 32-byte AES-256 key.
//
// Returns:
//   - string: Standard base64 of the nonce followed by the sealed ciphertext.
//   - error: Non-nil if the key is invalid or any cipher or nonce step fails.
func (c *cryptographer) EncryptBytes(plainBytes []byte, secretKey string) (string, error) {
	keyBytes, keyErr := c.decodeAndValidateSecretKey(secretKey)
	if keyErr != nil {
		c.logger.Error(c.module, "", "[EncryptBytes]: Failed to decode or validate secret key: %v", keyErr)
		return "", errors.Wrap(keyErr, "", "failed to decode or validate secret key")
	}

	aesBlock, blockErr := c.generateAESCipherBlock(keyBytes)
	if blockErr != nil {
		c.logger.Error(c.module, "", "[EncryptBytes]: Failed to create AES cipher block: %v", blockErr)
		return "", errors.Wrap(blockErr, "", "failed to create AES cipher block")
	}

	nonce, nonceErr := c.generateRandomNonce()
	if nonceErr != nil {
		c.logger.Error(c.module, "", "[EncryptBytes]: Failed to generate random nonce: %v", nonceErr)
		return "", errors.Wrap(nonceErr, "", "failed to generate random nonce")
	}

	gcm, gcmErr := c.generateGCMCipher(aesBlock)
	if gcmErr != nil {
		c.logger.Error(c.module, "", "[EncryptBytes]: Failed to create GCM cipher: %v", gcmErr)
		return "", errors.Wrap(gcmErr, "", "failed to create GCM cipher")
	}

	// Prefix the nonce so DecryptBytes can split it back off.
	sealed := gcm.Seal(nil, nonce, plainBytes, nil)
	return base64.StdEncoding.EncodeToString(append(nonce, sealed...)), nil
}
// DecryptBytes reverses EncryptBytes. It decodes the standard-base64 input,
// splits off the leading nonce, and opens the remaining AES-256-GCM ciphertext.
//
// Parameters:
//   - encryptedBytes string: Standard base64 of nonce || ciphertext.
//   - secretKey string: A standard-base64 encoding of the 32-byte key used for encryption.
//
// Returns:
//   - []byte: The decrypted data.
//   - error: Non-nil if the key or input is invalid, the decoded data is shorter
//     than the nonce, or authentication/decryption fails.
func (c *cryptographer) DecryptBytes(encryptedBytes, secretKey string) ([]byte, error) {
	key, err := c.decodeAndValidateSecretKey(secretKey)
	if err != nil {
		c.logger.Error(c.module, "", "[DecryptBytes]: Failed to decode or validate secret key: %v", err)
		return nil, errors.Wrap(err, "", "failed to decode or validate secret key")
	}

	decodedData, err := c.decodeBase64String(encryptedBytes)
	if err != nil {
		c.logger.Error(c.module, "", "[DecryptBytes]: Failed to decode base64 string: %v", err)
		return nil, errors.Wrap(err, "", "failed to decode base64 string")
	}

	// Use the shared nonceLength constant so encrypt and decrypt cannot drift apart.
	if len(decodedData) < nonceLength {
		c.logger.Error(c.module, "", "[DecryptBytes]: Invalid encrypted data length: %d bytes", len(decodedData))
		return nil, errors.New(errors.ErrCodeInvalidInput, "encrypted data must be at least 12 bytes for nonce")
	}
	nonce, cipherText := decodedData[:nonceLength], decodedData[nonceLength:]

	block, err := c.generateAESCipherBlock(key)
	if err != nil {
		c.logger.Error(c.module, "", "[DecryptBytes]: Failed to create AES cipher block: %v", err)
		return nil, errors.Wrap(err, "", "failed to create AES cipher block")
	}

	aesGCM, err := c.generateGCMCipher(block)
	if err != nil {
		c.logger.Error(c.module, "", "[DecryptBytes]: Failed to create GCM cipher: %v", err)
		return nil, errors.Wrap(err, "", "failed to create GCM cipher")
	}

	plainBytes, err := c.decryptCipherText(aesGCM, nonce, cipherText)
	if err != nil {
		c.logger.Error(c.module, "", "[DecryptBytes]: Failed to decrypt cipher text: %v", err)
		return nil, errors.Wrap(err, "", "failed to decrypt cipher text")
	}

	return plainBytes, nil
}
// HashString returns the bcrypt hash of plainStr, computed with bcrypt.DefaultCost.
// The result is a one-way hash, not an encryption.
//
// Parameters:
//   - plainStr string: The value to hash; must not be empty.
//
// Returns:
//   - string: The bcrypt hash.
//   - error: An ErrCodeHashingFailed error if plainStr is empty or bcrypt fails.
func (c *cryptographer) HashString(plainStr string) (string, error) {
	if len(plainStr) == 0 {
		return "", errors.New(errors.ErrCodeHashingFailed, "input string cannot be empty")
	}

	hashed, hashErr := bcrypt.GenerateFromPassword([]byte(plainStr), bcrypt.DefaultCost)
	if hashErr == nil {
		return string(hashed), nil
	}

	c.logger.Error(c.module, "", "[HashString]: Failed to hash string: %v", hashErr)
	return "", errors.New(errors.ErrCodeHashingFailed, "failed to hash string")
}
// GenerateRandomString reads length cryptographically secure random bytes and
// returns them encoded with unpadded URL-safe base64 (base64.RawURLEncoding).
//
// Note: length is the number of random bytes, not the length of the returned
// string. The encoded result is ceil(4*length/3) characters long.
//
// Parameters:
//   - length int: Number of random bytes to generate; must not be negative.
//
// Returns:
//   - string: The encoded random bytes ("" when length is 0).
//   - error: An error if length is negative or reading randomness fails.
func (c *cryptographer) GenerateRandomString(length int) (string, error) {
	// make() panics on a negative size; report it as invalid input instead.
	if length < 0 {
		c.logger.Error(c.module, "", "[GenerateRandomString]: Invalid length: %d", length)
		return "", errors.New(errors.ErrCodeInvalidInput, "length must not be negative")
	}

	buf := make([]byte, length)
	if _, err := rand.Read(buf); err != nil {
		c.logger.Error(c.module, "", "[GenerateRandomString]: Failed to generate random string: %v", err)
		return "", errors.New(errors.ErrCodeRandomGenerationFailed, "failed to generate random string")
	}

	return base64.RawURLEncoding.EncodeToString(buf), nil
}
// decodeAndValidateSecretKey decodes secretKey from standard base64 and checks
// that the decoded key is exactly keyLength bytes long.
func (c *cryptographer) decodeAndValidateSecretKey(secretKey string) ([]byte, error) {
	decoded, decodeErr := base64.StdEncoding.DecodeString(secretKey)
	if decodeErr != nil {
		c.logger.Error(c.module, "", "[decodeAndValidateSecretKey]: Failed to decode secret key: %v", decodeErr)
		return nil, errors.New(errors.ErrCodeInvalidInput, "failed to decode secret key")
	}

	if len(decoded) != keyLength {
		c.logger.Error(c.module, "", "[decodeAndValidateSecretKey]: Invalid secret key length: %d bytes", len(decoded))
		return nil, errors.New(errors.ErrCodeInvalidInput, "secret key must be 32 bytes for AES-256")
	}

	return decoded, nil
}
// generateAESCipherBlock wraps aes.NewCipher and maps failures to ErrCodeEncryptionFailed.
func (c *cryptographer) generateAESCipherBlock(key []byte) (cipher.Block, error) {
	aesBlock, cipherErr := aes.NewCipher(key)
	if cipherErr == nil {
		return aesBlock, nil
	}

	c.logger.Error(c.module, "", "[generateAESCipherBlock]: Failed to create AES cipher block: %v", cipherErr)
	return nil, errors.New(errors.ErrCodeEncryptionFailed, "failed to create AES cipher block")
}
// generateGCMCipher wraps cipher.NewGCM (standard nonce size) around block and
// maps failures to ErrCodeEncryptionFailed.
func (c *cryptographer) generateGCMCipher(block cipher.Block) (cipher.AEAD, error) {
	aead, gcmErr := cipher.NewGCM(block)
	if gcmErr == nil {
		return aead, nil
	}

	c.logger.Error(c.module, "", "[generateGCMCipher]: Failed to create GCM cipher: %v", gcmErr)
	return nil, errors.New(errors.ErrCodeEncryptionFailed, "failed to create GCM cipher")
}
// generateRandomNonce returns nonceLength bytes read in full from crypto/rand.Reader.
func (c *cryptographer) generateRandomNonce() ([]byte, error) {
	buf := make([]byte, nonceLength)
	_, readErr := io.ReadFull(rand.Reader, buf)
	if readErr == nil {
		return buf, nil
	}

	c.logger.Error(c.module, "", "[generateRandomNonce]: Failed to generate random nonce: %v", readErr)
	return nil, errors.New(errors.ErrCodeRandomGenerationFailed, "failed to generate random nonce")
}
// decodeBase64String decodes standard (padded) base64 and maps failures to ErrCodeInvalidInput.
func (c *cryptographer) decodeBase64String(cipherTextBase64 string) ([]byte, error) {
	raw, decodeErr := base64.StdEncoding.DecodeString(cipherTextBase64)
	if decodeErr == nil {
		return raw, nil
	}

	c.logger.Error(c.module, "", "[decodeBase64String]: Failed to decode base64 string: %v", decodeErr)
	return nil, errors.New(errors.ErrCodeInvalidInput, "failed to decode base64 string")
}
// decryptCipherText opens an AEAD ciphertext (no additional data) and maps
// authentication or decryption failures to ErrCodeDecryptionFailed.
func (c *cryptographer) decryptCipherText(aesGCM cipher.AEAD, nonce, cipherText []byte) ([]byte, error) {
	opened, openErr := aesGCM.Open(nil, nonce, cipherText, nil)
	if openErr == nil {
		return opened, nil
	}

	c.logger.Error(c.module, "", "[decryptCipherText]: Failed to decrypt cipher text: %v", openErr)
	return nil, errors.New(errors.ErrCodeDecryptionFailed, "failed to decrypt cipher text")
}
package service
import (
"github.com/vigiloauth/vigilo/v2/internal/constants"
domain "github.com/vigiloauth/vigilo/v2/internal/domain/email"
"github.com/vigiloauth/vigilo/v2/internal/errors"
"gopkg.in/gomail.v2"
)
// Compile-time check that *goMailer satisfies domain.Mailer.
var _ domain.Mailer = (*goMailer)(nil)

// goMailer is a stateless domain.Mailer backed by gopkg.in/gomail.v2.
type goMailer struct{}

// NewGoMailer returns a domain.Mailer that delegates to gomail.
func NewGoMailer() domain.Mailer {
	return new(goMailer)
}
// Dial connects and authenticates to the SMTP server and returns the open
// gomail.SendCloser. The caller owns the returned SendCloser and must Close it.
//
// Parameters:
//   - host string: The SMTP server's host.
//   - port int: The SMTP server's port.
//   - username string: The username used to authenticate.
//   - password string: The password used to authenticate.
//
// Returns:
//   - gomail.SendCloser: An open connection for sending messages.
//   - error: An ErrCodeInternalServerError error wrapping the dial failure, or nil.
func (m *goMailer) Dial(host string, port int, username string, password string) (gomail.SendCloser, error) {
	dialer := gomail.NewDialer(host, port, username, password)
	closer, err := dialer.Dial()
	if err != nil {
		return nil, errors.Wrap(err, errors.ErrCodeInternalServerError, "failed to dial to the SMTP server")
	}
	return closer, nil
}
// DialAndSend connects to the SMTP server, sends every provided message, and
// closes the connection, all via gomail's Dialer.DialAndSend.
//
// Parameters:
//   - host string: The SMTP server's host.
//   - port int: The SMTP server's port.
//   - username string: The username used to authenticate.
//   - password string: The password used to authenticate.
//   - message ...*gomail.Message: One or more messages to send.
//
// Returns:
//   - error: An ErrCodeInternalServerError error wrapping the failure, or nil on success.
func (m *goMailer) DialAndSend(host string, port int, username string, password string, message ...*gomail.Message) error {
	dialer := gomail.NewDialer(host, port, username, password)
	if err := dialer.DialAndSend(message...); err != nil {
		return errors.Wrap(err, errors.ErrCodeInternalServerError, "failed to dial and send message")
	}
	return nil
}
// NewMessage builds a gomail message with the sender, recipient (taken from
// request.Recipient) and subject headers set, and body as its content. The
// header keys and body content type come from the constants package. The
// returned message can be passed to DialAndSend.
//
// Parameters:
//   - request *domain.EmailRequest: The email request; must be non-nil (Recipient is read).
//   - body string: The body content of the email.
//   - subject string: The subject of the email.
//   - fromAddress string: The sender's email address.
//
// Returns:
//   - *gomail.Message: The populated message.
func (m *goMailer) NewMessage(request *domain.EmailRequest, body string, subject string, fromAddress string) *gomail.Message {
	message := gomail.NewMessage()
	message.SetHeader(constants.FromAddress, fromAddress)
	message.SetHeader(constants.Recipient, request.Recipient)
	message.SetHeader(constants.EmailSubject, subject)
	message.SetBody(constants.HTMLBody, body)
	return message
}
package service
import (
"bytes"
"context"
"text/template"
"time"
_ "embed"
"github.com/vigiloauth/vigilo/v2/idp/config"
"github.com/vigiloauth/vigilo/v2/internal/constants"
domain "github.com/vigiloauth/vigilo/v2/internal/domain/email"
"github.com/vigiloauth/vigilo/v2/internal/errors"
"github.com/vigiloauth/vigilo/v2/internal/utils"
)
// Embedded email templates and the compile-time interface check for emailService.
var (
// accountVerificationTemplate holds the raw contents of
// templates/account_verification.html, embedded at build time.
//go:embed templates/account_verification.html
accountVerificationTemplate string
// accountDeletionTemplate holds the raw contents of
// templates/account_deletion.html, embedded at build time.
//go:embed templates/account_deletion.html
accountDeletionTemplate string
// Compile-time check that *emailService satisfies domain.EmailService.
_ domain.EmailService = (*emailService)(nil)
)
// maxRetries is the total number of SMTP connection attempts TestConnection makes.
const (
maxRetries int = 2
)
// emailService implements domain.EmailService on top of a domain.Mailer.
// SMTP connection settings are copied from the SMTP configuration at
// construction time. smtpConfig is kept so server health can be read and updated.
type emailService struct {
	smtpConfig *config.SMTPConfig // source of IsHealthy/SetHealth

	// SMTP connection parameters passed to the mailer.
	host     string
	port     int
	username string
	password string

	baseURL     string // injected into templates as request.BaseURL
	fromAddress string // sender address for every outgoing message

	retryQueue *domain.EmailRetryQueue // requests that could not be sent
	mailer     domain.Mailer

	logger *config.Logger
	module string
}
// NewEmailService returns a domain.EmailService that sends through mailer. It
// uses the SMTP settings and base URL from the global server configuration and
// starts with an empty retry queue.
func NewEmailService(mailer domain.Mailer) domain.EmailService {
	smtpConfig := config.GetServerConfig().SMTPConfig()

	svc := &emailService{
		smtpConfig: smtpConfig,
		mailer:     mailer,
		retryQueue: &domain.EmailRetryQueue{},
		module:     "Email Service",
	}
	svc.host, svc.port = smtpConfig.Host(), smtpConfig.Port()
	svc.username, svc.password = smtpConfig.Username(), smtpConfig.Password()
	svc.baseURL = config.GetServerConfig().BaseURL()
	svc.fromAddress = smtpConfig.FromAddress()
	svc.logger = config.GetLogger()

	return svc
}
// SendEmail sends the email described by request.
//
// If the SMTP server is currently marked unhealthy, the request is queued on
// the retry queue and nil is returned. Otherwise the email is rendered and sent
// according to request.EmailType. Unsupported email types are logged and skipped.
//
// Parameters:
//   - ctx Context: The context for managing timeouts and cancellations.
//   - request *EmailRequest: The email request containing necessary details for sending the email.
//
// Returns:
//   - error: An ErrCodeEmailDeliveryFailed error if sending fails, or nil.
func (s *emailService) SendEmail(ctx context.Context, request *domain.EmailRequest) error {
	requestID := utils.GetRequestID(ctx)
	if !s.smtpConfig.IsHealthy() {
		s.logger.Warn(s.module, requestID, "[SendEmail]: SMTP server is down. Adding the request to the retry queue for future processing.")
		s.retryQueue.Add(request)
		return nil
	}

	switch request.EmailType {
	case domain.AccountVerification:
		if err := s.sendEmail(ctx, request, constants.VerifyEmailAddress); err != nil {
			return errors.Wrap(err, errors.ErrCodeEmailDeliveryFailed, "failed to send verification email")
		}
	case domain.AccountDeletion:
		// Previously reported as a verification-email failure (copy-paste error).
		if err := s.sendEmail(ctx, request, constants.AccountDeletion); err != nil {
			return errors.Wrap(err, errors.ErrCodeEmailDeliveryFailed, "failed to send account deletion email")
		}
	default:
		// Keep the existing nil return, but make the dropped request visible.
		s.logger.Warn(s.module, requestID, "[SendEmail]: Unsupported email type=[%s]; no email was sent", request.EmailType.String())
	}

	return nil
}
// TestConnection tries to connect to the SMTP server up to maxRetries times.
// It sleeps between attempts, starting at constants.FiveSecondTimeout and
// doubling the delay each time. The SMTP health flag is set to true on the
// first success, or to false when every attempt fails.
//
// Returns:
//   - error: The error from the last failed attempt, or nil once a connection succeeds.
func (s *emailService) TestConnection() error {
	var lastError error
	delay := constants.FiveSecondTimeout

	for attempt := 0; attempt < maxRetries; attempt++ {
		if attempt != 0 {
			time.Sleep(delay)
			delay *= 2
		}

		lastError = s.connectToSMTPServer()
		if lastError != nil {
			s.logger.Warn(s.module, "", "[TestConnection]: SMTP connection failed (attempt %d/%d): %v", attempt+1, maxRetries, lastError)
			continue
		}

		if attempt != 0 {
			s.logger.Info(s.module, "", "[TestConnection]: SMTP connection restored after %d attempts", attempt)
		}
		s.updateSMTPServerStatus(true)
		return nil
	}

	s.logger.Error(s.module, "", "[TestConnection]: SMTP connection check failed after %d attempts: %v", maxRetries, lastError)
	s.updateSMTPServerStatus(false)
	return lastError
}
// GetEmailRetryQueue returns the service's retry queue. SendEmail and sendEmail
// add requests to it when the SMTP server is unhealthy or a send fails.
//
// Returns:
//   - *EmailRetryQueue: The shared retry queue (never nil for services built by NewEmailService).
func (s *emailService) GetEmailRetryQueue() *domain.EmailRetryQueue {
	queue := s.retryQueue
	return queue
}
// connectToSMTPServer checks connectivity by opening a mailer connection and
// closing it immediately.
func (s *emailService) connectToSMTPServer() error {
	sendCloser, dialErr := s.mailer.Dial(s.host, s.port, s.username, s.password)
	if dialErr != nil {
		s.logger.Error(s.module, "", "Failed to connect to the SMTP server: %v", dialErr)
		return errors.Wrap(dialErr, errors.ErrCodeConnectionFailed, "failed to connect to the SMTP server")
	}

	closeErr := sendCloser.Close()
	if closeErr == nil {
		return nil
	}

	s.logger.Error(s.module, "", "[connectToSMTPServer]: Failed to close 'gomail.SendCloser()': %v", closeErr)
	return errors.Wrap(closeErr, errors.ErrCodeInternalServerError, "failed to close sender")
}
// sendEmail renders the body for request, builds the message with the given
// subject, and sends it in a goroutine while waiting on ctx.
//
// If ctx is cancelled first, a context error is returned. The send goroutine
// keeps running and writes its result into the 1-buffered channel, so it does
// not block forever. If the send fails, the request is added to the retry queue.
func (s *emailService) sendEmail(ctx context.Context, request *domain.EmailRequest, subject string) error {
	requestID := utils.GetRequestID(ctx)
	body, err := s.generateEmailBody(request)
	if err != nil {
		return errors.Wrap(err, "", "failed to generate email template")
	}

	message := s.mailer.NewMessage(request, body, subject, s.fromAddress)
	done := make(chan error, 1)
	go func() {
		done <- s.mailer.DialAndSend(s.host, s.port, s.username, s.password, message)
	}()

	select {
	case <-ctx.Done():
		// Once Done is closed, ctx.Err() is guaranteed to be non-nil.
		ctxErr := ctx.Err()
		s.logger.Error(s.module, requestID, "Context cancelled before email was sent: %v", ctxErr)
		return errors.NewContextError(ctxErr) //nolint:wrapcheck
	case err := <-done:
		if err != nil {
			s.logger.Error(s.module, requestID, "Failed to send email. Adding to retry queue: %v", err)
			s.retryQueue.Add(request)
			// Describe this layer's action instead of repeating err's own text.
			return errors.Wrap(err, errors.ErrCodeInternalServerError, "failed to send email")
		}
	}

	return nil
}
// generateEmailBody renders the embedded template that matches request.EmailType.
//
// Returns an error for unsupported email types. Previously they produced an
// empty body with a nil error, which would have sent a blank email.
func (s *emailService) generateEmailBody(request *domain.EmailRequest) (string, error) {
	var body string
	var err error

	switch request.EmailType {
	case domain.AccountVerification:
		body, err = s.generateEmailTemplate(request, accountVerificationTemplate)
	case domain.AccountDeletion:
		body, err = s.generateEmailTemplate(request, accountDeletionTemplate)
	default:
		err = errors.New(errors.ErrCodeInvalidInput, "unsupported email type: "+request.EmailType.String())
	}

	if err != nil {
		s.logger.Error(s.module, "", "An error occurred generating the email body request=[%s]: %v", request.EmailType.String(), err)
		return "", errors.Wrap(err, "", "failed to generate email template")
	}

	return body, nil
}
// generateEmailTemplate parses templateName as template source and executes
// it with request as data. It first sets request.BaseURL to the service's base
// URL, mutating the caller's request.
//
// Parameters:
//   - request *domain.EmailRequest: Template data; its BaseURL field is overwritten.
//   - templateName string: The template source text (for example, an embedded
//     HTML file), not a template name.
//
// Returns:
//   - string: The rendered template.
//   - error: An ErrCodeInternalServerError error if parsing or execution fails.
//
// NOTE(review): text/template does not HTML-escape data. If any request field is
// user-controlled, html/template would be safer for these HTML emails.
func (s *emailService) generateEmailTemplate(request *domain.EmailRequest, templateName string) (string, error) {
	// Use a short fixed name. Passing the full HTML source as the name made
	// parse/exec error messages embed the entire document.
	tmpl, err := template.New("email").Parse(templateName)
	if err != nil {
		s.logger.Error(s.module, "", "Failed to parse email template file: %v", err)
		return "", errors.Wrap(err, errors.ErrCodeInternalServerError, "failed to parse email template file")
	}

	request.BaseURL = s.baseURL
	var buf bytes.Buffer
	if err := tmpl.Execute(&buf, request); err != nil {
		s.logger.Error(s.module, "", "Failed to load email data into the template: %v", err)
		return "", errors.Wrap(err, errors.ErrCodeInternalServerError, "failed to execute template")
	}

	return buf.String(), nil
}
// updateSMTPServerStatus records the SMTP server's health on the shared SMTP configuration.
func (s *emailService) updateSMTPServerStatus(isHealthy bool) {
	cfg := s.smtpConfig
	cfg.SetHealth(isHealthy)
}
package service
import (
"context"
"crypto/rsa"
"github.com/golang-jwt/jwt"
"github.com/vigiloauth/vigilo/v2/idp/config"
domain "github.com/vigiloauth/vigilo/v2/internal/domain/jwt"
tokens "github.com/vigiloauth/vigilo/v2/internal/domain/token"
"github.com/vigiloauth/vigilo/v2/internal/errors"
"github.com/vigiloauth/vigilo/v2/internal/utils"
)
// Compile-time check that *jwtService satisfies domain.JWTService.
var _ domain.JWTService = (*jwtService)(nil)

// jwtService signs tokens with an RSA private key and verifies them with the
// matching public key, both taken from the token configuration.
type jwtService struct {
	publicKey  *rsa.PublicKey  // verification key used by ParseWithClaims
	privateKey *rsa.PrivateKey // signing key used by SignToken
	keyID      string          // written to the "kid" header of signed tokens

	logger *config.Logger
	module string
}
// NewJWTService returns a domain.JWTService that uses the key pair and key ID
// from the server's token configuration.
func NewJWTService() domain.JWTService {
	serverConfig := config.GetServerConfig()
	tokenConfig := serverConfig.TokenConfig()

	svc := &jwtService{module: "JWT Service"}
	svc.publicKey = tokenConfig.PublicKey()
	svc.privateKey = tokenConfig.SecretKey()
	svc.keyID = tokenConfig.KeyID()
	svc.logger = serverConfig.Logger()

	return svc
}
// ParseWithClaims parses and verifies tokenString into a *tokens.TokenClaims.
// Only RSA signing methods are accepted; the signature is checked against the
// service's public key.
//
// Returns an ErrCodeTokenParsing error if parsing or verification fails, and an
// ErrCodeInvalidToken error if the parsed token is not valid.
func (s *jwtService) ParseWithClaims(ctx context.Context, tokenString string) (*tokens.TokenClaims, error) {
	requestID := utils.GetRequestID(ctx)

	// Rejecting non-RSA methods prevents verifying with the RSA key under a different algorithm.
	keyFunc := func(token *jwt.Token) (any, error) {
		if _, isRSA := token.Method.(*jwt.SigningMethodRSA); !isRSA {
			s.logger.Error(s.module, requestID, "[ParseWithClaims]: Unexpected signing method received")
			return nil, errors.New(errors.ErrCodeTokenParsing, "unexpected signing method")
		}
		return s.publicKey, nil
	}

	parsed, err := jwt.ParseWithClaims(tokenString, &tokens.TokenClaims{}, keyFunc)
	if err != nil {
		s.logger.Error(s.module, requestID, "[ParseWithClaims]: An error occurred parsing the token: %v", err)
		return nil, errors.Wrap(err, errors.ErrCodeTokenParsing, "failed to parse JWT with claims")
	}

	claims, ok := parsed.Claims.(*tokens.TokenClaims)
	if !ok || !parsed.Valid {
		return nil, errors.New(errors.ErrCodeInvalidToken, "provided token is invalid")
	}
	return claims, nil
}
// SignToken signs claims with RS256 using the service's private key and sets
// the "kid" header to the configured key ID. ctx is currently unused.
func (s *jwtService) SignToken(ctx context.Context, claims *tokens.TokenClaims) (string, error) {
	unsigned := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
	unsigned.Header["kid"] = s.keyID

	signed, signErr := unsigned.SignedString(s.privateKey)
	if signErr != nil {
		return "", errors.Wrap(signErr, errors.ErrCodeTokenSigning, "failed to sign token with claims")
	}
	return signed, nil
}
package service
import (
"context"
"time"
"github.com/vigiloauth/vigilo/v2/idp/config"
login "github.com/vigiloauth/vigilo/v2/internal/domain/login"
user "github.com/vigiloauth/vigilo/v2/internal/domain/user"
"github.com/vigiloauth/vigilo/v2/internal/errors"
"github.com/vigiloauth/vigilo/v2/internal/utils"
)
// Compile-time check that *loginAttemptService satisfies login.LoginAttemptService.
var _ login.LoginAttemptService = (*loginAttemptService)(nil)

// loginAttemptService records login attempts and updates users after failed logins.
type loginAttemptService struct {
	userRepo  user.UserRepository
	loginRepo login.LoginAttemptRepository

	maxFailedLoginAttempts int // taken from the login configuration

	logger *config.Logger
	module string
}
// NewLoginAttemptService creates a LoginAttemptService backed by the given
// repositories. The failed-attempt limit and logger come from the global
// server configuration.
//
// Parameters:
//   - userRepo UserRepository: The user repository instance.
//   - loginRepo LoginAttemptRepository: The login attempt repository instance.
//
// Returns:
//   - LoginAttemptService: A new LoginAttemptService instance.
func NewLoginAttemptService(
	userRepo user.UserRepository,
	loginRepo login.LoginAttemptRepository,
) login.LoginAttemptService {
	return &loginAttemptService{
		userRepo:               userRepo,
		loginRepo:              loginRepo,
		maxFailedLoginAttempts: config.GetServerConfig().LoginConfig().MaxFailedAttempts(),
		logger:                 config.GetServerConfig().Logger(),
		module:                 "Login Attempt Service",
	}
}
// SaveLoginAttempt persists a login attempt through the login attempt repository.
//
// Parameters:
//   - ctx Context: The context for managing timeouts and cancellations.
//   - attempt *UserLoginAttempt: The login attempt to save.
//
// Returns:
//   - error: A wrapped repository error if saving fails, otherwise nil.
func (s *loginAttemptService) SaveLoginAttempt(ctx context.Context, attempt *user.UserLoginAttempt) error {
	saveErr := s.loginRepo.SaveLoginAttempt(ctx, attempt)
	if saveErr == nil {
		return nil
	}

	requestID := utils.GetRequestID(ctx)
	s.logger.Error(s.module, requestID, "[SaveLoginAttempt]: Failed to save login attempt for user=[%s]: %v",
		utils.TruncateSensitive(attempt.UserID), saveErr,
	)
	return errors.Wrap(saveErr, "", "failed to save login attempt")
}
// GetLoginAttemptsByUserID returns the login attempts the repository holds for userID.
//
// Parameters:
//   - ctx Context: The context for managing timeouts and cancellations.
//   - userID string: The user ID.
//
// Returns:
//   - []*UserLoginAttempt: The user's login attempts.
//   - error: A wrapped repository error if retrieval fails.
func (s *loginAttemptService) GetLoginAttemptsByUserID(ctx context.Context, userID string) ([]*user.UserLoginAttempt, error) {
	attempts, repoErr := s.loginRepo.GetLoginAttemptsByUserID(ctx, userID)
	if repoErr == nil {
		return attempts, nil
	}

	requestID := utils.GetRequestID(ctx)
	s.logger.Error(s.module, requestID, "[GetLoginAttemptsByUserID]: An error occurred retrieving login attempts: %v", repoErr)
	return nil, errors.Wrap(repoErr, "", "failed to retrieve user login attempts")
}
// HandleFailedLoginAttempt records a failed login and locks the account once
// the lockout threshold is reached.
//
// Steps, in order:
//  1. Set user.LastFailedLogin to the current time (in memory).
//  2. Persist the login attempt.
//  3. Persist the updated user.
//  4. If shouldLockAccount reports the threshold is reached, lock the account.
//
// Parameters:
//   - ctx Context: The context for managing timeouts and cancellations.
//   - user *User: The user who attempted to log in; modified in place.
//   - attempt *UserLoginAttempt: The login attempt information.
//
// Returns:
//   - error: A wrapped error from the first step that failed, or nil.
func (s *loginAttemptService) HandleFailedLoginAttempt(ctx context.Context, user *user.User, attempt *user.UserLoginAttempt) error {
	requestID := utils.GetRequestID(ctx)
	maskedUserID := utils.TruncateSensitive(user.ID)
	s.logger.Info(s.module, requestID, "[HandleFailedLoginAttempt]: Authentication failed for user=[%s]", maskedUserID)

	user.LastFailedLogin = time.Now()

	if err := s.loginRepo.SaveLoginAttempt(ctx, attempt); err != nil {
		s.logger.Error(s.module, requestID, "[HandleFailedLoginAttempt]: Failed to save login attempt for user=[%s]: %v", utils.TruncateSensitive(attempt.UserID), err)
		return errors.Wrap(err, "", "failed to save failed login attempt")
	}

	if err := s.userRepo.UpdateUser(ctx, user); err != nil {
		s.logger.Error(s.module, requestID, "[HandleFailedLoginAttempt]: Failed to update user=[%s]: %v", maskedUserID, err)
		return errors.Wrap(err, "", "failed to update the user")
	}

	if !s.shouldLockAccount(ctx, user.ID) {
		return nil
	}

	s.logger.Info(s.module, requestID, "[HandleFailedLoginAttempt]: Attempting to lock account for user=[%s]", maskedUserID)
	if err := s.lockAccount(ctx, user); err != nil {
		s.logger.Error(s.module, requestID, "[HandleFailedLoginAttempt]: Failed to lock account for user=[%s]: %v", maskedUserID, err)
		// Report the step that actually failed so callers/operators can tell a
		// lockout failure apart from the earlier user update.
		return errors.Wrap(err, "", "failed to lock account")
	}
	s.logger.Debug(s.module, requestID, "[HandleFailedLoginAttempt]: Account successfully locked for user=[%s]", maskedUserID)
	return nil
}
// shouldLockAccount reports whether the number of login attempts recorded for
// userID has reached the configured maximum (maxFailedLoginAttempts).
//
// If the attempts cannot be retrieved, the error is logged and false is
// returned: the lock decision is never made from an unknown state (the
// previous fall-through compared len(nil) against the limit).
//
// NOTE(review): every attempt returned by the repository is counted; whether
// that set contains only failed attempts depends on the repository — confirm.
func (s *loginAttemptService) shouldLockAccount(ctx context.Context, userID string) bool {
	loginAttempts, err := s.loginRepo.GetLoginAttemptsByUserID(ctx, userID)
	if err != nil {
		s.logger.Error(s.module, utils.GetRequestID(ctx), "[shouldLockAccount]: An error occurred retrieving user login attempts: %v", err)
		return false
	}
	return len(loginAttempts) >= s.maxFailedLoginAttempts
}
// lockAccount marks retrievedUser as locked and persists the change.
//
// The AccountLocked flag is set on the passed-in user before the update is
// attempted, so it remains set in memory even if persisting fails.
func (s *loginAttemptService) lockAccount(ctx context.Context, retrievedUser *user.User) error {
	retrievedUser.AccountLocked = true

	updateErr := s.userRepo.UpdateUser(ctx, retrievedUser)
	if updateErr == nil {
		return nil
	}

	requestID := utils.GetRequestID(ctx)
	s.logger.Error(s.module, requestID, "[lockAccount]: Failed to update user: %v", updateErr)
	return errors.Wrap(updateErr, "", "failed to update user")
}
package service
import (
"context"
"github.com/vigiloauth/vigilo/v2/idp/config"
"github.com/vigiloauth/vigilo/v2/internal/constants"
authz "github.com/vigiloauth/vigilo/v2/internal/domain/authorization"
domain "github.com/vigiloauth/vigilo/v2/internal/domain/claims"
jwks "github.com/vigiloauth/vigilo/v2/internal/domain/jwks"
oidc "github.com/vigiloauth/vigilo/v2/internal/domain/oidc"
token "github.com/vigiloauth/vigilo/v2/internal/domain/token"
user "github.com/vigiloauth/vigilo/v2/internal/domain/user"
"github.com/vigiloauth/vigilo/v2/internal/errors"
"github.com/vigiloauth/vigilo/v2/internal/types"
"github.com/vigiloauth/vigilo/v2/internal/utils"
)
// Compile-time assertion that *oidcService implements oidc.OIDCService.
var _ oidc.OIDCService = (*oidcService)(nil)
// oidcService serves OIDC UserInfo and JWKS requests.
type oidcService struct {
// authorizationService authorizes UserInfo requests and returns the user they refer to.
authorizationService authz.AuthorizationService
// logger is the server-wide logger taken from the global server configuration.
logger *config.Logger
// module is the component name passed as the first argument to every log call.
module string
}
// NewOIDCService builds the default OIDCService.
//
// UserInfo requests are authorized through the supplied AuthorizationService;
// the logger comes from the global server configuration.
//
// Parameters:
//   - authorizationService AuthorizationService: Used to authorize UserInfo requests.
//
// Returns:
//   - OIDCService: The configured service.
func NewOIDCService(authorizationService authz.AuthorizationService) oidc.OIDCService {
	svc := new(oidcService)
	svc.authorizationService = authorizationService
	svc.logger = config.GetServerConfig().Logger()
	svc.module = "OIDC Service"
	return svc
}
// GetUserInfo builds the UserInfo response for the user identified by a
// validated access token.
//
// The request is first authorized via the AuthorizationService. The response
// always carries the subject ("sub"); further fields are filled from the
// scopes in the token and then from any individually requested userinfo
// claims the token carries.
//
// Parameters:
//   - ctx Context: The context for managing timeouts and cancellations.
//   - accessTokenClaims *TokenClaims: Claims parsed from the validated access token.
//
// Returns:
//   - *UserInfoResponse: The populated response.
//   - error: A wrapped error if the request could not be authorized.
func (s *oidcService) GetUserInfo(ctx context.Context, accessTokenClaims *token.TokenClaims) (*user.UserInfoResponse, error) {
	retrievedUser, authErr := s.authorizationService.AuthorizeUserInfoRequest(ctx, accessTokenClaims)
	if authErr != nil {
		s.logger.Error(s.module, utils.GetRequestID(ctx), "[GetUserInfo]: Failed to authorize user info request: %v", authErr)
		return nil, errors.Wrap(authErr, "", "failed to authorize request")
	}

	response := &user.UserInfoResponse{Sub: retrievedUser.ID}

	grantedScopes := types.ParseScopesString(accessTokenClaims.Scopes.String())
	s.populateUserInfoFromScopes(response, retrievedUser, grantedScopes)

	if requested := accessTokenClaims.RequestedClaims; requested != nil {
		s.populateUserInfoFromRequestedClaims(response, retrievedUser, requested)
	}
	return response, nil
}
// GetJwks returns the JSON Web Key Set used to verify tokens issued by this
// provider.
//
// The set always contains exactly one key: the configured public key,
// labeled with the configured key ID. ctx is currently unused.
//
// Parameters:
//   - ctx Context: The context for managing timeouts and cancellations.
//
// Returns:
//   - *Jwks: The key set.
func (s *oidcService) GetJwks(ctx context.Context) *jwks.Jwks {
	cfg := config.GetServerConfig().TokenConfig()
	publicKey := cfg.PublicKey()
	signingKey := jwks.NewJWK(cfg.KeyID(), publicKey)
	return &jwks.Jwks{Keys: []jwks.JWK{signingKey}}
}
// populateUserInfoFromScopes copies the user attributes covered by each
// granted scope onto the response:
//   - UserProfileScope: name parts, nickname, preferred username, profile,
//     picture, website, gender, birthdate, zoneinfo, locale, and UpdatedAt
//     (as UTC Unix seconds)
//   - UserEmailScope: email and email-verified flag
//   - UserPhoneScope: phone number and phone-verified flag
//   - UserAddressScope: address
//
// Unrecognized scopes are ignored. The verified flags are stored as pointers
// into retrievedUser rather than copies.
func (s *oidcService) populateUserInfoFromScopes(
	userInfoResponse *user.UserInfoResponse,
	retrievedUser *user.User,
	requestedScopes []types.Scope,
) {
	resp, u := userInfoResponse, retrievedUser
	for i := range requestedScopes {
		switch requestedScopes[i] {
		case types.UserProfileScope:
			resp.Name = u.Name
			resp.FamilyName = u.FamilyName
			resp.GivenName = u.GivenName
			resp.MiddleName = u.MiddleName
			resp.Nickname = u.Nickname
			resp.PreferredUsername = u.PreferredUsername
			resp.Profile = u.Profile
			resp.Picture = u.Picture
			resp.Website = u.Website
			resp.Gender = u.Gender
			resp.Birthdate = u.Birthdate
			resp.Zoneinfo = u.Zoneinfo
			resp.Locale = u.Locale
			resp.UpdatedAt = u.UpdatedAt.UTC().Unix()
		case types.UserEmailScope:
			resp.Email, resp.EmailVerified = u.Email, &u.EmailVerified
		case types.UserPhoneScope:
			resp.PhoneNumber, resp.PhoneNumberVerified = u.PhoneNumber, &u.PhoneNumberVerified
		case types.UserAddressScope:
			resp.Address = u.Address
		}
	}
}
// populateUserInfoFromRequestedClaims copies onto the response each user
// attribute explicitly named in the "userinfo" member of the claims request.
//
// Only the claim names (keys) are consulted; any per-claim request options
// are ignored, as are unrecognized claim names. Nothing happens when the
// request has no userinfo member. Verified flags are stored as pointers into
// retrievedUser, and UpdatedAt is emitted as UTC Unix seconds.
func (s *oidcService) populateUserInfoFromRequestedClaims(
	userInfoResponse *user.UserInfoResponse,
	retrievedUser *user.User,
	requestedClaims *domain.ClaimsRequest,
) {
	if requestedClaims.UserInfo == nil {
		return
	}

	resp, u := userInfoResponse, retrievedUser
	for claimName := range *requestedClaims.UserInfo {
		switch claimName {
		case constants.NameClaim:
			resp.Name = u.Name
		case constants.GivenNameClaim:
			resp.GivenName = u.GivenName
		case constants.FamilyNameClaim:
			resp.FamilyName = u.FamilyName
		case constants.MiddleNameClaim:
			resp.MiddleName = u.MiddleName
		case constants.NicknameClaim:
			resp.Nickname = u.Nickname
		case constants.PreferredUsernameClaim:
			resp.PreferredUsername = u.PreferredUsername
		case constants.ProfileClaim:
			resp.Profile = u.Profile
		case constants.PictureClaim:
			resp.Picture = u.Picture
		case constants.WebsiteClaim:
			resp.Website = u.Website
		case constants.GenderClaim:
			resp.Gender = u.Gender
		case constants.BirthdateClaim:
			resp.Birthdate = u.Birthdate
		case constants.ZoneinfoClaim:
			resp.Zoneinfo = u.Zoneinfo
		case constants.LocaleClaim:
			resp.Locale = u.Locale
		case constants.EmailClaim:
			resp.Email = u.Email
		case constants.EmailVerifiedClaim:
			resp.EmailVerified = &u.EmailVerified
		case constants.PhoneNumberClaim:
			resp.PhoneNumber = u.PhoneNumber
		case constants.PhoneNumberVerifiedClaim:
			resp.PhoneNumberVerified = &u.PhoneNumberVerified
		case constants.UpdatedAtClaim:
			resp.UpdatedAt = u.UpdatedAt.UTC().Unix()
		case constants.AddressClaim:
			resp.Address = u.Address
		}
	}
}
package service
import (
"context"
"net/http"
"github.com/vigiloauth/vigilo/v2/idp/config"
cookie "github.com/vigiloauth/vigilo/v2/internal/domain/cookies"
session "github.com/vigiloauth/vigilo/v2/internal/domain/session"
"github.com/vigiloauth/vigilo/v2/internal/errors"
"github.com/vigiloauth/vigilo/v2/internal/utils"
)
// Compile-time assertion that *sessionManager implements session.SessionManager.
var _ session.SessionManager = (*sessionManager)(nil)
// sessionManager resolves session details (user ID, authentication time) from
// the session cookie on an incoming request.
type sessionManager struct {
// repo looks up stored sessions by session ID.
repo session.SessionRepository
// cookies reads the session cookie from requests.
cookies cookie.HTTPCookieService
// logger is the server-wide logger from the global server configuration.
logger *config.Logger
// module is the component name attached to every log line.
module string
}
// NewSessionManager builds the default SessionManager, which resolves session
// data from the session cookie on incoming requests.
//
// Parameters:
//   - repo SessionRepository: Store used to look up sessions by ID.
//   - cookies HTTPCookieService: Reads the session cookie from requests.
//
// Returns:
//   - SessionManager: The configured manager.
func NewSessionManager(
	repo session.SessionRepository,
	cookies cookie.HTTPCookieService,
) session.SessionManager {
	manager := &sessionManager{repo: repo, cookies: cookies}
	manager.logger = config.GetServerConfig().Logger()
	manager.module = "Session Manager"
	return manager
}
// GetUserIDFromSession returns the user ID stored in the session referenced by
// the request's session cookie.
//
// Parameters:
//   - ctx Context: The context for managing timeouts and cancellations.
//   - r *http.Request: The HTTP request carrying the session cookie.
//
// Returns:
//   - string: The session's user ID, or "" on error.
//   - error: A missing-header error if there is no session cookie, or a
//     wrapped repository error if the session cannot be loaded.
func (s *sessionManager) GetUserIDFromSession(ctx context.Context, r *http.Request) (string, error) {
	requestID := utils.GetRequestID(ctx)

	sessionCookie, cookieErr := s.cookies.GetSessionCookie(r)
	if cookieErr != nil {
		s.logger.Error(s.module, requestID, "[GetUserIDFromSession]: Failed to retrieve session cookie from header: %v", cookieErr)
		return "", errors.Wrap(cookieErr, errors.ErrCodeMissingHeader, "session cookie not found in header")
	}

	data, lookupErr := s.repo.GetSessionByID(ctx, sessionCookie.Value)
	if lookupErr != nil {
		s.logger.Error(s.module, requestID, "[GetUserIDFromSession]: Failed to retrieve user ID from session: %v", lookupErr)
		return "", errors.Wrap(lookupErr, "", "failed to retrieve session")
	}
	return data.UserID, nil
}
// GetUserAuthenticationTime returns the authentication time stored in the
// session referenced by the request's session cookie.
//
// Parameters:
//   - ctx Context: The context for managing timeouts and cancellations.
//   - r *http.Request: The HTTP request carrying the session cookie.
//
// Returns:
//   - int64: The session's AuthenticationTime (Unix timestamp), or 0 on error.
//   - error: A missing-header error if there is no session cookie, or a
//     wrapped repository error if the session cannot be loaded.
func (s *sessionManager) GetUserAuthenticationTime(ctx context.Context, r *http.Request) (int64, error) {
	requestID := utils.GetRequestID(ctx)

	sessionCookie, err := s.cookies.GetSessionCookie(r)
	if err != nil {
		// Log under this method's own tag so failures are traceable to it.
		s.logger.Error(s.module, requestID, "[GetUserAuthenticationTime]: Failed to retrieve session cookie from header: %v", err)
		return 0, errors.Wrap(err, errors.ErrCodeMissingHeader, "session cookie not found in header")
	}

	sessionData, err := s.repo.GetSessionByID(ctx, sessionCookie.Value)
	if err != nil {
		s.logger.Error(s.module, requestID, "[GetUserAuthenticationTime]: Failed to retrieve session by ID: %v", err)
		return 0, errors.Wrap(err, "", "failed to retrieve session")
	}
	return sessionData.AuthenticationTime, nil
}
package service
import (
"net/http"
"time"
"github.com/vigiloauth/vigilo/v2/idp/config"
"github.com/vigiloauth/vigilo/v2/internal/constants"
audit "github.com/vigiloauth/vigilo/v2/internal/domain/audit"
cookie "github.com/vigiloauth/vigilo/v2/internal/domain/cookies"
session "github.com/vigiloauth/vigilo/v2/internal/domain/session"
"github.com/vigiloauth/vigilo/v2/internal/errors"
"github.com/vigiloauth/vigilo/v2/internal/utils"
)
// Compile-time assertion that *sessionService implements session.SessionService.
var _ session.SessionService = (*sessionService)(nil)
// sessionService creates, reads, updates, and invalidates cookie-backed sessions.
type sessionService struct {
// sessionRepo stores session data keyed by session ID.
sessionRepo session.SessionRepository
// httpCookieService reads, sets, and clears the session cookie.
httpCookieService cookie.HTTPCookieService
// auditLogger records session creation and deletion events.
auditLogger audit.AuditLogger
// sessionDuration is the lifetime applied to new sessions and their cookies.
sessionDuration time.Duration
// logger is the server-wide logger from the global server configuration.
logger *config.Logger
// module is the component name attached to every log line.
module string
}
// NewSessionService builds the default SessionService.
//
// The logger and the session lifetime are read from the global server
// configuration; the lifetime is the token configuration's ExpirationTime.
//
// Parameters:
//   - sessionRepo SessionRepository: Store for session data.
//   - httpCookieService HTTPCookieService: Manages the session cookie.
//   - auditLogger AuditLogger: Records session audit events.
//
// Returns:
//   - SessionService: The configured service.
func NewSessionService(
	sessionRepo session.SessionRepository,
	httpCookieService cookie.HTTPCookieService,
	auditLogger audit.AuditLogger,
) session.SessionService {
	logger := config.GetServerConfig().Logger()
	lifetime := config.GetServerConfig().TokenConfig().ExpirationTime()

	return &sessionService{
		sessionRepo:       sessionRepo,
		httpCookieService: httpCookieService,
		auditLogger:       auditLogger,
		sessionDuration:   lifetime,
		logger:            logger,
		module:            "Session Service",
	}
}
// CreateSession assigns sessionData a new ID and expiration, persists it,
// records an audit event, and sets the session cookie on the response.
//
// Parameters:
//   - w http.ResponseWriter: Response the session cookie is written to.
//   - r *http.Request: The incoming request; its context is used for the
//     repository call, audit logging, and the request ID.
//   - sessionData *SessionData: The session to store. Its ID and
//     ExpirationTime fields are overwritten.
//
// Returns:
//   - error: A wrapped internal-server error if the session cannot be saved;
//     nil otherwise.
func (s *sessionService) CreateSession(w http.ResponseWriter, r *http.Request, sessionData *session.SessionData) error {
	ctx := r.Context()
	requestID := utils.GetRequestID(ctx)

	sessionData.ID = constants.SessionIDPrefix + utils.GenerateUUID()
	sessionData.ExpirationTime = time.Now().Add(s.sessionDuration)

	saveErr := s.sessionRepo.SaveSession(ctx, sessionData)
	if saveErr != nil {
		s.logger.Error(s.module, requestID, "[CreateSession]: Failed to save session: %v", saveErr)
		wrapped := errors.Wrap(saveErr, errors.ErrCodeInternalServerError, "error creating session")
		s.auditLogger.StoreEvent(ctx, audit.SessionCreated, false, audit.SessionCreationAction, audit.CookieMethod, wrapped)
		return wrapped
	}

	s.auditLogger.StoreEvent(ctx, audit.SessionCreated, true, audit.SessionCreationAction, audit.CookieMethod, nil)
	s.httpCookieService.SetSessionCookie(ctx, w, sessionData.ID, s.sessionDuration)
	return nil
}
// InvalidateSession ends the session referenced by the request's session
// cookie: the session is deleted from the repository, the cookie is cleared,
// and an audit event is recorded.
//
// Parameters:
//   - w http.ResponseWriter: Response on which the session cookie is cleared.
//   - r *http.Request: The HTTP request carrying the session cookie.
//
// Returns:
//   - error: A missing-header error if there is no session cookie, a wrapped
//     internal-server error if deletion fails, or nil on success.
func (s *sessionService) InvalidateSession(w http.ResponseWriter, r *http.Request) error {
	ctx := r.Context()
	requestID := utils.GetRequestID(ctx)

	sessionCookie, err := s.httpCookieService.GetSessionCookie(r)
	if err != nil {
		s.logger.Error(s.module, requestID, "[InvalidateSession]: Failed to retrieve session cookie from header: %v", err)
		return errors.Wrap(err, errors.ErrCodeMissingHeader, "session cookie not found in header")
	}

	sessionID := sessionCookie.Value
	if err := s.sessionRepo.DeleteSessionByID(ctx, sessionID); err != nil {
		s.logger.Error(s.module, requestID, "[InvalidateSession]: Failed to delete session=[%s]: %v", utils.TruncateSensitive(sessionID), err)
		wrappedErr := errors.Wrap(err, errors.ErrCodeInternalServerError, "failed to invalidate session")
		s.auditLogger.StoreEvent(ctx, audit.SessionDeleted, false, audit.SessionDeletionAction, audit.CookieMethod, wrappedErr)
		return wrappedErr
	}

	s.httpCookieService.ClearSessionCookie(ctx, w)
	s.auditLogger.StoreEvent(ctx, audit.SessionDeleted, true, audit.SessionDeletionAction, audit.CookieMethod, nil)
	return nil
}
// GetUserIDFromSession returns the user ID stored in the session referenced by
// the request's session cookie.
//
// Parameters:
//   - r *http.Request: The HTTP request carrying the session cookie.
//
// Returns:
//   - string: The session's user ID, or "" on error.
//   - error: A missing-header error if there is no session cookie, or a
//     wrapped repository error if the session cannot be loaded.
func (s *sessionService) GetUserIDFromSession(r *http.Request) (string, error) {
	ctx := r.Context()
	requestID := utils.GetRequestID(ctx)

	sessionCookie, cookieErr := s.httpCookieService.GetSessionCookie(r)
	if cookieErr != nil {
		s.logger.Error(s.module, requestID, "[GetUserIDFromSession]: Failed to retrieve session cookie from header: %v", cookieErr)
		return "", errors.Wrap(cookieErr, errors.ErrCodeMissingHeader, "session cookie not found in header")
	}

	data, lookupErr := s.sessionRepo.GetSessionByID(ctx, sessionCookie.Value)
	if lookupErr != nil {
		s.logger.Error(s.module, requestID, "[GetUserIDFromSession]: Failed to retrieve user ID from session: %v", lookupErr)
		return "", errors.Wrap(lookupErr, "", "failed to retrieve user ID from session")
	}
	return data.UserID, nil
}
// UpdateSession replaces the stored session identified by the request's
// session cookie with sessionData.
//
// The cookie's session ID must equal sessionData.ID; otherwise the update is
// rejected with an unauthorized error and nothing is written.
//
// Parameters:
//   - r *http.Request: The HTTP request carrying the session cookie.
//   - sessionData *SessionData: The new session contents.
//
// Returns:
//   - error: A missing-header error if there is no session cookie, an
//     unauthorized error if the IDs differ, a wrapped repository error if the
//     update fails, or nil on success.
func (s *sessionService) UpdateSession(r *http.Request, sessionData *session.SessionData) error {
	ctx := r.Context()
	requestID := utils.GetRequestID(ctx)

	sessionCookie, err := s.httpCookieService.GetSessionCookie(r)
	if err != nil {
		s.logger.Error(s.module, requestID, "[UpdateSession]: Failed to retrieve session cookie from header: %v", err)
		return errors.Wrap(err, errors.ErrCodeMissingHeader, "session cookie not found in header")
	}

	sessionID := sessionCookie.Value
	if sessionID != sessionData.ID {
		s.logger.Error(s.module, requestID, "[UpdateSession]: SessionID=[%s] and SessionDataID=[%s] do not match",
			utils.TruncateSensitive(sessionID),
			utils.TruncateSensitive(sessionData.ID),
		)
		return errors.New(errors.ErrCodeUnauthorized, "session IDs do not match")
	}

	if err := s.sessionRepo.UpdateSessionByID(ctx, sessionID, sessionData); err != nil {
		s.logger.Error(s.module, requestID, "[UpdateSession]: Failed to update session: %v", err)
		return errors.Wrap(err, "", "failed to update session")
	}
	return nil
}
// GetSessionData loads the session referenced by the request's session cookie.
//
// Parameters:
//   - r *http.Request: The HTTP request carrying the session cookie.
//
// Returns:
//   - *SessionData: The stored session, or nil on error.
//   - error: A missing-header error if there is no session cookie, or a
//     wrapped repository error if the session cannot be loaded.
func (s *sessionService) GetSessionData(r *http.Request) (*session.SessionData, error) {
	ctx := r.Context()
	requestID := utils.GetRequestID(ctx)

	sessionCookie, cookieErr := s.httpCookieService.GetSessionCookie(r)
	if cookieErr != nil {
		s.logger.Error(s.module, requestID, "[GetSessionData]: Failed to retrieve session cookie from header: %v", cookieErr)
		return nil, errors.Wrap(cookieErr, errors.ErrCodeMissingHeader, "session cookie not found in header")
	}

	data, lookupErr := s.sessionRepo.GetSessionByID(ctx, sessionCookie.Value)
	if lookupErr != nil {
		s.logger.Error(s.module, requestID, "[GetSessionData]: Failed to retrieve session by ID: %v", lookupErr)
		return nil, errors.Wrap(lookupErr, "", "failed to retrieve session")
	}
	return data, nil
}
package service
import (
"context"
"slices"
"strings"
"time"
"github.com/golang-jwt/jwt"
"github.com/vigiloauth/vigilo/v2/idp/config"
claims "github.com/vigiloauth/vigilo/v2/internal/domain/claims"
crypto "github.com/vigiloauth/vigilo/v2/internal/domain/crypto"
jwtService "github.com/vigiloauth/vigilo/v2/internal/domain/jwt"
token "github.com/vigiloauth/vigilo/v2/internal/domain/token"
"github.com/vigiloauth/vigilo/v2/internal/errors"
"github.com/vigiloauth/vigilo/v2/internal/types"
"github.com/vigiloauth/vigilo/v2/internal/utils"
)
var _ token.TokenCreator = (*tokenCreator)(nil)

// tokenIDLength is the length requested from the cryptographer when generating
// each token's random ID (the JWT "jti" claim).
const tokenIDLength int = 32

// tokenCreator builds, signs, hashes, and persists tokens.
//
// Token lifetimes are kept as relative durations. The absolute expiry of each
// token is computed at the moment that token is generated, so it is relative
// to issuance rather than to when the service was constructed.
type tokenCreator struct {
	repo          token.TokenRepository
	jwtService    jwtService.JWTService
	cryptographer crypto.Cryptographer

	issuer               string
	accessTokenDuration  time.Duration // lifetime of access and ID tokens
	refreshTokenDuration time.Duration // lifetime of refresh tokens
	keyID                string

	logger *config.Logger
	module string
}

// NewTokenCreator builds the default TokenCreator.
//
// The issuer (server URL + "/oauth2"), token lifetimes, signing key ID, and
// logger are read from the global server configuration.
//
// Parameters:
//   - repo TokenRepository: Store for issued (hashed) tokens.
//   - jwtService JWTService: Signs the generated claims.
//   - cryptographer Cryptographer: Source of random token IDs.
//
// Returns:
//   - TokenCreator: The configured token creator.
func NewTokenCreator(
	repo token.TokenRepository,
	jwtService jwtService.JWTService,
	cryptographer crypto.Cryptographer,
) token.TokenCreator {
	tokenConfig := config.GetServerConfig().TokenConfig()
	return &tokenCreator{
		repo:          repo,
		jwtService:    jwtService,
		cryptographer: cryptographer,
		issuer:        config.GetServerConfig().URL() + "/oauth2",
		// Store relative lifetimes only. Converting them to an absolute Unix
		// time here would freeze every token's expiry at construction time.
		accessTokenDuration:  tokenConfig.AccessTokenDuration(),
		refreshTokenDuration: tokenConfig.RefreshTokenDuration(),
		keyID:                tokenConfig.KeyID(),
		logger:               config.GetServerConfig().Logger(),
		module:               "Token Creator",
	}
}

// CreateAccessToken generates, signs, and stores an access token that expires
// after the configured access-token lifetime.
//
// Parameters:
//   - ctx Context: The context for managing timeouts and cancellations.
//   - subject string: The subject ("sub") of the token.
//   - audience string: The audience of the token (e.g., client ID); omitted when empty.
//   - scopes types.Scope: The scopes to add to the token; omitted when empty.
//   - roles string: The roles to add to the token (may be empty).
//   - nonce string: Client-provided value used to prevent replay attacks; omitted when empty.
//
// Returns:
//   - string: The signed JWT.
//   - error: An error if token generation fails.
func (t *tokenCreator) CreateAccessToken(
	ctx context.Context,
	subject string,
	audience string,
	scopes types.Scope,
	roles string,
	nonce string,
) (string, error) {
	requestID := utils.GetRequestID(ctx)
	accessToken, err := t.generateAndStoreToken(
		ctx,
		subject,
		audience,
		scopes,
		roles,
		nonce,
		t.accessTokenDuration,
		"",
		time.Time{},
		nil,
	)
	if err != nil {
		t.logger.Error(t.module, requestID, "[CreateAccessToken]: Failed to create access token: %v", err)
		return "", errors.Wrap(err, "", "failed to create access token")
	}
	return accessToken, nil
}

// CreateRefreshToken generates, signs, and stores a refresh token that expires
// after the configured refresh-token lifetime.
//
// Parameters:
//   - ctx Context: The context for managing timeouts and cancellations.
//   - subject string: The subject ("sub") of the token.
//   - audience string: The audience of the token (e.g., client ID); omitted when empty.
//   - scopes types.Scope: The scopes to add to the token; omitted when empty.
//   - roles string: The roles to add to the token (may be empty).
//   - nonce string: Client-provided value used to prevent replay attacks; omitted when empty.
//
// Returns:
//   - string: The signed JWT.
//   - error: An error if token generation fails.
func (t *tokenCreator) CreateRefreshToken(
	ctx context.Context,
	subject string,
	audience string,
	scopes types.Scope,
	roles string,
	nonce string,
) (string, error) {
	requestID := utils.GetRequestID(ctx)
	refreshToken, err := t.generateAndStoreToken(
		ctx,
		subject,
		audience,
		scopes,
		roles,
		nonce,
		t.refreshTokenDuration,
		"",
		time.Time{},
		nil,
	)
	if err != nil {
		t.logger.Error(t.module, requestID, "[CreateRefreshToken]: Failed to create refresh token: %v", err)
		return "", errors.Wrap(err, "", "failed to create refresh token")
	}
	return refreshToken, nil
}

// CreateAccessTokenWithClaims behaves like CreateAccessToken but also embeds
// the caller's requested claims in the token.
//
// Parameters:
//   - ctx Context: The context for managing timeouts and cancellations.
//   - subject string: The subject ("sub") of the token.
//   - audience string: The audience of the token (e.g., client ID); omitted when empty.
//   - scopes types.Scope: The scopes to add to the token; omitted when empty.
//   - roles string: The roles to add to the token (may be empty).
//   - nonce string: Client-provided value used to prevent replay attacks; omitted when empty.
//   - requestedClaims *claims.ClaimsRequest: The requested claims; omitted when nil.
//
// Returns:
//   - string: The signed JWT.
//   - error: An error if token generation fails.
func (t *tokenCreator) CreateAccessTokenWithClaims(
	ctx context.Context,
	subject string,
	audience string,
	scopes types.Scope,
	roles string,
	nonce string,
	requestedClaims *claims.ClaimsRequest,
) (string, error) {
	requestID := utils.GetRequestID(ctx)
	accessToken, err := t.generateAndStoreToken(
		ctx,
		subject,
		audience,
		scopes,
		roles,
		nonce,
		t.accessTokenDuration,
		"",
		time.Time{},
		requestedClaims,
	)
	if err != nil {
		t.logger.Error(t.module, requestID, "[CreateAccessTokenWithClaims]: Failed to create access token with claims: %v", err)
		return "", errors.Wrap(err, "", "failed to create access token")
	}
	return accessToken, nil
}

// CreateIDToken creates, signs, and stores an ID token for the specified user
// and client. The ID token uses the access-token lifetime.
//
// Parameters:
//   - ctx Context: Context for the request, containing the request ID for logging.
//   - userID string: The user identifier, used as the token subject.
//   - clientID string: The client identifier, used as the token audience.
//   - scopes types.Scope: The scopes to add to the token; omitted when empty.
//   - nonce string: Client-provided value used to prevent replay attacks; omitted when empty.
//   - acrValues string: Space-separated requested ACR values. Only "1" is
//     currently recognized; it sets the token's ACR value to "1".
//   - authTime time.Time: When the user authenticated. Pass the zero Time to
//     omit the auth_time claim (it only applies when "max_age" was requested).
//
// Returns:
//   - string: The signed ID token.
//   - error: An error if token generation fails.
func (t *tokenCreator) CreateIDToken(
	ctx context.Context,
	userID string,
	clientID string,
	scopes types.Scope,
	nonce string,
	acrValues string,
	authTime time.Time,
) (string, error) {
	requestID := utils.GetRequestID(ctx)
	IDToken, err := t.generateAndStoreToken(
		ctx,
		userID,
		clientID,
		scopes,
		"",
		nonce,
		t.accessTokenDuration,
		acrValues,
		authTime,
		nil,
	)
	if err != nil {
		t.logger.Error(t.module, requestID, "[CreateIDToken]: Failed to create ID token: %v", err)
		return "", errors.Wrap(err, "", "failed to create ID token")
	}
	return IDToken, nil
}

// generateAndStoreToken makes up to maxRetries attempts to build, sign, and
// persist a token with the given lifetime.
//
// Any error from an attempt, including a token-ID collision, triggers another
// attempt. If every attempt fails, the last error is returned wrapped as an
// internal-server error.
func (t *tokenCreator) generateAndStoreToken(
	ctx context.Context,
	subject string,
	audience string,
	scopes types.Scope,
	roles string,
	nonce string,
	duration time.Duration,
	acrValues string,
	authTime time.Time,
	requestedClaims *claims.ClaimsRequest,
) (string, error) {
	requestID := utils.GetRequestID(ctx)
	const maxRetries int = 5

	var lastErr error
	for i := range maxRetries {
		signedToken, err := t.attemptTokenGeneration(
			ctx,
			subject,
			audience,
			scopes,
			roles,
			nonce,
			duration,
			acrValues,
			authTime,
			requestedClaims,
		)
		if err == nil {
			return signedToken, nil
		}
		lastErr = err
		t.logger.Warn(t.module, requestID, "[generateAndStoreToken]: Failed to generate token (attempt %d/%d): %v", i+1, maxRetries, err)
	}
	return "", errors.Wrap(lastErr, errors.ErrCodeInternalServerError, "failed to generate token after maximum retries")
}

// attemptTokenGeneration performs a single build-sign-store cycle.
//
// The issue time and absolute expiry are computed once here, so the JWT
// "iat"/"exp" claims, the stored TokenData, and the repository expiration all
// agree. The repository stores the SHA-256 hash of the signed token, not the
// token itself.
func (t *tokenCreator) attemptTokenGeneration(
	ctx context.Context,
	subject string,
	audience string,
	scopes types.Scope,
	roles string,
	nonce string,
	duration time.Duration,
	acrValues string,
	authTime time.Time,
	requestedClaims *claims.ClaimsRequest,
) (string, error) {
	requestID := utils.GetRequestID(ctx)

	issuedAt := time.Now()
	expiresAt := issuedAt.Add(duration)

	standardClaims, err := t.generateStandardClaims(
		ctx,
		subject,
		audience,
		scopes,
		roles,
		nonce,
		issuedAt,
		expiresAt,
		acrValues,
		authTime,
		requestedClaims,
	)
	if err != nil {
		t.logger.Error(t.module, requestID, "[attemptTokenGeneration]: An error occurred generating standard claims: %v", err)
		return "", err
	}

	signedToken, err := t.jwtService.SignToken(ctx, standardClaims)
	if err != nil {
		t.logger.Error(t.module, requestID, "[attemptTokenGeneration]: Failed to sign token: %v", err)
		return "", errors.Wrap(err, "", "failed to sign token")
	}

	hashedToken := utils.EncodeSHA256(signedToken)
	tokenData := &token.TokenData{
		Token:     hashedToken,
		ID:        subject,
		ExpiresAt: expiresAt.Unix(),
		// NOTE(review): TokenID is populated with the signing key ID rather than
		// the generated jti (standardClaims.Id); confirm this is what
		// ExistsByTokenID and other lookups expect.
		TokenID:     t.keyID,
		TokenClaims: standardClaims,
	}
	if err := t.repo.SaveToken(ctx, hashedToken, subject, tokenData, expiresAt); err != nil {
		t.logger.Error(t.module, requestID, "[attemptTokenGeneration]: Failed to save token: %v", err)
		return "", errors.Wrap(err, "", "failed to save token")
	}
	return signedToken, nil
}

// generateStandardClaims builds the claims for a new token.
//
// A fresh random token ID is generated and checked against the repository; a
// collision is reported as an error so the caller can retry. The audience,
// nonce, auth_time, scopes, and requested claims are only set when a
// non-empty value was supplied. The ACR value is set to "1" only when "1" is
// among the space-separated acrValues.
//
// Parameters (beyond those documented on the public Create* methods):
//   - issuedAt time.Time: Value for the "iat" claim.
//   - expiresAt time.Time: Absolute expiry for the "exp" claim.
func (t *tokenCreator) generateStandardClaims(
	ctx context.Context,
	subject string,
	audience string,
	scopes types.Scope,
	roles string,
	nonce string,
	issuedAt time.Time,
	expiresAt time.Time,
	acrValues string,
	authTime time.Time,
	requestedClaims *claims.ClaimsRequest,
) (*token.TokenClaims, error) {
	requestID := utils.GetRequestID(ctx)

	tokenID, err := t.cryptographer.GenerateRandomString(tokenIDLength)
	if err != nil {
		t.logger.Error(t.module, requestID, "[generateStandardClaims]: Failed to generate token ID: %v", err)
		return nil, errors.Wrap(err, "", "failed to generate token ID")
	}

	existsByID, err := t.repo.ExistsByTokenID(ctx, tokenID)
	if err != nil {
		t.logger.Error(t.module, requestID, "[generateStandardClaims]: An error occurred retrieving the token: %v", err)
		return nil, errors.Wrap(err, "", "failed to verify if token exists by token ID")
	}
	if existsByID {
		t.logger.Warn(t.module, requestID, "[generateStandardClaims]: Failed to generate standard claims. Token already exists.")
		return nil, errors.New(errors.ErrCodeInternalServerError, "token ID already exists")
	}

	tokenClaims := &token.TokenClaims{
		Roles: roles,
		StandardClaims: &jwt.StandardClaims{
			Id:        tokenID,
			Subject:   subject,
			Issuer:    t.issuer,
			IssuedAt:  issuedAt.Unix(),
			ExpiresAt: expiresAt.Unix(),
		},
	}

	if audience != "" {
		tokenClaims.Audience = audience
	}

	// VigiloAuth only supports password-based auth at the moment.
	// This will be refactored in the future to support multiple ACRs.
	if acrValues != "" && slices.Contains(strings.Split(acrValues, " "), "1") {
		tokenClaims.ACRValues = "1"
	}

	if nonce != "" {
		tokenClaims.Nonce = nonce
	}
	// The zero Time means "no auth_time"; its Unix() value must not leak into the claim.
	if !authTime.IsZero() {
		tokenClaims.AuthTime = authTime.Unix()
	}
	if scopes != "" {
		tokenClaims.Scopes = scopes
	}
	if requestedClaims != nil {
		tokenClaims.RequestedClaims = requestedClaims
	}
	return tokenClaims, nil
}
package service
import (
"context"
"net/http"
"strings"
"github.com/vigiloauth/vigilo/v2/idp/config"
authz "github.com/vigiloauth/vigilo/v2/internal/domain/authorization"
client "github.com/vigiloauth/vigilo/v2/internal/domain/client"
tokens "github.com/vigiloauth/vigilo/v2/internal/domain/token"
users "github.com/vigiloauth/vigilo/v2/internal/domain/user"
"github.com/vigiloauth/vigilo/v2/internal/errors"
"github.com/vigiloauth/vigilo/v2/internal/types"
"github.com/vigiloauth/vigilo/v2/internal/utils"
)
// Compile-time assertion that *tokenGrantProcessor implements tokens.TokenGrantProcessor.
var _ tokens.TokenGrantProcessor = (*tokenGrantProcessor)(nil)
// tokenGrantProcessor issues token responses for the supported OAuth2 grant types.
type tokenGrantProcessor struct {
// tokenIssuer issues the access/refresh token pair for a grant.
tokenIssuer tokens.TokenIssuer
// tokenManager is a token management dependency; not used by the methods in
// this excerpt — presumably used by other grant handlers (TODO confirm).
tokenManager tokens.TokenManager
// clientAuthenticator authenticates the requesting client and its requested grant/scopes.
clientAuthenticator client.ClientAuthenticator
// userAuthenticator is a user authentication dependency; not used by the
// methods in this excerpt (presumably for the resource-owner grant — confirm).
userAuthenticator users.UserAuthenticator
// authorization is an authorization service dependency; not used by the
// methods in this excerpt (confirm where it is used).
authorization authz.AuthorizationService
// tokenDuration is the access-token lifetime in whole seconds, reported as ExpiresIn.
tokenDuration int64
// logger is the server-wide logger from the global server configuration.
logger *config.Logger
// module is the component name attached to every log line.
module string
}
// NewTokenGrantProcessor builds the default TokenGrantProcessor.
//
// The advertised token lifetime (ExpiresIn) is the configured access-token
// duration truncated to whole seconds; the logger comes from the global server
// configuration.
//
// NOTE(review): the module name is "Token Issuer", not "Token Grant
// Processor"; it is kept as-is because log filtering may depend on it.
//
// Parameters:
//   - tokenIssuer TokenIssuer: Issues access/refresh token pairs.
//   - tokenManager TokenManager: Token management dependency.
//   - clientAuthenticator ClientAuthenticator: Authenticates requesting clients.
//   - userAuthenticator UserAuthenticator: User authentication dependency.
//   - authorization AuthorizationService: Authorization dependency.
//
// Returns:
//   - TokenGrantProcessor: The configured processor.
func NewTokenGrantProcessor(
	tokenIssuer tokens.TokenIssuer,
	tokenManager tokens.TokenManager,
	clientAuthenticator client.ClientAuthenticator,
	userAuthenticator users.UserAuthenticator,
	authorization authz.AuthorizationService,
) tokens.TokenGrantProcessor {
	accessTokenLifetime := config.GetServerConfig().TokenConfig().AccessTokenDuration()

	processor := &tokenGrantProcessor{
		tokenIssuer:         tokenIssuer,
		tokenManager:        tokenManager,
		clientAuthenticator: clientAuthenticator,
		userAuthenticator:   userAuthenticator,
		authorization:       authorization,
		tokenDuration:       int64(accessTokenLifetime.Seconds()),
		logger:              config.GetServerConfig().Logger(),
		module:              "Token Issuer",
	}
	return processor
}
// IssueClientCredentialsToken serves the OAuth2 "client_credentials" grant.
//
// The client is authenticated first. A token pair is then issued with an empty
// subject, because there is no end user, and with the client ID as audience.
//
// NOTE(review): RFC 6749 §4.4.3 says a refresh token SHOULD NOT be included in
// a client-credentials response, but one is returned here. Confirm this is
// intended.
//
// Parameters:
//   - ctx context.Context: The context for managing timeouts and cancellations.
//   - clientID string: The ID of the client requesting the token.
//   - clientSecret string: The secret associated with the client.
//   - grantType string: The OAuth2 grant type being used (expected "client_credentials").
//   - scopes types.Scope: The scopes to associate with the issued token.
//
// Returns:
//   - *tokens.TokenResponse: The issued tokens, the token type, the lifetime and the scopes.
//   - error: An error if client authentication or token issuance fails.
func (s *tokenGrantProcessor) IssueClientCredentialsToken(
	ctx context.Context,
	clientID string,
	clientSecret string,
	grantType string,
	scopes types.Scope,
) (*tokens.TokenResponse, error) {
	requestID := utils.GetRequestID(ctx)

	authErr := s.clientAuthenticator.AuthenticateClient(ctx, &client.ClientAuthenticationRequest{
		ClientID:        clientID,
		ClientSecret:    clientSecret,
		RequestedGrant:  grantType,
		RequestedScopes: scopes,
	})
	if authErr != nil {
		s.logger.Error(s.module, requestID, "[IssueClientCredentialsToken]: Failed to authenticate client: %v", authErr)
		return nil, errors.Wrap(authErr, "", "failed to authenticate client")
	}

	// No subject, roles, nonce or requested claims for machine-to-machine tokens.
	access, refresh, issueErr := s.tokenIssuer.IssueTokenPair(ctx, "", clientID, scopes, "", "", nil)
	if issueErr != nil {
		s.logger.Error(s.module, requestID, "[IssueClientCredentialsToken]: Failed to issue tokens: %v", issueErr)
		return nil, errors.Wrap(issueErr, "", "failed to issue access and refresh tokens")
	}

	response := &tokens.TokenResponse{
		AccessToken:  access,
		RefreshToken: refresh,
		TokenType:    tokens.BearerToken,
		ExpiresIn:    s.tokenDuration,
		Scope:        scopes,
	}
	return response, nil
}
// IssueResourceOwnerToken serves the OAuth2 "password" (Resource Owner Password
// Credentials) grant.
//
// It authenticates the client first and then the user. It then issues a token
// pair whose subject is the authenticated user's ID and whose audience is the
// client ID.
//
// Parameters:
//   - ctx context.Context: The context for managing timeouts and cancellations.
//   - clientID string: The ID of the client requesting the token.
//   - clientSecret string: The secret associated with the client.
//   - grantType string: The OAuth2 grant type being used (expected "password").
//   - scopes types.Scope: The scopes to associate with the issued token.
//   - user *users.UserLoginRequest: The user's login credentials.
//
// Returns:
//   - *tokens.TokenResponse: The issued tokens, the token type, the lifetime and the scopes.
//   - error: An error if client or user authentication, or token issuance, fails.
func (s *tokenGrantProcessor) IssueResourceOwnerToken(
	ctx context.Context,
	clientID string,
	clientSecret string,
	grantType string,
	scopes types.Scope,
	user *users.UserLoginRequest,
) (*tokens.TokenResponse, error) {
	requestID := utils.GetRequestID(ctx)

	authRequest := &client.ClientAuthenticationRequest{
		ClientID:        clientID,
		ClientSecret:    clientSecret,
		RequestedGrant:  grantType,
		RequestedScopes: scopes,
	}
	if clientErr := s.clientAuthenticator.AuthenticateClient(ctx, authRequest); clientErr != nil {
		s.logger.Error(s.module, requestID, "[IssueResourceOwnerToken]: Failed to authenticate client: %v", clientErr)
		return nil, errors.Wrap(clientErr, "", "failed to authenticate client")
	}

	loggedIn, userErr := s.userAuthenticator.AuthenticateUser(ctx, user)
	if userErr != nil {
		s.logger.Error(s.module, requestID, "[IssueResourceOwnerToken]: Failed to authenticate user: %v", userErr)
		return nil, errors.Wrap(userErr, "", "failed to authenticate user")
	}

	access, refresh, issueErr := s.tokenIssuer.IssueTokenPair(ctx, loggedIn.UserID, clientID, scopes, "", "", nil)
	if issueErr != nil {
		s.logger.Error(s.module, requestID, "[IssueResourceOwnerToken]: Failed to issue tokens: %v", issueErr)
		return nil, errors.Wrap(issueErr, "", "failed to issue token pair")
	}

	return &tokens.TokenResponse{
		AccessToken:  access,
		RefreshToken: refresh,
		TokenType:    tokens.BearerToken,
		ExpiresIn:    s.tokenDuration,
		Scope:        scopes,
	}, nil
}
// RefreshToken exchanges a refresh token for a new access/refresh token pair
// (OAuth2 "refresh_token" grant).
//
// The presented refresh token is blacklisted when this method returns, on
// success and on every failure path, so it can be used at most once. A
// blacklisting failure is logged but does not change the result.
//
// If scopes is empty, the scopes originally granted to the refresh token are
// reused. Otherwise, the requested scopes must be a subset of the original grant.
//
// NOTE(review): the stored token is fetched with GetTokenData only. This method
// does not itself check the refresh token's expiry or blacklist status. Confirm
// that check is enforced elsewhere.
//
// Parameters:
//   - ctx context.Context: The context for managing timeouts and cancellations.
//   - clientID string: The ID of the client requesting the token.
//   - clientSecret string: The secret associated with the client.
//   - grantType string: The OAuth2 grant type being used (expected "refresh_token").
//   - refreshToken string: The refresh token being exchanged.
//   - scopes types.Scope: The requested scopes; empty means "same as original".
//
// Returns:
//   - *tokens.TokenResponse: The new access and refresh tokens.
//   - error: An ErrCodeInvalidGrant error if the refresh token is unknown or was
//     issued to another client. An ErrCodeInvalidRequest error if the requested
//     scopes exceed the original grant. Otherwise, a wrapped authentication or
//     issuance error.
func (s *tokenGrantProcessor) RefreshToken(
	ctx context.Context,
	clientID string,
	clientSecret string,
	grantType string,
	refreshToken string,
	scopes types.Scope,
) (resp *tokens.TokenResponse, err error) {
	requestID := utils.GetRequestID(ctx)

	defer func() {
		// Every return path sets resp or err, so the token is always rotated.
		if resp == nil && err == nil {
			return
		}
		if blacklistErr := s.tokenManager.BlacklistToken(ctx, refreshToken); blacklistErr != nil {
			s.logger.Error(s.module, requestID, "[RefreshToken]: Failed to blacklist token: %v", blacklistErr)
		}
	}()

	authRequest := &client.ClientAuthenticationRequest{
		ClientID:        clientID,
		ClientSecret:    clientSecret,
		RequestedGrant:  grantType,
		RequestedScopes: scopes,
	}
	if authErr := s.clientAuthenticator.AuthenticateClient(ctx, authRequest); authErr != nil {
		s.logger.Error(s.module, requestID, "[RefreshToken]: Failed to authenticate client: %v", authErr)
		return nil, errors.Wrap(authErr, "", "failed to authenticate client")
	}

	stored, lookupErr := s.tokenManager.GetTokenData(ctx, refreshToken)
	if lookupErr != nil {
		s.logger.Error(s.module, requestID, "[RefreshToken]: Failed to get token data: %v", lookupErr)
		return nil, errors.New(errors.ErrCodeInvalidGrant, "invalid token")
	}

	if stored.TokenClaims.Audience != clientID {
		s.logger.Error(s.module, requestID, "[RefreshToken]: Client ID does not match with associated refresh token")
		return nil, errors.New(errors.ErrCodeInvalidGrant, "refresh token was issued to a different client")
	}

	switch {
	case scopes == "":
		scopes = stored.TokenClaims.Scopes
	case !utils.IsSubset(strings.Fields(scopes.String()), strings.Fields(stored.TokenClaims.Scopes.String())):
		return nil, errors.New(errors.ErrCodeInvalidRequest, "requested scopes exceed originally granted scopes")
	}

	access, rotated, issueErr := s.tokenIssuer.IssueTokenPair(ctx, stored.TokenClaims.Subject, clientID, scopes, "", "", nil)
	if issueErr != nil {
		s.logger.Error(s.module, requestID, "[RefreshToken]: Failed to issue new access and refresh tokens: %v", issueErr)
		return nil, errors.Wrap(issueErr, "", "failed to issue new tokens")
	}

	return &tokens.TokenResponse{
		AccessToken:  access,
		RefreshToken: rotated,
		TokenType:    tokens.BearerToken,
		ExpiresIn:    s.tokenDuration,
		Scope:        scopes,
	}, nil
}
// ExchangeAuthorizationCode serves the OAuth2 "authorization_code" grant.
//
// Steps:
//  1. Validate the code exchange through the authorization service.
//  2. Issue an access/refresh token pair for the code's user and client.
//  3. Record the access token on the authorization code data and persist it.
//  4. Issue an ID token.
//
// Parameters:
//   - ctx context.Context: The context for managing timeouts and cancellations.
//   - request *tokens.TokenRequest: The token request data.
//
// Returns:
//   - *tokens.TokenResponse: The access, refresh and ID tokens, plus the granted scope.
//   - error: An error if authorization, token issuance or the code update fails.
func (s *tokenGrantProcessor) ExchangeAuthorizationCode(
	ctx context.Context,
	request *tokens.TokenRequest,
) (*tokens.TokenResponse, error) {
	requestID := utils.GetRequestID(ctx)

	authzCodeData, err := s.authorization.AuthorizeTokenExchange(ctx, request)
	if err != nil {
		s.logger.Error(s.module, requestID, "[ExchangeAuthorizationCode]: Failed to authorize token exchange: %v", err)
		return nil, errors.Wrap(err, "", "failed to authorize token exchange")
	}

	accessToken, refreshToken, err := s.tokenIssuer.IssueTokenPair(
		ctx,
		authzCodeData.UserID,
		authzCodeData.ClientID,
		authzCodeData.Scope, "",
		authzCodeData.Nonce,
		authzCodeData.ClaimsRequest,
	)
	// Check issuance before touching the stored code. Previously this error was
	// checked only after UpdateAuthorizationCode had already persisted an empty
	// access token.
	if err != nil {
		s.logger.Error(s.module, requestID, "[ExchangeAuthorizationCode]: Failed to issue access and refresh tokens: %v", err)
		return nil, errors.Wrap(err, "", "failed to issue tokens")
	}

	// NOTE(review): the field is named AccessTokenHash but receives the raw
	// access token. Confirm whether it should be hashed (e.g. utils.EncodeSHA256).
	authzCodeData.AccessTokenHash = accessToken
	if err := s.authorization.UpdateAuthorizationCode(ctx, authzCodeData); err != nil {
		s.logger.Error(s.module, requestID, "[ExchangeAuthorizationCode]: Failed to update authorization code data: %v", err)
		return nil, errors.Wrap(err, errors.ErrCodeInternalServerError, "something went wrong updating the authorization code")
	}

	IDToken, err := s.tokenIssuer.IssueIDToken(
		ctx,
		authzCodeData.UserID,
		authzCodeData.ClientID,
		authzCodeData.Scope,
		authzCodeData.Nonce,
		authzCodeData.ACRValues,
		authzCodeData.UserAuthenticationTime,
	)
	if err != nil {
		s.logger.Error(s.module, requestID, "[ExchangeAuthorizationCode]: Failed to issue ID token: %v", err)
		return nil, errors.Wrap(err, "", "failed to issue the ID token")
	}

	return &tokens.TokenResponse{
		AccessToken:  accessToken,
		RefreshToken: refreshToken,
		IDToken:      IDToken,
		TokenType:    tokens.BearerToken,
		ExpiresIn:    s.tokenDuration,
		Scope:        authzCodeData.Scope,
	}, nil
}
// IntrospectToken authenticates the calling client for the token-introspection
// scope. It then reports the state of tokenStr as computed by the token manager.
//
// The introspection result itself never produces an error. Inactive, unknown or
// invalid tokens are described by the returned response.
//
// Parameters:
//   - ctx context.Context: The context for managing timeouts and cancellations.
//   - r *http.Request: The request used for client authentication.
//   - tokenStr string: The token to introspect.
//
// Returns:
//   - *tokens.TokenIntrospectionResponse: The token's activity status and claims.
//   - error: An error only if client authentication fails.
func (s *tokenGrantProcessor) IntrospectToken(
	ctx context.Context,
	r *http.Request,
	tokenStr string,
) (*tokens.TokenIntrospectionResponse, error) {
	authErr := s.clientAuthenticator.AuthenticateRequest(ctx, r, types.TokenIntrospectScope)
	if authErr != nil {
		s.logger.Error(s.module, utils.GetRequestID(ctx), "[Introspect Token]: Failed to authenticate client request: %v", authErr)
		return nil, errors.Wrap(authErr, "", "failed to authenticate client")
	}
	return s.tokenManager.Introspect(ctx, tokenStr), nil
}
// RevokeToken authenticates the calling client for the token-revocation scope
// and then revokes tokenStr, which may be an access or a refresh token.
// Failures are logged and returned.
//
// Parameters:
//   - ctx context.Context: The context for managing timeouts and cancellations.
//   - r *http.Request: The request used for client authentication.
//   - tokenStr string: The token to revoke.
//
// Returns:
//   - error: An error if client authentication or revocation fails.
func (s *tokenGrantProcessor) RevokeToken(
	ctx context.Context,
	r *http.Request,
	tokenStr string,
) error {
	requestID := utils.GetRequestID(ctx)

	authErr := s.clientAuthenticator.AuthenticateRequest(ctx, r, types.TokenRevokeScope)
	if authErr != nil {
		s.logger.Error(s.module, requestID, "[RevokeToken]: Failed to authenticate client request: %v", authErr)
		return errors.Wrap(authErr, "", "failed to authenticate client")
	}

	revokeErr := s.tokenManager.Revoke(ctx, tokenStr)
	if revokeErr == nil {
		return nil
	}
	s.logger.Error(s.module, requestID, "[RevokeToken]: Failed to revoke token: %v", revokeErr)
	return errors.Wrap(revokeErr, "", "failed to revoke token")
}
package service
import (
"context"
"time"
"github.com/vigiloauth/vigilo/v2/idp/config"
claims "github.com/vigiloauth/vigilo/v2/internal/domain/claims"
token "github.com/vigiloauth/vigilo/v2/internal/domain/token"
"github.com/vigiloauth/vigilo/v2/internal/errors"
"github.com/vigiloauth/vigilo/v2/internal/types"
"github.com/vigiloauth/vigilo/v2/internal/utils"
)
// Compile-time check that *tokenIssuer satisfies token.TokenIssuer.
var _ token.TokenIssuer = (*tokenIssuer)(nil)
// tokenIssuer issues access, refresh and ID tokens by delegating to a
// TokenCreator. It logs and wraps any creation failure.
type tokenIssuer struct {
creator token.TokenCreator // performs the actual token construction
logger *config.Logger
module string // module label used in log lines
}
// NewTokenIssuer returns a TokenIssuer backed by creator. It logs through the
// server configuration's logger.
func NewTokenIssuer(creator token.TokenCreator) token.TokenIssuer {
	issuer := &tokenIssuer{
		creator: creator,
		module:  "Token Creator",
	}
	issuer.logger = config.GetServerConfig().Logger()
	return issuer
}
// IssueTokenPair creates an access token (including requestedClaims) and then a
// refresh token for the same subject, audience, scopes, roles and nonce.
//
// Returns the access token, the refresh token and an error. If either creation
// fails, both strings are empty and the wrapped error is returned.
func (t *tokenIssuer) IssueTokenPair(
	ctx context.Context,
	subject string,
	audience string,
	scopes types.Scope,
	roles string,
	nonce string,
	requestedClaims *claims.ClaimsRequest,
) (string, string, error) {
	requestID := utils.GetRequestID(ctx)

	access, accessErr := t.creator.CreateAccessTokenWithClaims(ctx, subject, audience, scopes, roles, nonce, requestedClaims)
	if accessErr != nil {
		t.logger.Error(t.module, requestID, "[IssueTokenPair]: Failed to issue access token: %v", accessErr)
		return "", "", errors.Wrap(accessErr, "", "failed to issue access token")
	}

	refresh, refreshErr := t.creator.CreateRefreshToken(ctx, subject, audience, scopes, roles, nonce)
	if refreshErr != nil {
		t.logger.Error(t.module, requestID, "[IssueTokenPair]: Failed to issue refresh token: %v", refreshErr)
		return "", "", errors.Wrap(refreshErr, "", "failed to issue refresh token")
	}

	return access, refresh, nil
}
// IssueIDToken creates an OpenID Connect ID token for subject and audience.
//
// The scopes argument is accepted for interface compatibility but is not
// forwarded. An empty scope is always passed to the creator, so scopes are
// never embedded in the ID token.
//
// Parameters:
//   - ctx context.Context: Carries the request ID used for logging.
//   - subject string: The token subject.
//   - audience string: The token audience.
//   - scopes types.Scope: Ignored (see above).
//   - nonce string: The nonce to embed.
//   - acrValues string: The ACR values passed to the creator.
//   - authTime time.Time: The user authentication time passed to the creator.
//
// Returns:
//   - string: The ID token.
//   - error: A wrapped error if creation fails.
func (t *tokenIssuer) IssueIDToken(
	ctx context.Context,
	subject string,
	audience string,
	scopes types.Scope,
	nonce string,
	acrValues string,
	authTime time.Time,
) (string, error) {
	requestID := utils.GetRequestID(ctx)
	IDToken, err := t.creator.CreateIDToken(
		ctx,
		subject,
		audience,
		"", // scopes are not included in ID token
		nonce,
		acrValues,
		authTime,
	)
	if err != nil {
		// Tag with this method's name; it was previously mislabelled "[IssueTokenPair]".
		t.logger.Error(t.module, requestID, "[IssueIDToken]: Failed to issue ID token: %v", err)
		return "", errors.Wrap(err, "", "failed to issue ID token")
	}
	return IDToken, nil
}
// IssueAccessToken creates a single access token (no requested claims) for the
// given subject, audience, scopes, roles and nonce.
//
// Returns the token, or an empty string and a wrapped error if creation fails.
func (t *tokenIssuer) IssueAccessToken(
	ctx context.Context,
	subject string,
	audience string,
	scopes types.Scope,
	roles string,
	nonce string,
) (string, error) {
	access, err := t.creator.CreateAccessToken(ctx, subject, audience, scopes, roles, nonce)
	if err == nil {
		return access, nil
	}
	t.logger.Error(t.module, utils.GetRequestID(ctx), "[IssueAccessToken]: Failed to issue access token: %v", err)
	return "", errors.Wrap(err, "", "failed to issue access token")
}
package service
import (
"context"
"github.com/vigiloauth/vigilo/v2/idp/config"
token "github.com/vigiloauth/vigilo/v2/internal/domain/token"
"github.com/vigiloauth/vigilo/v2/internal/errors"
"github.com/vigiloauth/vigilo/v2/internal/utils"
)
// Compile-time check that *tokenManager satisfies token.TokenManager.
var _ token.TokenManager = (*tokenManager)(nil)
// tokenManager implements token lifecycle operations (introspection,
// revocation, lookup, blacklisting and deletion) on top of a TokenRepository.
type tokenManager struct {
repo token.TokenRepository // storage for token records and the blacklist
parser token.TokenParser // decodes token strings into claims
validator token.TokenValidator // expiry/blacklist checks used by Introspect
logger *config.Logger
module string // module label used in log lines
}
// NewTokenManager builds a TokenManager from the given repository, parser and
// validator. It logs through the server configuration's logger.
func NewTokenManager(
	repo token.TokenRepository,
	parser token.TokenParser,
	validator token.TokenValidator,
) token.TokenManager {
	logger := config.GetServerConfig().Logger()
	return &tokenManager{
		module:    "Token Manager",
		logger:    logger,
		validator: validator,
		parser:    parser,
		repo:      repo,
	}
}
// Introspect reports whether tokenStr is an active token and, when it can be
// parsed, the claims it carries.
//
// The response is inactive when the token is not found in the repository,
// cannot be parsed, or fails validation (expired or blacklisted).
//
// Parameters:
//   - ctx context.Context: The context for managing timeouts and cancellations.
//   - tokenStr string: The raw token to introspect.
//
// Returns:
//   - *token.TokenIntrospectionResponse: Information about the token. Never nil.
func (m *tokenManager) Introspect(ctx context.Context, tokenStr string) *token.TokenIntrospectionResponse {
	requestID := utils.GetRequestID(ctx)

	// Tokens are stored under their SHA-256 digest; GetTokenData, BlacklistToken,
	// DeleteToken and Revoke all hash before calling the repository. Looking up
	// the raw string would never match, so every token would appear inactive.
	if _, err := m.repo.GetToken(ctx, utils.EncodeSHA256(tokenStr)); err != nil {
		m.logger.Warn(m.module, requestID, "[Introspect]: An error occurred retrieving the requested token: %v", err)
		return &token.TokenIntrospectionResponse{Active: false}
	}

	tokenClaims, err := m.parser.ParseToken(ctx, tokenStr)
	if err != nil {
		m.logger.Error(m.module, requestID, "[Introspect]: An error occurred parsing the token: %v", err)
		return &token.TokenIntrospectionResponse{Active: false}
	}

	response := token.NewTokenIntrospectionResponse(tokenClaims)
	if err := m.validator.ValidateToken(ctx, tokenStr); err != nil {
		m.logger.Warn(m.module, requestID, "[Introspect]: Token is either blacklisted or expired... Setting active to false")
		response.Active = false
	}
	return response
}
// Revoke invalidates tokenStr by adding its SHA-256 digest to the repository's
// blacklist.
//
// Parameters:
//   - ctx context.Context: The context for managing timeouts and cancellations.
//   - tokenStr string: The token to revoke.
//
// Returns:
//   - error: A wrapped error if blacklisting fails.
func (m *tokenManager) Revoke(ctx context.Context, tokenStr string) error {
	digest := utils.EncodeSHA256(tokenStr)
	blacklistErr := m.repo.BlacklistToken(ctx, digest)
	if blacklistErr == nil {
		return nil
	}
	m.logger.Error(m.module, utils.GetRequestID(ctx), "[Revoke]: Failed to blacklist token: %v", blacklistErr)
	return errors.Wrap(blacklistErr, "", "failed to revoke token")
}
// GetTokenData loads the stored record for tokenStr. The lookup key is the
// token's SHA-256 digest.
//
// No expiry or blacklist check is done here; this is a plain repository lookup.
//
// Parameters:
//   - ctx context.Context: The context for managing timeouts and cancellations.
//   - tokenStr string: The raw token.
//
// Returns:
//   - *token.TokenData: The stored token data, or nil on error.
//   - error: A wrapped repository error if the lookup fails.
func (m *tokenManager) GetTokenData(ctx context.Context, tokenStr string) (*token.TokenData, error) {
	requestID := utils.GetRequestID(ctx)
	record, lookupErr := m.repo.GetToken(ctx, utils.EncodeSHA256(tokenStr))
	if lookupErr != nil {
		m.logger.Error(m.module, requestID, "[GetTokenData]: Failed to retrieve token data: %v", lookupErr)
		return nil, errors.Wrap(lookupErr, "", "failed to retrieve token data")
	}
	return record, nil
}
// BlacklistToken adds tokenStr's SHA-256 digest to the repository blacklist.
// This makes the token unusable even before it expires.
//
// Parameters:
//   - ctx context.Context: The context for managing timeouts and cancellations.
//   - tokenStr string: The raw token to blacklist.
//
// Returns:
//   - error: An ErrCodeInternalServerError-wrapped error if the repository
//     rejects the blacklist request.
func (m *tokenManager) BlacklistToken(ctx context.Context, tokenStr string) error {
	requestID := utils.GetRequestID(ctx)
	if repoErr := m.repo.BlacklistToken(ctx, utils.EncodeSHA256(tokenStr)); repoErr != nil {
		m.logger.Error(m.module, requestID, "[BlacklistToken]: An error occurred while attempting to blacklist token: %v", repoErr)
		return errors.Wrap(repoErr, errors.ErrCodeInternalServerError, "failed to blacklist token")
	}
	return nil
}
// DeleteExpiredTokens fetches every expired token from the repository and
// deletes each one using its stored Token key.
//
// It stops at the first deletion failure, so later tokens are not attempted.
//
// Parameters:
//   - ctx context.Context: The context for managing timeouts and cancellations.
//
// Returns:
//   - error: A wrapped error if retrieval or any deletion fails.
func (m *tokenManager) DeleteExpiredTokens(ctx context.Context) error {
	requestID := utils.GetRequestID(ctx)

	expiredTokens, fetchErr := m.repo.GetExpiredTokens(ctx)
	if fetchErr != nil {
		m.logger.Error(m.module, requestID, "[DeleteExpiredTokens]: Failed to retrieve expired tokens: %v", fetchErr)
		return errors.Wrap(fetchErr, "", "failed to retrieve expired tokens")
	}

	// Index loop: avoids a loop variable that would shadow the `token` package.
	for i := range expiredTokens {
		if deleteErr := m.repo.DeleteToken(ctx, expiredTokens[i].Token); deleteErr != nil {
			m.logger.Error(m.module, requestID, "[DeleteExpiredTokens]: Failed to delete an expired token: %v", deleteErr)
			return errors.Wrap(deleteErr, errors.ErrCodeInternalServerError, "an error occurred deleting expired tokens")
		}
	}
	return nil
}
// DeleteToken removes the record stored under tokenStr's SHA-256 digest.
//
// Parameters:
//   - ctx context.Context: The context for managing timeouts and cancellations.
//   - tokenStr string: The raw token to delete.
//
// Returns:
//   - error: A wrapped repository error if deletion fails.
func (m *tokenManager) DeleteToken(ctx context.Context, tokenStr string) error {
	deleteErr := m.repo.DeleteToken(ctx, utils.EncodeSHA256(tokenStr))
	if deleteErr == nil {
		return nil
	}
	m.logger.Error(m.module, utils.GetRequestID(ctx), "[DeleteToken]: An error occurred deleting a token: %v", deleteErr)
	return errors.Wrap(deleteErr, "", "an error occurred deleting the given token")
}
package service
import (
"context"
"github.com/vigiloauth/vigilo/v2/idp/config"
jwt "github.com/vigiloauth/vigilo/v2/internal/domain/jwt"
token "github.com/vigiloauth/vigilo/v2/internal/domain/token"
"github.com/vigiloauth/vigilo/v2/internal/errors"
"github.com/vigiloauth/vigilo/v2/internal/utils"
)
// Compile-time check that *tokenParser satisfies token.TokenParser.
var _ token.TokenParser = (*tokenParser)(nil)
// tokenParser converts token strings into token claims via the JWT service.
type tokenParser struct {
jwtService jwt.JWTService // performs the underlying JWT parsing
logger *config.Logger
module string // module label used in log lines
}
// NewTokenParser returns a TokenParser that delegates to jwtService. It logs
// through the server configuration's logger.
func NewTokenParser(jwtService jwt.JWTService) token.TokenParser {
	parser := &tokenParser{jwtService: jwtService, module: "Token Parser"}
	parser.logger = config.GetServerConfig().Logger()
	return parser
}
// ParseToken decodes tokenString into token claims using the JWT service.
//
// Parameters:
//   - ctx context.Context: Carries the request ID used for logging.
//   - tokenString string: The JWT to parse.
//
// Returns:
//   - *token.TokenClaims: The parsed claims, or nil on failure.
//   - error: A wrapped error if the JWT service fails to parse the token.
func (t *tokenParser) ParseToken(ctx context.Context, tokenString string) (*token.TokenClaims, error) {
	parsed, parseErr := t.jwtService.ParseWithClaims(ctx, tokenString)
	if parseErr == nil {
		return parsed, nil
	}
	t.logger.Error(t.module, utils.GetRequestID(ctx), "[ParseToken]: Failed to parse token: %v", parseErr)
	return nil, errors.Wrap(parseErr, "", "failed to parse token with claims")
}
package service
import (
"context"
"time"
"github.com/vigiloauth/vigilo/v2/idp/config"
token "github.com/vigiloauth/vigilo/v2/internal/domain/token"
"github.com/vigiloauth/vigilo/v2/internal/errors"
"github.com/vigiloauth/vigilo/v2/internal/utils"
)
// Compile-time check that *tokenValidator satisfies token.TokenValidator.
var _ token.TokenValidator = (*tokenValidator)(nil)
// tokenValidator decides whether a token is still usable. It checks the
// token's expiry (via the parser) and its blacklist status (via the repository).
type tokenValidator struct {
repo token.TokenRepository // consulted for blacklist status
parser token.TokenParser // used to read the token's expiry claim
logger *config.Logger
module string // module label used in log lines
}
// NewTokenValidator builds a TokenValidator from a token repository and parser.
// It logs through the server configuration's logger.
func NewTokenValidator(
	tokenRepo token.TokenRepository,
	tokenParser token.TokenParser,
) token.TokenValidator {
	validator := &tokenValidator{
		repo:   tokenRepo,
		parser: tokenParser,
		module: "Token Validator",
	}
	validator.logger = config.GetServerConfig().Logger()
	return validator
}
// ValidateToken reports whether tokenStr is still usable.
//
// Expiry is checked first. A token that cannot be parsed is also treated as
// expired. The blacklist is consulted only for unexpired tokens. A failed
// blacklist lookup is treated as blacklisted.
//
// Parameters:
//   - ctx context.Context: The context for managing timeouts and cancellations.
//   - tokenStr string: The raw token to check.
//
// Returns:
//   - error: ErrCodeExpiredToken if expired, ErrCodeUnauthorized if
//     blacklisted, or nil if the token is usable.
func (t *tokenValidator) ValidateToken(ctx context.Context, tokenStr string) error {
	requestID := utils.GetRequestID(ctx)

	switch {
	case t.isTokenExpired(ctx, tokenStr):
		t.logger.Warn(t.module, requestID, "[ValidateToken]: Token '%s' is expired", utils.TruncateSensitive(tokenStr))
		return errors.New(errors.ErrCodeExpiredToken, "the token is expired")
	case t.isTokenBlacklisted(ctx, tokenStr):
		t.logger.Warn(t.module, requestID, "[ValidateToken]: Token '%s' is blacklisted", utils.TruncateSensitive(tokenStr))
		return errors.New(errors.ErrCodeUnauthorized, "the token is blacklisted")
	default:
		return nil
	}
}
// isTokenExpired reports whether the token's ExpiresAt (Unix seconds) is in the
// past. A token that fails to parse is treated as expired (fail closed).
func (t *tokenValidator) isTokenExpired(ctx context.Context, tokenStr string) bool {
	parsed, parseErr := t.parser.ParseToken(ctx, tokenStr)
	if parseErr != nil {
		t.logger.Warn(t.module, utils.GetRequestID(ctx), "[isTokenExpired]: Failed to parse token: %v", parseErr)
		return true
	}
	now := time.Now().Unix()
	return now > parsed.ExpiresAt
}
// isTokenBlacklisted asks the repository whether tokenStr's SHA-256 digest is
// blacklisted. A repository error is treated as blacklisted (fail closed).
func (t *tokenValidator) isTokenBlacklisted(ctx context.Context, tokenStr string) bool {
	blacklisted, lookupErr := t.repo.IsTokenBlacklisted(ctx, utils.EncodeSHA256(tokenStr))
	if lookupErr == nil {
		return blacklisted
	}
	t.logger.Warn(t.module, utils.GetRequestID(ctx), "[isTokenBlacklisted]: An error occurred checking if the token was blacklisted: %v", lookupErr)
	return true
}
package service
import (
"context"
"time"
"github.com/vigiloauth/vigilo/v2/idp/config"
"github.com/vigiloauth/vigilo/v2/internal/constants"
audit "github.com/vigiloauth/vigilo/v2/internal/domain/audit"
login "github.com/vigiloauth/vigilo/v2/internal/domain/login"
users "github.com/vigiloauth/vigilo/v2/internal/domain/user"
"github.com/vigiloauth/vigilo/v2/internal/errors"
"github.com/vigiloauth/vigilo/v2/internal/utils"
)
// Compile-time check that *userAuthenticator satisfies users.UserAuthenticator.
var _ users.UserAuthenticator = (*userAuthenticator)(nil)
// RequestMetadata holds client details about an authentication request. The
// fields are read from the request context.
type RequestMetadata struct {
IPAddress string // client IP address, empty if not present in the context
UserAgent string // client user agent, empty if not present in the context
}
// userAuthenticator verifies username/password credentials. It records audit
// events and login attempts, and pads each call to a minimum duration.
type userAuthenticator struct {
repo users.UserRepository // user lookup and update
auditLogger audit.AuditLogger // receives login audit events
loginAttemptService login.LoginAttemptService // saves attempts and handles failed ones
artificialDelay time.Duration // minimum duration of AuthenticateUser (see applyArtificialDelay)
maxFailedAuthenticationAttempts int // populated from LoginConfig().MaxFailedAttempts()
logger *config.Logger
module string // module label used in log lines
}
// NewUserAuthenticator builds a UserAuthenticator. The artificial delay and the
// maximum number of failed attempts come from the server's login configuration.
// Logging goes through the server configuration's logger.
func NewUserAuthenticator(
	repo users.UserRepository,
	auditLogger audit.AuditLogger,
	loginAttemptService login.LoginAttemptService,
) users.UserAuthenticator {
	serverConfig := config.GetServerConfig()
	loginConfig := serverConfig.LoginConfig()

	authenticator := &userAuthenticator{
		repo:                repo,
		auditLogger:         auditLogger,
		loginAttemptService: loginAttemptService,
		module:              "User Authenticator",
	}
	authenticator.artificialDelay = loginConfig.Delay()
	authenticator.maxFailedAuthenticationAttempts = loginConfig.MaxFailedAttempts()
	authenticator.logger = serverConfig.Logger()
	return authenticator
}
// AuthenticateUser verifies a username/password pair.
//
// Every call takes at least the configured artificial delay (see
// applyArtificialDelay), so response timing does not reveal which step failed.
// Exactly one login audit event is stored per call. When a known user fails to
// authenticate, the attempt is also passed to the login attempt service. The
// client IP address and user agent are read from ctx.
//
// Parameters:
//   - ctx context.Context: The context for managing timeouts and cancellations.
//     It also carries the request ID, IP address and user agent.
//   - request *users.UserLoginRequest: The login request containing the
//     username and password.
//
// Returns:
//   - *users.UserLoginResponse: The response built from the authenticated user.
//   - error: An invalid-credentials error if the user is unknown or the
//     password does not match, or ErrCodeAccountLocked if the account is locked.
//     A wrapped error is returned if updating the user or saving the login
//     attempt fails.
func (u *userAuthenticator) AuthenticateUser(
	ctx context.Context,
	request *users.UserLoginRequest,
) (res *users.UserLoginResponse, err error) {
	requestID := utils.GetRequestID(ctx)
	startTime := time.Now()
	defer u.applyArtificialDelay(startTime)

	requestMetadata := u.extractMetadataFromContext(ctx)
	loginAttempt := &users.UserLoginAttempt{
		Timestamp: time.Now().UTC(),
		IPAddress: requestMetadata.IPAddress,
		UserAgent: requestMetadata.UserAgent,
	}

	user, err := u.repo.GetUserByUsername(ctx, request.Username)
	if err != nil {
		u.logger.Error(u.module, requestID, "[AuthenticateUser]: Failed to retrieve user by username: %v", err)
		u.logAuthenticationAttempt(ctx, false, err, "")
		return nil, errors.Wrap(err, errors.ErrCodeInvalidCredentials, "username or password are incorrect")
	}
	if user == nil {
		// Defensive: a repository that reports "not found" as (nil, nil) would
		// otherwise cause a nil dereference on user.ID below.
		err = errors.New(errors.ErrCodeInvalidCredentials, "username or password are incorrect")
		u.logAuthenticationAttempt(ctx, false, err, "")
		return nil, err
	}
	loginAttempt.UserID = user.ID

	// From here on, the audit event is recorded exactly once, by this deferred
	// call, using the final value of the named result err. The branches below
	// must not call logAuthenticationAttempt themselves. Previously they did, so
	// most outcomes were audited twice.
	defer func() {
		u.logAuthenticationAttempt(ctx, err == nil, err, user.ID)
	}()

	// Registered after the audit defer, so it runs first (defers are LIFO).
	defer func() {
		if err == nil {
			return
		}
		if handleErr := u.loginAttemptService.HandleFailedLoginAttempt(ctx, user, loginAttempt); handleErr != nil {
			u.logger.Error(u.module, requestID, "[AuthenticateUser]: An error occurred while handling the failed auth attempt: %v", handleErr)
		}
	}()

	if user.AccountLocked {
		return nil, errors.New(errors.ErrCodeAccountLocked, "account has been locked due to too many failed attempts")
	}

	if cmpErr := u.comparePasswords(request.Password, user.Password); cmpErr != nil {
		u.logger.Error(u.module, requestID, "[AuthenticateUser]: Failed to compare passwords: %v", cmpErr)
		return nil, errors.Wrap(cmpErr, "", "failed to authenticate user")
	}

	user.LastFailedLogin = time.Time{}
	if updateErr := u.repo.UpdateUser(ctx, user); updateErr != nil {
		u.logger.Error(u.module, requestID, "[AuthenticateUser]: Failed to update user: %v", updateErr)
		return nil, errors.Wrap(updateErr, "", "failed to update user")
	}

	if saveErr := u.loginAttemptService.SaveLoginAttempt(ctx, loginAttempt); saveErr != nil {
		u.logger.Error(u.module, requestID, "[AuthenticateUser]: Failed to save authentication attempt: %v", saveErr)
		return nil, errors.Wrap(saveErr, "", "failed to save authentication attempt")
	}

	return users.NewUserLoginResponse(user), nil
}
// applyArtificialDelay sleeps for whatever part of u.artificialDelay has not
// already elapsed since startTime. Callers therefore take at least that long.
func (u *userAuthenticator) applyArtificialDelay(startTime time.Time) {
	if remaining := u.artificialDelay - time.Since(startTime); remaining > 0 {
		time.Sleep(remaining)
	}
}
// logAuthenticationAttempt stores a login audit event. userID is attached to
// the context under ContextKeyUserID before the event is stored.
func (u *userAuthenticator) logAuthenticationAttempt(ctx context.Context, success bool, err error, userID string) {
	auditCtx := utils.AddKeyValueToContext(ctx, constants.ContextKeyUserID, userID)
	u.auditLogger.StoreEvent(
		auditCtx,
		audit.LoginAttempt,
		success,
		audit.AuthenticationAction,
		audit.OAuthMethod,
		err,
	)
}
// comparePasswords checks the plaintext password against the stored hash.
// It returns an ErrCodeInvalidCredentials error on mismatch.
func (u *userAuthenticator) comparePasswords(password string, hashedPassword string) error {
	if utils.CompareHash(password, hashedPassword) {
		return nil
	}
	return errors.New(errors.ErrCodeInvalidCredentials, "invalid credentials")
}
// extractMetadataFromContext reads the client IP address and user agent from
// ctx. A missing or non-string value leaves the corresponding field empty.
func (u *userAuthenticator) extractMetadataFromContext(ctx context.Context) RequestMetadata {
	ipAddress, _ := utils.GetValueFromContext(ctx, constants.ContextKeyIPAddress).(string)
	userAgent, _ := utils.GetValueFromContext(ctx, constants.ContextKeyUserAgent).(string)
	return RequestMetadata{
		IPAddress: ipAddress,
		UserAgent: userAgent,
	}
}
package service
import (
"context"
"strings"
"time"
"github.com/vigiloauth/vigilo/v2/idp/config"
"github.com/vigiloauth/vigilo/v2/internal/constants"
audit "github.com/vigiloauth/vigilo/v2/internal/domain/audit"
crypto "github.com/vigiloauth/vigilo/v2/internal/domain/crypto"
emails "github.com/vigiloauth/vigilo/v2/internal/domain/email"
tokens "github.com/vigiloauth/vigilo/v2/internal/domain/token"
users "github.com/vigiloauth/vigilo/v2/internal/domain/user"
"github.com/vigiloauth/vigilo/v2/internal/errors"
"github.com/vigiloauth/vigilo/v2/internal/utils"
)
// Compile-time check that *userCreator satisfies users.UserCreator.
var _ users.UserCreator = (*userCreator)(nil)
// userCreator registers new users. It hashes the password, stores the user,
// issues tokens and sends the account-verification email.
type userCreator struct {
repo users.UserRepository // user persistence
issuer tokens.TokenIssuer // issues the token pair used at registration
audit audit.AuditLogger // records registration audit events
email emails.EmailService // sends the verification email
cryptographer crypto.Cryptographer // hashes user passwords
logger *config.Logger
module string // module label used in log lines
}
// NewUserCreator builds a UserCreator from its collaborators. It logs through
// the server configuration's logger.
func NewUserCreator(
	repo users.UserRepository,
	issuer tokens.TokenIssuer,
	audit audit.AuditLogger,
	email emails.EmailService,
	cryptographer crypto.Cryptographer,
) users.UserCreator {
	creator := &userCreator{
		repo:          repo,
		issuer:        issuer,
		audit:         audit,
		email:         email,
		cryptographer: cryptographer,
		module:        "User Creator",
	}
	creator.logger = config.GetServerConfig().Logger()
	return creator
}
// CreateUser registers a new user.
//
// Steps:
//  1. Hash the password and reset the timestamps and verification flags.
//  2. Generate an ID with the admin or user prefix, depending on whether the
//     user has the admin role.
//  3. Store the user.
//  4. Issue a token pair. The access token goes into the registration response;
//     the refresh token is emailed as the account verification code.
//
// A registration audit event recording success or failure is stored on return.
//
// NOTE(review): IssueTokenPair takes (subject, audience, scopes, roles, ...).
// This call passes user.Email as the subject and user.ID as the audience, with
// empty scopes. Confirm that ordering is intended.
// NOTE(review): a failure after AddUser (token issuance or email delivery)
// leaves the stored user in place.
//
// Parameters:
//   - ctx context.Context: The context for managing timeouts and cancellations.
//   - user *users.User: The user to register. It is mutated in place
//     (password, timestamps, flags, ID).
//
// Returns:
//   - *users.UserRegistrationResponse: The registered user and the access token.
//   - error: A wrapped error if any step fails.
func (u *userCreator) CreateUser(
	ctx context.Context,
	user *users.User,
) (res *users.UserRegistrationResponse, err error) {
	requestID := utils.GetRequestID(ctx)
	defer func() {
		u.audit.StoreEvent(ctx, audit.RegistrationAttempt, err == nil, audit.RegistrationAction, audit.EmailMethod, err)
	}()

	hashedPassword, hashErr := u.cryptographer.HashString(user.Password)
	if hashErr != nil {
		u.logger.Error(u.module, requestID, "[CreateUser]: Failed to encrypt password: %v", hashErr)
		return nil, errors.Wrap(hashErr, "", "failed to encrypt password")
	}

	user.Password = hashedPassword
	user.CreatedAt = time.Now()
	user.UpdatedAt = time.Now()
	user.EmailVerified = false
	user.PhoneNumberVerified = false

	idPrefix := constants.UserRoleIDPrefix
	if user.HasRole(constants.AdminRole) {
		idPrefix = constants.AdminRoleIDPrefix
	}
	user.ID = idPrefix + utils.GenerateUUID()

	if addErr := u.repo.AddUser(ctx, user); addErr != nil {
		u.logger.Error(u.module, requestID, "[CreateUser]: Failed to save user: %v", addErr)
		return nil, errors.Wrap(addErr, "", "failed to save user")
	}

	roles := strings.Join(user.Roles, " ")
	accessToken, verificationCode, issueErr := u.issuer.IssueTokenPair(ctx, user.Email, user.ID, "", roles, "", nil)
	if issueErr != nil {
		u.logger.Error(u.module, requestID, "[CreateUser]: Failed to generate verification code: %v", issueErr)
		return nil, errors.Wrap(issueErr, "", "failed to generate verification code")
	}

	verificationEmail := emails.NewEmailRequest(user.Email, verificationCode, verificationCode, emails.AccountVerification)
	if sendErr := u.email.SendEmail(ctx, verificationEmail); sendErr != nil {
		u.logger.Error(u.module, requestID, "[CreateUser]: Failed to send verification email: %v", sendErr)
		return nil, errors.Wrap(sendErr, "", "failed to send verification email")
	}

	return users.NewUserRegistrationResponse(user, accessToken), nil
}
package service
import (
"context"
"github.com/vigiloauth/vigilo/v2/idp/config"
crypto "github.com/vigiloauth/vigilo/v2/internal/domain/crypto"
tokens "github.com/vigiloauth/vigilo/v2/internal/domain/token"
users "github.com/vigiloauth/vigilo/v2/internal/domain/user"
"github.com/vigiloauth/vigilo/v2/internal/errors"
"github.com/vigiloauth/vigilo/v2/internal/utils"
)
// Compile-time check that *userManager implements users.UserManager.
var _ users.UserManager = (*userManager)(nil)
// userManager implements users.UserManager: user lookups, cleanup of
// unverified accounts, and token-based password resets.
type userManager struct {
// repo persists and retrieves user records.
repo users.UserRepository
// parser parses reset tokens (used by ResetPassword).
parser tokens.TokenParser
// manager blacklists reset tokens after use (used by ResetPassword).
manager tokens.TokenManager
// cryptographer hashes new passwords before storage.
cryptographer crypto.Cryptographer
// logger and module tag every log line emitted by this service.
logger *config.Logger
module string
}
// NewUserManager wires a userManager from its collaborators. The logger is
// taken from the global server configuration.
func NewUserManager(
	repo users.UserRepository,
	parser tokens.TokenParser,
	manager tokens.TokenManager,
	cryptographer crypto.Cryptographer,
) users.UserManager {
	m := &userManager{module: "User Manager"}
	m.repo = repo
	m.parser = parser
	m.manager = manager
	m.cryptographer = cryptographer
	m.logger = config.GetServerConfig().Logger()
	return m
}
// GetUserByUsername looks up a single user by username.
//
// Parameters:
// - ctx Context: Carries deadlines, cancellation and the request ID.
// - username string: The username to search for.
//
// Returns:
// - *User: The matching user, or nil on failure.
// - error: A wrapped repository error if the lookup fails.
func (u *userManager) GetUserByUsername(ctx context.Context, username string) (*users.User, error) {
	found, lookupErr := u.repo.GetUserByUsername(ctx, username)
	if lookupErr == nil {
		return found, nil
	}

	u.logger.Error(u.module, utils.GetRequestID(ctx), "[GetUserByUsername]: Failed to get user by username: %v", lookupErr)
	return nil, errors.Wrap(lookupErr, "", "failed to retrieve user")
}
// GetUserByID looks up a single user by its unique identifier.
//
// Parameters:
// - ctx Context: Carries deadlines, cancellation and the request ID.
// - userID string: The ID of the user to load.
//
// Returns:
// - *User: The matching user, or nil on failure.
// - error: A wrapped repository error if the lookup fails.
func (u *userManager) GetUserByID(ctx context.Context, userID string) (*users.User, error) {
	found, lookupErr := u.repo.GetUserByID(ctx, userID)
	if lookupErr == nil {
		return found, nil
	}

	u.logger.Error(u.module, utils.GetRequestID(ctx), "[GetUserByID]: Failed to get user by ID: %v", lookupErr)
	return nil, errors.Wrap(lookupErr, "", "failed to retrieve user")
}
// DeleteUnverifiedUsers removes every account the repository reports as
// unverified and older than a week.
//
// Deletion stops at the first failure; users already deleted stay deleted.
//
// Parameter:
// - ctx Context: Carries deadlines, cancellation and the request ID.
//
// Returns:
// - error: A wrapped error if the search or any deletion fails, otherwise nil.
func (u *userManager) DeleteUnverifiedUsers(ctx context.Context) error {
	requestID := utils.GetRequestID(ctx)

	stale, findErr := u.repo.FindUnverifiedUsersOlderThanWeek(ctx)
	if findErr != nil {
		u.logger.Error(u.module, requestID, "[DeleteUnverifiedUsers]: Failed to retrieve unverified users: %v", findErr)
		return errors.Wrap(findErr, "", "failed to retrieve unverified users")
	}

	for _, candidate := range stale {
		deleteErr := u.repo.DeleteUserByID(ctx, candidate.ID)
		if deleteErr == nil {
			continue
		}
		u.logger.Error(u.module, requestID, "[DeleteUnverifiedUsers]: Failed to delete user by ID: %v", deleteErr)
		return errors.Wrap(deleteErr, "", "failed to delete user by ID")
	}

	return nil
}
// ResetPassword resets the user's password using the provided reset token.
//
// The reset token is parsed before the user is looked up. A request with an
// unparseable token is therefore rejected with "invalid reset token" whether or
// not userEmail belongs to an account, so callers without a valid token cannot
// probe which emails are registered. The token's subject must equal the user's
// ID. A locked account is unlocked as part of the reset. The reset token is
// blacklisted on return, whether or not the reset succeeded.
//
// Parameters:
// - ctx Context: The context for managing timeouts and cancellations.
// - userEmail string: The user's email address.
// - newPassword string: The new plain-text password; it is hashed before storage.
// - resetToken string: The reset token.
//
// Returns:
// - *users.UserPasswordResetResponse: A success message.
// - error: An error if the token cannot be parsed, the user cannot be found,
//   the token subject does not match the user, hashing fails, or the update fails.
func (u *userManager) ResetPassword(
	ctx context.Context,
	userEmail string,
	newPassword string,
	resetToken string,
) (*users.UserPasswordResetResponse, error) {
	requestID := utils.GetRequestID(ctx)

	// Blacklist the reset token on every exit path, success or failure.
	defer func() {
		if err := u.manager.BlacklistToken(ctx, resetToken); err != nil {
			u.logger.Error(u.module, requestID, "[ResetPassword]: Failed to blacklist reset token: %v", err)
		}
	}()

	// Check the token before touching the user store (see doc comment).
	parsedToken, err := u.parser.ParseToken(ctx, resetToken)
	if err != nil {
		u.logger.Error(u.module, requestID, "[ResetPassword]: Failed to parse reset token: %v", err)
		return nil, errors.Wrap(err, "", "invalid reset token")
	}

	user, err := u.repo.GetUserByEmail(ctx, userEmail)
	if err != nil {
		u.logger.Error(u.module, requestID, "[ResetPassword]: Failed to retrieve user: %v", err)
		return nil, errors.Wrap(err, "", "user not found")
	}

	if user.ID != parsedToken.Subject {
		u.logger.Error(u.module, requestID, "[ResetPassword]: User ID does not match the token subject")
		return nil, errors.New(errors.ErrCodeUnauthorized, "invalid credentials")
	}

	encryptedPassword, err := u.cryptographer.HashString(newPassword)
	if err != nil {
		u.logger.Error(u.module, requestID, "[ResetPassword]: Failed to encrypt password: %v", err)
		return nil, errors.Wrap(err, "", "failed to encrypt password")
	}

	if user.AccountLocked {
		// Truncate the email, as elsewhere in the codebase, to keep PII out of logs.
		u.logger.Debug(u.module, requestID, "[ResetPassword]: Unlocking account for user=[%s]", utils.TruncateSensitive(userEmail))
		user.AccountLocked = false
	}

	user.Password = encryptedPassword
	if err := u.repo.UpdateUser(ctx, user); err != nil {
		u.logger.Error(u.module, requestID, "[ResetPassword]: Failed to update user: %v", err)
		return nil, errors.Wrap(err, "", "failed to update password")
	}

	return &users.UserPasswordResetResponse{
		Message: "Password has been reset successfully",
	}, nil
}
package service
import (
"context"
"github.com/vigiloauth/vigilo/v2/idp/config"
tokens "github.com/vigiloauth/vigilo/v2/internal/domain/token"
users "github.com/vigiloauth/vigilo/v2/internal/domain/user"
"github.com/vigiloauth/vigilo/v2/internal/errors"
"github.com/vigiloauth/vigilo/v2/internal/utils"
)
// Compile-time check that *userVerifier implements users.UserVerifier.
var _ users.UserVerifier = (*userVerifier)(nil)
// userVerifier implements users.UserVerifier and marks email addresses as
// verified from verification codes.
type userVerifier struct {
// repo loads and updates the user identified by a verification code.
repo users.UserRepository
// parser extracts claims (the subject) from a verification code.
parser tokens.TokenParser
// validator checks a verification code before it is parsed.
validator tokens.TokenValidator
// manager blacklists verification codes after an attempt to use them.
manager tokens.TokenManager
// logger and module tag every log line emitted by this service.
logger *config.Logger
module string
}
// NewUserVerifier wires a userVerifier from its collaborators. The logger is
// taken from the global server configuration.
func NewUserVerifier(
	repo users.UserRepository,
	parser tokens.TokenParser,
	validator tokens.TokenValidator,
	manager tokens.TokenManager,
) users.UserVerifier {
	v := &userVerifier{module: "User Verifier"}
	v.repo = repo
	v.parser = parser
	v.validator = validator
	v.manager = manager
	v.logger = config.GetServerConfig().Logger()
	return v
}
// VerifyEmailAddress validates the verification code and marks the user's email as verified.
//
// The code is validated, then parsed, and its subject is used as the user ID
// to load and update. The code is blacklisted on return, whether or not
// verification succeeded.
//
// Parameter:
// - ctx Context: The context for managing timeouts and cancellations.
// - verificationCode string: The verification code to verify.
//
// Returns:
// - error: An unauthorized error if the code fails validation or the user
//   cannot be loaded; a wrapped error if parsing or the update fails; otherwise nil.
func (u *userVerifier) VerifyEmailAddress(ctx context.Context, verificationCode string) error {
	requestID := utils.GetRequestID(ctx)

	// Blacklist the code on every exit path.
	defer func() {
		if err := u.manager.BlacklistToken(ctx, verificationCode); err != nil {
			u.logger.Error(u.module, requestID, "[VerifyEmailAddress]: Failed to blacklist verification code: %v", err)
		}
	}()

	if err := u.validator.ValidateToken(ctx, verificationCode); err != nil {
		u.logger.Error(u.module, requestID, "[VerifyEmailAddress]: Failed to validate verification code: %v", err)
		return errors.New(errors.ErrCodeUnauthorized, "the verification code either does not exist or is expired")
	}

	tokenClaims, err := u.parser.ParseToken(ctx, verificationCode)
	if err != nil {
		u.logger.Error(u.module, requestID, "[VerifyEmailAddress]: Failed to parse verification code: %v", err)
		return errors.Wrap(err, "", "failed to parse token")
	}

	user, err := u.repo.GetUserByID(ctx, tokenClaims.Subject)
	if err != nil {
		u.logger.Error(u.module, requestID, "[VerifyEmailAddress]: Failed to retrieve user by ID: %v", err)
		return errors.Wrap(err, errors.ErrCodeUnauthorized, "failed to retrieve user")
	}

	user.EmailVerified = true
	if err := u.repo.UpdateUser(ctx, user); err != nil {
		u.logger.Error(u.module, requestID, "[VerifyEmailAddress]: Failed to update user: %v", err)
		return errors.Wrap(err, "", "failed to update user")
	}

	return nil
}
package service
import (
"context"
"fmt"
"net/http"
"net/url"
"github.com/vigiloauth/vigilo/v2/idp/config"
clients "github.com/vigiloauth/vigilo/v2/internal/domain/client"
session "github.com/vigiloauth/vigilo/v2/internal/domain/session"
users "github.com/vigiloauth/vigilo/v2/internal/domain/user"
consent "github.com/vigiloauth/vigilo/v2/internal/domain/userconsent"
"github.com/vigiloauth/vigilo/v2/internal/errors"
"github.com/vigiloauth/vigilo/v2/internal/types"
"github.com/vigiloauth/vigilo/v2/internal/utils"
"github.com/vigiloauth/vigilo/v2/internal/web"
)
// Compile-time check that *userConsentService implements consent.UserConsentService.
var _ consent.UserConsentService = (*userConsentService)(nil)
// userConsentService implements the UserConsentService interface
// and manages user consent-related operations by coordinating
// between consent and user repositories, the session service and the
// client manager.
type userConsentService struct {
// consentRepo stores, checks and revokes consent records.
consentRepo consent.UserConsentRepository
// userRepo is used to confirm the user exists before consent operations.
userRepo users.UserRepository
// sessionService reads and updates the session attached to the request.
sessionService session.SessionService
// clientManager loads client details shown on the consent screen.
clientManager clients.ClientManager
// logger and module tag every log line emitted by this service.
logger *config.Logger
module string
}
// NewUserConsentService wires a userConsentService from its collaborators.
// The logger is taken from the global server configuration.
func NewUserConsentService(
	consentRepo consent.UserConsentRepository,
	userRepo users.UserRepository,
	sessionService session.SessionService,
	clientManager clients.ClientManager,
) consent.UserConsentService {
	svc := &userConsentService{module: "User Consent Service"}
	svc.consentRepo = consentRepo
	svc.userRepo = userRepo
	svc.sessionService = sessionService
	svc.clientManager = clientManager
	svc.logger = config.GetServerConfig().Logger()
	return svc
}
// CheckUserConsent reports whether the user has already granted the client
// consent for the requested scope.
//
// Parameters:
// - ctx Context: Carries deadlines, cancellation and the request ID.
// - userID string: The user whose consent is checked; the user must exist.
// - clientID string: The client application asking for access.
// - scope types.Scope: The space-separated permissions being requested.
//
// Returns:
// - bool: True when consent already exists, false when it must be requested.
// - error: A wrapped error if the user cannot be loaded or the check fails.
func (c *userConsentService) CheckUserConsent(
	ctx context.Context,
	userID string,
	clientID string,
	scope types.Scope,
) (bool, error) {
	requestID := utils.GetRequestID(ctx)

	// Fail if the user cannot be loaded.
	_, lookupErr := c.userRepo.GetUserByID(ctx, userID)
	if lookupErr != nil {
		c.logger.Error(c.module, requestID, "[CheckUserConsent]: An error occurred retrieving a user by ID: %v", lookupErr)
		return false, errors.Wrap(lookupErr, "", "failed to retrieve user by ID")
	}

	granted, checkErr := c.consentRepo.HasConsent(ctx, userID, clientID, scope)
	if checkErr != nil {
		c.logger.Error(c.module, requestID, "[CheckUserConsent]: Failed to check user consent: %v", checkErr)
		return false, errors.Wrap(checkErr, "", "failed to verify consent")
	}

	return granted, nil
}
// SaveUserConsent persists the user's consent for the client to access the
// given scope.
//
// Parameters:
// - ctx Context: Carries deadlines, cancellation and the request ID.
// - userID string: The user granting consent; the user must exist.
// - clientID string: The client application receiving consent.
// - scope types.Scope: The space-separated permissions being granted.
//
// Returns:
// - error: A wrapped error if the user cannot be loaded or the save fails, otherwise nil.
func (c *userConsentService) SaveUserConsent(
	ctx context.Context,
	userID string,
	clientID string,
	scope types.Scope,
) error {
	requestID := utils.GetRequestID(ctx)

	_, lookupErr := c.userRepo.GetUserByID(ctx, userID)
	if lookupErr != nil {
		c.logger.Error(c.module, requestID, "[SaveUserConsent]: An error occurred retrieving the user by ID: %v", lookupErr)
		return errors.Wrap(lookupErr, "", "failed to retrieve user by ID")
	}

	saveErr := c.consentRepo.SaveConsent(ctx, userID, clientID, scope)
	if saveErr != nil {
		c.logger.Error(c.module, requestID, "[SaveUserConsent]: Failed to save user consent: %v", saveErr)
		return errors.Wrap(saveErr, "", "failed to process consent decision")
	}

	return nil
}
// RevokeConsent deletes the consent the user previously granted to the client.
//
// Parameters:
// - ctx Context: Carries deadlines, cancellation and the request ID.
// - userID string: The user revoking consent; the user must exist.
// - clientID string: The client application losing consent.
//
// Returns:
// - error: A wrapped error if the user cannot be loaded or revocation fails, otherwise nil.
func (c *userConsentService) RevokeConsent(ctx context.Context, userID, clientID string) error {
	requestID := utils.GetRequestID(ctx)

	_, lookupErr := c.userRepo.GetUserByID(ctx, userID)
	if lookupErr != nil {
		c.logger.Error(c.module, requestID, "[RevokeConsent]: An error occurred retrieving the user by ID: %v", lookupErr)
		return errors.Wrap(lookupErr, "", "failed to revoke user consent")
	}

	revokeErr := c.consentRepo.RevokeConsent(ctx, userID, clientID)
	if revokeErr != nil {
		c.logger.Error(c.module, requestID, "[RevokeConsent]: Failed to revoke consent: %v", revokeErr)
		return errors.Wrap(revokeErr, "", "failed to revoke consent")
	}

	return nil
}
// GetConsentDetails retrieves the details required for the user consent process.
//
// The method works in this order:
//   1. Validates the required parameters and loads the client.
//   2. Stores the client ID, state and redirect URI in the user's session.
//   3. Checks whether the user already consented to the requested scope.
//
// When consent already exists, the consent is saved again and a success
// response is returned. Otherwise the response carries what is needed to
// render the consent screen.
//
// Parameters:
// - userID string: The unique identifier of the user.
// - clientID string: The identifier of the client application requesting access.
// - redirectURI string: The redirect URI provided by the client application.
// - state string: The state value of the authorization request; stored in the
//   session and returned in the consent-screen response.
// - scope types.Scope: The space-separated list of permissions being requested.
// - responseType string: Copied into the consent request built when consent already exists.
// - nonce string: Copied into the consent request built when consent already exists.
// - display string: Copied into the consent request built when consent already exists.
// - r *http.Request: The HTTP request carrying the context and the session.
//
// Returns:
// - *consent.UserConsentResponse: A success response if consent already exists,
//   otherwise the client and scope details for the consent screen.
// - error: An error if validation, client lookup, session access/update, or
//   the consent check fails.
func (c *userConsentService) GetConsentDetails(
	userID string,
	clientID string,
	redirectURI string,
	state string,
	scope types.Scope,
	responseType string,
	nonce string,
	display string,
	r *http.Request,
) (*consent.UserConsentResponse, error) {
	ctx := r.Context()
	requestID := utils.GetRequestID(ctx)

	if err := c.validateRequest(userID, clientID, redirectURI, scope); err != nil {
		c.logger.Error(c.module, requestID, "[GetConsentDetails]: Failed to retrieve consent details: %v", err)
		return nil, errors.Wrap(err, "", "invalid request parameters")
	}

	client, err := c.clientManager.GetClientByID(ctx, clientID)
	if err != nil {
		c.logger.Error(c.module, requestID, "[GetConsentDetails]: An error occurred retrieving client by ID: %v", err)
		return nil, errors.Wrap(err, "", "failed to retrieve client by ID")
	}

	sessionData, err := c.sessionService.GetSessionData(r)
	if err != nil {
		c.logger.Error(c.module, requestID, "[GetConsentDetails]: Failed to retrieve consent details: %v", err)
		return nil, errors.Wrap(err, "", "failed to get session data")
	}

	if err := c.updateSessionWithConsentDetails(r, sessionData, clientID, state, redirectURI); err != nil {
		c.logger.Error(c.module, requestID, "[GetConsentDetails]: Failed to update session: %v", err)
		return nil, err
	}

	approved, err := c.CheckUserConsent(ctx, userID, clientID, scope)
	if err != nil {
		c.logger.Error(c.module, requestID, "[GetConsentDetails]: Failed to check user consent: %v", err)
		return nil, err
	}

	if approved {
		c.logger.Debug(c.module, requestID, "[GetConsentDetails]: User has previously given consent. Processing approval.")
		consentRequest := &consent.UserConsentRequest{
			ResponseType: responseType,
			State:        state,
			Nonce:        nonce,
			Display:      display,
		}
		return c.processApprovedConsent(ctx, userID, clientID, scope, consentRequest)
	}

	// Consent is still required: return what the consent screen needs.
	return &consent.UserConsentResponse{
		Approved:        false,
		ClientID:        clientID,
		ClientName:      client.Name,
		RedirectURI:     redirectURI,
		Scopes:          types.ParseScopesString(scope.String()),
		State:           state,
		ConsentEndpoint: web.OAuthEndpoints.UserConsent,
	}, nil
}
// ProcessUserConsent processes the user's decision for the consent request.
//
// A denied decision produces an access_denied redirect response. An approved
// decision saves consent for the approved scopes (or the requested scope when
// none were selected) and returns a success response.
//
// Parameters:
// - userID string: The unique identifier of the user.
// - clientID string: The identifier of the client application requesting access.
// - redirectURI string: The redirect URI provided by the client application.
// - scope types.Scope: The space-separated list of permissions being requested.
// - consentRequest *consent.UserConsentRequest: The user's consent decision and
//   approved scopes; must not be nil.
// - r *http.Request: The HTTP request carrying the request context.
//
// Returns:
// - *consent.UserConsentResponse: The result of the consent process (success or denial).
// - error: A bad-request error for missing parameters or a nil consentRequest,
//   or an error if saving consent fails.
func (c *userConsentService) ProcessUserConsent(
	userID string,
	clientID string,
	redirectURI string,
	scope types.Scope,
	consentRequest *consent.UserConsentRequest,
	r *http.Request,
) (*consent.UserConsentResponse, error) {
	ctx := r.Context()
	requestID := utils.GetRequestID(ctx)

	if err := c.validateRequest(userID, clientID, redirectURI, scope); err != nil {
		c.logger.Error(c.module, requestID, "[ProcessUserConsent]: Failed to process user consent: %v", err)
		return nil, errors.Wrap(err, "", "invalid request parameters")
	}

	// Guard the dereferences below; a nil decision is a malformed request
	// rather than a reason to panic.
	if consentRequest == nil {
		c.logger.Error(c.module, requestID, "[ProcessUserConsent]: Consent request is nil")
		return nil, errors.New(errors.ErrCodeBadRequest, "missing consent decision")
	}

	if !consentRequest.Approved {
		c.logger.Warn(c.module, requestID, "[ProcessUserConsent]: Creating error response for denied consent")
		return c.handleDeniedConsent(consentRequest.State, redirectURI), nil
	}

	return c.processApprovedConsent(ctx, userID, clientID, scope, consentRequest)
}
// handleDeniedConsent builds the response for a user who denied consent.
// It adds error=access_denied, an error_description and, when non-empty, the
// state to the redirect URI.
//
// The parameters are merged into any query the redirect URI already has
// (keeping it intact and placing them before a fragment) instead of blindly
// appending "?". If the URI cannot be parsed, it falls back to plain
// concatenation.
func (c *userConsentService) handleDeniedConsent(state, redirectURI string) *consent.UserConsentResponse {
	params := url.Values{}
	params.Set("error", "access_denied")
	params.Set("error_description", "user denied access to the requested scope")
	if state != "" {
		params.Set("state", state)
	}
	// Encode sorts keys: error, error_description, state.
	encoded := params.Encode()

	errorURL := fmt.Sprintf("%s?%s", redirectURI, encoded)
	if parsed, err := url.Parse(redirectURI); err == nil {
		if parsed.RawQuery != "" {
			parsed.RawQuery += "&" + encoded
		} else {
			parsed.RawQuery = encoded
		}
		errorURL = parsed.String()
	}

	return &consent.UserConsentResponse{
		Error:       errors.ErrCodeAccessDenied,
		RedirectURI: errorURL,
	}
}
// validateRequest rejects requests that are missing any of the user ID,
// client ID, redirect URI or scope.
func (c *userConsentService) validateRequest(
	userID string,
	clientID string,
	redirectURI string,
	scope types.Scope,
) error {
	complete := userID != "" && clientID != "" && redirectURI != "" && scope != ""
	if complete {
		return nil
	}

	c.logger.Error(c.module, "", "Missing required OAuth parameters in request")
	return errors.New(errors.ErrCodeBadRequest, "missing required OAuth parameters")
}
// getApprovedScopes returns the user-selected scopes joined into one Scope,
// or defaultScopes when the user selected none.
func (c *userConsentService) getApprovedScopes(defaultScopes types.Scope, requestedScopes []types.Scope) types.Scope {
	if len(requestedScopes) == 0 {
		return defaultScopes
	}
	return types.CombineScopes(requestedScopes...)
}
// processApprovedConsent saves the consent decision and returns a success response.
//
// The saved scopes are the ones listed in consentRequest.Scopes, or scope
// when that list is empty. Only consentRequest.Scopes is read.
func (c *userConsentService) processApprovedConsent(
	ctx context.Context,
	userID string,
	clientID string,
	scope types.Scope,
	consentRequest *consent.UserConsentRequest,
) (*consent.UserConsentResponse, error) {
	requestID := utils.GetRequestID(ctx)
	approvedScopes := c.getApprovedScopes(scope, consentRequest.Scopes)

	if err := c.consentRepo.SaveConsent(ctx, userID, clientID, approvedScopes); err != nil {
		c.logger.Error(c.module, requestID, "[processApprovedConsent]: Failed to save user consent: %v", err)
		return nil, errors.Wrap(err, "", "failed to save user consent")
	}

	c.logger.Debug(c.module, requestID, "[processApprovedConsent]: Building success response for approved consent")
	return &consent.UserConsentResponse{
		Success:  true,
		Approved: true,
	}, nil
}
// updateSessionWithConsentDetails records the client ID, redirect URI and
// state in the session and persists it through the session service.
//
// Log lines carry the request ID from r's context so they can be correlated
// with the rest of the request. Identifiers are truncated and the URI is
// sanitized before logging.
func (c *userConsentService) updateSessionWithConsentDetails(r *http.Request, sessionData *session.SessionData, clientID, state, redirectURI string) error {
	requestID := utils.GetRequestID(r.Context())
	c.logger.Info(c.module, requestID, "Updating session with consent details for sessionID=%s, clientID=%s, redirectURI=%s",
		utils.TruncateSensitive(sessionData.ID), utils.TruncateSensitive(clientID), utils.SanitizeURL(redirectURI))

	sessionData.ClientID = clientID
	sessionData.RedirectURI = redirectURI
	sessionData.State = state

	if err := c.sessionService.UpdateSession(r, sessionData); err != nil {
		c.logger.Error(c.module, requestID, "Failed to update session with consent details: %v", err)
		return errors.Wrap(err, "", "failed to update session")
	}

	return nil
}
package types
// ClientType is the OAuth client type string: "confidential" or "public".
type ClientType string

// Client types recognised by the server.
const (
	PublicClient       ClientType = "public"
	ConfidentialClient ClientType = "confidential"
)

// SupportedClientTypes is the set of accepted client types; each key maps to true.
var SupportedClientTypes = map[ClientType]bool{
	PublicClient:       true,
	ConfidentialClient: true,
}

// String returns the raw string form of the client type.
func (c ClientType) String() string {
	return string(c)
}
package types
// CodeChallengeMethod names a PKCE code challenge method: "plain" or "S256".
type CodeChallengeMethod string

// Code challenge methods recognised by the server.
const (
	SHA256CodeChallengeMethod CodeChallengeMethod = "S256"
	PlainCodeChallengeMethod  CodeChallengeMethod = "plain"
)

// SupportedCodeChallengeMethods is the set of accepted methods; each key maps to true.
var SupportedCodeChallengeMethods = map[CodeChallengeMethod]bool{
	SHA256CodeChallengeMethod: true,
	PlainCodeChallengeMethod:  true,
}

// String returns the raw string form of the method.
func (c CodeChallengeMethod) String() string {
	return string(c)
}
package types
import (
"slices"
"strings"
)
// Scope is an OAuth 2.0 / OpenID Connect scope value. Scopes express the level
// of access a client is asking for.
type Scope string

const (
	// TokenIntrospectScope permits introspecting access tokens; it is meant
	// for resource servers and internal services.
	TokenIntrospectScope Scope = "tokens:introspect"

	// TokenRevokeScope permits revoking access or refresh tokens.
	TokenRevokeScope Scope = "tokens:revoke"

	// OpenIDScope marks a request as an OpenID Connect authentication request
	// and is required for OIDC identity information.
	OpenIDScope Scope = "openid"

	// UserProfileScope grants the user's basic profile claims: name, family,
	// given and middle name, preferred username, profile/picture/website URLs,
	// gender, birthdate, timezone, locale and the updated_at timestamp.
	UserProfileScope Scope = "profile"

	// UserEmailScope grants the email address and its verified status.
	UserEmailScope Scope = "email"

	// UserPhoneScope grants the phone number and its verified status.
	UserPhoneScope Scope = "phone"

	// UserAddressScope grants the address claims: formatted address, street
	// address, locality (city), region (state/province), postal code, country.
	UserAddressScope Scope = "address"

	// UserOfflineAccessScope requests access while the user is offline,
	// typically in the form of a long-lived refresh token.
	UserOfflineAccessScope Scope = "offline_access"
)

// SupportedScopes is the set of scopes the application recognises; every key
// maps to true.
var SupportedScopes = map[Scope]bool{
	OpenIDScope:            true,
	UserProfileScope:       true,
	UserEmailScope:         true,
	UserPhoneScope:         true,
	UserAddressScope:       true,
	UserOfflineAccessScope: true,
	TokenIntrospectScope:   true,
	TokenRevokeScope:       true,
}

// String returns the raw string form of the scope.
func (s Scope) String() string {
	return string(s)
}
// ParseScopesString converts a space-delimited scope string into a slice of Scope values.
//
// Runs of spaces and leading/trailing spaces are ignored, so no empty scopes
// are produced. Only the ASCII space is treated as a delimiter, matching the
// OAuth scope syntax. The result is never nil: an empty or blank input yields
// an empty slice.
func ParseScopesString(scopeStr string) []Scope {
	parts := strings.FieldsFunc(scopeStr, func(r rune) bool { return r == ' ' })
	scopes := make([]Scope, len(parts))
	for i, part := range parts {
		scopes[i] = Scope(part)
	}
	return scopes
}
// ContainsScope reports whether target appears in scopes.
func ContainsScope(scopes []Scope, target Scope) bool {
	return slices.Index(scopes, target) != -1
}
// NewScopeList joins the given scopes into one space-separated Scope.
// Calling it with no scopes yields the empty Scope.
func NewScopeList(scopes ...Scope) Scope {
	parts := make([]string, len(scopes))
	for i := range scopes {
		parts[i] = scopes[i].String()
	}
	return Scope(strings.Join(parts, " "))
}
// CombineScopes joins the given scopes into one space-separated Scope.
// Calling it with no scopes yields the empty Scope.
func CombineScopes(scopes ...Scope) Scope {
	switch len(scopes) {
	case 0:
		return Scope("")
	case 1:
		return Scope(scopes[0].String())
	}

	var b strings.Builder
	b.WriteString(scopes[0].String())
	for _, s := range scopes[1:] {
		b.WriteByte(' ')
		b.WriteString(s.String())
	}
	return Scope(b.String())
}
package types
// TokenAuthMethod names an OAuth 2.0 token endpoint client authentication
// method, i.e. how a client proves its identity when requesting a token.
type TokenAuthMethod string

const (
	// ClientSecretBasicTokenAuth: client_id and client_secret are sent with
	// HTTP Basic authentication in the Authorization header.
	ClientSecretBasicTokenAuth TokenAuthMethod = "client_secret_basic"

	// ClientSecretPostTokenAuth: client_id and client_secret are sent as
	// HTTP POST body parameters.
	ClientSecretPostTokenAuth TokenAuthMethod = "client_secret_post"

	// NoTokenAuth: the client does not authenticate at the token endpoint.
	NoTokenAuth TokenAuthMethod = "none"
)

// SupportedTokenEndpointAuthMethods is the set of recognised authentication
// methods, e.g. for validating configuration or requests; each key maps to true.
var SupportedTokenEndpointAuthMethods = map[TokenAuthMethod]bool{
	ClientSecretBasicTokenAuth: true,
	ClientSecretPostTokenAuth:  true,
	NoTokenAuth:                true,
}

// String returns the raw string form of the method.
func (t TokenAuthMethod) String() string {
	return string(t)
}
package utils
import (
"context"
"github.com/vigiloauth/vigilo/v2/internal/constants"
)
// GetRequestID returns the request ID stored in ctx.
//
// Parameters:
// - ctx context.Context: The context that may carry a request ID.
//
// Returns:
// - string: The request ID, or "" when it is absent or not a string.
func GetRequestID(ctx context.Context) string {
	// A failed type assertion yields the zero value "", which is the fallback.
	requestID, _ := ctx.Value(constants.ContextKeyRequestID).(string)
	return requestID
}
// GetValueFromContext returns the value stored in ctx under key.
//
// Parameters:
// - ctx context.Context: The context from which to retrieve the value.
// - key constants.ContextKey: The key the value was stored under.
//
// Returns:
// - any: The stored value, or nil if ctx holds no value for key. The value
//   is returned untyped; callers must type-assert it.
func GetValueFromContext(ctx context.Context, key constants.ContextKey) any {
	return ctx.Value(key)
}
// AddKeyValueToContext derives a child context of ctx that carries value
// under key. The parent context is left unchanged.
//
// Parameters:
// - ctx context.Context: The parent context.
// - key constants.ContextKey: The key to store the value under.
// - value any: The value to store.
//
// Returns:
// - context.Context: The derived context.
func AddKeyValueToContext(ctx context.Context, key constants.ContextKey, value any) context.Context {
	derived := context.WithValue(ctx, key, value)
	return derived
}
package utils
import (
"crypto/sha256"
"encoding/base64"
"net"
"net/url"
"slices"
"strings"
"github.com/google/uuid"
"github.com/vigiloauth/vigilo/v2/internal/errors"
"golang.org/x/crypto/bcrypt"
)
// KeysToSlice collects the keys of input into a new slice.
//
// The order of the keys follows map iteration order and is not specified.
// The result is never nil, even for an empty map.
//
// Parameters:
// - input map[K]V: The map whose keys are collected.
//
// Returns:
// - []K: The keys of input.
func KeysToSlice[K comparable, V any](input map[K]V) []K {
	keys := make([]K, len(input))
	i := 0
	for k := range input {
		keys[i] = k
		i++
	}
	return keys
}
// IsSubset reports whether every element of subset also appears in set.
//
// Duplicates are irrelevant, and an empty (or nil) subset is a subset of any
// set. Runs in O(len(subset) + len(set)).
func IsSubset(subset, set []string) bool {
	// Pre-size the lookup table to avoid rehashing while it is filled.
	members := make(map[string]struct{}, len(set))
	for _, s := range set {
		members[s] = struct{}{}
	}
	for _, s := range subset {
		if _, ok := members[s]; !ok {
			return false
		}
	}
	return true
}
// Contains reports whether element is present in slice.
func Contains[T comparable](slice []T, element T) bool {
	return slices.Index(slice, element) >= 0
}
// ContainsResponseType reports whether any entry of responseTypes, each a
// whitespace-separated combination such as "code id_token", includes the
// given component (e.g. "code", "id_token", "token").
func ContainsResponseType(responseTypes []string, component string) bool {
	return slices.ContainsFunc(responseTypes, func(combo string) bool {
		return slices.Contains(strings.Fields(combo), component)
	})
}
// ContainsWildcard reports whether uri contains the '*' wildcard character.
func ContainsWildcard(uri string) bool {
	return strings.ContainsRune(uri, '*')
}
// IsLoopbackIP reports whether host is a literal loopback IP address such as
// "127.0.0.1" or "::1". Host names (including "localhost") are not resolved
// and report false.
func IsLoopbackIP(host string) bool {
	if ip := net.ParseIP(host); ip != nil {
		return ip.IsLoopback()
	}
	return false
}
// ValidateRedirectURIScheme accepts the "https" and "http" schemes and any
// scheme starting with "custom"; every other scheme yields an
// invalid-redirect-URI error.
func ValidateRedirectURIScheme(parsedURL *url.URL) error {
	scheme := parsedURL.Scheme
	switch {
	case scheme == "https", scheme == "http", strings.HasPrefix(scheme, "custom"):
		return nil
	}
	return errors.New(
		errors.ErrCodeInvalidRedirectURI, "invalid scheme, must be 'https' or 'http' for localhost or 'custom' for mobile",
	)
}
// ParseURI parses uri as a redirect URI. It rejects URIs that fail to parse
// and URIs that contain a fragment.
func ParseURI(uri string) (*url.URL, error) {
	parsedURL, err := url.Parse(uri)
	switch {
	case err != nil:
		return nil, errors.Wrap(err, errors.ErrCodeInvalidRedirectURI, "invalid redirect URI format")
	case parsedURL.Fragment != "":
		return nil, errors.New(errors.ErrCodeInvalidRedirectURI, "fragments are not allowed in the redirect URI")
	}
	return parsedURL, nil
}
// ValidatePublicURIScheme enforces the scheme rules for public clients.
// "http" is only allowed when the host is localhost, and "https" is rejected
// when the host is localhost.
//
// The host is compared with the port stripped (URL.Hostname) and
// case-insensitively. This means "http://localhost:8080/cb" is treated as
// localhost, whereas comparing URL.Host would have rejected it because Host
// includes the port.
func ValidatePublicURIScheme(parsedURL *url.URL) error {
	isLocalhost := strings.EqualFold(parsedURL.Hostname(), "localhost")

	if parsedURL.Scheme == "http" && !isLocalhost {
		return errors.New(errors.ErrCodeInvalidRedirectURI, "'http' scheme is only allowed for 'localhost'")
	}
	if parsedURL.Scheme == "https" && isLocalhost {
		return errors.New(
			errors.ErrCodeInvalidRedirectURI,
			"'https' scheme is not allowed for public clients using 'localhost'",
		)
	}
	return nil
}
// ValidateConfidentialURIScheme rejects redirect URIs whose host contains a
// '*' wildcard; confidential clients must register exact hosts.
func ValidateConfidentialURIScheme(parsedURL *url.URL) error {
	if !strings.ContainsRune(parsedURL.Host, '*') {
		return nil
	}
	return errors.New(errors.ErrCodeInvalidRedirectURI, "wildcards are not allowed for confidential clients")
}
// GenerateJWKKeyID derives a JWK Key ID (kid) from key. The kid is the SHA-256
// digest of key, encoded as unpadded base64url.
//
// Parameters:
// - key string: The key material to derive the kid from.
//
// Returns:
// - string: The base64url (no padding) encoded digest.
func GenerateJWKKeyID(key string) string {
	digest := sha256.Sum256([]byte(key))
	return base64.RawURLEncoding.EncodeToString(digest[:])
}
// CompareHash reports whether plainStr matches the bcrypt hash hashStr.
// Any bcrypt error, including a malformed hash, counts as a mismatch.
//
// Parameters:
// - plainStr string: The plain text candidate.
// - hashStr string: The bcrypt hash to compare against.
//
// Returns:
// - bool: True on a match, otherwise false.
func CompareHash(plainStr, hashStr string) bool {
	return bcrypt.CompareHashAndPassword([]byte(hashStr), []byte(plainStr)) == nil
}
// GenerateUUID returns a new random (version 4) UUID in its canonical string form.
//
// Returns:
// - string: The generated UUID.
func GenerateUUID() string {
	// Return directly rather than via a local named "uuid", which shadowed the
	// imported package.
	return uuid.New().String()
}
// EncodeSHA256 returns the SHA-256 digest of input encoded as unpadded base64url.
func EncodeSHA256(input string) string {
	h := sha256.New()
	// hash.Hash.Write never returns an error, so its result is ignored.
	h.Write([]byte(input))
	return base64.RawURLEncoding.EncodeToString(h.Sum(nil))
}
package utils
import (
	"net/url"
	"unicode/utf8"
)
// TruncateSensitive shortens sensitive strings for safe logging.
//
// Parameters:
// - data string: The sensitive string to truncate.
//
// Returns:
// - string: If data is longer than 5 bytes, at most its first 5 bytes followed
//   by "[REDACTED]"; otherwise data unchanged. The cut is moved back to a
//   UTF-8 rune boundary so a multi-byte character is never split, which would
//   otherwise put invalid UTF-8 into log output.
func TruncateSensitive(data string) string {
	const minDataLength int = 5
	if len(data) <= minDataLength {
		return data
	}

	// len(data) > minDataLength, so data[cut] is in range.
	cut := minDataLength
	for cut > 0 && !utf8.RuneStart(data[cut]) {
		cut--
	}
	return data[:cut] + "[REDACTED]"
}
// SanitizeURL redacts the sensitive parts of a URL for secure logging.
//
// The query is always replaced by "[REDACTED]". Any fragment is removed,
// because it may carry credentials such as tokens. A password in the userinfo
// component is masked by URL.Redacted.
//
// Parameters:
// - uri string: The URL string to sanitize.
//
// Returns:
// - string: The sanitized URL, or "[INVALID URL]" if uri cannot be parsed.
func SanitizeURL(uri string) string {
	parsed, err := url.Parse(uri)
	if err != nil {
		return "[INVALID URL]"
	}

	parsed.RawQuery = "[REDACTED]"
	parsed.Fragment = ""
	parsed.RawFragment = ""

	return parsed.Redacted()
}
package web
import (
"encoding/json"
"net/http"
"github.com/vigiloauth/vigilo/v2/internal/errors"
)
// DecodeJSONRequest decodes the JSON request body into the provided generic type T.
// It reads the request body, attempts to decode it into the specified type, and returns
// a pointer to the decoded object. If decoding fails, it returns an error wrapped with
// an internal server error code.
//
// Parameters:
//
// - w: The HTTP response writer.
// - r: The HTTP request containing the JSON body to decode.
//
// Returns:
//
// - A pointer to the decoded object of type T, or an error if decoding fails.
func DecodeJSONRequest[T any](w http.ResponseWriter, r *http.Request) (*T, error) {
var request T
if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
return nil, errors.Wrap(err, errors.ErrCodeInternalServerError, "failed to decode request")
}
return &request, nil
}
package web
import (
"encoding/base64"
"net/http"
"strings"
"github.com/vigiloauth/vigilo/v2/internal/constants"
"github.com/vigiloauth/vigilo/v2/internal/errors"
)
// ExtractClientBasicAuth extracts the client ID and client secret from the Authorization
// header of the request. The header must be in the form "Basic <base64 encoded
// client_id:client_secret>". If the header is invalid, an error is returned.
func ExtractClientBasicAuth(r *http.Request) (string, string, error) {
authHeader := r.Header.Get(constants.AuthorizationHeader)
if !strings.HasPrefix(authHeader, constants.BasicAuthHeader) {
return "", "", errors.New(errors.ErrCodeInvalidClient, "the authorization header is invalid or missing")
}
credentials, err := base64.StdEncoding.DecodeString(authHeader[6:])
if err != nil {
return "", "", errors.New(errors.ErrCodeInvalidClient, "invalid credentials in the authorization header")
}
const subStrCount int = 2
parts := strings.SplitN(string(credentials), ":", subStrCount)
if len(parts) != subStrCount {
return "", "", errors.New(errors.ErrCodeInvalidClient, "invalid credentials format in the authorization header")
}
return parts[0], parts[1], nil
}
// ExtractBearerToken extracts the Bearer token from the Authorization header of an HTTP request.
// It trims the "Bearer" prefix from the token and returns the token string.
//
// Parameters:
// - r *http.Request: The HTTP request containing the Authorization header.
//
// Returns:
// - string: The extracted Bearer token.
// - error: An error if the Authorization header is missing or invalid.
func ExtractBearerToken(r *http.Request) (string, error) {
authHeader := r.Header.Get(constants.AuthorizationHeader)
if authHeader == "" {
err := errors.New(errors.ErrCodeMissingHeader, "authorization header is missing")
return "", err
}
lowercaseHeader := strings.ToLower(authHeader)
if !strings.HasPrefix(lowercaseHeader, "bearer ") {
err := errors.New(errors.ErrCodeInvalidFormat, "authorization header must start with Bearer")
return "", err
}
return authHeader[7:], nil
}
// SetNoStoreHeader sets constants.CacheControlHeader to constants.NoStoreHeader
// on w, replacing any previous value, so the response is not cached by the
// client or intermediaries.
//
// Parameters:
//
//	w: The HTTP response writer whose headers are modified.
func SetNoStoreHeader(w http.ResponseWriter) {
	header := w.Header()
	header.Set(constants.CacheControlHeader, constants.NoStoreHeader)
}
package web
import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"github.com/vigiloauth/vigilo/v2/internal/constants"
"github.com/vigiloauth/vigilo/v2/internal/errors"
)
// WriteJSON writes data as a JSON response with the given HTTP status code.
//
// The Content-Type header is set (not appended) to "application/json". When
// data is nil, only the status and headers are written.
//
// data is marshaled before anything is sent. A value that cannot be encoded
// therefore produces a plain-text 500 Internal Server Error, instead of a
// panic after the original status has already gone out.
//
// Errors from writing the body are ignored. They mean the client is gone
// (or the status forbids a body), and no response can be delivered anymore.
func WriteJSON(w http.ResponseWriter, status int, data any) {
	w.Header().Set("Content-Type", "application/json")
	if data == nil {
		w.WriteHeader(status)
		return
	}

	body, err := json.Marshal(data)
	if err != nil {
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
		return
	}

	w.WriteHeader(status)
	// json.Encoder.Encode terminates each value with a newline; keep that format.
	_, _ = w.Write(append(body, '\n'))
}
// WriteError writes err to w as a JSON error response.
//
// An *errors.ErrorCollection is reported as one validation error
// (errors.ErrCodeValidationError, HTTP 400) with the individual errors
// attached. Any other error is serialized as-is, with its status taken from
// errors.HTTPStatusCodeMap using its error code. Codes with no mapping fall
// back to 500 Internal Server Error, because a zero status would make
// WriteHeader panic.
func WriteError(w http.ResponseWriter, err error) {
	if collection, ok := err.(*errors.ErrorCollection); ok { //nolint:errorlint
		resp := errors.VigiloAuthError{
			ErrorCode:        errors.ErrCodeValidationError,
			ErrorDescription: "One or more validation errors occurred",
			Errors:           collection.Errors(),
		}
		WriteJSON(w, http.StatusBadRequest, resp)
		return
	}

	status := errors.HTTPStatusCodeMap[errors.ErrorCode(err)]
	if status == 0 {
		status = http.StatusInternalServerError
	}
	WriteJSON(w, status, err)
}
// RenderErrorPage redirects the client (302 Found) to the "/error" page.
//
// The page receives a "type" query parameter taken from
// errors.SystemErrorCodeMap[errorCode]. When invalidURI is non-empty, it also
// receives a "uri" parameter. Both values are query-escaped through
// url.Values, so characters such as '&' or '#' in the type string cannot
// corrupt the redirect target.
func RenderErrorPage(w http.ResponseWriter, r *http.Request, errorCode string, invalidURI string) {
	query := url.Values{}
	query.Set("type", errors.SystemErrorCodeMap[errorCode])
	if invalidURI != "" {
		query.Set("uri", invalidURI)
	}
	// Encode sorts keys, so "type" still precedes "uri" as before.
	http.Redirect(w, r, "/error?"+query.Encode(), http.StatusFound)
}
// BuildErrorURL builds the URL used to report an authorization error back to
// a client by appending "error", "error_description" and the state parameter
// (constants.StateReqField, always present even when empty) to redirectURI.
//
// RFC 6749 §3.1.2 requires an existing query component of the redirection URI
// to be retained. When redirectURI already has a query, the new parameters are
// therefore joined with "&" instead of starting a second "?" query. When it
// ends in a bare "?", they are appended directly. A redirectURI that cannot be
// parsed is treated as having no query.
func BuildErrorURL(errCode, errDescription, state, redirectURI string) string {
	params := url.Values{}
	params.Add("error", errCode)
	params.Add("error_description", errDescription)
	params.Add(constants.StateReqField, state)

	separator := "?"
	if parsed, err := url.Parse(redirectURI); err == nil {
		switch {
		case parsed.RawQuery != "":
			separator = "&"
		case parsed.ForceQuery:
			// redirectURI ends with "?" and an empty query.
			separator = ""
		}
	}
	return redirectURI + separator + params.Encode()
}
// BuildRedirectURL assembles a relative URL of the form "/<endpoint>?<query>"
// carrying an authorization request's parameters.
//
// Client ID, redirect URI, scope and response type (under their
// constants.*ReqField names) are always included, even when empty. State,
// nonce, prompt, ACR values and claims are only included when non-empty.
// display is forwarded when it maps to true in
// constants.ValidAuthenticationDisplays; otherwise constants.DisplayPage is
// sent. The query is produced by url.Values.Encode, so parameters appear
// sorted by key.
func BuildRedirectURL(
	clientID string,
	redirectURI string,
	scope string,
	responseType string,
	state string,
	nonce string,
	prompt string,
	display string,
	acrValues string,
	claims string,
	endpoint string,
) string {
	query := url.Values{}
	query.Add(constants.ClientIDReqField, clientID)
	query.Add(constants.RedirectURIReqField, redirectURI)
	query.Add(constants.ScopeReqField, scope)
	query.Add(constants.ResponseTypeReqField, responseType)

	optional := []struct{ key, value string }{
		{constants.StateReqField, state},
		{constants.NonceReqField, nonce},
		{constants.PromptReqField, prompt},
		{constants.ACRReqField, acrValues},
		{constants.ClaimsReqField, claims},
	}
	for _, param := range optional {
		if param.value == "" {
			continue
		}
		query.Add(param.key, param.value)
	}

	chosenDisplay := constants.DisplayPage
	if display != "" && constants.ValidAuthenticationDisplays[display] {
		chosenDisplay = display
	}
	query.Add(constants.DisplayReqField, chosenDisplay)

	return fmt.Sprintf("/%s?%s", endpoint, query.Encode())
}
package integration
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
"slices"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/vigiloauth/vigilo/v2/idp/config"
"github.com/vigiloauth/vigilo/v2/idp/server"
"github.com/vigiloauth/vigilo/v2/internal/constants"
audit "github.com/vigiloauth/vigilo/v2/internal/domain/audit"
domain "github.com/vigiloauth/vigilo/v2/internal/domain/claims"
service "github.com/vigiloauth/vigilo/v2/internal/service/crypto"
jwtService "github.com/vigiloauth/vigilo/v2/internal/service/jwt"
tokenService "github.com/vigiloauth/vigilo/v2/internal/service/token"
"github.com/vigiloauth/vigilo/v2/internal/types"
"github.com/vigiloauth/vigilo/v2/internal/utils"
"github.com/vigiloauth/vigilo/v2/internal/web"
client "github.com/vigiloauth/vigilo/v2/internal/domain/client"
token "github.com/vigiloauth/vigilo/v2/internal/domain/token"
userConsent "github.com/vigiloauth/vigilo/v2/internal/domain/userconsent"
"github.com/vigiloauth/vigilo/v2/internal/errors"
auditEventRepo "github.com/vigiloauth/vigilo/v2/internal/repository/audit"
sessionRepo "github.com/vigiloauth/vigilo/v2/internal/repository/session"
tokenRepo "github.com/vigiloauth/vigilo/v2/internal/repository/token"
consentRepo "github.com/vigiloauth/vigilo/v2/internal/repository/userconsent"
users "github.com/vigiloauth/vigilo/v2/internal/domain/user"
clientRepo "github.com/vigiloauth/vigilo/v2/internal/repository/client"
userRepo "github.com/vigiloauth/vigilo/v2/internal/repository/user"
)
const (
// Test constants for reuse
// User profile fixtures used by WithUser and GetUserRegistrationRequest.
testUsername string = "testUser"
testFirstName string = "John"
testMiddleName string = "Mary"
testFamilyName string = "Doe"
testBirthdate string = "2000-12-06"
testPhoneNumber string = "+14255551212"
testGender string = "male"
testStreetAddress string = "123 Main St"
testLocality string = "Springfield"
testRegion string = "IL"
testPostalCode string = "62704"
testCountry string = "USA"
testEmail string = "test@email.com"
// Plain-text login password of the test user; WithUser stores it hashed.
testPassword1 string = "Password123!@"
testPassword2 string = "NewPassword_$55"
testClientName1 string = "Test App"
testClientName2 string = "Test App 2"
// Deliberately short password for negative validation tests.
testInvalidPassword string = "weak"
// Identifiers shared by the client, user, token and consent fixtures.
testClientID string = "client-1234"
testUserID string = "user-1234"
// Confidential-client secret. NewVigiloTestContext also uses it as the PKCE
// plain code challenge and as the input of the S256 code challenge.
testClientSecret string = "a-string-secret-at-least-256-bits-long-enough"
// Space-separated scopes used when issuing test tokens and consent.
testScope string = "openid profile address"
// URL-encoded form of "client:manage user:manage" (NOT an encoding of testScope).
encodedTestScope string = "client%3Amanage%20user%3Amanage"
// The only redirect URI registered by WithClient.
testRedirectURI string = "https://vigiloauth.com/callback"
testConsentApproved string = "true"
testAuthzCode string = "valid-auth-code"
// Client IP placed in the context by WithAuditEvents.
testIP string = "192.168.1.10"
testNonce string = "123na"
testState string = "12345State"
)
// VigiloTestContext bundles the state shared by the integration-test helpers:
// the server under test, the last recorded response, and the fixtures (user,
// session cookie, tokens, PKCE challenges) created by the With* methods.
type VigiloTestContext struct {
// T is the test that owns this context; helpers report failures through it.
T *testing.T
// VigiloServer is the identity server whose router handles requests.
VigiloServer *server.VigiloIdentityServer
// ResponseRecorder holds the response of the most recent SendHTTPRequest call.
ResponseRecorder *httptest.ResponseRecorder
// TestServer is the live server started by WithLiveHTTPServer; TearDown closes it.
TestServer *httptest.Server
// HttpClient is used by SendLiveRequest. NewVigiloTestContext does not set it.
HttpClient *http.Client
// User is the user created by WithUser.
User *users.User
// OAuthClient is checked by WithClientCredentialsToken; NOTE(review): no
// helper visible here assigns it (WithClient does not).
OAuthClient *client.Client
// SessionCookie is the session cookie captured by WithUserSession.
SessionCookie *http.Cookie
// JWTToken is the most recently issued user access token.
JWTToken string
// ClientAuthToken is the access token obtained by WithClientCredentialsToken.
ClientAuthToken string
// State is the consent state captured by WithOAuthLogin.
State string
// SH256CodeChallenge is EncodeSHA256(testClientSecret), the S256 PKCE challenge.
SH256CodeChallenge string
// PlainCodeChallenge is testClientSecret, the plain PKCE challenge.
PlainCodeChallenge string
RequestURI string
}
// NewVigiloTestContext returns a context bound to t and backed by a newly
// constructed identity server. As a side effect it sets the global server
// logger to "debug"; TearDown sets it back to "info".
//
// Both PKCE challenges are derived from testClientSecret: the plain challenge
// is the secret itself, and the S256 challenge is its EncodeSHA256 digest.
func NewVigiloTestContext(t *testing.T) *VigiloTestContext {
	config.GetServerConfig().Logger().SetLevel("debug")

	tc := &VigiloTestContext{T: t}
	tc.VigiloServer = server.NewVigiloIdentityServer()
	tc.SH256CodeChallenge = utils.EncodeSHA256(testClientSecret)
	tc.PlainCodeChallenge = testClientSecret
	return tc
}
// WithLiveHTTPServer starts an httptest.Server that serves the identity
// server's router and stores it in tc.TestServer. TearDown closes it.
// Calling it again replaces tc.TestServer without closing the previous one.
func (tc *VigiloTestContext) WithLiveHTTPServer() *VigiloTestContext {
	tc.TestServer = httptest.NewServer(tc.VigiloServer.Router())
	return tc
}
// WithUser builds a fully populated test user (ID testUserID, username
// testUsername) with the given roles and stores it in the in-memory user
// repository. The password testPassword1 is hashed with the crypto service
// before storage. The stored user is also recorded on tc.User.
//
// The error from AddUser is deliberately ignored, so calling WithUser for an
// already stored user does not fail the test.
func (tc *VigiloTestContext) WithUser(roles []string) *VigiloTestContext {
	address := &users.UserAddress{
		Formatted:     strings.Join([]string{testStreetAddress, testLocality, testRegion, testPostalCode, testCountry}, ", "),
		StreetAddress: testStreetAddress,
		Locality:      testLocality,
		Region:        testRegion,
		PostalCode:    testPostalCode,
		Country:       testCountry,
	}

	newUser := &users.User{
		ID:                  testUserID,
		PreferredUsername:   testUsername,
		Name:                strings.Join([]string{testFirstName, testMiddleName, testFamilyName}, " "),
		GivenName:           testFirstName,
		MiddleName:          testMiddleName,
		FamilyName:          testFamilyName,
		Email:               testEmail,
		PhoneNumber:         testPhoneNumber,
		Password:            testPassword1,
		Birthdate:           testBirthdate,
		Gender:              testGender,
		Roles:               roles,
		LastFailedLogin:     time.Time{},
		CreatedAt:           time.Now(),
		UpdatedAt:           time.Now(),
		AccountLocked:       false,
		EmailVerified:       false,
		PhoneNumberVerified: true,
		Address:             address,
	}

	hasher := service.NewCryptographer()
	hashed, err := hasher.HashString(newUser.Password)
	require.NoError(tc.T, err)
	newUser.Password = hashed

	_ = userRepo.GetInMemoryUserRepository().AddUser(context.Background(), newUser)
	tc.User = newUser
	return tc
}
// WithUserConsent stores a consent record for testUserID and testClientID
// covering testScope in the in-memory user-consent repository. The save
// error is ignored.
func (tc *VigiloTestContext) WithUserConsent() *VigiloTestContext {
	repo := consentRepo.GetInMemoryUserConsentRepository()
	scopes := types.CombineScopes(types.Scope(testScope))
	_ = repo.SaveConsent(context.Background(), testUserID, testClientID, scopes)
	return tc
}
// WithClient registers an OAuth client in the in-memory client repository.
// The client has ID testClientID, name testClientName1, redirect URI
// testRedirectURI and the response type constants.CodeResponseType.
//
// Parameters:
//
//	clientType types.ClientType: the client type (public or confidential).
//	scopes []types.Scope: the scopes the client is registered with. When
//	  empty, no scopes are stored and CanRequestScopes is set instead.
//	grantTypes []string: the grant types the client may use.
//
// Confidential clients get testClientSecret, constants.WebApplicationType and
// types.ClientSecretBasicTokenAuth. Any other client whose grant types include
// constants.AuthorizationCodeGrantType requires PKCE and gets
// constants.NativeApplicationType and types.NoTokenAuth.
//
// The SaveClient error is ignored. This method does not assign tc.OAuthClient.
func (tc *VigiloTestContext) WithClient(clientType types.ClientType, scopes []types.Scope, grantTypes []string) {
	registered := &client.Client{
		Name:          testClientName1,
		ID:            testClientID,
		Type:          clientType,
		GrantTypes:    grantTypes,
		ResponseTypes: []string{constants.CodeResponseType},
		RedirectURIs:  []string{testRedirectURI},
	}

	if len(scopes) > 0 {
		registered.Scopes = scopes
	} else {
		registered.CanRequestScopes = true
	}

	switch {
	case clientType == types.ConfidentialClient:
		registered.Secret = testClientSecret
		registered.ApplicationType = constants.WebApplicationType
		registered.TokenEndpointAuthMethod = types.ClientSecretBasicTokenAuth
	case slices.Contains(grantTypes, constants.AuthorizationCodeGrantType):
		registered.RequiresPKCE = true
		registered.ApplicationType = constants.NativeApplicationType
		registered.TokenEndpointAuthMethod = types.NoTokenAuth
	}

	_ = clientRepo.GetInMemoryClientRepository().SaveClient(context.Background(), registered)
}
// WithJWTToken issues an access token for subject id and client testClientID
// with testScope and nonce testNonce, and stores it in tc.JWTToken. If
// tc.User is nil, an admin test user is created first.
//
// NOTE(review): duration is not used. The token's lifetime is whatever
// IssueAccessToken applies, so callers cannot control expiry through it.
func (tc *VigiloTestContext) WithJWTToken(id string, duration time.Duration) *VigiloTestContext {
	if tc.User == nil {
		tc.WithUser([]string{constants.AdminRole})
	}

	repo := tokenRepo.GetInMemoryTokenRepository()
	crypto := service.NewCryptographer()
	signer := jwtService.NewJWTService()
	issuer := tokenService.NewTokenIssuer(tokenService.NewTokenCreator(repo, signer, crypto))

	scopes := types.CombineScopes(types.Scope(testScope))
	accessToken, err := issuer.IssueAccessToken(context.Background(), id, testClientID, scopes, "", testNonce)
	require.NoError(tc.T, err)

	tc.JWTToken = accessToken
	return tc
}
// WithAdminToken issues an access token for testUserID and testClientID with
// testScope and stores it in tc.JWTToken. It passes constants.AdminRole where
// the other helpers pass an empty string (presumably the roles argument of
// IssueAccessToken; confirm). If tc.User is nil, an admin test user is
// created first.
//
// NOTE(review): neither id nor duration is used.
func (tc *VigiloTestContext) WithAdminToken(id string, duration time.Duration) *VigiloTestContext {
	if tc.User == nil {
		tc.WithUser([]string{constants.AdminRole})
	}

	repo := tokenRepo.GetInMemoryTokenRepository()
	crypto := service.NewCryptographer()
	signer := jwtService.NewJWTService()
	issuer := tokenService.NewTokenIssuer(tokenService.NewTokenCreator(repo, signer, crypto))

	scopes := types.CombineScopes(types.Scope(testScope))
	adminToken, err := issuer.IssueAccessToken(
		context.Background(),
		testUserID,
		testClientID,
		scopes,
		constants.AdminRole,
		testNonce,
	)
	require.NoError(tc.T, err)

	tc.JWTToken = adminToken
	return tc
}
// WithJWTTokenWithScopes issues an access token for testUserID and
// testClientID carrying the given scopes (types.OpenIDScope when none are
// given), and stores it in tc.JWTToken. If tc.User is nil, an admin test user
// is created first.
//
// NOTE(review): subject, audience and duration are not used.
func (tc *VigiloTestContext) WithJWTTokenWithScopes(subject, audience string, scopes []types.Scope, duration time.Duration) *VigiloTestContext {
	if tc.User == nil {
		tc.WithUser([]string{constants.AdminRole})
	}
	if len(scopes) == 0 {
		// Use a fresh slice. Appending to the caller's empty-but-non-nil slice
		// could write into the caller's backing array.
		scopes = []types.Scope{types.OpenIDScope}
	}

	repo := tokenRepo.GetInMemoryTokenRepository()
	cryptoService := service.NewCryptographer()
	jwt := jwtService.NewJWTService()
	creator := tokenService.NewTokenCreator(repo, jwt, cryptoService)
	issuer := tokenService.NewTokenIssuer(creator)

	accessToken, err := issuer.IssueAccessToken(
		context.Background(),
		testUserID,
		testClientID,
		types.NewScopeList(scopes...),
		"", testNonce,
	)
	require.NoError(tc.T, err)
	tc.JWTToken = accessToken
	return tc
}
// WithJWTTokenWithClaims calls IssueTokenPair for testUserID and testClientID
// with types.OpenIDScope, nonce testNonce and the given claims request. It
// stores the first returned token in tc.JWTToken and discards the second
// (presumably the refresh token). If tc.User is nil, an admin test user is
// created first.
//
// NOTE(review): subject and audience are not used.
func (tc *VigiloTestContext) WithJWTTokenWithClaims(subject, audience string, claims *domain.ClaimsRequest) *VigiloTestContext {
	if tc.User == nil {
		tc.WithUser([]string{constants.AdminRole})
	}

	repo := tokenRepo.GetInMemoryTokenRepository()
	crypto := service.NewCryptographer()
	signer := jwtService.NewJWTService()
	issuer := tokenService.NewTokenIssuer(tokenService.NewTokenCreator(repo, signer, crypto))

	accessToken, _, err := issuer.IssueTokenPair(
		context.Background(),
		testUserID,
		testClientID,
		types.OpenIDScope,
		"",
		testNonce,
		claims,
	)
	require.NoError(tc.T, err)

	tc.JWTToken = accessToken
	return tc
}
// WithBlacklistedToken issues an access token for testUserID and testClientID
// with testScope, stores it in tc.JWTToken, and then blacklists it through the
// token manager. The test fails immediately if issuing or blacklisting fails,
// so callers never receive a token that is silently still valid.
//
// NOTE(review): id is not used.
func (tc *VigiloTestContext) WithBlacklistedToken(id string) *VigiloTestContext {
	repo := tokenRepo.GetInMemoryTokenRepository()
	cryptoService := service.NewCryptographer()
	jwt := jwtService.NewJWTService()
	creator := tokenService.NewTokenCreator(repo, jwt, cryptoService)
	issuer := tokenService.NewTokenIssuer(creator)

	accessToken, err := issuer.IssueAccessToken(
		context.Background(),
		testUserID,
		testClientID,
		types.CombineScopes(types.Scope(testScope)),
		"", testNonce,
	)
	require.NoError(tc.T, err)
	tc.JWTToken = accessToken

	parser := tokenService.NewTokenParser(jwt)
	validator := tokenService.NewTokenValidator(repo, parser)
	manager := tokenService.NewTokenManager(repo, parser, validator)
	require.NoError(tc.T, manager.BlacklistToken(context.Background(), accessToken), "failed to blacklist test token")
	return tc
}
// GetSessionCookie returns the first cookie in tc.ResponseRecorder's response
// whose name equals the configured session cookie name. If none is found, it
// records a (non-fatal) assertion failure and returns nil.
func (tc *VigiloTestContext) GetSessionCookie() *http.Cookie {
	cookies := tc.ResponseRecorder.Result().Cookies()
	var found *http.Cookie
	for i := range cookies {
		if cookies[i].Name != config.GetServerConfig().SessionCookieName() {
			continue
		}
		found = cookies[i]
		break
	}
	assert.NotNil(tc.T, found)
	return found
}
// WithClientCredentialsToken obtains an access token with the
// client_credentials grant and stores it in tc.ClientAuthToken.
//
// If tc.OAuthClient is nil, a confidential client with the openid scope and
// the client_credentials grant is registered first. The token request
// authenticates with HTTP Basic credentials built by
// GenerateHeaderWithCredentials. The test stops if the token endpoint does not
// answer 200 OK, or if its body cannot be decoded as a token response.
func (tc *VigiloTestContext) WithClientCredentialsToken() *VigiloTestContext {
	if tc.OAuthClient == nil {
		tc.WithClient(
			types.ConfidentialClient,
			[]types.Scope{types.OpenIDScope},
			[]string{constants.ClientCredentialsGrantType},
		)
	}

	formData := url.Values{}
	formData.Add(constants.GrantTypeReqField, constants.ClientCredentialsGrantType)
	formData.Add(constants.ScopeReqField, types.OpenIDScope.String())

	rr := tc.SendHTTPRequest(
		http.MethodPost,
		web.OAuthEndpoints.Token,
		strings.NewReader(formData.Encode()),
		GenerateHeaderWithCredentials(testClientID, testClientSecret),
	)
	// Check the status first so a failed token request is reported as such,
	// with its body, rather than as a confusing decode error.
	require.Equal(tc.T, http.StatusOK, rr.Code, rr.Body.String())

	var tokenResponse token.TokenResponse
	require.NoError(tc.T, json.NewDecoder(rr.Body).Decode(&tokenResponse))
	tc.ClientAuthToken = tokenResponse.AccessToken
	return tc
}
// WithExpiredToken generates an expired token for testing.
func (tc *VigiloTestContext) WithExpiredToken() *VigiloTestContext {
return tc.WithJWTToken(testUserID, -10*time.Hour)
}
// WithPasswordResetToken issues an ordinary access token (via
// IssueAccessToken) for testUserID and testClientID with testScope, and
// returns it together with tc. If tc.User is nil, an admin test user is
// created first. The token is not stored on tc.
//
// NOTE(review): duration is not used.
func (tc *VigiloTestContext) WithPasswordResetToken(duration time.Duration) (string, *VigiloTestContext) {
	if tc.User == nil {
		tc.WithUser([]string{constants.AdminRole})
	}

	repo := tokenRepo.GetInMemoryTokenRepository()
	crypto := service.NewCryptographer()
	signer := jwtService.NewJWTService()
	issuer := tokenService.NewTokenIssuer(tokenService.NewTokenCreator(repo, signer, crypto))

	scopes := types.CombineScopes(types.Scope(testScope))
	resetToken, err := issuer.IssueAccessToken(context.Background(), testUserID, testClientID, scopes, "", testNonce)
	require.NoError(tc.T, err)
	return resetToken, tc
}
// WithCustomConfig builds a server configuration from options and then
// replaces tc.VigiloServer with a newly constructed identity server.
//
// The value returned by config.NewServerConfig is discarded. NOTE(review):
// this presumably relies on NewServerConfig installing the config globally so
// that the new server picks it up; confirm.
func (tc *VigiloTestContext) WithCustomConfig(options ...config.ServerConfigOptions) *VigiloTestContext {
	_ = config.NewServerConfig(options...)
	tc.VigiloServer = server.NewVigiloIdentityServer()
	return tc
}
// SendHTTPRequest dispatches an in-process request to the server's router and
// returns the recorded response, which is also kept in tc.ResponseRecorder.
// Headers are applied through addHeaderAuth, which defaults Content-Type to
// application/json.
func (tc *VigiloTestContext) SendHTTPRequest(method, endpoint string, body io.Reader, headers map[string]string) *httptest.ResponseRecorder {
	request := httptest.NewRequest(method, endpoint, body)
	tc.addHeaderAuth(request, headers)

	recorder := httptest.NewRecorder()
	tc.VigiloServer.Router().ServeHTTP(recorder, request)

	tc.ResponseRecorder = recorder
	return recorder
}
// ClearSession resets the in-memory session repository via
// sessionRepo.ResetInMemorySessionRepository. Presumably this drops every
// stored session, not only those created through this context; confirm.
func (tc *VigiloTestContext) ClearSession() {
sessionRepo.ResetInMemorySessionRepository()
}
// WithOAuthLogin captures the consent state into tc.State (via
// GetStateFromSession, which needs tc.SessionCookie). It then posts the test
// user's credentials as JSON to the authenticate endpoint, with client_id and
// redirect_uri query parameters. The response body is logged, and a non-fatal
// assertion expects 200 OK.
func (tc *VigiloTestContext) WithOAuthLogin() {
	payload, err := json.Marshal(users.UserLoginRequest{
		Username: testUsername,
		Password: testPassword1,
	})
	require.NoError(tc.T, err)

	tc.State = tc.GetStateFromSession()

	params := url.Values{}
	params.Add(constants.ClientIDReqField, testClientID)
	params.Add(constants.RedirectURIReqField, testRedirectURI)
	target := web.OAuthEndpoints.Authenticate + "?" + params.Encode()

	rr := tc.SendHTTPRequest(http.MethodPost, target, bytes.NewReader(payload), nil)
	tc.T.Log(rr.Body.String())
	assert.Equal(tc.T, http.StatusOK, rr.Code)
}
// SendLiveRequest sends a real HTTP request to the live test server, starting
// one with WithLiveHTTPServer if needed, and returns the client's result
// unchanged. Headers are applied through addHeaderAuth.
//
// tc.HttpClient is used when set. Otherwise the test server's own client is
// used: NewVigiloTestContext never sets HttpClient, and calling Do on a nil
// *http.Client panics. The caller must close the returned response body.
func (tc *VigiloTestContext) SendLiveRequest(method, endpoint string, body io.Reader, headers map[string]string) (*http.Response, error) {
	if tc.TestServer == nil {
		tc.WithLiveHTTPServer()
	}

	httpClient := tc.HttpClient
	if httpClient == nil {
		httpClient = tc.TestServer.Client()
	}

	target := tc.TestServer.URL + endpoint
	req, err := http.NewRequest(method, target, body)
	require.NoError(tc.T, err)
	tc.addHeaderAuth(req, headers)
	return httpClient.Do(req) //nolint:wrapcheck
}
// WithUserSession logs the test user in through the authenticate endpoint,
// creating an admin test user first if tc.User is nil. It stores the returned
// session cookie (matched by the configured session cookie name) in
// tc.SessionCookie. The test is aborted with Fatalf if no such cookie is
// returned.
func (tc *VigiloTestContext) WithUserSession() {
	if tc.User == nil {
		tc.WithUser([]string{constants.AdminRole})
	}

	payload, err := json.Marshal(users.NewUserLoginRequest(testUsername, testPassword1))
	require.NoError(tc.T, err)

	rr := tc.SendHTTPRequest(http.MethodPost, web.OAuthEndpoints.Authenticate, bytes.NewBuffer(payload), nil)
	assert.Equal(tc.T, http.StatusOK, rr.Code)

	for _, candidate := range rr.Result().Cookies() {
		if candidate.Name != config.GetServerConfig().SessionCookieName() {
			continue
		}
		tc.SessionCookie = candidate
		return
	}
	tc.T.Fatalf("Session cookie not found in login response")
}
// GetStateFromSession sends a GET to the user-consent endpoint with the test
// client ID, redirect URI and scope, using tc.SessionCookie (which must be
// non-nil). It returns the State field of the decoded consent response, and
// asserts a 200 status and a non-empty state.
func (tc *VigiloTestContext) GetStateFromSession() string {
	params := url.Values{}
	params.Add(constants.ClientIDReqField, testClientID)
	params.Add(constants.RedirectURIReqField, testRedirectURI)
	params.Add(constants.ScopeReqField, testScope)
	target := web.OAuthEndpoints.UserConsent + "?" + params.Encode()

	cookieHeader := tc.SessionCookie.Name + "=" + tc.SessionCookie.Value
	rr := tc.SendHTTPRequest(http.MethodGet, target, nil, map[string]string{"Cookie": cookieHeader})
	assert.Equal(tc.T, http.StatusOK, rr.Code)

	var consent userConsent.UserConsentResponse
	require.NoError(tc.T, json.Unmarshal(rr.Body.Bytes(), &consent))

	assert.NotEmpty(tc.T, consent.State)
	return consent.State
}
// GetAuthzCode runs the authorization-code request without PKCE and returns
// the code from the redirect location. It delegates to GetAuthzCodeWithPKCE
// with empty challenge arguments, so both helpers read the code from the same
// query key (constants.CodeURLValue).
func (tc *VigiloTestContext) GetAuthzCode() string {
	return tc.GetAuthzCodeWithPKCE("", "")
}
// CreateAuthorizationCodeRequestQueryParams returns the query for an
// authorization-code request. The query contains: response type
// constants.CodeResponseType, the test client ID and redirect URI, scope
// "openid profile", tc.State, consent approval "true", testNonce, and
// acr_values "1 2". The PKCE code challenge and method are added only when
// non-empty.
func (tc *VigiloTestContext) CreateAuthorizationCodeRequestQueryParams(codeChallenge, codeChallengeMethod string) url.Values {
	params := url.Values{
		constants.ResponseTypeReqField:    {constants.CodeResponseType},
		constants.ClientIDReqField:        {testClientID},
		constants.RedirectURIReqField:     {testRedirectURI},
		constants.ScopeReqField:           {"openid profile"},
		constants.StateReqField:           {tc.State},
		constants.ConsentApprovedURLValue: {"true"},
		constants.NonceReqField:           {testNonce},
		"acr_values":                      {"1 2"},
	}

	if codeChallenge != "" {
		params.Add(constants.CodeChallengeReqField, codeChallenge)
	}
	if codeChallengeMethod != "" {
		params.Add(constants.CodeChallengeMethodReqField, codeChallengeMethod)
	}
	return params
}
// GetAuthzCodeWithPKCE runs the authorization-code request with the given
// PKCE challenge and method, and returns the constants.CodeURLValue query
// parameter of the redirect location. A non-fatal assertion checks that the
// code is non-empty.
func (tc *VigiloTestContext) GetAuthzCodeWithPKCE(codeChallenge, codeChallengeMethod string) string {
	params := tc.CreateAuthorizationCodeRequestQueryParams(codeChallenge, codeChallengeMethod)
	redirect := tc.sendAuthorizationCodeRequest(params)

	code := redirect.Query().Get(constants.CodeURLValue)
	assert.NotEmpty(tc.T, code, "Authorization code should not be empty")
	return code
}
// AssertErrorResponseDescription decodes rr's body as an
// errors.VigiloAuthError. It asserts that the body's error code and
// description equal the expected values; error details are not checked.
func (tc *VigiloTestContext) AssertErrorResponseDescription(
	rr *httptest.ResponseRecorder,
	expectedErrCode, expectedDescription string,
) {
	got := tc.decodeErrorResponse(rr)
	assert.Equal(tc.T, expectedErrCode, got.ErrorCode)
	assert.Equal(tc.T, expectedDescription, got.ErrorDescription)
}
// AssertErrorResponseDetails decodes rr's body as an errors.VigiloAuthError.
// It asserts that the body's error code and error details equal the expected
// values; the description is not checked.
func (tc *VigiloTestContext) AssertErrorResponseDetails(
	rr *httptest.ResponseRecorder,
	expectedErrCode, expectedDetails string,
) {
	got := tc.decodeErrorResponse(rr)
	assert.Equal(tc.T, expectedErrCode, got.ErrorCode)
	assert.Equal(tc.T, expectedDetails, got.ErrorDetails)
}
// AssertErrorResponse decodes rr's body as an errors.VigiloAuthError. It
// asserts that the error code, description and details all equal the
// expected values.
func (tc *VigiloTestContext) AssertErrorResponse(
	rr *httptest.ResponseRecorder,
	expectedErrCode, expectedDescription, expectedDetails string,
) {
	got := tc.decodeErrorResponse(rr)
	assert.Equal(tc.T, expectedErrCode, got.ErrorCode)
	assert.Equal(tc.T, expectedDescription, got.ErrorDescription)
	assert.Equal(tc.T, expectedDetails, got.ErrorDetails)
}
// WithDebugLogs sets the global server logger level to "DEBUG".
// NOTE(review): NewVigiloTestContext passes lowercase "debug"; confirm that
// Logger.SetLevel treats level names case-insensitively.
func (tc *VigiloTestContext) WithDebugLogs() {
config.GetServerConfig().Logger().SetLevel("DEBUG")
}
func (tc *VigiloTestContext) WithAuditEvents() {
ctx := context.Background()
ctx = utils.AddKeyValueToContext(ctx, constants.ContextKeyUserID, testUserID)
ctx = utils.AddKeyValueToContext(ctx, constants.ContextKeyIPAddress, testIP)
ctx = utils.AddKeyValueToContext(ctx, constants.ContextKeyRequestID, "req-"+utils.GenerateUUID())
ctx = utils.AddKeyValueToContext(ctx, constants.ContextKeySessionID, "sess-"+utils.GenerateUUID())
ctx = utils.AddKeyValueToContext(ctx, constants.ContextKeyTokenClaims, tc.JWTToken)
eventCount := 100
for range eventCount {
event := audit.NewAuditEvent(ctx, audit.LoginAttempt, false, audit.AuthenticationAction, audit.EmailMethod, errors.ErrCodeAccountLocked)
err := auditEventRepo.GetInMemoryAuditEventRepository().StoreAuditEvent(ctx, event)
require.NoError(tc.T, err)
}
}
// GetUserRegistrationRequest returns a registration request populated from the
// test user constants, with the admin role and the test postal address.
// Unlike WithUser, the address carries no Formatted value.
func (tc *VigiloTestContext) GetUserRegistrationRequest() *users.UserRegistrationRequest {
	address := users.UserAddress{
		StreetAddress: testStreetAddress,
		Locality:      testLocality,
		Region:        testRegion,
		PostalCode:    testPostalCode,
		Country:       testCountry,
	}
	return &users.UserRegistrationRequest{
		Username:    testUsername,
		FirstName:   testFirstName,
		MiddleName:  testMiddleName,
		FamilyName:  testFamilyName,
		Birthdate:   testBirthdate,
		Email:       testEmail,
		Gender:      testGender,
		PhoneNumber: testPhoneNumber,
		Password:    testPassword1,
		Roles:       []string{constants.AdminRole},
		Address:     address,
	}
}
// TearDown closes the live test server (if one was started), sets the global
// server logger back to "info", and resets all in-memory repositories.
func (tc *VigiloTestContext) TearDown() {
	if live := tc.TestServer; live != nil {
		live.Close()
	}
	config.GetServerConfig().Logger().SetLevel("info")
	resetInMemoryStores()
}
// sendAuthorizationCodeRequest sends a GET to the authorize endpoint with
// queryParams and tc.SessionCookie (which must be non-nil). It asserts a
// 302 Found with a non-empty redirect location, and returns that location
// parsed as a URL.
func (tc *VigiloTestContext) sendAuthorizationCodeRequest(queryParams url.Values) *url.URL {
	target := web.OAuthEndpoints.Authorize + "?" + queryParams.Encode()
	cookieHeader := tc.SessionCookie.Name + "=" + tc.SessionCookie.Value

	rr := tc.SendHTTPRequest(http.MethodGet, target, nil, map[string]string{"Cookie": cookieHeader})
	assert.Equal(tc.T, http.StatusFound, rr.Code)

	redirectTo := rr.Header().Get(constants.RedirectLocationURLValue)
	assert.NotEmpty(tc.T, redirectTo, "Redirect location should not be empty")

	parsed, err := url.Parse(redirectTo)
	require.NoError(tc.T, err)
	return parsed
}
func (tc *VigiloTestContext) addHeaderAuth(req *http.Request, headers map[string]string) {
if _, exists := headers["Content-Type"]; !exists {
req.Header.Set("Content-Type", "application/json")
}
// Add all headers
for key, value := range headers {
req.Header.Set(key, value)
}
}
// resetInMemoryStores calls the Reset function of every in-memory repository
// used by the integration tests: users, tokens, clients, user consent,
// sessions and audit events. TearDown uses it to isolate tests.
func resetInMemoryStores() {
userRepo.ResetInMemoryUserRepository()
tokenRepo.ResetInMemoryTokenRepository()
clientRepo.ResetInMemoryClientRepository()
consentRepo.ResetInMemoryUserConsentRepository()
sessionRepo.ResetInMemorySessionRepository()
auditEventRepo.ResetInMemoryAuditEventRepository()
}
// decodeErrorResponse decodes the next JSON value from rr.Body (consuming it)
// into an errors.VigiloAuthError. The test stops if decoding fails.
func (tc *VigiloTestContext) decodeErrorResponse(rr *httptest.ResponseRecorder) errors.VigiloAuthError {
	decoder := json.NewDecoder(rr.Body)
	var decoded errors.VigiloAuthError
	decodeErr := decoder.Decode(&decoded)
	require.NoError(tc.T, decodeErr, "Failed to unmarshal response body")
	return decoded
}
// GenerateHeaderWithCredentials returns request headers for a form-encoded
// token request authenticated with HTTP Basic credentials id:secret.
func GenerateHeaderWithCredentials(id, secret string) map[string]string {
	return map[string]string{
		"Content-Type":  "application/x-www-form-urlencoded",
		"Authorization": "Basic " + encodeClientCredentials(id, secret),
	}
}
// encodeClientCredentials returns the standard base64 encoding of
// "clientID:clientSecret", the payload of an HTTP Basic Authorization header.
func encodeClientCredentials(clientID, clientSecret string) string {
	pair := clientID + ":" + clientSecret
	return base64.StdEncoding.EncodeToString([]byte(pair))
}