package ctrl
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/JMURv/sso/internal/auth"
"github.com/JMURv/sso/internal/cache"
"github.com/JMURv/sso/internal/dto"
repo "github.com/JMURv/sso/internal/repository"
"github.com/JMURv/sso/pkg/consts"
"github.com/JMURv/sso/pkg/model"
"github.com/google/uuid"
"github.com/opentracing/opentracing-go"
"go.uber.org/zap"
"golang.org/x/crypto/bcrypt"
"math/rand"
"strconv"
"time"
)
// codeCacheKey is the cache key template for one-time login codes; it is
// formatted with the user's email (see SendLoginCode / CheckLoginCode).
const codeCacheKey = "code:%v"

// recoveryCacheKey is the cache key template for password-recovery codes; it
// is formatted with the user's ID (see SendForgotPasswordEmail /
// CheckForgotPasswordEmail).
const recoveryCacheKey = "recovery:%v"
// Authenticate verifies an email/password pair and, on success, issues an
// access token for the user.
//
// Errors:
//   - ErrNotFound when no user is registered under req.Email;
//   - ErrInvalidCredentials when req.Password does not match the stored
//     bcrypt hash;
//   - ErrWhileGeneratingToken when the access token cannot be created;
//   - any other repository error, unchanged.
func (c *Controller) Authenticate(ctx context.Context, req *dto.EmailAndPasswordRequest) (*dto.EmailAndPasswordResponse, error) {
	const op = "sso.Authenticate.ctrl"
	span, ctx := opentracing.StartSpanFromContext(ctx, op)
	defer span.Finish()

	u, err := c.repo.GetUserByEmail(ctx, req.Email)
	if err != nil && errors.Is(err, repo.ErrNotFound) {
		zap.L().Debug(
			"failed to find user",
			zap.Error(err), zap.String("op", op),
		)
		return nil, ErrNotFound
	} else if err != nil {
		zap.L().Debug(
			"failed to get user",
			zap.Error(err), zap.String("op", op),
		)
		return nil, err
	}

	// The stored password is a bcrypt hash (SendLoginCode and
	// CheckForgotPasswordEmail treat it the same way). Without this check any
	// caller that knows an email address could obtain a token for it.
	if err = bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(req.Password)); err != nil {
		zap.L().Debug(
			"invalid password",
			zap.String("op", op),
		)
		return nil, ErrInvalidCredentials
	}

	accessToken, err := c.auth.NewToken(u, auth.AccessTokenDuration)
	if err != nil {
		zap.L().Debug(
			"failed to generate access token",
			zap.Error(err), zap.String("op", op),
		)
		return nil, ErrWhileGeneratingToken
	}
	return &dto.EmailAndPasswordResponse{
		Token: accessToken,
	}, nil
}
// ParseClaims verifies token with the auth service and returns its claims.
//
// The returned claims are guaranteed to contain a string "uid" entry. If that
// entry is missing or is not a string, ErrParseUUID is returned. Verification
// failures are returned exactly as produced by the auth service.
func (c *Controller) ParseClaims(ctx context.Context, token string) (map[string]any, error) {
	const op = "sso.ParseClaims.ctrl"
	// Nothing below takes a context (VerifyToken has none), so the derived
	// context is not kept.
	span, _ := opentracing.StartSpanFromContext(ctx, op)
	defer span.Finish()

	claims, err := c.auth.VerifyToken(token)
	if err != nil {
		zap.L().Debug("invalid token", zap.Error(err), zap.String("op", op))
		return nil, err
	}
	if _, ok := claims["uid"].(string); !ok {
		zap.L().Debug("failed to parse uid", zap.String("op", op))
		return nil, ErrParseUUID
	}
	return claims, nil
}
// GetUserByToken resolves the user that owns token.
//
// The token is checked via ParseClaims and its "uid" claim is parsed as a
// UUID. The user is returned from the cache when present. Otherwise it is
// loaded from the repository and written back to the cache; a failed cache
// write is only logged.
//
// Returns ErrParseUUID for a malformed uid and ErrNotFound for an unknown user.
func (c *Controller) GetUserByToken(ctx context.Context, token string) (*model.User, error) {
	const op = "sso.GetUserByToken.ctrl"
	span, _ := opentracing.StartSpanFromContext(ctx, op)
	ctx = opentracing.ContextWithSpan(ctx, span)
	defer span.Finish()

	claims, err := c.ParseClaims(ctx, token)
	if err != nil {
		zap.L().Debug("invalid token", zap.Error(err))
		return nil, err
	}

	// ParseClaims already guaranteed that "uid" holds a string.
	userID, parseErr := uuid.Parse(claims["uid"].(string))
	if parseErr != nil {
		zap.L().Debug("failed to parse uuid", zap.String("op", op))
		return nil, ErrParseUUID
	}

	key := fmt.Sprintf(userCacheKey, userID)
	fromCache := new(model.User)
	if c.cache.GetToStruct(ctx, key, fromCache) == nil {
		return fromCache, nil
	}

	user, err := c.repo.GetUserByID(ctx, userID)
	switch {
	case err != nil && errors.Is(err, repo.ErrNotFound):
		return nil, ErrNotFound
	case err != nil:
		zap.L().Debug(
			"failed to get user",
			zap.Error(err), zap.String("op", op),
			zap.String("id", userID.String()),
		)
		return nil, err
	}

	payload, marshalErr := json.Marshal(user)
	if marshalErr != nil {
		return user, nil
	}
	if setErr := c.cache.Set(ctx, consts.DefaultCacheTime, key, payload); setErr != nil {
		zap.L().Debug(
			"failed to set to cache",
			zap.Error(setErr), zap.String("op", op),
			zap.String("id", userID.String()),
		)
	}
	return user, nil
}
// SendSupportEmail looks up the user identified by uid and passes that user,
// theme and text to the email service's SendSupportEmail.
//
// Returns ErrNotFound when the user does not exist. Any other repository or
// email error is returned unchanged.
func (c *Controller) SendSupportEmail(ctx context.Context, uid uuid.UUID, theme, text string) error {
	const op = "sso.SendSupportEmail.ctrl"
	span, _ := opentracing.StartSpanFromContext(ctx, op)
	ctx = opentracing.ContextWithSpan(ctx, span)
	defer span.Finish()

	user, err := c.repo.GetUserByID(ctx, uid)
	switch {
	case err != nil && errors.Is(err, repo.ErrNotFound):
		zap.L().Debug("Error find user", zap.Error(err), zap.String("op", op))
		return ErrNotFound
	case err != nil:
		zap.L().Debug("Error getting user", zap.Error(err), zap.String("op", op))
		return err
	}

	sendErr := c.smtp.SendSupportEmail(ctx, user, theme, text)
	if sendErr != nil {
		zap.L().Debug("Error sending email", zap.Error(sendErr), zap.String("op", op))
	}
	return sendErr
}
// CheckForgotPasswordEmail completes a password reset. It compares code with
// the recovery code cached for uid and, if they match, replaces the user's
// password with a bcrypt hash of password. The user's cache entry is then
// evicted.
//
// The recovery code is single-use. It is deleted from the cache as soon as it
// matches, so the same code cannot be replayed for the rest of its lifetime.
//
// Errors:
//   - ErrCodeIsNotValid when code does not match;
//   - ErrNotFound when the user does not exist;
//   - repo.ErrGeneratingPassword when hashing fails;
//   - cache/repository errors otherwise, unchanged.
func (c *Controller) CheckForgotPasswordEmail(ctx context.Context, password string, uid uuid.UUID, code int) error {
	const op = "sso.CheckForgotPasswordEmail.ctrl"
	span, _ := opentracing.StartSpanFromContext(ctx, op)
	ctx = opentracing.ContextWithSpan(ctx, span)
	defer span.Finish()

	recoveryKey := fmt.Sprintf(recoveryCacheKey, uid)
	storedCode, err := c.cache.GetCode(ctx, recoveryKey)
	if err != nil {
		zap.L().Debug(
			"Error getting from cache",
			zap.Error(err), zap.String("op", op),
		)
		return err
	}
	if storedCode != code {
		// NOTE(review): there is no attempt limit. The code has only four
		// digits (see SendForgotPasswordEmail), so consider rate limiting.
		return ErrCodeIsNotValid
	}

	// Consume the code before any further work so it cannot be reused.
	// Best effort: if the delete fails, the code simply expires on its own.
	if err = c.cache.Delete(ctx, recoveryKey); err != nil {
		zap.L().Debug(
			"Error deleting from cache",
			zap.Error(err), zap.String("op", op),
		)
	}

	u, err := c.repo.GetUserByID(ctx, uid)
	if err != nil && errors.Is(err, repo.ErrNotFound) {
		zap.L().Debug(
			"Error find user",
			zap.Error(err), zap.String("op", op),
		)
		return ErrNotFound
	} else if err != nil {
		zap.L().Debug(
			"Error getting user",
			zap.Error(err), zap.String("op", op),
		)
		return err
	}

	newPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if err != nil {
		return repo.ErrGeneratingPassword
	}
	u.Password = string(newPassword)
	if err = c.repo.UpdateUser(ctx, uid, u); err != nil {
		zap.L().Debug(
			"Error updating user",
			zap.Error(err), zap.String("op", op),
		)
		return err
	}

	if err = c.cache.Delete(ctx, fmt.Sprintf(userCacheKey, uid)); err != nil {
		zap.L().Debug(
			"Error deleting from cache",
			zap.Error(err), zap.String("op", op),
		)
	}
	return nil
}
// SendForgotPasswordEmail starts a password reset for the account registered
// under email. It generates a code in [1000, 9999], caches it for 15 minutes
// under the user's ID, and emails the code and the user ID to email.
//
// An unknown email is reported as ErrInvalidCredentials. Cache and email
// failures are returned unchanged.
func (c *Controller) SendForgotPasswordEmail(ctx context.Context, email string) error {
	const op = "sso.SendForgotPasswordEmail.ctrl"
	span, _ := opentracing.StartSpanFromContext(ctx, op)
	ctx = opentracing.ContextWithSpan(ctx, span)
	defer span.Finish()

	u, err := c.repo.GetUserByEmail(ctx, email)
	if err != nil && errors.Is(err, repo.ErrNotFound) {
		zap.L().Debug(
			"failed to find user",
			zap.Error(err), zap.String("op", op),
		)
		return ErrInvalidCredentials
	} else if err != nil {
		zap.L().Debug(
			"failed to get user",
			zap.Error(err), zap.String("op", op),
		)
		return err
	}

	// Draw from a freshly seeded local generator rather than building one and
	// discarding it. This way the code does not depend on how (or whether) the
	// global math/rand source was seeded.
	// NOTE(review): math/rand is not cryptographically secure; crypto/rand
	// would be preferable for an account-recovery secret.
	rng := rand.New(rand.NewSource(time.Now().UnixNano()))
	code := rng.Intn(9999-1000+1) + 1000

	if err = c.cache.Set(ctx, time.Minute*15, fmt.Sprintf(recoveryCacheKey, u.ID.String()), code); err != nil {
		zap.L().Debug(
			"failed to set to cache",
			zap.Error(err), zap.String("op", op),
		)
		return err
	}
	if err = c.smtp.SendForgotPasswordEmail(ctx, strconv.Itoa(code), u.ID.String(), email); err != nil {
		zap.L().Debug(
			"failed to send email",
			zap.Error(err), zap.String("op", op),
		)
		return err
	}
	return nil
}
// SendLoginCode performs the first step of code-based login. It checks email
// and password against the stored bcrypt hash, then generates a code in
// [1000, 9999], caches it for 15 minutes under email, and emails it.
//
// An unknown email or a wrong password is reported as ErrInvalidCredentials.
// Cache and email failures are returned unchanged.
func (c *Controller) SendLoginCode(ctx context.Context, email, password string) error {
	const op = "sso.SendLoginCode.ctrl"
	span, _ := opentracing.StartSpanFromContext(ctx, op)
	ctx = opentracing.ContextWithSpan(ctx, span)
	defer span.Finish()

	u, err := c.repo.GetUserByEmail(ctx, email)
	if err != nil && errors.Is(err, repo.ErrNotFound) {
		zap.L().Debug(
			"failed to find user",
			zap.Error(err), zap.String("op", op),
		)
		return ErrInvalidCredentials
	} else if err != nil {
		zap.L().Debug(
			"failed to get user",
			zap.Error(err), zap.String("op", op),
		)
		return err
	}
	if err = bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password)); err != nil {
		return ErrInvalidCredentials
	}

	// Draw from a freshly seeded local generator rather than building one and
	// discarding it. This way the code does not depend on how (or whether) the
	// global math/rand source was seeded.
	// NOTE(review): math/rand is not cryptographically secure; crypto/rand
	// would be preferable for an authentication code.
	rng := rand.New(rand.NewSource(time.Now().UnixNano()))
	code := rng.Intn(9999-1000+1) + 1000

	if err = c.cache.Set(ctx, time.Minute*15, fmt.Sprintf(codeCacheKey, email), code); err != nil {
		zap.L().Debug(
			"failed to set to cache",
			zap.Error(err), zap.String("op", op),
		)
		return err
	}
	if err = c.smtp.SendLoginEmail(ctx, code, email); err != nil {
		zap.L().Debug(
			"failed to send an email",
			zap.Error(err), zap.String("op", op),
		)
		return err
	}
	return nil
}
// CheckLoginCode performs the second step of code-based login. It compares
// code with the code cached for email and, on a match, returns a new access
// token and refresh token for the user (in that order).
//
// Login codes are single-use. The cached code is deleted as soon as it
// matches, so it cannot be replayed for the rest of its lifetime.
//
// Errors:
//   - ErrNotFound when no code is cached for email or the user does not exist;
//   - ErrCodeIsNotValid when code does not match;
//   - ErrWhileGeneratingToken when a token cannot be created;
//   - other cache/repository errors, unchanged.
func (c *Controller) CheckLoginCode(ctx context.Context, email string, code int) (string, string, error) {
	const op = "sso.CheckLoginCode.ctrl"
	span, _ := opentracing.StartSpanFromContext(ctx, op)
	ctx = opentracing.ContextWithSpan(ctx, span)
	defer span.Finish()

	codeKey := fmt.Sprintf(codeCacheKey, email)
	storedCode, err := c.cache.GetCode(ctx, codeKey)
	if err != nil && errors.Is(err, cache.ErrNotFoundInCache) {
		return "", "", ErrNotFound
	} else if err != nil {
		zap.L().Debug(
			"failed to get from cache",
			zap.Error(err), zap.String("op", op),
		)
		return "", "", err
	}
	if storedCode != code {
		// NOTE(review): there is no attempt limit on a four-digit code;
		// consider rate limiting.
		return "", "", ErrCodeIsNotValid
	}

	// Consume the code immediately so it cannot be reused. Best effort: if the
	// delete fails, the code still expires on its own.
	if err = c.cache.Delete(ctx, codeKey); err != nil {
		zap.L().Debug(
			"failed to delete from cache",
			zap.Error(err), zap.String("op", op),
		)
	}

	u, err := c.repo.GetUserByEmail(ctx, email)
	if err != nil && errors.Is(err, repo.ErrNotFound) {
		zap.L().Debug(
			"failed to find user",
			zap.Error(err), zap.String("op", op),
		)
		return "", "", ErrNotFound
	} else if err != nil {
		zap.L().Debug(
			"failed to get user",
			zap.Error(err), zap.String("op", op),
		)
		return "", "", err
	}

	accessToken, err := c.auth.NewToken(u, auth.AccessTokenDuration)
	if err != nil {
		zap.L().Debug(
			"failed to generate access token",
			zap.Error(err), zap.String("op", op),
		)
		return "", "", ErrWhileGeneratingToken
	}
	refreshToken, err := c.auth.NewToken(u, auth.RefreshTokenDuration)
	if err != nil {
		zap.L().Debug(
			"failed to generate refresh token",
			zap.Error(err), zap.String("op", op),
		)
		return "", "", ErrWhileGeneratingToken
	}
	return accessToken, refreshToken, nil
}
package ctrl
import (
"context"
md "github.com/JMURv/sso/pkg/model"
"io"
"time"
)
// AppRepo is the persistence dependency of Controller: the union of the
// user-repository and permission-repository contracts.
type AppRepo interface {
	userRepo
	permRepo
}
// AuthService creates and verifies tokens for users.
type AuthService interface {
	// NewToken creates a token for u. Controller passes auth.AccessTokenDuration
	// or auth.RefreshTokenDuration as d, presumably the token lifetime.
	NewToken(u *md.User, d time.Duration) (string, error)
	// VerifyToken checks tokenStr and returns its claims. Controller expects a
	// string "uid" claim in the result.
	VerifyToken(tokenStr string) (map[string]any, error)
}
// CacheService is the key/value cache used by Controller. It must be closed
// (io.Closer) when no longer needed.
type CacheService interface {
	io.Closer
	// GetCode reads an integer value (a login or recovery code) stored under key.
	GetCode(ctx context.Context, key string) (int, error)
	// GetToStruct decodes the value stored under key into dest. Controller
	// stores json.Marshal output, so this is presumably JSON decoding.
	GetToStruct(ctx context.Context, key string, dest any) error
	// Set stores val under key. t is presumably the entry's TTL (callers pass
	// consts.DefaultCacheTime or 15 minutes).
	Set(ctx context.Context, t time.Duration, key string, val any) error
	// Delete removes key.
	Delete(ctx context.Context, key string) error
	// InvalidateKeysByPattern removes keys matching pattern. It returns no
	// error, so failures are invisible to callers.
	InvalidateKeysByPattern(ctx context.Context, pattern string)
}
// EmailService sends the transactional emails used by Controller.
type EmailService interface {
	// SendLoginEmail sends a one-time login code to toEmail.
	SendLoginEmail(ctx context.Context, code int, toEmail string) error
	// SendForgotPasswordEmail sends a recovery code (token) and the user ID
	// (uid64; Controller passes uuid.String()) to toEmail.
	SendForgotPasswordEmail(ctx context.Context, token, uid64, toEmail string) error
	// SendSupportEmail forwards a support message from user u.
	SendSupportEmail(ctx context.Context, u *md.User, theme, text string) error
	// SendUserCredentials emails the account's email and password to email.
	SendUserCredentials(_ context.Context, email, pass string) error
	// SendOptFile emails the file named filename with content bytes to email.
	SendOptFile(_ context.Context, email string, filename string, bytes []byte) error
}
// Controller holds the SSO business logic. It is built with New from a
// repository, a token service, a cache and an email sender.
type Controller struct {
	repo  AppRepo      // user and permission persistence
	auth  AuthService  // token creation and verification
	cache CacheService // codes and read-through entity/list caches
	smtp  EmailService // outgoing email
}
// New returns a Controller wired to the given dependencies. The arguments are
// stored as-is and are not validated.
func New(auth AuthService, repo AppRepo, cache CacheService, smtp EmailService) *Controller {
	ctl := new(Controller)
	ctl.auth = auth
	ctl.repo = repo
	ctl.cache = cache
	ctl.smtp = smtp
	return ctl
}
package ctrl
import (
"context"
"errors"
"fmt"
repo "github.com/JMURv/sso/internal/repository"
"github.com/JMURv/sso/pkg/consts"
"github.com/JMURv/sso/pkg/model"
"github.com/goccy/go-json"
"github.com/opentracing/opentracing-go"
"go.uber.org/zap"
)
// permKey is the cache key template for a single permission, formatted with
// its numeric ID.
const permKey = "perm:%v"

// permListKey is the cache key template for a page of permissions, formatted
// with page and size.
const permListKey = "perms-list:%v:%v"

// permPattern matches every permListKey entry. Mutations use it to drop
// cached listings; it does not match permKey entries, which are deleted
// individually.
const permPattern = "perms-*"
// permRepo is the permission-persistence contract Controller depends on.
// Controller treats repo.ErrNotFound (get/update/delete) and
// repo.ErrAlreadyExists (create) as sentinel errors.
type permRepo interface {
	// ListPermissions returns one page of permissions.
	ListPermissions(ctx context.Context, page, size int) (*model.PaginatedPermission, error)
	// GetPermission returns the permission with the given id.
	GetPermission(ctx context.Context, id uint64) (*model.Permission, error)
	// CreatePerm stores req and returns the new permission's id.
	CreatePerm(ctx context.Context, req *model.Permission) (uint64, error)
	// UpdatePerm applies req to the permission with the given id.
	UpdatePerm(ctx context.Context, id uint64, req *model.Permission) error
	// DeletePerm removes the permission with the given id.
	DeletePerm(ctx context.Context, id uint64) error
}
// ListPermissions returns one page of permissions. The page is served from the
// cache when present; otherwise it is loaded from the repository and written
// back to the cache, where a failed cache write is only logged.
func (c *Controller) ListPermissions(ctx context.Context, page, size int) (*model.PaginatedPermission, error) {
	const op = "sso.ListPermissions.ctrl"
	span, _ := opentracing.StartSpanFromContext(ctx, op)
	ctx = opentracing.ContextWithSpan(ctx, span)
	defer span.Finish()

	cached := &model.PaginatedPermission{}
	key := fmt.Sprintf(permListKey, page, size)
	// Pass the struct pointer itself, not &cached. Decoding into a **T lets a
	// cached `null` reset cached to nil, which would make this return
	// (nil, nil). This also matches GetPermission.
	if err := c.cache.GetToStruct(ctx, key, cached); err == nil {
		return cached, nil
	}

	res, err := c.repo.ListPermissions(ctx, page, size)
	if err != nil {
		zap.L().Debug(
			"failed to list permissions",
			zap.Error(err), zap.String("op", op),
			zap.Int("page", page), zap.Int("size", size),
		)
		return nil, err
	}

	if bytes, err := json.Marshal(res); err == nil {
		if err = c.cache.Set(ctx, consts.DefaultCacheTime, key, bytes); err != nil {
			zap.L().Debug(
				"failed to set to cache",
				zap.Error(err), zap.String("op", op),
				zap.Int("page", page), zap.Int("size", size),
			)
		}
	}
	return res, nil
}
// GetPermission returns the permission with the given id. It is served from
// the cache when present; otherwise it is loaded from the repository and
// cached, where a failed cache write is only logged.
//
// Returns ErrNotFound for an unknown id.
func (c *Controller) GetPermission(ctx context.Context, id uint64) (*model.Permission, error) {
	const op = "sso.GetPermission.ctrl"
	span, _ := opentracing.StartSpanFromContext(ctx, op)
	ctx = opentracing.ContextWithSpan(ctx, span)
	defer span.Finish()

	key := fmt.Sprintf(permKey, id)
	hit := new(model.Permission)
	if c.cache.GetToStruct(ctx, key, hit) == nil {
		return hit, nil
	}

	perm, err := c.repo.GetPermission(ctx, id)
	if err != nil {
		if errors.Is(err, repo.ErrNotFound) {
			zap.L().Debug(
				"failed to find permission",
				zap.Error(err), zap.String("op", op),
				zap.Uint64("id", id),
			)
			return nil, ErrNotFound
		}
		zap.L().Debug(
			"failed to get permission",
			zap.Error(err), zap.String("op", op),
			zap.Uint64("id", id),
		)
		return nil, err
	}

	payload, marshalErr := json.Marshal(perm)
	if marshalErr == nil {
		if setErr := c.cache.Set(ctx, consts.DefaultCacheTime, key, payload); setErr != nil {
			zap.L().Debug(
				"failed to set to cache",
				zap.Error(setErr), zap.String("op", op),
				zap.Uint64("id", id),
			)
		}
	}
	return perm, nil
}
// CreatePerm stores a new permission and returns its id. Cached permission
// listings are invalidated asynchronously.
//
// Returns ErrAlreadyExists when the repository reports a duplicate.
func (c *Controller) CreatePerm(ctx context.Context, req *model.Permission) (uint64, error) {
	const op = "sso.CreatePerm.ctrl"
	span, _ := opentracing.StartSpanFromContext(ctx, op)
	ctx = opentracing.ContextWithSpan(ctx, span)
	defer span.Finish()

	res, err := c.repo.CreatePerm(ctx, req)
	if err != nil && errors.Is(err, repo.ErrAlreadyExists) {
		return 0, ErrAlreadyExists
	} else if err != nil {
		zap.L().Debug(
			"failed to create permission",
			zap.Error(err), zap.String("op", op),
		)
		return 0, err
	}

	// Use a detached context. The request ctx is typically cancelled as soon
	// as this call returns (gRPC cancels it when the RPC completes), which
	// would abort the background invalidation.
	go c.cache.InvalidateKeysByPattern(context.Background(), permPattern)
	return res, nil
}
// UpdatePerm applies req to the permission with the given id. It then evicts
// that permission's cache entry and asynchronously invalidates cached
// listings.
//
// Returns ErrNotFound for an unknown id.
func (c *Controller) UpdatePerm(ctx context.Context, id uint64, req *model.Permission) error {
	const op = "sso.UpdatePerm.ctrl"
	span, _ := opentracing.StartSpanFromContext(ctx, op)
	ctx = opentracing.ContextWithSpan(ctx, span)
	defer span.Finish()

	err := c.repo.UpdatePerm(ctx, id, req)
	if err != nil && errors.Is(err, repo.ErrNotFound) {
		zap.L().Debug(
			"failed to find permission",
			zap.Error(err), zap.String("op", op),
			zap.Uint64("id", id),
		)
		return ErrNotFound
	} else if err != nil {
		zap.L().Debug(
			"failed to update permission",
			zap.Error(err), zap.String("op", op),
			zap.Uint64("id", id),
		)
		return err
	}

	if err := c.cache.Delete(ctx, fmt.Sprintf(permKey, id)); err != nil {
		zap.L().Debug(
			"failed to delete from cache",
			zap.Error(err), zap.String("op", op),
			zap.Uint64("id", id),
		)
	}
	// Use a detached context. The request ctx is typically cancelled as soon
	// as this call returns, which would abort the background invalidation.
	go c.cache.InvalidateKeysByPattern(context.Background(), permPattern)
	return nil
}
// DeletePerm removes the permission with the given id. It then evicts that
// permission's cache entry and asynchronously invalidates cached listings.
//
// Returns ErrNotFound for an unknown id.
func (c *Controller) DeletePerm(ctx context.Context, id uint64) error {
	const op = "sso.DeletePerm.ctrl"
	span, _ := opentracing.StartSpanFromContext(ctx, op)
	ctx = opentracing.ContextWithSpan(ctx, span)
	defer span.Finish()

	if err := c.repo.DeletePerm(ctx, id); err != nil && errors.Is(err, repo.ErrNotFound) {
		zap.L().Debug(
			"failed to delete permission",
			zap.Error(err), zap.String("op", op),
			zap.Uint64("id", id),
		)
		return ErrNotFound
	} else if err != nil {
		zap.L().Debug(
			"failed to delete permission",
			zap.Error(err), zap.String("op", op),
			zap.Uint64("id", id),
		)
		return err
	}

	if err := c.cache.Delete(ctx, fmt.Sprintf(permKey, id)); err != nil {
		zap.L().Debug(
			"failed to delete from cache",
			zap.Error(err), zap.String("op", op),
			zap.Uint64("id", id),
		)
	}
	// Use a detached context. The request ctx is typically cancelled as soon
	// as this call returns, which would abort the background invalidation.
	go c.cache.InvalidateKeysByPattern(context.Background(), permPattern)
	return nil
}
package ctrl
import (
"context"
"errors"
"fmt"
"github.com/JMURv/sso/internal/auth"
repo "github.com/JMURv/sso/internal/repository"
"github.com/JMURv/sso/pkg/consts"
"github.com/JMURv/sso/pkg/model"
"github.com/goccy/go-json"
"github.com/google/uuid"
"github.com/opentracing/opentracing-go"
"go.uber.org/zap"
)
// userRepo is the user-persistence contract Controller depends on.
// Controller treats repo.ErrNotFound (lookups, update, delete) and
// repo.ErrAlreadyExists (create) as sentinel errors.
type userRepo interface {
	// SearchUser returns one page of users for query.
	SearchUser(ctx context.Context, query string, page int, size int) (*model.PaginatedUser, error)
	// ListUsers returns one page of users.
	ListUsers(ctx context.Context, page, size int) (*model.PaginatedUser, error)
	// GetUserByID returns the user with the given ID.
	GetUserByID(ctx context.Context, userID uuid.UUID) (*model.User, error)
	// GetUserByEmail returns the user registered under email.
	GetUserByEmail(ctx context.Context, email string) (*model.User, error)
	// CreateUser stores req and returns the new user's ID.
	CreateUser(ctx context.Context, req *model.User) (uuid.UUID, error)
	// UpdateUser applies req to the user with the given id.
	UpdateUser(ctx context.Context, id uuid.UUID, req *model.User) error
	// DeleteUser removes the user with the given ID.
	DeleteUser(ctx context.Context, userID uuid.UUID) error
}
// userCacheKey is the cache key template for a single user. Most callers
// format it with the user's UUID.
const userCacheKey = "user:%v"

// usersSearchCacheKey is the cache key template for a search result page,
// formatted with query, page and size.
const usersSearchCacheKey = "users-search:%v:%v:%v"

// usersListKey is the cache key template for a listing page, formatted with
// page and size.
const usersListKey = "users-list:%v:%v"

// userPattern matches every usersSearchCacheKey and usersListKey entry, but
// not userCacheKey entries, which are deleted individually.
const userPattern = "users-*"
// IsUserExist reports whether a user is registered under email.
//
// A repo.ErrNotFound lookup is reported as (false, nil). Any other repository
// error is returned with isExist == false, because existence could not be
// determined.
func (c *Controller) IsUserExist(ctx context.Context, email string) (isExist bool, err error) {
	const op = "sso.IsUserExist.ctrl"
	span, _ := opentracing.StartSpanFromContext(ctx, op)
	ctx = opentracing.ContextWithSpan(ctx, span)
	defer span.Finish()

	_, err = c.repo.GetUserByEmail(ctx, email)
	if err != nil && errors.Is(err, repo.ErrNotFound) {
		return false, nil
	} else if err != nil {
		zap.L().Debug(
			"failed to get user",
			zap.Error(err), zap.String("op", op),
		)
		return false, err
	}
	return true, nil
}
// SearchUser returns one page of users matching query. The page is served
// from the cache when present; otherwise it is loaded from the repository and
// cached, where a failed cache write is only logged.
func (c *Controller) SearchUser(ctx context.Context, query string, page, size int) (*model.PaginatedUser, error) {
	const op = "users.UserSearch.ctrl"
	span, _ := opentracing.StartSpanFromContext(ctx, op)
	ctx = opentracing.ContextWithSpan(ctx, span)
	defer span.Finish()

	cached := &model.PaginatedUser{}
	cacheKey := fmt.Sprintf(usersSearchCacheKey, query, page, size)
	// Pass the struct pointer itself, not &cached. Decoding into a **T lets a
	// cached `null` reset cached to nil, which would make this return
	// (nil, nil).
	if err := c.cache.GetToStruct(ctx, cacheKey, cached); err == nil {
		return cached, nil
	}

	res, err := c.repo.SearchUser(ctx, query, page, size)
	if err != nil {
		zap.L().Debug(
			"failed to search users",
			zap.Error(err), zap.String("op", op),
			zap.String("query", query),
		)
		return nil, err
	}

	if bytes, err := json.Marshal(res); err == nil {
		if err = c.cache.Set(ctx, consts.DefaultCacheTime, cacheKey, bytes); err != nil {
			zap.L().Debug(
				"failed to set to cache",
				zap.Error(err), zap.String("op", op),
				zap.String("query", query),
			)
		}
	}
	return res, nil
}
// ListUsers returns one page of users. The page is served from the cache when
// present; otherwise it is loaded from the repository and cached, where a
// failed cache write is only logged.
//
// Returns ErrNotFound when the repository reports repo.ErrNotFound.
func (c *Controller) ListUsers(ctx context.Context, page, size int) (*model.PaginatedUser, error) {
	const op = "users.GetUsersList.ctrl"
	span, _ := opentracing.StartSpanFromContext(ctx, op)
	ctx = opentracing.ContextWithSpan(ctx, span)
	defer span.Finish()

	cached := &model.PaginatedUser{}
	cacheKey := fmt.Sprintf(usersListKey, page, size)
	// Pass the struct pointer itself, not &cached. Decoding into a **T lets a
	// cached `null` reset cached to nil, which would make this return
	// (nil, nil).
	if err := c.cache.GetToStruct(ctx, cacheKey, cached); err == nil {
		return cached, nil
	}

	res, err := c.repo.ListUsers(ctx, page, size)
	if err != nil && errors.Is(err, repo.ErrNotFound) {
		zap.L().Debug(
			"failed to list users",
			zap.Error(err), zap.String("op", op),
			zap.Int("page", page), zap.Int("size", size),
		)
		return nil, ErrNotFound
	} else if err != nil {
		zap.L().Debug(
			"failed to list users",
			zap.Error(err), zap.String("op", op),
			zap.Int("page", page), zap.Int("size", size),
		)
		return nil, err
	}

	if bytes, err := json.Marshal(res); err == nil {
		if err = c.cache.Set(ctx, consts.DefaultCacheTime, cacheKey, bytes); err != nil {
			zap.L().Debug(
				"failed to set to cache",
				zap.Error(err), zap.String("op", op),
				zap.Int("page", page), zap.Int("size", size),
			)
		}
	}
	return res, nil
}
// GetUserByID returns the user with the given ID. It is served from the cache
// when present; otherwise it is loaded from the repository and cached, where a
// failed cache write is only logged.
//
// Returns ErrNotFound for an unknown ID.
func (c *Controller) GetUserByID(ctx context.Context, userID uuid.UUID) (*model.User, error) {
	const op = "users.GetUserByID.ctrl"
	span, _ := opentracing.StartSpanFromContext(ctx, op)
	ctx = opentracing.ContextWithSpan(ctx, span)
	defer span.Finish()

	idField := zap.String("id", userID.String())
	key := fmt.Sprintf(userCacheKey, userID)

	hit := new(model.User)
	if c.cache.GetToStruct(ctx, key, hit) == nil {
		return hit, nil
	}

	user, err := c.repo.GetUserByID(ctx, userID)
	if err != nil {
		if errors.Is(err, repo.ErrNotFound) {
			zap.L().Debug("failed to find user", zap.Error(err), zap.String("op", op), idField)
			return nil, ErrNotFound
		}
		zap.L().Debug("failed to get user", zap.Error(err), zap.String("op", op), idField)
		return nil, err
	}

	payload, marshalErr := json.Marshal(user)
	if marshalErr == nil {
		if setErr := c.cache.Set(ctx, consts.DefaultCacheTime, key, payload); setErr != nil {
			zap.L().Debug("failed to set to cache", zap.Error(setErr), zap.String("op", op), idField)
		}
	}
	return user, nil
}
// GetUserByEmail returns the user registered under email, read directly from
// the repository.
//
// The result is deliberately not cached. UpdateUser, DeleteUser and
// CheckForgotPasswordEmail only evict the ID-keyed entry (user:<uuid>). An
// email-keyed entry would therefore keep serving deleted users and stale data,
// including an outdated password hash, until it expired.
//
// Returns ErrNotFound for an unknown email.
func (c *Controller) GetUserByEmail(ctx context.Context, email string) (*model.User, error) {
	const op = "users.GetUserByEmail.ctrl"
	span, _ := opentracing.StartSpanFromContext(ctx, op)
	ctx = opentracing.ContextWithSpan(ctx, span)
	defer span.Finish()

	res, err := c.repo.GetUserByEmail(ctx, email)
	if err != nil && errors.Is(err, repo.ErrNotFound) {
		zap.L().Debug(
			"failed to find user",
			zap.Error(err), zap.String("op", op),
			zap.String("email", email),
		)
		return nil, ErrNotFound
	} else if err != nil {
		zap.L().Debug(
			"failed to get user",
			zap.Error(err), zap.String("op", op),
			zap.String("email", email),
		)
		return nil, err
	}
	return res, nil
}
// CreateUser persists u and returns the new user's ID together with an access
// token and a refresh token.
//
// If fileName and bytes are both set, the file is emailed to u.Email; a
// failure to send it is only logged. After the tokens are issued, the
// credentials (u.Email, u.Password) are emailed. If that fails, the ID and
// tokens are still returned alongside the error, because the user already
// exists at that point.
//
// Errors:
//   - ErrAlreadyExists when the repository reports a duplicate;
//   - ErrWhileGeneratingToken when either token cannot be created;
//   - other repository or email errors, unchanged.
func (c *Controller) CreateUser(ctx context.Context, u *model.User, fileName string, bytes []byte) (uuid.UUID, string, string, error) {
	const op = "users.CreateUser.ctrl"
	span, _ := opentracing.StartSpanFromContext(ctx, op)
	ctx = opentracing.ContextWithSpan(ctx, span)
	defer span.Finish()

	id, err := c.repo.CreateUser(ctx, u)
	if err != nil && errors.Is(err, repo.ErrAlreadyExists) {
		zap.L().Debug(
			"user already exists",
			zap.Error(err), zap.String("op", op),
		)
		return uuid.Nil, "", "", ErrAlreadyExists
	} else if err != nil {
		zap.L().Debug(
			"failed to create user",
			zap.Error(err), zap.String("op", op),
		)
		return uuid.Nil, "", "", err
	}

	// Cached listings/search pages are now stale. Invalidate them immediately
	// so a later failure, such as the credentials email below, cannot skip it.
	// The context is detached because the request ctx is typically cancelled
	// once this call returns (gRPC cancels it when the RPC completes).
	go c.cache.InvalidateKeysByPattern(context.Background(), userPattern)

	// The per-user cache entry is intentionally not written here. u is the
	// caller's request payload and may not carry the repository-assigned ID or
	// other stored fields; GetUserByID populates the cache on first read.

	if fileName != "" && bytes != nil {
		if err = c.smtp.SendOptFile(ctx, u.Email, fileName, bytes); err != nil {
			zap.L().Debug(
				"failed to send email",
				zap.Error(err),
				zap.String("op", op),
			)
		}
	}

	accessToken, err := c.auth.NewToken(u, auth.AccessTokenDuration)
	if err != nil {
		zap.L().Debug(
			"failed to create access token",
			zap.Error(err), zap.String("op", op),
		)
		return id, "", "", ErrWhileGeneratingToken
	}
	refreshToken, err := c.auth.NewToken(u, auth.RefreshTokenDuration)
	if err != nil {
		zap.L().Debug(
			"failed to create refresh token",
			zap.Error(err), zap.String("op", op),
		)
		return id, "", "", ErrWhileGeneratingToken
	}

	// NOTE(review): if repo.CreateUser hashes u.Password in place, this would
	// email the hash rather than the plaintext. Confirm against the repository.
	if err = c.smtp.SendUserCredentials(ctx, u.Email, u.Password); err != nil {
		return id, accessToken, refreshToken, err
	}
	return id, accessToken, refreshToken, nil
}
// UpdateUser applies req to the user with the given id. It then evicts that
// user's cache entry and asynchronously invalidates cached listings/search
// pages.
//
// Returns ErrNotFound for an unknown id.
func (c *Controller) UpdateUser(ctx context.Context, id uuid.UUID, req *model.User) error {
	const op = "users.UpdateUser.ctrl"
	span, _ := opentracing.StartSpanFromContext(ctx, op)
	ctx = opentracing.ContextWithSpan(ctx, span)
	defer span.Finish()

	err := c.repo.UpdateUser(ctx, id, req)
	if err != nil && errors.Is(err, repo.ErrNotFound) {
		zap.L().Debug(
			"failed to find user",
			zap.Error(err), zap.String("op", op),
			zap.String("id", id.String()),
		)
		return ErrNotFound
	} else if err != nil {
		zap.L().Debug(
			"failed to update user",
			zap.Error(err), zap.String("op", op),
			zap.String("id", id.String()),
		)
		return err
	}

	if err := c.cache.Delete(ctx, fmt.Sprintf(userCacheKey, id)); err != nil {
		zap.L().Debug(
			"failed to delete from cache",
			zap.Error(err), zap.String("op", op),
			zap.String("id", id.String()),
		)
	}
	// Use a detached context. The request ctx is typically cancelled as soon
	// as this call returns, which would abort the background invalidation.
	go c.cache.InvalidateKeysByPattern(context.Background(), userPattern)
	return nil
}
// DeleteUser removes the user with the given ID. It then evicts that user's
// cache entry and asynchronously invalidates cached listings/search pages.
//
// Returns ErrNotFound for an unknown ID.
func (c *Controller) DeleteUser(ctx context.Context, userID uuid.UUID) error {
	const op = "users.DeleteUser.ctrl"
	span, _ := opentracing.StartSpanFromContext(ctx, op)
	ctx = opentracing.ContextWithSpan(ctx, span)
	defer span.Finish()

	if err := c.repo.DeleteUser(ctx, userID); err != nil && errors.Is(err, repo.ErrNotFound) {
		zap.L().Debug(
			"failed to find user",
			zap.Error(err), zap.String("op", op),
			zap.String("id", userID.String()),
		)
		return ErrNotFound
	} else if err != nil {
		zap.L().Debug(
			"failed to delete user",
			zap.Error(err), zap.String("op", op),
			zap.String("id", userID.String()),
		)
		return err
	}

	if err := c.cache.Delete(ctx, fmt.Sprintf(userCacheKey, userID)); err != nil {
		zap.L().Debug(
			"failed to delete from cache",
			zap.Error(err), zap.String("op", op),
			zap.String("id", userID.String()),
		)
	}
	// Use a detached context. The request ctx is typically cancelled as soon
	// as this call returns, which would abort the background invalidation.
	go c.cache.InvalidateKeysByPattern(context.Background(), userPattern)
	return nil
}
package grpc
import (
"context"
"errors"
pb "github.com/JMURv/sso/api/pb"
ctrl "github.com/JMURv/sso/internal/controller"
"github.com/JMURv/sso/internal/dto"
metrics "github.com/JMURv/sso/internal/metrics/prometheus"
"github.com/JMURv/sso/internal/validation"
utils "github.com/JMURv/sso/pkg/utils/grpc"
"github.com/google/uuid"
"github.com/opentracing/opentracing-go"
"go.uber.org/zap"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"strconv"
"time"
)
// Authenticate is the gRPC endpoint for email/password login. It validates the
// request, delegates to the controller, and maps controller errors to gRPC
// codes:
//   - ErrNotFound → NotFound;
//   - ErrInvalidCredentials → InvalidArgument;
//   - anything else → Internal, with a generic message.
func (h *Handler) Authenticate(ctx context.Context, req *pb.SSO_EmailAndPasswordRequest) (*pb.SSO_EmailAndPasswordResponse, error) {
	const op = "sso.Authenticate.hdl"
	s, c := time.Now(), codes.OK
	span, ctx := opentracing.StartSpanFromContext(ctx, op)
	defer func() {
		span.Finish()
		metrics.ObserveRequest(time.Since(s), int(c), op)
	}()

	if req == nil {
		c = codes.InvalidArgument
		zap.L().Debug("failed to decode request", zap.String("op", op))
		return nil, status.Error(c, ctrl.ErrDecodeRequest.Error())
	}

	r := &dto.EmailAndPasswordRequest{
		Email:    req.Email,
		Password: req.Password,
	}
	if err := validation.LoginAndPasswordRequest(r); err != nil {
		c = codes.InvalidArgument
		// status.Error, not Errorf: the message is data, not a format string.
		return nil, status.Error(c, err.Error())
	}

	res, err := h.ctrl.Authenticate(ctx, r)
	if err != nil {
		switch {
		case errors.Is(err, ctrl.ErrNotFound):
			c = codes.NotFound
			return nil, status.Error(c, err.Error())
		case errors.Is(err, ctrl.ErrInvalidCredentials):
			c = codes.InvalidArgument
			return nil, status.Error(c, err.Error())
		default:
			c = codes.Internal
			return nil, status.Error(c, ctrl.ErrInternalError.Error())
		}
	}
	return &pb.SSO_EmailAndPasswordResponse{
		Token: res.Token,
	}, nil
}
// GetUserByToken is the gRPC endpoint that resolves the user owning a token.
// A nil request or an empty token yields InvalidArgument, ErrNotFound maps to
// NotFound, and any other error maps to Internal with a generic message.
func (h *Handler) GetUserByToken(ctx context.Context, req *pb.SSO_StringMsg) (*pb.SSO_User, error) {
	s, c := time.Now(), codes.OK
	const op = "sso.GetUserByToken.hdl"
	// Derive the span from ctx so it nests under any incoming trace.
	span, ctx := opentracing.StartSpanFromContext(ctx, op)
	defer func() {
		span.Finish()
		metrics.ObserveRequest(time.Since(s), int(c), op)
	}()

	// GetString_ is nil-safe; reading req.String_ before the nil check would
	// panic on a nil request.
	token := req.GetString_()
	if req == nil || token == "" {
		c = codes.InvalidArgument
		zap.L().Debug("failed to decode request", zap.String("op", op))
		return nil, status.Error(c, ctrl.ErrDecodeRequest.Error())
	}

	u, err := h.ctrl.GetUserByToken(ctx, token)
	if err != nil && errors.Is(err, ctrl.ErrNotFound) {
		c = codes.NotFound
		return nil, status.Error(c, err.Error())
	} else if err != nil {
		c = codes.Internal
		return nil, status.Error(c, ctrl.ErrInternalError.Error())
	}
	return utils.ModelToProto(u), nil
}
// ParseClaims is the gRPC endpoint that verifies a token and returns its
// "uid", "email" and "exp" claims as strings. "uid" is returned in the Token
// field. A missing claim becomes "". Numeric claims are rendered in decimal
// without an exponent.
//
// A nil request or an empty token yields InvalidArgument; a verification
// failure yields Internal with a generic message.
func (h *Handler) ParseClaims(ctx context.Context, req *pb.SSO_StringMsg) (*pb.SSO_ParseClaimsRes, error) {
	s, c := time.Now(), codes.OK
	const op = "sso.ValidateToken.hdl"
	span, ctx := opentracing.StartSpanFromContext(ctx, op)
	defer func() {
		span.Finish()
		metrics.ObserveRequest(time.Since(s), int(c), op)
	}()

	token := req.GetString_()
	if req == nil || token == "" {
		c = codes.InvalidArgument
		zap.L().Debug("failed to decode request", zap.String("op", op))
		return nil, status.Error(c, ctrl.ErrDecodeRequest.Error())
	}

	res, err := h.ctrl.ParseClaims(ctx, token)
	if err != nil {
		c = codes.Internal
		zap.L().Debug("failed to parse claims", zap.String("op", op), zap.Error(err))
		return nil, status.Error(c, ctrl.ErrInternalError.Error())
	}

	// claimString renders a claim value without panicking. A bare .(string)
	// assertion panics on a missing claim or a non-string one. JSON-decoded
	// numeric claims such as "exp" are usually float64 — presumably the case
	// for this auth service; confirm.
	claimString := func(v any) string {
		switch t := v.(type) {
		case string:
			return t
		case float64:
			return strconv.FormatFloat(t, 'f', -1, 64)
		case int64:
			return strconv.FormatInt(t, 10)
		case int:
			return strconv.Itoa(t)
		case interface{ String() string }:
			return t.String()
		default:
			return ""
		}
	}

	return &pb.SSO_ParseClaimsRes{
		Token: claimString(res["uid"]),
		Email: claimString(res["email"]),
		Exp:   claimString(res["exp"]),
	}, nil
}
// SendLoginCode is the gRPC endpoint for the first step of code-based login.
// It validates the email and password fields and asks the controller to email
// a login code.
//
// Missing fields or a malformed email yield InvalidArgument, as does
// ErrInvalidCredentials. Any other error yields Internal with a generic
// message and marks the span as errored.
func (h *Handler) SendLoginCode(ctx context.Context, req *pb.SSO_SendLoginCodeReq) (*pb.SSO_Empty, error) {
	const op = "sso.SendLoginCode.hdl"
	s, c := time.Now(), codes.OK
	span, ctx := opentracing.StartSpanFromContext(ctx, op)
	defer func() {
		span.Finish()
		metrics.ObserveRequest(time.Since(s), int(c), op)
	}()

	if req == nil || req.Email == "" || req.Password == "" {
		c = codes.InvalidArgument
		return nil, status.Error(c, ctrl.ErrDecodeRequest.Error())
	}
	if err := validation.ValidateEmail(req.Email); err != nil {
		c = codes.InvalidArgument
		return nil, status.Error(c, ctrl.ErrDecodeRequest.Error())
	}

	err := h.ctrl.SendLoginCode(ctx, req.Email, req.Password)
	if err != nil && errors.Is(err, ctrl.ErrInvalidCredentials) {
		c = codes.InvalidArgument
		zap.L().Debug("failed to send login code", zap.String("op", op), zap.Error(err))
		return nil, status.Error(c, err.Error())
	} else if err != nil {
		span.SetTag("error", true)
		c = codes.Internal
		zap.L().Error("failed to send login code", zap.String("op", op), zap.Error(err))
		return nil, status.Error(c, ctrl.ErrInternalError.Error())
	}
	return &pb.SSO_Empty{}, nil
}
// CheckLoginCode is the gRPC endpoint for the second step of code-based login.
// It exchanges an email and login code for an access/refresh token pair.
//
// Error mapping:
//   - nil request, missing fields or malformed email → InvalidArgument;
//   - ErrCodeIsNotValid → InvalidArgument;
//   - ErrNotFound (no pending code or no such user) → NotFound;
//   - anything else → Internal, with a generic message.
func (h *Handler) CheckLoginCode(ctx context.Context, req *pb.SSO_CheckLoginCodeReq) (*pb.SSO_CheckLoginCodeRes, error) {
	s, c := time.Now(), codes.OK
	const op = "sso.checkLoginCode.handler"
	span, ctx := opentracing.StartSpanFromContext(ctx, op)
	defer func() {
		span.Finish()
		metrics.ObserveRequest(time.Since(s), int(c), op)
	}()

	// Check for nil before touching fields; req.Email on a nil request panics.
	if req == nil || req.Email == "" || req.Code == 0 {
		c = codes.InvalidArgument
		zap.L().Debug("failed to decode request", zap.String("op", op))
		return nil, status.Error(c, ctrl.ErrDecodeRequest.Error())
	}
	email, code := req.Email, req.Code

	if err := validation.ValidateEmail(email); err != nil {
		c = codes.InvalidArgument
		zap.L().Debug("failed to validate email", zap.String("op", op), zap.Error(err))
		return nil, status.Error(c, err.Error())
	}

	access, refresh, err := h.ctrl.CheckLoginCode(ctx, email, int(code))
	if err != nil {
		switch {
		case errors.Is(err, ctrl.ErrNotFound):
			c = codes.NotFound
			zap.L().Debug("user not found", zap.String("op", op), zap.Error(err))
			return nil, status.Error(c, err.Error())
		case errors.Is(err, ctrl.ErrCodeIsNotValid):
			// A wrong code is a client error, not a server fault.
			c = codes.InvalidArgument
			zap.L().Debug("invalid code", zap.String("op", op), zap.Error(err))
			return nil, status.Error(c, err.Error())
		default:
			c = codes.Internal
			zap.L().Debug("failed to check login code", zap.String("op", op), zap.Error(err))
			return nil, status.Error(c, ctrl.ErrInternalError.Error())
		}
	}
	return &pb.SSO_CheckLoginCodeRes{
		Access:  access,
		Refresh: refresh,
	}, nil
}
// CheckEmail is the gRPC endpoint that reports whether a user is registered
// under the given email.
//
// A nil request or an empty email yields InvalidArgument. ErrNotFound maps to
// NotFound; any other controller error maps to Internal with a generic
// message.
func (h *Handler) CheckEmail(ctx context.Context, req *pb.SSO_EmailMsg) (*pb.SSO_CheckEmailRes, error) {
	s, c := time.Now(), codes.OK
	const op = "sso.CheckEmail.handler"
	span, ctx := opentracing.StartSpanFromContext(ctx, op)
	defer func() {
		span.Finish()
		metrics.ObserveRequest(time.Since(s), int(c), op)
	}()

	if req == nil || req.Email == "" {
		c = codes.InvalidArgument
		zap.L().Debug("failed to decode request", zap.String("op", op))
		return nil, status.Error(c, ctrl.ErrDecodeRequest.Error())
	}

	isExist, err := h.ctrl.IsUserExist(ctx, req.Email)
	if err != nil && errors.Is(err, ctrl.ErrNotFound) {
		c = codes.NotFound
		zap.L().Debug("user not found", zap.String("op", op), zap.Error(err))
		return nil, status.Error(c, err.Error())
	} else if err != nil {
		c = codes.Internal
		zap.L().Debug("failed to check email", zap.String("op", op), zap.Error(err))
		return nil, status.Error(c, ctrl.ErrInternalError.Error())
	}
	return &pb.SSO_CheckEmailRes{
		IsExist: isExist,
	}, nil
}
// Logout checks that the request context carries a well-formed, non-nil user
// ID under the "uid" key and returns an empty response.
//
// A missing uid yields Unauthenticated; an unparsable or nil UUID yields
// InvalidArgument.
//
// NOTE(review): no server-side state changes here (no token revocation or
// blacklist), so logout only takes effect if the client discards its tokens.
// Confirm this is intended.
func (h *Handler) Logout(ctx context.Context, _ *pb.SSO_Empty) (*pb.SSO_Empty, error) {
	s, c := time.Now(), codes.OK
	const op = "sso.Logout.handler"
	span, ctx := opentracing.StartSpanFromContext(ctx, op)
	defer func() {
		span.Finish()
		metrics.ObserveRequest(time.Since(s), int(c), op)
	}()

	// NOTE(review): "uid" is presumably set by an auth interceptor using a
	// plain string key; this lookup must stay in sync with it.
	uidStr, ok := ctx.Value("uid").(string)
	if !ok {
		c = codes.Unauthenticated
		zap.L().Debug("failed to get uid from context", zap.String("op", op))
		return nil, status.Error(c, ctrl.ErrUnauthorized.Error())
	}

	uid, err := uuid.Parse(uidStr)
	if uid == uuid.Nil || err != nil {
		c = codes.InvalidArgument
		zap.L().Debug("failed to parse uid", zap.String("op", op), zap.Error(err))
		return nil, status.Error(c, ctrl.ErrParseUUID.Error())
	}
	return &pb.SSO_Empty{}, nil
}
// SendForgotPasswordEmail is the gRPC endpoint that starts a password reset
// for the given email.
//
// A nil request or an empty email yields InvalidArgument, as does
// ErrInvalidCredentials (unknown email). Any other error yields Internal with
// a generic message.
func (h *Handler) SendForgotPasswordEmail(ctx context.Context, req *pb.SSO_EmailMsg) (*pb.SSO_Empty, error) {
	s, c := time.Now(), codes.OK
	const op = "sso.SendForgotPasswordEmail.handler"
	span, ctx := opentracing.StartSpanFromContext(ctx, op)
	defer func() {
		span.Finish()
		metrics.ObserveRequest(time.Since(s), int(c), op)
	}()

	if req == nil || req.Email == "" {
		c = codes.InvalidArgument
		zap.L().Debug("failed to decode request", zap.String("op", op))
		return nil, status.Error(c, ctrl.ErrDecodeRequest.Error())
	}

	err := h.ctrl.SendForgotPasswordEmail(ctx, req.Email)
	if err != nil && errors.Is(err, ctrl.ErrInvalidCredentials) {
		c = codes.InvalidArgument
		zap.L().Debug("invalid credentials", zap.String("op", op), zap.Error(err))
		return nil, status.Error(c, err.Error())
	} else if err != nil {
		c = codes.Internal
		zap.L().Debug("failed to send forgot password email", zap.String("op", op), zap.Error(err))
		return nil, status.Error(c, ctrl.ErrInternalError.Error())
	}
	return &pb.SSO_Empty{}, nil
}
// CheckForgotPasswordEmail validates a password-recovery request and passes the
// new password, the parsed user UUID (req.Uidb64) and the numeric code
// (req.Token) to ctrl.CheckForgotPasswordEmail.
//
// Status codes: InvalidArgument for a nil request, any empty field, an
// unparsable UUID, a non-numeric code, or ctrl.ErrCodeIsNotValid; NotFound for
// ctrl.ErrNotFound; Internal for any other controller error.
func (h *Handler) CheckForgotPasswordEmail(ctx context.Context, req *pb.SSO_CheckForgotPasswordEmailReq) (*pb.SSO_Empty, error) {
	s, c := time.Now(), codes.OK
	const op = "sso.CheckForgotPasswordEmail.handler"
	span := opentracing.GlobalTracer().StartSpan(op)
	ctx = opentracing.ContextWithSpan(ctx, span)
	defer func() {
		span.Finish()
		metrics.ObserveRequest(time.Since(s), int(c), op)
	}()

	// Check req for nil before reading any of its fields.
	if req == nil || req.Password == "" || req.Uidb64 == "" || req.Token == "" {
		c = codes.InvalidArgument
		zap.L().Debug("failed to decode request", zap.String("op", op))
		return nil, status.Error(c, ctrl.ErrDecodeRequest.Error())
	}

	uid, err := uuid.Parse(req.Uidb64)
	if err != nil {
		c = codes.InvalidArgument
		zap.L().Debug("failed to parse uid", zap.String("op", op), zap.Error(err))
		return nil, status.Error(c, ctrl.ErrParseUUID.Error())
	}

	// The code is client-supplied: a non-numeric value is an invalid code,
	// not a server-side failure.
	code, err := strconv.Atoi(req.Token)
	if err != nil {
		c = codes.InvalidArgument
		zap.L().Debug("failed to parse code", zap.String("op", op), zap.Error(err))
		return nil, status.Error(c, ctrl.ErrCodeIsNotValid.Error())
	}

	err = h.ctrl.CheckForgotPasswordEmail(ctx, req.Password, uid, code)
	switch {
	case errors.Is(err, ctrl.ErrCodeIsNotValid):
		c = codes.InvalidArgument
		zap.L().Debug("invalid code", zap.String("op", op), zap.Error(err))
		return nil, status.Error(c, err.Error())
	case errors.Is(err, ctrl.ErrNotFound):
		c = codes.NotFound
		zap.L().Debug("failed to find user", zap.String("op", op), zap.Error(err))
		return nil, status.Error(c, ctrl.ErrNotFound.Error())
	case err != nil:
		c = codes.Internal
		zap.L().Debug("failed to check forgot password email", zap.String("op", op), zap.Error(err))
		return nil, status.Error(c, ctrl.ErrInternalError.Error())
	}

	return &pb.SSO_Empty{}, nil
}
// SendSupportEmail forwards a support message (theme and text) from the
// authenticated caller to ctrl.SendSupportEmail.
//
// Status codes: Unauthenticated when the "uid" context value is missing;
// InvalidArgument for an unparsable/nil uid or an empty theme/text; NotFound for
// ctrl.ErrNotFound; Internal for any other controller error.
func (h *Handler) SendSupportEmail(ctx context.Context, req *pb.SSO_SendSupportEmailReq) (*pb.SSO_Empty, error) {
	s, c := time.Now(), codes.OK
	const op = "sso.SendSupportEmail.handler"
	span := opentracing.GlobalTracer().StartSpan(op)
	ctx = opentracing.ContextWithSpan(ctx, span)
	defer func() {
		span.Finish()
		metrics.ObserveRequest(time.Since(s), int(c), op)
	}()

	uidStr, ok := ctx.Value("uid").(string)
	if !ok {
		c = codes.Unauthenticated
		zap.L().Debug("failed to get uid from context", zap.String("op", op))
		return nil, status.Error(c, ctrl.ErrUnauthorized.Error())
	}
	uid, err := uuid.Parse(uidStr)
	if err != nil || uid == uuid.Nil {
		c = codes.InvalidArgument
		zap.L().Debug("failed to parse uid", zap.String("op", op), zap.Error(err))
		return nil, status.Error(c, ctrl.ErrParseUUID.Error())
	}

	if req == nil || req.Theme == "" || req.Text == "" {
		c = codes.InvalidArgument
		zap.L().Debug("failed to decode request", zap.String("op", op))
		return nil, status.Error(c, ctrl.ErrDecodeRequest.Error())
	}

	err = h.ctrl.SendSupportEmail(ctx, uid, req.Theme, req.Text)
	switch {
	case errors.Is(err, ctrl.ErrNotFound):
		c = codes.NotFound
		zap.L().Debug("failed to find user", zap.String("op", op), zap.Error(err))
		return nil, status.Error(c, err.Error())
	case err != nil:
		c = codes.Internal
		zap.L().Debug("failed to send email", zap.String("op", op), zap.Error(err))
		return nil, status.Error(c, "failed to send email")
	}

	return &pb.SSO_Empty{}, nil
}
// Me returns the profile of the authenticated caller, identified by the "uid"
// context value set by AuthUnaryInterceptor.
//
// Status codes: Unauthenticated when uid is missing, InvalidArgument when it is
// not a non-nil UUID, NotFound for ctrl.ErrNotFound, Internal otherwise.
func (h *Handler) Me(ctx context.Context, _ *pb.SSO_Empty) (*pb.SSO_User, error) {
	s, c := time.Now(), codes.OK
	const op = "sso.Me.handler"
	span := opentracing.GlobalTracer().StartSpan(op)
	ctx = opentracing.ContextWithSpan(ctx, span)
	defer func() {
		span.Finish()
		metrics.ObserveRequest(time.Since(s), int(c), op)
	}()

	uidStr, ok := ctx.Value("uid").(string)
	if !ok {
		c = codes.Unauthenticated
		zap.L().Debug("failed to get uid from context", zap.String("op", op))
		return nil, status.Error(c, ctrl.ErrUnauthorized.Error())
	}
	uid, err := uuid.Parse(uidStr)
	if err != nil || uid == uuid.Nil {
		c = codes.InvalidArgument
		zap.L().Debug("failed to parse uid", zap.String("op", op), zap.Error(err))
		return nil, status.Error(c, ctrl.ErrParseUUID.Error())
	}

	u, err := h.ctrl.GetUserByID(ctx, uid)
	switch {
	case errors.Is(err, ctrl.ErrNotFound):
		c = codes.NotFound
		zap.L().Debug("user not found", zap.String("op", op), zap.Error(err))
		return nil, status.Error(c, err.Error())
	case err != nil:
		c = codes.Internal
		zap.L().Debug("failed to get user", zap.String("op", op), zap.Error(err))
		return nil, status.Error(c, ctrl.ErrInternalError.Error())
	}

	return utils.ModelToProto(u), nil
}
// UpdateMe validates req and applies it to the authenticated caller's user
// record via ctrl.UpdateUser.
//
// Status codes: Unauthenticated when the "uid" context value is missing;
// InvalidArgument for a bad uid, nil request, or failed validation; NotFound for
// ctrl.ErrNotFound; Internal otherwise.
// NOTE(review): on success an empty SSO_User is returned rather than the updated
// record — confirm clients do not rely on the response body.
func (h *Handler) UpdateMe(ctx context.Context, req *pb.SSO_User) (*pb.SSO_User, error) {
	s, c := time.Now(), codes.OK
	const op = "sso.UpdateMe.handler"
	span := opentracing.GlobalTracer().StartSpan(op)
	ctx = opentracing.ContextWithSpan(ctx, span)
	defer func() {
		span.Finish()
		metrics.ObserveRequest(time.Since(s), int(c), op)
	}()

	uidStr, ok := ctx.Value("uid").(string)
	if !ok {
		c = codes.Unauthenticated
		zap.L().Debug("failed to get uid from context", zap.String("op", op))
		return nil, status.Error(c, ctrl.ErrUnauthorized.Error())
	}
	uid, err := uuid.Parse(uidStr)
	if err != nil || uid == uuid.Nil {
		c = codes.InvalidArgument
		zap.L().Debug("failed to parse uid", zap.String("op", op), zap.Error(err))
		return nil, status.Error(c, ctrl.ErrParseUUID.Error())
	}

	if req == nil {
		c = codes.InvalidArgument
		zap.L().Debug("failed to decode request", zap.String("op", op))
		return nil, status.Error(c, ctrl.ErrDecodeRequest.Error())
	}

	user := utils.ProtoToModel(req)
	if err = validation.UserValidation(user); err != nil {
		c = codes.InvalidArgument
		zap.L().Debug("failed to validate obj", zap.String("op", op), zap.Error(err))
		return nil, status.Error(c, err.Error())
	}

	err = h.ctrl.UpdateUser(ctx, uid, user)
	switch {
	case errors.Is(err, ctrl.ErrNotFound):
		c = codes.NotFound
		zap.L().Debug("user not found", zap.String("op", op), zap.Error(err))
		return nil, status.Error(c, err.Error())
	case err != nil:
		c = codes.Internal
		zap.L().Debug("failed to update user", zap.String("op", op), zap.Error(err))
		return nil, status.Error(c, ctrl.ErrInternalError.Error())
	}

	return &pb.SSO_User{}, nil
}
package grpc
import (
"context"
"errors"
"fmt"
pb "github.com/JMURv/sso/api/pb"
ctrl "github.com/JMURv/sso/internal/controller"
"github.com/JMURv/sso/internal/dto"
metrics "github.com/JMURv/sso/internal/metrics/prometheus"
md "github.com/JMURv/sso/pkg/model"
"github.com/google/uuid"
pm "github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/health"
"google.golang.org/grpc/health/grpc_health_v1"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/reflection"
"log"
"net"
)
// Ctrl is the controller surface the gRPC Handler delegates to (stored in
// Handler.ctrl). Methods are grouped by concern; the method set is unchanged.
type Ctrl interface {
	// Authentication, tokens and login codes.
	Authenticate(ctx context.Context, req *dto.EmailAndPasswordRequest) (*dto.EmailAndPasswordResponse, error)
	ParseClaims(ctx context.Context, token string) (map[string]any, error)
	GetUserByToken(ctx context.Context, token string) (*md.User, error)
	SendLoginCode(ctx context.Context, email, password string) error
	CheckLoginCode(ctx context.Context, email string, code int) (string, string, error)

	// Password recovery and support mail.
	SendForgotPasswordEmail(ctx context.Context, email string) error
	CheckForgotPasswordEmail(ctx context.Context, password string, uid uuid.UUID, code int) error
	SendSupportEmail(ctx context.Context, uid uuid.UUID, theme, text string) error

	// User management.
	IsUserExist(ctx context.Context, email string) (isExist bool, err error)
	SearchUser(ctx context.Context, query string, page, size int) (*md.PaginatedUser, error)
	ListUsers(ctx context.Context, page, size int) (*md.PaginatedUser, error)
	GetUserByID(ctx context.Context, userID uuid.UUID) (*md.User, error)
	GetUserByEmail(ctx context.Context, email string) (*md.User, error)
	CreateUser(ctx context.Context, u *md.User, fileName string, bytes []byte) (uuid.UUID, string, string, error)
	UpdateUser(ctx context.Context, id uuid.UUID, req *md.User) error
	DeleteUser(ctx context.Context, userID uuid.UUID) error

	// Permission management.
	ListPermissions(ctx context.Context, page, size int) (*md.PaginatedPermission, error)
	GetPermission(ctx context.Context, id uint64) (*md.Permission, error)
	CreatePerm(ctx context.Context, req *md.Permission) (uint64, error)
	UpdatePerm(ctx context.Context, id uint64, req *md.Permission) error
	DeletePerm(ctx context.Context, id uint64) error
}
// Handler implements the SSO, Users and PermissionSvc gRPC services on top of a
// Ctrl, and owns the gRPC server and health server those services are
// registered on (see New and Start).
type Handler struct {
// NOTE(review): these embed the generated service *interfaces*, which are nil at
// runtime. Any RPC that Handler does not implement itself would panic on a nil
// dereference instead of returning codes.Unimplemented. Consider embedding the
// generated Unimplemented*Server structs instead (confirm the generator emits them).
pb.SSOServer
pb.UsersServer
pb.PermissionSvcServer
// srv is the gRPC server built in New and served in Start.
srv *grpc.Server
// hsrv is the health-check server registered on srv in Start.
hsrv *health.Server
// ctrl is the business-logic layer every RPC handler delegates to.
ctrl Ctrl
}
// New builds a Handler backed by ctrl.
//
// The gRPC server runs AuthUnaryInterceptor(auth) followed by the Prometheus
// unary interceptor on unary calls, and only the Prometheus interceptor on
// streams. A health server reporting "sso" as SERVING and server reflection
// are set up here; the service registrations themselves happen in Start.
func New(auth ctrl.AuthService, ctrl Ctrl) *Handler {
	unaryChain := grpc.ChainUnaryInterceptor(
		AuthUnaryInterceptor(auth),
		metrics.SrvMetrics.UnaryServerInterceptor(pm.WithExemplarFromContext(metrics.Exemplar)),
	)
	streamChain := grpc.ChainStreamInterceptor(
		metrics.SrvMetrics.StreamServerInterceptor(pm.WithExemplarFromContext(metrics.Exemplar)),
	)
	server := grpc.NewServer(unaryChain, streamChain)

	healthSrv := health.NewServer()
	healthSrv.SetServingStatus("sso", grpc_health_v1.HealthCheckResponse_SERVING)

	reflection.Register(server)

	h := &Handler{
		srv:  server,
		hsrv: healthSrv,
		ctrl: ctrl,
	}
	return h
}
// Start registers the SSO, Users, PermissionSvc and health services on the
// Handler's gRPC server, listens on TCP ":<port>" and serves until the server
// stops. Listen failures and any Serve error other than grpc.ErrServerStopped
// terminate the process via log.Fatal.
func (h *Handler) Start(port int) {
	pb.RegisterSSOServer(h.srv, h)
	pb.RegisterUsersServer(h.srv, h)
	pb.RegisterPermissionSvcServer(h.srv, h)
	grpc_health_v1.RegisterHealthServer(h.srv, h.hsrv)

	addr := fmt.Sprintf(":%v", port)
	listener, err := net.Listen("tcp", addr)
	if err != nil {
		log.Fatalf("failed to listen: %v", err)
	}

	serveErr := h.srv.Serve(listener)
	if serveErr == nil || errors.Is(serveErr, grpc.ErrServerStopped) {
		return
	}
	log.Fatal(serveErr)
}
// Close gracefully stops the underlying gRPC server. It always returns nil;
// the error result only gives Close the io.Closer signature.
func (h *Handler) Close() error {
	server := h.srv
	server.GracefulStop()
	return nil
}
// AuthUnaryInterceptor returns a unary interceptor that performs best-effort
// authentication.
//
// It reads the first "authorization" metadata value, strips an optional
// "Bearer " prefix and verifies it with auth.VerifyToken. On success the token's
// "uid" claim is stored in the context under the plain string key "uid", which
// the handlers read back with ctx.Value("uid").(string). Missing metadata, a
// missing header or an invalid token is only logged. The call always proceeds,
// and each handler decides whether a uid is required.
func AuthUnaryInterceptor(auth ctrl.AuthService) grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
		// Named meta (not md) to avoid shadowing the md model-package alias.
		meta, found := metadata.FromIncomingContext(ctx)
		if !found {
			zap.L().Debug("missing metadata")
			return handler(ctx, req)
		}

		values := meta["authorization"]
		if len(values) == 0 {
			zap.L().Debug("missing authorization token")
			return handler(ctx, req)
		}

		const bearer = "Bearer "
		token := values[0]
		if len(token) > len(bearer) && token[:len(bearer)] == bearer {
			token = token[len(bearer):]
		}

		claims, err := auth.VerifyToken(token)
		if err != nil {
			zap.L().Debug("invalid token", zap.Error(err))
			return handler(ctx, req)
		}

		// NOTE(review): a string context key collides easily (staticcheck SA1029);
		// changing it requires updating every ctx.Value("uid") reader.
		authed := context.WithValue(ctx, "uid", claims["uid"])
		return handler(authed, req)
	}
}
package grpc
import (
"context"
"errors"
pb "github.com/JMURv/sso/api/pb"
ctrl "github.com/JMURv/sso/internal/controller"
metrics "github.com/JMURv/sso/internal/metrics/prometheus"
"github.com/JMURv/sso/internal/validation"
md "github.com/JMURv/sso/pkg/model"
utils "github.com/JMURv/sso/pkg/utils/grpc"
"github.com/opentracing/opentracing-go"
"go.uber.org/zap"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"time"
)
// ListPermissions returns one page of permissions.
//
// Status codes: InvalidArgument for a nil request or a zero page/size, Internal
// for any controller error.
func (h *Handler) ListPermissions(ctx context.Context, req *pb.SSO_ListReq) (*pb.SSO_PermissionList, error) {
	s, c := time.Now(), codes.OK
	const op = "sso.ListPermissions.hdl"
	span := opentracing.GlobalTracer().StartSpan(op)
	ctx = opentracing.ContextWithSpan(ctx, span)
	defer func() {
		span.Finish()
		metrics.ObserveRequest(time.Since(s), int(c), op)
	}()

	// Reject a nil request before reading its fields.
	if req == nil || req.Page == 0 || req.Size == 0 {
		c = codes.InvalidArgument
		zap.L().Debug("failed to decode request", zap.String("op", op))
		return nil, status.Error(c, ctrl.ErrDecodeRequest.Error())
	}

	res, err := h.ctrl.ListPermissions(ctx, int(req.Page), int(req.Size))
	if err != nil {
		c = codes.Internal
		zap.L().Debug("failed to list permissions", zap.String("op", op), zap.Error(err))
		return nil, status.Error(c, ctrl.ErrInternalError.Error())
	}

	return &pb.SSO_PermissionList{
		Data:        utils.ListPermissionsToProto(res.Data),
		Count:       res.Count,
		TotalPages:  int64(res.TotalPages),
		CurrentPage: int64(res.CurrentPage),
		HasNextPage: res.HasNextPage,
	}, nil
}
// GetPermission returns the permission with ID req.Uint64.
//
// Status codes: InvalidArgument for a nil request or zero ID, NotFound for
// ctrl.ErrNotFound, Internal for any other controller error.
func (h *Handler) GetPermission(ctx context.Context, req *pb.SSO_Uint64Msg) (*pb.SSO_Permission, error) {
	s, c := time.Now(), codes.OK
	const op = "sso.GetPermission.hdl"
	span := opentracing.GlobalTracer().StartSpan(op)
	ctx = opentracing.ContextWithSpan(ctx, span)
	defer func() {
		span.Finish()
		metrics.ObserveRequest(time.Since(s), int(c), op)
	}()

	if req == nil || req.Uint64 == 0 {
		c = codes.InvalidArgument
		// Do not log req.Uint64 here: req may be nil.
		zap.L().Debug("failed to parse uid", zap.String("op", op))
		return nil, status.Error(c, ctrl.ErrDecodeRequest.Error())
	}

	res, err := h.ctrl.GetPermission(ctx, req.Uint64)
	switch {
	case errors.Is(err, ctrl.ErrNotFound):
		c = codes.NotFound
		return nil, status.Error(c, err.Error())
	case err != nil:
		c = codes.Internal
		zap.L().Debug("failed to get permission", zap.String("op", op), zap.Error(err))
		return nil, status.Error(c, ctrl.ErrInternalError.Error())
	}

	return utils.PermissionToProto(res), nil
}
// CreatePermission validates and creates a permission, returning its new ID.
//
// Status codes: InvalidArgument for a nil request or failed validation,
// AlreadyExists for ctrl.ErrAlreadyExists, Internal for other controller errors.
// NOTE(review): unlike UpdatePermission/DeletePermission this RPC does not
// require a "uid" in the context — confirm creation is meant to be
// unauthenticated.
func (h *Handler) CreatePermission(ctx context.Context, req *pb.SSO_Permission) (*pb.SSO_Uint64Msg, error) {
	s, c := time.Now(), codes.OK
	const op = "sso.CreatePermission.hdl"
	span := opentracing.GlobalTracer().StartSpan(op)
	ctx = opentracing.ContextWithSpan(ctx, span)
	defer func() {
		span.Finish()
		metrics.ObserveRequest(time.Since(s), int(c), op)
	}()

	if req == nil {
		c = codes.InvalidArgument
		zap.L().Debug("failed to decode request", zap.String("op", op))
		return nil, status.Error(c, ctrl.ErrDecodeRequest.Error())
	}

	perm := &md.Permission{
		ID:   req.Id,
		Name: req.Name,
	}
	if err := validation.PermValidation(perm); err != nil {
		c = codes.InvalidArgument
		zap.L().Debug("failed to validate obj", zap.String("op", op), zap.Error(err))
		return nil, status.Error(c, err.Error())
	}

	id, err := h.ctrl.CreatePerm(ctx, perm)
	switch {
	case errors.Is(err, ctrl.ErrAlreadyExists):
		c = codes.AlreadyExists
		return nil, status.Error(c, err.Error())
	case err != nil:
		c = codes.Internal
		zap.L().Debug("failed to create permission", zap.String("op", op), zap.Error(err))
		return nil, status.Error(c, ctrl.ErrInternalError.Error())
	}

	return &pb.SSO_Uint64Msg{Uint64: id}, nil
}
// UpdatePermission validates req and updates the permission with ID req.Id.
//
// Status codes: Unauthenticated when no "uid" string is in the context;
// InvalidArgument for a nil request, zero ID or failed validation; NotFound for
// ctrl.ErrNotFound; Internal for other controller errors.
func (h *Handler) UpdatePermission(ctx context.Context, req *pb.SSO_Permission) (*pb.SSO_Empty, error) {
	s, c := time.Now(), codes.OK
	const op = "sso.UpdatePermission.hdl"
	span := opentracing.GlobalTracer().StartSpan(op)
	ctx = opentracing.ContextWithSpan(ctx, span)
	defer func() {
		span.Finish()
		metrics.ObserveRequest(time.Since(s), int(c), op)
	}()

	if _, ok := ctx.Value("uid").(string); !ok {
		c = codes.Unauthenticated
		zap.L().Debug("failed to get uid from context", zap.String("op", op))
		return nil, status.Error(c, ctrl.ErrUnauthorized.Error())
	}

	if req == nil || req.Id == 0 {
		c = codes.InvalidArgument
		zap.L().Debug("failed to decode request", zap.String("op", op))
		return nil, status.Error(c, ctrl.ErrDecodeRequest.Error())
	}

	perm := utils.PermissionFromProto(req)
	if err := validation.PermValidation(perm); err != nil {
		c = codes.InvalidArgument
		zap.L().Debug("failed to validate obj", zap.String("op", op), zap.Error(err))
		return nil, status.Error(c, err.Error())
	}

	err := h.ctrl.UpdatePerm(ctx, req.Id, perm)
	switch {
	case errors.Is(err, ctrl.ErrNotFound):
		c = codes.NotFound
		return nil, status.Error(c, err.Error())
	case err != nil:
		c = codes.Internal
		zap.L().Debug("failed to update permission", zap.String("op", op), zap.Error(err))
		return nil, status.Error(c, ctrl.ErrInternalError.Error())
	}

	return &pb.SSO_Empty{}, nil
}
// DeletePermission deletes the permission with ID req.Uint64.
//
// Status codes: Unauthenticated when no "uid" string is in the context;
// InvalidArgument for a nil request or zero ID; NotFound for ctrl.ErrNotFound;
// Internal for other controller errors.
func (h *Handler) DeletePermission(ctx context.Context, req *pb.SSO_Uint64Msg) (*pb.SSO_Empty, error) {
	s, c := time.Now(), codes.OK
	const op = "sso.DeletePermission.hdl"
	span := opentracing.GlobalTracer().StartSpan(op)
	ctx = opentracing.ContextWithSpan(ctx, span)
	defer func() {
		span.Finish()
		metrics.ObserveRequest(time.Since(s), int(c), op)
	}()

	if _, ok := ctx.Value("uid").(string); !ok {
		c = codes.Unauthenticated
		zap.L().Debug("failed to get uid from context", zap.String("op", op))
		return nil, status.Error(c, ctrl.ErrUnauthorized.Error())
	}

	if req == nil || req.Uint64 == 0 {
		c = codes.InvalidArgument
		zap.L().Debug("failed to decode request", zap.String("op", op))
		return nil, status.Error(c, ctrl.ErrDecodeRequest.Error())
	}

	err := h.ctrl.DeletePerm(ctx, req.Uint64)
	switch {
	case errors.Is(err, ctrl.ErrNotFound):
		c = codes.NotFound
		return nil, status.Error(c, err.Error())
	case err != nil:
		c = codes.Internal
		zap.L().Debug("failed to delete permission", zap.String("op", op), zap.Error(err))
		return nil, status.Error(c, ctrl.ErrInternalError.Error())
	}

	return &pb.SSO_Empty{}, nil
}
package grpc
import (
"context"
"errors"
pb "github.com/JMURv/sso/api/pb"
ctrl "github.com/JMURv/sso/internal/controller"
metrics "github.com/JMURv/sso/internal/metrics/prometheus"
"github.com/JMURv/sso/internal/validation"
md "github.com/JMURv/sso/pkg/model"
utils "github.com/JMURv/sso/pkg/utils/grpc"
"github.com/google/uuid"
"github.com/opentracing/opentracing-go"
"go.uber.org/zap"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"time"
)
// SearchUser returns one page of users matching req.Query.
//
// Status codes: InvalidArgument for a nil request, empty query or zero
// page/size; Internal for any controller error.
func (h *Handler) SearchUser(ctx context.Context, req *pb.SSO_SearchReq) (*pb.SSO_PaginatedUsersRes, error) {
	s, c := time.Now(), codes.OK
	const op = "sso.SearchUser.hdl"
	span := opentracing.GlobalTracer().StartSpan(op)
	ctx = opentracing.ContextWithSpan(ctx, span)
	defer func() {
		span.Finish()
		metrics.ObserveRequest(time.Since(s), int(c), op)
	}()

	// Reject a nil request before reading its fields.
	if req == nil || req.Query == "" || req.Page == 0 || req.Size == 0 {
		c = codes.InvalidArgument
		zap.L().Debug("failed to decode request", zap.String("op", op))
		return nil, status.Error(c, ctrl.ErrDecodeRequest.Error())
	}

	u, err := h.ctrl.SearchUser(ctx, req.Query, int(req.Page), int(req.Size))
	if err != nil {
		c = codes.Internal
		zap.L().Debug("failed to search users", zap.String("op", op), zap.Error(err))
		return nil, status.Error(c, ctrl.ErrInternalError.Error())
	}

	return &pb.SSO_PaginatedUsersRes{
		Data:        utils.ListModelToProto(u.Data),
		Count:       u.Count,
		TotalPages:  int64(u.TotalPages),
		CurrentPage: int64(u.CurrentPage),
		HasNextPage: u.HasNextPage,
	}, nil
}
// ListUsers returns one page of users.
//
// Status codes: InvalidArgument for a nil request or zero page/size, Internal
// for any controller error.
func (h *Handler) ListUsers(ctx context.Context, req *pb.SSO_ListReq) (*pb.SSO_PaginatedUsersRes, error) {
	s, c := time.Now(), codes.OK
	const op = "sso.ListUsers.hdl"
	span := opentracing.GlobalTracer().StartSpan(op)
	ctx = opentracing.ContextWithSpan(ctx, span)
	defer func() {
		span.Finish()
		metrics.ObserveRequest(time.Since(s), int(c), op)
	}()

	// Reject a nil request before reading its fields.
	if req == nil || req.Page == 0 || req.Size == 0 {
		c = codes.InvalidArgument
		zap.L().Debug("failed to decode request", zap.String("op", op))
		return nil, status.Error(c, ctrl.ErrDecodeRequest.Error())
	}

	u, err := h.ctrl.ListUsers(ctx, int(req.Page), int(req.Size))
	if err != nil {
		c = codes.Internal
		zap.L().Debug("failed to list users", zap.String("op", op), zap.Error(err))
		return nil, status.Error(c, ctrl.ErrInternalError.Error())
	}

	return &pb.SSO_PaginatedUsersRes{
		Data:        utils.ListModelToProto(u.Data),
		Count:       u.Count,
		TotalPages:  int64(u.TotalPages),
		CurrentPage: int64(u.CurrentPage),
		HasNextPage: u.HasNextPage,
	}, nil
}
// CreateUser validates and registers a new user and returns its UUID together
// with the access and refresh tokens issued by the controller.
//
// Status codes: InvalidArgument for a nil request or failed validation,
// AlreadyExists for ctrl.ErrAlreadyExists, Internal for other controller errors.
func (h *Handler) CreateUser(ctx context.Context, req *pb.SSO_CreateUserReq) (*pb.SSO_CreateUserRes, error) {
	s, c := time.Now(), codes.OK
	const op = "sso.CreateUser.hdl"
	span := opentracing.GlobalTracer().StartSpan(op)
	ctx = opentracing.ContextWithSpan(ctx, span)
	defer func() {
		span.Finish()
		metrics.ObserveRequest(time.Since(s), int(c), op)
	}()

	if req == nil {
		c = codes.InvalidArgument
		zap.L().Debug("failed to decode request", zap.String("op", op))
		return nil, status.Error(c, ctrl.ErrDecodeRequest.Error())
	}

	user := &md.User{
		Name:     req.Name,
		Email:    req.Email,
		Password: req.Password,
	}
	if err := validation.NewUserValidation(user); err != nil {
		c = codes.InvalidArgument
		zap.L().Debug("failed to validate user", zap.String("op", op), zap.Error(err))
		return nil, status.Error(c, err.Error())
	}

	// File is a nested message and is nil when the client omits it; reading
	// req.File.Filename unguarded would panic. An absent file is passed on as
	// an empty name and nil bytes.
	var (
		fileName  string
		fileBytes []byte
	)
	if req.File != nil {
		fileName, fileBytes = req.File.Filename, req.File.File
	}

	uid, access, refresh, err := h.ctrl.CreateUser(ctx, user, fileName, fileBytes)
	switch {
	case errors.Is(err, ctrl.ErrAlreadyExists):
		c = codes.AlreadyExists
		return nil, status.Error(c, err.Error())
	case err != nil:
		c = codes.Internal
		zap.L().Debug("failed to create user", zap.String("op", op), zap.Error(err))
		return nil, status.Error(c, ctrl.ErrInternalError.Error())
	}

	return &pb.SSO_CreateUserRes{
		Uid:     uid.String(),
		Access:  access,
		Refresh: refresh,
	}, nil
}
// GetUser returns the user whose UUID is req.Uuid.
//
// Status codes: InvalidArgument for a nil request or an unparsable/nil UUID,
// NotFound for ctrl.ErrNotFound, Internal for other controller errors.
func (h *Handler) GetUser(ctx context.Context, req *pb.SSO_UuidMsg) (*pb.SSO_User, error) {
	s, c := time.Now(), codes.OK
	const op = "sso.GetUser.hdl"
	span := opentracing.GlobalTracer().StartSpan(op)
	ctx = opentracing.ContextWithSpan(ctx, span)
	defer func() {
		span.Finish()
		metrics.ObserveRequest(time.Since(s), int(c), op)
	}()

	if req == nil {
		c = codes.InvalidArgument
		zap.L().Debug("failed to decode request", zap.String("op", op))
		return nil, status.Error(c, ctrl.ErrDecodeRequest.Error())
	}

	uid, err := uuid.Parse(req.Uuid)
	if err != nil || uid == uuid.Nil {
		c = codes.InvalidArgument
		zap.L().Debug("failed to parse uid", zap.String("op", op), zap.Error(err))
		return nil, status.Error(c, ctrl.ErrParseUUID.Error())
	}

	u, err := h.ctrl.GetUserByID(ctx, uid)
	switch {
	case errors.Is(err, ctrl.ErrNotFound):
		c = codes.NotFound
		return nil, status.Error(c, err.Error())
	case err != nil:
		c = codes.Internal
		zap.L().Debug("failed to get user", zap.String("op", op), zap.Error(err))
		return nil, status.Error(c, ctrl.ErrInternalError.Error())
	}

	return utils.ModelToProto(u), nil
}
// UpdateUser validates req.User and applies it to the user identified by
// req.Uid, echoing that UUID back on success.
//
// Status codes: Unauthenticated when no "uid" string is in the context;
// InvalidArgument for a nil request/user, an unparsable/nil UUID, or failed
// validation; NotFound for ctrl.ErrNotFound; Internal for other errors.
// NOTE(review): any authenticated caller may update any user — confirm an
// authorization check exists elsewhere.
func (h *Handler) UpdateUser(ctx context.Context, req *pb.SSO_UserWithUid) (*pb.SSO_UuidMsg, error) {
	s, c := time.Now(), codes.OK
	const op = "sso.UpdateUser.hdl"
	span := opentracing.GlobalTracer().StartSpan(op)
	ctx = opentracing.ContextWithSpan(ctx, span)
	defer func() {
		span.Finish()
		metrics.ObserveRequest(time.Since(s), int(c), op)
	}()

	if _, ok := ctx.Value("uid").(string); !ok {
		c = codes.Unauthenticated
		zap.L().Debug("failed to get uid from context", zap.String("op", op))
		return nil, status.Error(c, ctrl.ErrUnauthorized.Error())
	}

	// Both the envelope and the nested user payload must be present.
	if req == nil || req.User == nil {
		c = codes.InvalidArgument
		zap.L().Debug("failed to decode request", zap.String("op", op))
		return nil, status.Error(c, ctrl.ErrDecodeRequest.Error())
	}

	uid, err := uuid.Parse(req.Uid)
	if err != nil || uid == uuid.Nil {
		c = codes.InvalidArgument
		zap.L().Debug("failed to parse uid", zap.String("op", op), zap.Error(err))
		return nil, status.Error(c, ctrl.ErrParseUUID.Error())
	}

	user := utils.ProtoToModel(req.User)
	if err = validation.UserValidation(user); err != nil {
		c = codes.InvalidArgument
		zap.L().Debug("failed to validate obj", zap.String("op", op), zap.Error(err))
		return nil, status.Error(c, err.Error())
	}

	err = h.ctrl.UpdateUser(ctx, uid, user)
	switch {
	case errors.Is(err, ctrl.ErrNotFound):
		c = codes.NotFound
		return nil, status.Error(c, err.Error())
	case err != nil:
		c = codes.Internal
		zap.L().Debug("failed to update user", zap.String("op", op), zap.Error(err))
		return nil, status.Error(c, ctrl.ErrInternalError.Error())
	}

	return &pb.SSO_UuidMsg{Uuid: uid.String()}, nil
}
// DeleteUser deletes the user whose UUID is req.Uuid.
//
// Status codes: Unauthenticated when no "uid" string is in the context;
// InvalidArgument for a nil request or an unparsable/nil UUID; NotFound for
// ctrl.ErrNotFound (matching the other handlers); Internal for other errors.
func (h *Handler) DeleteUser(ctx context.Context, req *pb.SSO_UuidMsg) (*pb.SSO_Empty, error) {
	s, c := time.Now(), codes.OK
	const op = "sso.DeleteUser.hdl"
	span := opentracing.GlobalTracer().StartSpan(op)
	ctx = opentracing.ContextWithSpan(ctx, span)
	defer func() {
		span.Finish()
		metrics.ObserveRequest(time.Since(s), int(c), op)
	}()

	if _, ok := ctx.Value("uid").(string); !ok {
		c = codes.Unauthenticated
		zap.L().Debug("failed to get uid from context", zap.String("op", op))
		return nil, status.Error(c, ctrl.ErrUnauthorized.Error())
	}

	if req == nil {
		c = codes.InvalidArgument
		zap.L().Debug("failed to decode request", zap.String("op", op))
		return nil, status.Error(c, ctrl.ErrDecodeRequest.Error())
	}

	uid, err := uuid.Parse(req.Uuid)
	if err != nil || uid == uuid.Nil {
		c = codes.InvalidArgument
		zap.L().Debug("failed to parse uid", zap.String("op", op), zap.Error(err))
		return nil, status.Error(c, ctrl.ErrParseUUID.Error())
	}

	err = h.ctrl.DeleteUser(ctx, uid)
	switch {
	case errors.Is(err, ctrl.ErrNotFound):
		c = codes.NotFound
		return nil, status.Error(c, err.Error())
	case err != nil:
		c = codes.Internal
		zap.L().Debug("failed to delete user", zap.String("op", op), zap.Error(err))
		return nil, status.Error(c, ctrl.ErrInternalError.Error())
	}

	return &pb.SSO_Empty{}, nil
}
package http
import (
"errors"
"github.com/JMURv/sso/internal/auth"
controller "github.com/JMURv/sso/internal/controller"
"github.com/JMURv/sso/internal/dto"
"github.com/JMURv/sso/internal/handler"
metrics "github.com/JMURv/sso/internal/metrics/prometheus"
"github.com/JMURv/sso/internal/validation"
"github.com/JMURv/sso/pkg/model"
utils "github.com/JMURv/sso/pkg/utils/http"
"github.com/goccy/go-json"
"github.com/google/uuid"
"github.com/opentracing/opentracing-go"
"go.uber.org/zap"
"net/http"
"time"
)
// RegisterAuthRoutes mounts the SSO authentication endpoints on mux.
//
// /api/sso/recovery (POST/PUT) and /api/sso/me (GET/PUT) dispatch on the HTTP
// method and answer 405 otherwise. Every other route goes to a single handler
// that performs its own method check. /api/sso/me, /api/sso/logout and
// /api/sso/support run behind h.authMiddleware.
func RegisterAuthRoutes(mux *http.ServeMux, h *Handler) {
	// middlewareFunc only *builds* the wrapped handler. The result must be
	// invoked with (w, r), otherwise the request is never served and the
	// client gets an empty 200. Build them once here and call them below.
	me := middlewareFunc(h.me, h.authMiddleware)
	updateMe := middlewareFunc(h.updateMe, h.authMiddleware)

	mux.HandleFunc("/api/sso/parse", h.parseClaims)
	mux.HandleFunc("/api/sso/user", h.getUserByToken)
	mux.HandleFunc(
		"/api/sso/recovery", func(w http.ResponseWriter, r *http.Request) {
			switch r.Method {
			case http.MethodPost:
				h.sendForgotPasswordEmail(w, r)
			case http.MethodPut:
				h.checkForgotPasswordEmail(w, r)
			default:
				utils.ErrResponse(w, http.StatusMethodNotAllowed, handler.ErrMethodNotAllowed)
			}
		},
	)
	mux.HandleFunc(
		"/api/sso/me", func(w http.ResponseWriter, r *http.Request) {
			switch r.Method {
			case http.MethodGet:
				me(w, r)
			case http.MethodPut:
				updateMe(w, r)
			default:
				utils.ErrResponse(w, http.StatusMethodNotAllowed, handler.ErrMethodNotAllowed)
			}
		},
	)
	mux.HandleFunc("/api/sso/auth", h.authenticate)
	mux.HandleFunc("/api/sso/send-login-code", h.sendLoginCode)
	mux.HandleFunc("/api/sso/check-login-code", h.checkLoginCode)
	mux.HandleFunc("/api/sso/check-email", h.checkEmail)
	mux.HandleFunc("/api/sso/logout", middlewareFunc(h.logout, h.authMiddleware))
	mux.HandleFunc("/api/sso/support", middlewareFunc(h.sendSupportEmail, h.authMiddleware))
}
// authenticate handles POST /api/sso/auth: it decodes and validates an
// email/password body and responds with the result of ctrl.Authenticate.
//
// Responses: 405 for non-POST, 400 for an undecodable or invalid body, 404 when
// the controller returns ErrNotFound, 500 (generic ErrInternalError, with the
// real cause only logged) for any other controller error.
// NOTE(review): the controller's Authenticate looks the user up by email and
// issues a token without comparing the password — confirm password
// verification happens elsewhere.
func (h *Handler) authenticate(w http.ResponseWriter, r *http.Request) {
	const op = "sso.authenticate.hdl"
	s, c := time.Now(), http.StatusOK
	span, ctx := opentracing.StartSpanFromContext(r.Context(), op)
	defer func() {
		span.Finish()
		metrics.ObserveRequest(time.Since(s), c, op)
	}()

	if r.Method != http.MethodPost {
		c = http.StatusMethodNotAllowed
		utils.ErrResponse(w, c, handler.ErrMethodNotAllowed)
		return
	}

	req := &dto.EmailAndPasswordRequest{}
	if err := json.NewDecoder(r.Body).Decode(req); err != nil {
		c = http.StatusBadRequest
		utils.ErrResponse(w, c, controller.ErrDecodeRequest)
		return
	}
	if err := validation.LoginAndPasswordRequest(req); err != nil {
		c = http.StatusBadRequest
		utils.ErrResponse(w, c, err)
		return
	}

	res, err := h.ctrl.Authenticate(ctx, req)
	switch {
	case errors.Is(err, controller.ErrNotFound):
		c = http.StatusNotFound
		utils.ErrResponse(w, c, err)
		return
	case err != nil:
		// Do not leak internal error text to the client, matching the other handlers.
		c = http.StatusInternalServerError
		zap.L().Debug("failed to authenticate", zap.String("op", op), zap.Error(err))
		utils.ErrResponse(w, c, controller.ErrInternalError)
		return
	}

	utils.SuccessResponse(w, c, res)
}
// TokenReq is the JSON request body decoded by parseClaims and getUserByToken,
// e.g. {"token": "<token>"}; both handlers reject an empty Token with 400.
type TokenReq struct {
Token string `json:"token"`
}
// parseClaims handles POST /api/sso/parse: it decodes a TokenReq body and
// responds with the claims returned by ctrl.ParseClaims.
//
// Responses: 405 for non-POST, 400 for an undecodable body or empty token, 404
// when the controller returns ErrNotFound, 500 for any other error or a
// recovered panic. Request latency and final status are always recorded.
func (h *Handler) parseClaims(w http.ResponseWriter, r *http.Request) {
	const op = "sso.parseClaims.hdl"
	start, code := time.Now(), http.StatusOK

	// A single deferred func: recover first (so a panic's 500 is what gets
	// recorded), then observe metrics.
	defer func() {
		if p := recover(); p != nil {
			zap.L().Error("panic", zap.Any("err", p))
			code = http.StatusInternalServerError
			utils.ErrResponse(w, code, controller.ErrInternalError)
		}
		metrics.ObserveRequest(time.Since(start), code, op)
	}()

	if r.Method != http.MethodPost {
		code = http.StatusMethodNotAllowed
		utils.ErrResponse(w, code, handler.ErrMethodNotAllowed)
		return
	}

	var body TokenReq
	if err := json.NewDecoder(r.Body).Decode(&body); err != nil || body.Token == "" {
		code = http.StatusBadRequest
		utils.ErrResponse(w, code, controller.ErrDecodeRequest)
		return
	}

	claims, err := h.ctrl.ParseClaims(r.Context(), body.Token)
	switch {
	case errors.Is(err, controller.ErrNotFound):
		code = http.StatusNotFound
		utils.ErrResponse(w, code, err)
	case err != nil:
		code = http.StatusInternalServerError
		utils.ErrResponse(w, code, controller.ErrInternalError)
	default:
		utils.SuccessResponse(w, code, claims)
	}
}
// getUserByToken handles POST /api/sso/user: it decodes a TokenReq body and
// responds with the user returned by ctrl.GetUserByToken.
//
// Responses: 405 for non-POST, 400 for an undecodable body or empty token, 404
// when the controller returns ErrNotFound, 500 for any other error or a
// recovered panic. Request latency and final status are always recorded.
func (h *Handler) getUserByToken(w http.ResponseWriter, r *http.Request) {
	const op = "sso.getUserByToken.hdl"
	start, code := time.Now(), http.StatusOK

	// Recover first so a panic's 500 is the status that metrics record.
	defer func() {
		if p := recover(); p != nil {
			zap.L().Error("panic", zap.Any("err", p))
			code = http.StatusInternalServerError
			utils.ErrResponse(w, code, controller.ErrInternalError)
		}
		metrics.ObserveRequest(time.Since(start), code, op)
	}()

	if r.Method != http.MethodPost {
		code = http.StatusMethodNotAllowed
		utils.ErrResponse(w, code, handler.ErrMethodNotAllowed)
		return
	}

	var body TokenReq
	if err := json.NewDecoder(r.Body).Decode(&body); err != nil || body.Token == "" {
		code = http.StatusBadRequest
		utils.ErrResponse(w, code, controller.ErrDecodeRequest)
		return
	}

	user, err := h.ctrl.GetUserByToken(r.Context(), body.Token)
	switch {
	case errors.Is(err, controller.ErrNotFound):
		code = http.StatusNotFound
		utils.ErrResponse(w, code, err)
	case err != nil:
		code = http.StatusInternalServerError
		utils.ErrResponse(w, code, controller.ErrInternalError)
	default:
		utils.SuccessResponse(w, code, user)
	}
}
// sendSupportEmailRequest is the JSON body decoded by sendSupportEmail; Theme
// and Text are passed to ctrl.SendSupportEmail together with the caller's uid.
type sendSupportEmailRequest struct {
Theme string `json:"theme"`
Text string `json:"text"`
}
// sendSupportEmail handles POST /api/sso/support (behind authMiddleware): it
// forwards the caller's theme/text to ctrl.SendSupportEmail.
//
// Responses: 405 for non-POST; 401 when the "uid" context value is missing or
// not a UUID; 400 for an undecodable body or empty theme/text (matching the
// gRPC handler); 404 for ErrNotFound; 500 for other errors or a recovered panic.
func (h *Handler) sendSupportEmail(w http.ResponseWriter, r *http.Request) {
	s, c := time.Now(), http.StatusOK
	const op = "sso.sendSupportEmail.handler"
	defer func() {
		metrics.ObserveRequest(time.Since(s), c, op)
	}()
	defer func() {
		if err := recover(); err != nil {
			zap.L().Error("panic", zap.Any("err", err))
			c = http.StatusInternalServerError
			utils.ErrResponse(w, c, controller.ErrInternalError)
		}
	}()

	if r.Method != http.MethodPost {
		c = http.StatusMethodNotAllowed
		utils.ErrResponse(w, c, handler.ErrMethodNotAllowed)
		return
	}

	// Use a checked assertion: a missing uid is an authentication failure,
	// not a panic that surfaces as 500.
	str, ok := r.Context().Value("uid").(string)
	if !ok {
		zap.L().Debug(
			"failed to get uid from context",
			zap.String("op", op),
		)
		c = http.StatusUnauthorized
		utils.ErrResponse(w, c, controller.ErrUnauthorized)
		return
	}
	uid, err := uuid.Parse(str)
	if err != nil {
		zap.L().Debug(
			"failed to parse uid",
			zap.String("op", op), zap.Error(err),
		)
		c = http.StatusUnauthorized
		utils.ErrResponse(w, c, controller.ErrParseUUID)
		return
	}

	req := &sendSupportEmailRequest{}
	if err = json.NewDecoder(r.Body).Decode(req); err != nil {
		c = http.StatusBadRequest
		zap.L().Debug(
			"failed to decode request",
			zap.String("op", op), zap.Error(err),
		)
		utils.ErrResponse(w, c, controller.ErrDecodeRequest)
		return
	}
	if req.Theme == "" || req.Text == "" {
		c = http.StatusBadRequest
		zap.L().Debug("failed to decode request", zap.String("op", op))
		utils.ErrResponse(w, c, controller.ErrDecodeRequest)
		return
	}

	err = h.ctrl.SendSupportEmail(r.Context(), uid, req.Theme, req.Text)
	switch {
	case errors.Is(err, controller.ErrNotFound):
		c = http.StatusNotFound
		utils.ErrResponse(w, c, err)
		return
	case err != nil:
		c = http.StatusInternalServerError
		utils.ErrResponse(w, c, controller.ErrInternalError)
		return
	}

	utils.SuccessResponse(w, c, "OK")
}
// checkForgotPasswordEmailRequest is the JSON body decoded by
// checkForgotPasswordEmail; its fields are passed, in order, as the password,
// uid and code arguments of ctrl.CheckForgotPasswordEmail.
type checkForgotPasswordEmailRequest struct {
Password string `json:"password"`
// Uidb64 is decoded as a uuid.UUID, so despite the name it is expected in UUID
// text form rather than base64. NOTE(review): confirm clients send a UUID string.
Uidb64 uuid.UUID `json:"uidb64"`
// Token is the numeric code (JSON number) checked by the controller.
Token int `json:"token"`
}
// checkForgotPasswordEmail handles PUT /api/sso/recovery: it passes the new
// password, user UUID and numeric code to ctrl.CheckForgotPasswordEmail.
//
// Responses: 405 for non-PUT; 400 for an undecodable body, empty password, nil
// UUID, or ErrCodeIsNotValid from the controller (matching the gRPC handler);
// 404 for ErrNotFound; 500 for other errors or a recovered panic.
func (h *Handler) checkForgotPasswordEmail(w http.ResponseWriter, r *http.Request) {
	s, c := time.Now(), http.StatusOK
	const op = "sso.checkForgotPasswordEmail.handler"
	defer func() {
		metrics.ObserveRequest(time.Since(s), c, op)
	}()
	defer func() {
		if err := recover(); err != nil {
			zap.L().Error("panic", zap.Any("err", err))
			c = http.StatusInternalServerError
			utils.ErrResponse(w, c, controller.ErrInternalError)
		}
	}()

	if r.Method != http.MethodPut {
		c = http.StatusMethodNotAllowed
		utils.ErrResponse(w, c, handler.ErrMethodNotAllowed)
		return
	}

	req := &checkForgotPasswordEmailRequest{}
	if err := json.NewDecoder(r.Body).Decode(req); err != nil {
		c = http.StatusBadRequest
		zap.L().Debug(
			"failed to decode request",
			zap.String("op", op), zap.Error(err),
		)
		utils.ErrResponse(w, c, controller.ErrDecodeRequest)
		return
	}
	// Missing fields decode to zero values; reject them as a bad request.
	if req.Password == "" || req.Uidb64 == uuid.Nil {
		c = http.StatusBadRequest
		zap.L().Debug("failed to decode request", zap.String("op", op))
		utils.ErrResponse(w, c, controller.ErrDecodeRequest)
		return
	}

	err := h.ctrl.CheckForgotPasswordEmail(r.Context(), req.Password, req.Uidb64, req.Token)
	switch {
	case errors.Is(err, controller.ErrCodeIsNotValid):
		c = http.StatusBadRequest
		utils.ErrResponse(w, c, err)
		return
	case errors.Is(err, controller.ErrNotFound):
		c = http.StatusNotFound
		utils.ErrResponse(w, c, err)
		return
	case err != nil:
		c = http.StatusInternalServerError
		utils.ErrResponse(w, c, controller.ErrInternalError)
		return
	}

	utils.SuccessResponse(w, c, "OK")
}
// sendForgotPasswordEmail handles POST /api/sso/recovery: it decodes a body
// containing "email" (into a model.User) and calls ctrl.SendForgotPasswordEmail.
//
// Responses: 405 for non-POST; 400 for an undecodable body, empty email, or
// ErrInvalidCredentials (matching the gRPC handler); 404 for ErrNotFound; 500
// for other errors or a recovered panic.
func (h *Handler) sendForgotPasswordEmail(w http.ResponseWriter, r *http.Request) {
	s, c := time.Now(), http.StatusOK
	const op = "sso.sendForgotPasswordEmail.handler"
	defer func() {
		metrics.ObserveRequest(time.Since(s), c, op)
	}()
	defer func() {
		if err := recover(); err != nil {
			zap.L().Error("panic", zap.Any("err", err))
			c = http.StatusInternalServerError
			utils.ErrResponse(w, c, controller.ErrInternalError)
		}
	}()

	if r.Method != http.MethodPost {
		c = http.StatusMethodNotAllowed
		utils.ErrResponse(w, c, handler.ErrMethodNotAllowed)
		return
	}

	req := &model.User{}
	if err := json.NewDecoder(r.Body).Decode(req); err != nil {
		c = http.StatusBadRequest
		zap.L().Debug(
			"failed to decode request",
			zap.String("op", op), zap.Error(err),
		)
		utils.ErrResponse(w, c, controller.ErrDecodeRequest)
		return
	}
	if req.Email == "" {
		c = http.StatusBadRequest
		zap.L().Debug("failed to decode request", zap.String("op", op))
		utils.ErrResponse(w, c, controller.ErrDecodeRequest)
		return
	}

	err := h.ctrl.SendForgotPasswordEmail(r.Context(), req.Email)
	switch {
	case errors.Is(err, controller.ErrInvalidCredentials):
		c = http.StatusBadRequest
		utils.ErrResponse(w, c, err)
		return
	case errors.Is(err, controller.ErrNotFound):
		c = http.StatusNotFound
		utils.ErrResponse(w, c, err)
		return
	case err != nil:
		c = http.StatusInternalServerError
		utils.ErrResponse(w, c, controller.ErrInternalError)
		return
	}

	utils.SuccessResponse(w, c, "OK")
}
// updateMe handles PUT /api/sso/me (behind authMiddleware): it decodes and
// validates a model.User body and applies it to the caller's account via
// ctrl.UpdateUser.
//
// Responses: 405 for non-PUT; 401 when the "uid" context value is missing or
// not a UUID; 400 for an undecodable or invalid body; 404 for ErrNotFound; 500
// for other errors or a recovered panic.
func (h *Handler) updateMe(w http.ResponseWriter, r *http.Request) {
	const op = "sso.updateMe.handler"
	start, code := time.Now(), http.StatusOK

	// Recover first so a panic's 500 is the status that metrics record.
	defer func() {
		if p := recover(); p != nil {
			zap.L().Error("panic", zap.Any("err", p))
			code = http.StatusInternalServerError
			utils.ErrResponse(w, code, controller.ErrInternalError)
		}
		metrics.ObserveRequest(time.Since(start), code, op)
	}()

	if r.Method != http.MethodPut {
		code = http.StatusMethodNotAllowed
		utils.ErrResponse(w, code, handler.ErrMethodNotAllowed)
		return
	}

	rawUID, ok := r.Context().Value("uid").(string)
	if !ok {
		zap.L().Debug("failed to get uid from context", zap.String("op", op))
		code = http.StatusUnauthorized
		utils.ErrResponse(w, code, controller.ErrUnauthorized)
		return
	}

	uid, err := uuid.Parse(rawUID)
	if err != nil {
		zap.L().Debug("failed to parse uid", zap.String("op", op), zap.Error(err))
		code = http.StatusUnauthorized
		utils.ErrResponse(w, code, controller.ErrParseUUID)
		return
	}

	var user model.User
	if err = json.NewDecoder(r.Body).Decode(&user); err != nil {
		code = http.StatusBadRequest
		zap.L().Debug("failed to decode request", zap.String("op", op), zap.Error(err))
		utils.ErrResponse(w, code, controller.ErrDecodeRequest)
		return
	}

	if err = validation.UserValidation(&user); err != nil {
		code = http.StatusBadRequest
		zap.L().Debug("failed to validate obj", zap.String("op", op), zap.Error(err))
		utils.ErrResponse(w, code, err)
		return
	}

	err = h.ctrl.UpdateUser(r.Context(), uid, &user)
	switch {
	case errors.Is(err, controller.ErrNotFound):
		code = http.StatusNotFound
		utils.ErrResponse(w, code, err)
	case err != nil:
		code = http.StatusInternalServerError
		utils.ErrResponse(w, code, controller.ErrInternalError)
	default:
		utils.SuccessResponse(w, code, "OK")
	}
}
// checkEmail handles POST requests that ask whether an account exists for an
// email address. The JSON body is decoded into a model.User; only Email is
// used and it must be non-empty. The answer is written as {"is_exist": bool}
// via utils.SuccessPaginatedResponse.
//
// Responses: 405 for non-POST, 400 on an undecodable body or empty email,
// 404 when the controller reports ErrNotFound, and 500 on any other
// controller error. The 500 carries the generic ErrInternalError, not the
// underlying cause.
func (h *Handler) checkEmail(w http.ResponseWriter, r *http.Request) {
	s, c := time.Now(), http.StatusOK
	const op = "sso.checkEmail.handler"
	defer func() {
		metrics.ObserveRequest(time.Since(s), c, op)
	}()
	defer func() {
		if err := recover(); err != nil {
			zap.L().Error("panic", zap.Any("err", err))
			c = http.StatusInternalServerError
			utils.ErrResponse(w, c, controller.ErrInternalError)
		}
	}()
	if r.Method != http.MethodPost {
		c = http.StatusMethodNotAllowed
		utils.ErrResponse(w, c, handler.ErrMethodNotAllowed)
		return
	}

	u := &model.User{}
	if err := json.NewDecoder(r.Body).Decode(u); err != nil {
		c = http.StatusBadRequest
		zap.L().Debug(
			"failed to decode request",
			zap.String("op", op), zap.Error(err),
		)
		utils.ErrResponse(w, c, controller.ErrDecodeRequest)
		return
	}
	if u.Email == "" {
		c = http.StatusBadRequest
		utils.ErrResponse(w, c, handler.ErrMissingEmail)
		return
	}

	isExist, err := h.ctrl.IsUserExist(r.Context(), u.Email)
	switch {
	case errors.Is(err, controller.ErrNotFound):
		c = http.StatusNotFound
		utils.ErrResponse(w, c, err)
		return
	case err != nil:
		// Don't echo internal failures (e.g. storage errors) to the client.
		// Log them and return the generic error, as the sibling handlers do.
		c = http.StatusInternalServerError
		zap.L().Debug(
			"failed to check user existence",
			zap.String("op", op), zap.Error(err),
		)
		utils.ErrResponse(w, c, controller.ErrInternalError)
		return
	}

	utils.SuccessPaginatedResponse(
		w, c, struct {
			IsExist bool `json:"is_exist"`
		}{
			IsExist: isExist,
		},
	)
}
// loginCodeRequest is the JSON body accepted by sendLoginCode. Both fields
// must be non-empty for the request to be processed.
type loginCodeRequest struct {
	Email    string `json:"email"`
	Password string `json:"password"`
}
// sendLoginCode handles POST requests carrying an email and password. When
// the controller accepts the credentials it sends a login code, and the
// handler replies 200 "Login code sent successfully".
//
// Responses:
//   - 405 for non-POST.
//   - 400 on an undecodable body, a missing email/password, or an email
//     rejected by validation.ValidateEmail.
//   - 404 when the controller reports ErrInvalidCredentials.
//   - 500 with the generic ErrInternalError on any other controller error.
func (h *Handler) sendLoginCode(w http.ResponseWriter, r *http.Request) {
	s, c := time.Now(), http.StatusOK
	const op = "sso.sendLoginCode.handler"
	defer func() {
		metrics.ObserveRequest(time.Since(s), c, op)
	}()
	defer func() {
		if err := recover(); err != nil {
			zap.L().Error("panic", zap.Any("err", err))
			c = http.StatusInternalServerError
			utils.ErrResponse(w, c, controller.ErrInternalError)
		}
	}()
	if r.Method != http.MethodPost {
		c = http.StatusMethodNotAllowed
		utils.ErrResponse(w, c, handler.ErrMethodNotAllowed)
		return
	}

	data := &loginCodeRequest{}
	if err := json.NewDecoder(r.Body).Decode(data); err != nil {
		c = http.StatusBadRequest
		zap.L().Debug(
			"failed to decode request",
			zap.String("op", op), zap.Error(err),
		)
		utils.ErrResponse(w, c, controller.ErrDecodeRequest)
		return
	}

	email, pass := data.Email, data.Password
	// Check for missing fields before validating the format. Otherwise an
	// empty email would get ValidateEmail's error instead of
	// ErrEmailAndPasswordRequired. This is the same order checkLoginCode uses.
	if email == "" || pass == "" {
		c = http.StatusBadRequest
		utils.ErrResponse(w, c, handler.ErrEmailAndPasswordRequired)
		return
	}
	if err := validation.ValidateEmail(email); err != nil {
		c = http.StatusBadRequest
		utils.ErrResponse(w, c, err)
		return
	}

	err := h.ctrl.SendLoginCode(r.Context(), email, pass)
	switch {
	case errors.Is(err, controller.ErrInvalidCredentials):
		c = http.StatusNotFound
		utils.ErrResponse(w, c, err)
		return
	case err != nil:
		// Don't leak internal failure details to the client.
		c = http.StatusInternalServerError
		zap.L().Debug(
			"failed to send login code",
			zap.String("op", op), zap.Error(err),
		)
		utils.ErrResponse(w, c, controller.ErrInternalError)
		return
	}
	utils.SuccessResponse(w, c, "Login code sent successfully")
}
// checkLoginCodeRequest is the JSON body accepted by checkLoginCode. The
// handler treats an empty Email or a Code of 0 as missing.
type checkLoginCodeRequest struct {
	Email string `json:"email"`
	Code  int    `json:"code"`
}
// checkLoginCode handles POST requests that exchange an email and a login
// code for tokens. On success the access and refresh tokens are set as
// HttpOnly, Secure, SameSite=Strict cookies on path "/", with expiries of
// auth.AccessTokenDuration and auth.RefreshTokenDuration respectively.
//
// Responses:
//   - 405 for non-POST.
//   - 400 on an undecodable body, an empty email or zero code, or an email
//     rejected by validation.ValidateEmail.
//   - 404 when the controller reports ErrNotFound.
//   - 500 on other errors.
//   - 200 "OK" otherwise.
func (h *Handler) checkLoginCode(w http.ResponseWriter, r *http.Request) {
	const op = "sso.checkLoginCode.handler"
	start, status := time.Now(), http.StatusOK
	defer func() {
		metrics.ObserveRequest(time.Since(start), status, op)
	}()
	defer func() {
		if rec := recover(); rec != nil {
			zap.L().Error("panic", zap.Any("err", rec))
			status = http.StatusInternalServerError
			utils.ErrResponse(w, status, controller.ErrInternalError)
		}
	}()

	if r.Method != http.MethodPost {
		status = http.StatusMethodNotAllowed
		utils.ErrResponse(w, status, handler.ErrMethodNotAllowed)
		return
	}

	var body checkLoginCodeRequest
	if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
		status = http.StatusBadRequest
		zap.L().Debug("failed to decode request", zap.String("op", op), zap.Error(err))
		utils.ErrResponse(w, status, controller.ErrDecodeRequest)
		return
	}
	if body.Email == "" || body.Code == 0 {
		status = http.StatusBadRequest
		utils.ErrResponse(w, status, handler.ErrEmailAndCodeRequired)
		return
	}
	if err := validation.ValidateEmail(body.Email); err != nil {
		status = http.StatusBadRequest
		utils.ErrResponse(w, status, err)
		return
	}

	access, refresh, err := h.ctrl.CheckLoginCode(r.Context(), body.Email, body.Code)
	switch {
	case errors.Is(err, controller.ErrNotFound):
		status = http.StatusNotFound
		utils.ErrResponse(w, status, err)
		return
	case err != nil:
		status = http.StatusInternalServerError
		utils.ErrResponse(w, status, controller.ErrInternalError)
		return
	}

	// setAuthCookie writes one auth cookie; both tokens share every attribute
	// except name, value and lifetime.
	setAuthCookie := func(name, value string, ttl time.Duration) {
		http.SetCookie(w, &http.Cookie{
			Name:     name,
			Value:    value,
			Expires:  time.Now().Add(ttl),
			HttpOnly: true,
			Secure:   true,
			Path:     "/",
			SameSite: http.SameSiteStrictMode,
		})
	}
	setAuthCookie("access", access, auth.AccessTokenDuration)
	setAuthCookie("refresh", refresh, auth.RefreshTokenDuration)

	utils.SuccessResponse(w, status, "OK")
}
// me handles GET requests for the authenticated caller's own profile. The
// caller's id is the string "uid" context value, which is parsed as a UUID
// and looked up through the controller.
//
// Responses: 405 for non-GET, 401 when the uid is absent or not a UUID,
// 404 when the controller reports ErrNotFound, 500 on other errors, and 200
// with the user otherwise.
func (h *Handler) me(w http.ResponseWriter, r *http.Request) {
	const op = "sso.me.handler"
	start, status := time.Now(), http.StatusOK
	defer func() {
		metrics.ObserveRequest(time.Since(start), status, op)
	}()
	defer func() {
		if rec := recover(); rec != nil {
			zap.L().Error("panic", zap.Any("err", rec))
			status = http.StatusInternalServerError
			utils.ErrResponse(w, status, controller.ErrInternalError)
		}
	}()

	if r.Method != http.MethodGet {
		status = http.StatusMethodNotAllowed
		utils.ErrResponse(w, status, handler.ErrMethodNotAllowed)
		return
	}

	// A missing value and an unparsable value get the same 401 reply. When
	// the assertion fails, raw is "" and uuid.Parse rejects it anyway.
	raw, ok := r.Context().Value("uid").(string)
	uid, err := uuid.Parse(raw)
	if !ok || err != nil {
		status = http.StatusUnauthorized
		utils.ErrResponse(w, status, controller.ErrParseUUID)
		return
	}

	user, err := h.ctrl.GetUserByID(r.Context(), uid)
	switch {
	case errors.Is(err, controller.ErrNotFound):
		status = http.StatusNotFound
		utils.ErrResponse(w, status, err)
	case err != nil:
		status = http.StatusInternalServerError
		utils.ErrResponse(w, status, controller.ErrInternalError)
	default:
		utils.SuccessResponse(w, status, user)
	}
}
// logout handles POST requests by expiring the "access" and "refresh"
// cookies (empty value, MaxAge -1) and replying 200 "OK". Other methods get
// 405.
func (h *Handler) logout(w http.ResponseWriter, r *http.Request) {
	const op = "sso.logout.handler"
	start, status := time.Now(), http.StatusOK
	defer func() {
		metrics.ObserveRequest(time.Since(start), status, op)
	}()
	defer func() {
		if rec := recover(); rec != nil {
			zap.L().Error("panic", zap.Any("err", rec))
			status = http.StatusInternalServerError
			utils.ErrResponse(w, status, controller.ErrInternalError)
		}
	}()

	if r.Method != http.MethodPost {
		status = http.StatusMethodNotAllowed
		utils.ErrResponse(w, status, handler.ErrMethodNotAllowed)
		return
	}

	// Path/Secure/HttpOnly/SameSite match the attributes checkLoginCode uses
	// when issuing these cookies.
	for _, name := range [...]string{"access", "refresh"} {
		http.SetCookie(w, &http.Cookie{
			Name:     name,
			Value:    "",
			MaxAge:   -1,
			HttpOnly: true,
			Secure:   true,
			Path:     "/",
			SameSite: http.SameSiteStrictMode,
		})
	}

	utils.SuccessResponse(w, status, "OK")
}
package http
import (
"context"
"errors"
"fmt"
controller "github.com/JMURv/sso/internal/controller"
"github.com/JMURv/sso/internal/handler/grpc"
utils "github.com/JMURv/sso/pkg/utils/http"
"github.com/opentracing/opentracing-go"
"go.uber.org/zap"
"net/http"
"strings"
"time"
)
// Handler serves the HTTP API.
type Handler struct {
	// srv is created by Start and shut down by Close.
	srv *http.Server
	// ctrl provides the business operations invoked by the route handlers.
	ctrl grpc.Ctrl
	// auth verifies bearer tokens in authMiddleware.
	auth controller.AuthService
}
// New returns a Handler that dispatches requests to ctrl and verifies bearer
// tokens with auth. The *http.Server itself is not created until Start.
func New(auth controller.AuthService, ctrl grpc.Ctrl) *Handler {
	h := new(Handler)
	h.auth, h.ctrl = auth, ctrl
	return h
}
// Start registers the auth, user and permission routes plus GET /health on a
// new ServeMux, and serves HTTP on the given port. It blocks until the server
// stops. http.ErrServerClosed (the result of Close) is treated as a clean
// exit. Any other failure is logged at Error level.
func (h *Handler) Start(port int) {
	mux := http.NewServeMux()
	RegisterAuthRoutes(mux, h)
	RegisterUserRoutes(mux, h)
	RegisterPermRoutes(mux, h)
	mux.HandleFunc(
		"/health", func(w http.ResponseWriter, r *http.Request) {
			utils.SuccessResponse(w, http.StatusOK, "OK")
		},
	)

	h.srv = &http.Server{
		Handler:      mux,
		Addr:         fmt.Sprintf(":%v", port),
		WriteTimeout: 15 * time.Second,
		ReadTimeout:  15 * time.Second,
		IdleTimeout:  60 * time.Second,
	}

	// Failures such as "address already in use" mean the service is not
	// serving at all. Logging them at Debug hid them in normal deployments.
	if err := h.srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
		zap.L().Error("Server error", zap.Error(err))
	}
}
// Close gracefully shuts down the server started by Start and returns the
// error from http.Server.Shutdown. Calling it before Start has created the
// server is a no-op instead of a nil-pointer panic.
//
// NOTE(review): Start assigns h.srv without synchronization. A Close that
// races a concurrent Start may still observe nil.
func (h *Handler) Close() error {
	if h.srv == nil {
		return nil
	}
	return h.srv.Shutdown(context.Background())
}
// middlewareFunc wraps h with the given middleware and returns the result as
// an http.HandlerFunc. The middleware are applied in the order given, so the
// LAST one listed is the outermost and runs first on each request. With no
// middleware, h is called directly.
//
// The chain is assembled once, when middlewareFunc is called, instead of
// being rebuilt on every request.
func middlewareFunc(h http.HandlerFunc, middleware ...func(http.Handler) http.Handler) http.HandlerFunc {
	var chain http.Handler = h
	for _, m := range middleware {
		chain = m(chain)
	}
	return chain.ServeHTTP
}
// authMiddleware requires an "Authorization: Bearer <token>" header verified
// by h.auth.VerifyToken. Requests with no header, no "Bearer " prefix, or a
// token the verifier rejects get 401. Otherwise the token's "uid" claim is
// stored in the request context under the string key "uid" and next is
// called.
//
// NOTE(review): a plain string context key can collide with other packages
// (staticcheck SA1029). It is kept because handlers such as me and updateMe
// read r.Context().Value("uid").
func (h *Handler) authMiddleware(next http.Handler) http.Handler {
	const bearer = "Bearer "
	check := func(w http.ResponseWriter, r *http.Request) {
		header := r.Header.Get("Authorization")
		if len(header) == 0 {
			utils.ErrResponse(w, http.StatusUnauthorized, errors.New("authorization header is missing"))
			return
		}
		if !strings.HasPrefix(header, bearer) {
			utils.ErrResponse(w, http.StatusUnauthorized, errors.New("invalid token format"))
			return
		}

		claims, err := h.auth.VerifyToken(header[len(bearer):])
		if err != nil {
			utils.ErrResponse(w, http.StatusUnauthorized, err)
			return
		}

		ctx := context.WithValue(r.Context(), "uid", claims["uid"])
		next.ServeHTTP(w, r.WithContext(ctx))
	}
	return http.HandlerFunc(check)
}
// tracingMiddleware starts a span named "<METHOD> <path>" for each request,
// logs the method and request URI at Info level, and calls next.
//
// The span is attached to the request context, so handlers and repositories
// that call opentracing.StartSpanFromContext create child spans of it. In the
// original they created disconnected root spans. The span name uses only
// URL.Path: query strings would give every distinct request its own
// operation name.
func (h *Handler) tracingMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(
		func(w http.ResponseWriter, r *http.Request) {
			span := opentracing.GlobalTracer().StartSpan(
				fmt.Sprintf("%s %s", r.Method, r.URL.Path),
			)
			defer span.Finish()

			zap.L().Info("Request", zap.String("method", r.Method), zap.String("uri", r.RequestURI))
			next.ServeHTTP(w, r.WithContext(opentracing.ContextWithSpan(r.Context(), span)))
		},
	)
}
package http
import (
"errors"
controller "github.com/JMURv/sso/internal/controller"
"github.com/JMURv/sso/internal/handler"
metrics "github.com/JMURv/sso/internal/metrics/prometheus"
"github.com/JMURv/sso/internal/validation"
"github.com/JMURv/sso/pkg/consts"
"github.com/JMURv/sso/pkg/model"
utils "github.com/JMURv/sso/pkg/utils/http"
"github.com/goccy/go-json"
"go.uber.org/zap"
"net/http"
"strconv"
"strings"
"time"
)
// RegisterPermRoutes mounts the permission endpoints on mux:
//
//	GET    /api/perm       list permissions (paginated)
//	POST   /api/perm       create a permission
//	GET    /api/perm/{id}  fetch one permission
//	PUT    /api/perm/{id}  update a permission (bearer token required)
//	DELETE /api/perm/{id}  delete a permission (bearer token required)
//
// Any other method on these paths gets 405.
func RegisterPermRoutes(mux *http.ServeMux, h *Handler) {
	// middlewareFunc only builds the wrapped handler; it must be invoked.
	// Previously the result was discarded, so PUT/DELETE silently did nothing.
	// Build both once, at registration time.
	updatePerm := middlewareFunc(h.updatePerm, h.authMiddleware)
	deletePerm := middlewareFunc(h.deletePerm, h.authMiddleware)

	mux.HandleFunc(
		"/api/perm", func(w http.ResponseWriter, r *http.Request) {
			switch r.Method {
			case http.MethodGet:
				h.listPerms(w, r)
			case http.MethodPost:
				// NOTE(review): creation is not behind authMiddleware while
				// update/delete are. Confirm this is intended.
				h.createPerm(w, r)
			default:
				utils.ErrResponse(w, http.StatusMethodNotAllowed, handler.ErrMethodNotAllowed)
			}
		},
	)
	mux.HandleFunc(
		"/api/perm/", func(w http.ResponseWriter, r *http.Request) {
			switch r.Method {
			case http.MethodGet:
				h.getPerm(w, r)
			case http.MethodPut:
				updatePerm(w, r)
			case http.MethodDelete:
				deletePerm(w, r)
			default:
				utils.ErrResponse(w, http.StatusMethodNotAllowed, handler.ErrMethodNotAllowed)
			}
		},
	)
}
// listPerms returns one page of permissions. The "page" and "size" query
// parameters default to 1 and consts.DefaultPageSize when they are absent,
// not integers, or below 1. A controller failure yields 500.
func (h *Handler) listPerms(w http.ResponseWriter, r *http.Request) {
	const op = "sso.listPerms.hdl"
	start, status := time.Now(), http.StatusOK
	defer func() {
		metrics.ObserveRequest(time.Since(start), status, op)
	}()
	defer func() {
		if rec := recover(); rec != nil {
			zap.L().Error("panic", zap.Any("err", rec))
			status = http.StatusInternalServerError
			utils.ErrResponse(w, status, controller.ErrInternalError)
		}
	}()

	params := r.URL.Query()
	// positiveParam reads an integer query parameter, or returns def when it
	// is missing, malformed, or less than 1.
	positiveParam := func(key string, def int) int {
		v, err := strconv.Atoi(params.Get(key))
		if err != nil || v < 1 {
			return def
		}
		return v
	}
	page := positiveParam("page", 1)
	size := positiveParam("size", consts.DefaultPageSize)

	res, err := h.ctrl.ListPermissions(r.Context(), page, size)
	if err != nil {
		status = http.StatusInternalServerError
		utils.ErrResponse(w, status, controller.ErrInternalError)
		return
	}
	utils.SuccessPaginatedResponse(w, status, res)
}
// createPerm decodes a model.Permission from the JSON body, validates it with
// validation.PermValidation and creates it. It replies 201 with the new
// permission's id.
//
// Responses: 400 on an undecodable or invalid body, 409 when the controller
// reports ErrAlreadyExists, and 500 on other errors.
func (h *Handler) createPerm(w http.ResponseWriter, r *http.Request) {
	s, c := time.Now(), http.StatusCreated
	const op = "sso.createPerm.hdl"
	defer func() {
		metrics.ObserveRequest(time.Since(s), c, op)
	}()
	defer func() {
		if err := recover(); err != nil {
			zap.L().Error("panic", zap.Any("err", err))
			c = http.StatusInternalServerError
			utils.ErrResponse(w, c, controller.ErrInternalError)
		}
	}()

	req := &model.Permission{}
	if err := json.NewDecoder(r.Body).Decode(req); err != nil {
		c = http.StatusBadRequest
		zap.L().Debug(
			"failed to decode request",
			zap.String("op", op), zap.Error(err),
		)
		utils.ErrResponse(w, c, controller.ErrDecodeRequest)
		return
	}
	if err := validation.PermValidation(req); err != nil {
		c = http.StatusBadRequest
		// The message previously said "user": a copy-paste from createUser.
		zap.L().Debug("failed to validate obj", zap.String("op", op), zap.Error(err))
		utils.ErrResponse(w, c, err)
		return
	}

	id, err := h.ctrl.CreatePerm(r.Context(), req)
	switch {
	case errors.Is(err, controller.ErrAlreadyExists):
		c = http.StatusConflict
		utils.ErrResponse(w, c, err)
		return
	case err != nil:
		c = http.StatusInternalServerError
		utils.ErrResponse(w, c, controller.ErrInternalError)
		return
	}
	utils.SuccessResponse(w, c, id)
}
// getPerm handles GET /api/perm/{id}. It parses id as a base-10 uint64 and
// returns the matching permission.
//
// Responses: 400 when id is not a valid uint64, 404 when the controller
// reports ErrNotFound, 500 on other errors, and 200 with the permission.
func (h *Handler) getPerm(w http.ResponseWriter, r *http.Request) {
	const op = "sso.getPerm.hdl"
	start, status := time.Now(), http.StatusOK
	defer func() {
		metrics.ObserveRequest(time.Since(start), status, op)
	}()
	defer func() {
		if rec := recover(); rec != nil {
			zap.L().Error("panic", zap.Any("err", rec))
			status = http.StatusInternalServerError
			utils.ErrResponse(w, status, controller.ErrInternalError)
		}
	}()

	rawID := strings.TrimPrefix(r.URL.Path, "/api/perm/")
	permID, err := strconv.ParseUint(rawID, 10, 64)
	if err != nil {
		status = http.StatusBadRequest
		zap.L().Debug("failed to parse id", zap.String("op", op))
		utils.ErrResponse(w, status, handler.ErrRetrievePathVars)
		return
	}

	perm, err := h.ctrl.GetPermission(r.Context(), permID)
	switch {
	case errors.Is(err, controller.ErrNotFound):
		status = http.StatusNotFound
		utils.ErrResponse(w, status, err)
	case err != nil:
		status = http.StatusInternalServerError
		utils.ErrResponse(w, status, controller.ErrInternalError)
	default:
		utils.SuccessResponse(w, status, perm)
	}
}
// updatePerm handles PUT /api/perm/{id}. It parses id as a base-10 uint64,
// decodes a model.Permission from the JSON body, validates it, and asks the
// controller to apply the update.
//
// Responses: 400 for a bad id or an undecodable/invalid body, 404 when the
// controller reports ErrNotFound, 500 on other errors, and 200 "OK"
// otherwise.
func (h *Handler) updatePerm(w http.ResponseWriter, r *http.Request) {
	const op = "sso.updatePerm.hdl"
	start, status := time.Now(), http.StatusOK
	defer func() {
		metrics.ObserveRequest(time.Since(start), status, op)
	}()
	defer func() {
		if rec := recover(); rec != nil {
			zap.L().Error("panic", zap.Any("err", rec))
			status = http.StatusInternalServerError
			utils.ErrResponse(w, status, controller.ErrInternalError)
		}
	}()

	permID, err := strconv.ParseUint(strings.TrimPrefix(r.URL.Path, "/api/perm/"), 10, 64)
	if err != nil {
		status = http.StatusBadRequest
		zap.L().Debug("failed to parse id", zap.String("op", op))
		utils.ErrResponse(w, status, handler.ErrRetrievePathVars)
		return
	}

	payload := new(model.Permission)
	if err = json.NewDecoder(r.Body).Decode(payload); err != nil {
		status = http.StatusBadRequest
		zap.L().Debug("failed to decode request", zap.String("op", op), zap.Error(err))
		utils.ErrResponse(w, status, controller.ErrDecodeRequest)
		return
	}
	if err = validation.PermValidation(payload); err != nil {
		status = http.StatusBadRequest
		zap.L().Debug("failed to validate obj", zap.String("op", op), zap.Error(err))
		utils.ErrResponse(w, status, err)
		return
	}

	switch err = h.ctrl.UpdatePerm(r.Context(), permID, payload); {
	case errors.Is(err, controller.ErrNotFound):
		status = http.StatusNotFound
		utils.ErrResponse(w, status, err)
	case err != nil:
		status = http.StatusInternalServerError
		utils.ErrResponse(w, status, controller.ErrInternalError)
	default:
		utils.SuccessResponse(w, status, "OK")
	}
}
// deletePerm handles DELETE /api/perm/{id}. It parses id as a base-10
// uint64 and deletes the permission.
//
// Responses: 400 for a bad id, 404 when the controller reports ErrNotFound,
// 500 on other errors, and 204 No Content (empty body) on success.
func (h *Handler) deletePerm(w http.ResponseWriter, r *http.Request) {
	s, c := time.Now(), http.StatusNoContent
	const op = "sso.deletePerm.hdl"
	defer func() {
		metrics.ObserveRequest(time.Since(s), c, op)
	}()
	defer func() {
		if err := recover(); err != nil {
			zap.L().Error("panic", zap.Any("err", err))
			c = http.StatusInternalServerError
			utils.ErrResponse(w, c, controller.ErrInternalError)
		}
	}()

	id, err := strconv.ParseUint(strings.TrimPrefix(r.URL.Path, "/api/perm/"), 10, 64)
	if err != nil {
		c = http.StatusBadRequest
		zap.L().Debug(
			"failed to parse id",
			zap.String("op", op), zap.Error(err),
		)
		utils.ErrResponse(w, c, handler.ErrRetrievePathVars)
		return
	}

	err = h.ctrl.DeletePerm(r.Context(), id)
	switch {
	case errors.Is(err, controller.ErrNotFound):
		c = http.StatusNotFound
		utils.ErrResponse(w, c, err)
		return
	case err != nil:
		c = http.StatusInternalServerError
		utils.ErrResponse(w, c, controller.ErrInternalError)
		return
	}
	// A 204 response must not carry a body; net/http rejects body writes with
	// http.ErrBodyNotAllowed. Send only the status line.
	w.WriteHeader(c)
}
package http
import (
"errors"
"github.com/JMURv/sso/internal/auth"
controller "github.com/JMURv/sso/internal/controller"
"github.com/JMURv/sso/internal/handler"
metrics "github.com/JMURv/sso/internal/metrics/prometheus"
"github.com/JMURv/sso/internal/validation"
"github.com/JMURv/sso/pkg/consts"
"github.com/JMURv/sso/pkg/model"
utils "github.com/JMURv/sso/pkg/utils/http"
"github.com/goccy/go-json"
"github.com/google/uuid"
"go.uber.org/zap"
"io"
"net/http"
"strconv"
"strings"
"time"
)
// RegisterUserRoutes mounts the user endpoints on mux:
//
//	GET    /api/users/search  substring search (paginated)
//	GET    /api/users         list users (paginated)
//	POST   /api/users         create a user (multipart form)
//	GET    /api/users/{id}    fetch one user
//	PUT    /api/users/{id}    update a user (bearer token required)
//	DELETE /api/users/{id}    delete a user (bearer token required)
//
// Any other method on these paths gets 405.
func RegisterUserRoutes(mux *http.ServeMux, h *Handler) {
	// middlewareFunc only builds the wrapped handler; it must be invoked.
	// Previously the result was discarded, so PUT/DELETE silently did nothing.
	// Build both once, at registration time.
	updateUser := middlewareFunc(h.updateUser, h.authMiddleware)
	deleteUser := middlewareFunc(h.deleteUser, h.authMiddleware)

	mux.HandleFunc("/api/users/search", h.searchUser)
	mux.HandleFunc(
		"/api/users", func(w http.ResponseWriter, r *http.Request) {
			switch r.Method {
			case http.MethodGet:
				h.listUsers(w, r)
			case http.MethodPost:
				h.createUser(w, r)
			default:
				utils.ErrResponse(w, http.StatusMethodNotAllowed, handler.ErrMethodNotAllowed)
			}
		},
	)
	mux.HandleFunc(
		"/api/users/", func(w http.ResponseWriter, r *http.Request) {
			switch r.Method {
			case http.MethodGet:
				h.getUser(w, r)
			case http.MethodPut:
				updateUser(w, r)
			case http.MethodDelete:
				deleteUser(w, r)
			default:
				utils.ErrResponse(w, http.StatusMethodNotAllowed, handler.ErrMethodNotAllowed)
			}
		},
	)
}
// createUser registers a user from a multipart form and logs them in.
//
// Form fields: name, email, password, and an optional "file" part whose name
// and bytes are passed to the controller. Up to 10 MiB of the form is held in
// memory (the ParseMultipartForm limit). On success it sets the "access" and
// "refresh" cookies (HttpOnly, Secure, SameSite=Strict, path "/") and replies
// 201 with the new user's id.
//
// Responses:
//   - 400 on a form parse, validation, or file-retrieval error.
//   - 409 when the controller reports ErrAlreadyExists.
//   - 500 on a file read error or other controller errors.
func (h *Handler) createUser(w http.ResponseWriter, r *http.Request) {
	s, c := time.Now(), http.StatusCreated
	const op = "sso.createUser.hdl"
	defer func() {
		metrics.ObserveRequest(time.Since(s), c, op)
	}()
	defer func() {
		if err := recover(); err != nil {
			zap.L().Error("panic", zap.Any("err", err))
			c = http.StatusInternalServerError
			utils.ErrResponse(w, c, controller.ErrInternalError)
		}
	}()

	if err := r.ParseMultipartForm(10 << 20); err != nil {
		c = http.StatusBadRequest
		zap.L().Debug("failed to parse multipart form", zap.String("op", op), zap.Error(err))
		utils.ErrResponse(w, c, err)
		return
	}

	u := &model.User{
		Name:     r.FormValue("name"),
		Email:    r.FormValue("email"),
		Password: r.FormValue("password"),
	}
	if err := validation.NewUserValidation(u); err != nil {
		c = http.StatusBadRequest
		zap.L().Debug("failed to validate user", zap.String("op", op), zap.Error(err))
		utils.ErrResponse(w, c, err)
		return
	}

	// The file header was previously named `handler`, which shadowed the
	// imported handler package for the rest of this function.
	var fileName string
	var content []byte
	file, header, err := r.FormFile("file")
	if err != nil && err != http.ErrMissingFile {
		c = http.StatusBadRequest
		zap.L().Debug("failed to retrieve file", zap.String("op", op), zap.Error(err))
		utils.ErrResponse(w, c, err)
		return
	}
	if file != nil {
		defer file.Close()
		content, err = io.ReadAll(file)
		if err != nil {
			c = http.StatusInternalServerError
			zap.L().Debug("failed to read file", zap.String("op", op), zap.Error(err))
			utils.ErrResponse(w, c, controller.ErrInternalError)
			return
		}
		fileName = header.Filename
	}

	uid, access, refresh, err := h.ctrl.CreateUser(r.Context(), u, fileName, content)
	switch {
	case errors.Is(err, controller.ErrAlreadyExists):
		c = http.StatusConflict
		utils.ErrResponse(w, c, err)
		return
	case err != nil:
		c = http.StatusInternalServerError
		utils.ErrResponse(w, c, controller.ErrInternalError)
		return
	}

	http.SetCookie(
		w, &http.Cookie{
			Name:     "access",
			Value:    access,
			Expires:  time.Now().Add(auth.AccessTokenDuration),
			HttpOnly: true,
			Secure:   true,
			Path:     "/",
			SameSite: http.SameSiteStrictMode,
		},
	)
	http.SetCookie(
		w, &http.Cookie{
			Name:     "refresh",
			Value:    refresh,
			Expires:  time.Now().Add(auth.RefreshTokenDuration),
			HttpOnly: true,
			Secure:   true,
			Path:     "/",
			SameSite: http.SameSiteStrictMode,
		},
	)
	utils.SuccessResponse(w, c, uid)
}
// searchUser handles GET /api/users/search?q=...&page=...&size=....
//
// A query shorter than 3 bytes short-circuits to an empty
// model.PaginatedUser without calling the controller. "page" and "size"
// default to 1 and consts.DefaultPageSize when they are missing, malformed,
// or below 1. Non-GET requests get 405, and a controller failure yields 500.
func (h *Handler) searchUser(w http.ResponseWriter, r *http.Request) {
	const op = "sso.search.hdl"
	start, status := time.Now(), http.StatusOK
	defer func() {
		metrics.ObserveRequest(time.Since(start), status, op)
	}()
	defer func() {
		if rec := recover(); rec != nil {
			zap.L().Error("panic", zap.Any("err", rec))
			status = http.StatusInternalServerError
			utils.ErrResponse(w, status, controller.ErrInternalError)
		}
	}()

	if r.Method != http.MethodGet {
		status = http.StatusMethodNotAllowed
		utils.ErrResponse(w, status, handler.ErrMethodNotAllowed)
		return
	}

	params := r.URL.Query()
	q := params.Get("q")
	if len(q) < 3 {
		utils.SuccessPaginatedResponse(w, status, model.PaginatedUser{})
		return
	}

	// positiveParam reads an integer query parameter, or returns def when it
	// is missing, malformed, or less than 1.
	positiveParam := func(key string, def int) int {
		v, err := strconv.Atoi(params.Get(key))
		if err != nil || v < 1 {
			return def
		}
		return v
	}

	res, err := h.ctrl.SearchUser(r.Context(), q, positiveParam("page", 1), positiveParam("size", consts.DefaultPageSize))
	if err != nil {
		status = http.StatusInternalServerError
		utils.ErrResponse(w, status, controller.ErrInternalError)
		return
	}
	utils.SuccessPaginatedResponse(w, status, res)
}
// listUsers returns one page of users. The "page" and "size" query
// parameters default to 1 and consts.DefaultPageSize when they are absent,
// not integers, or below 1. A controller failure yields 500.
func (h *Handler) listUsers(w http.ResponseWriter, r *http.Request) {
	const op = "sso.listUsers.hdl"
	start, status := time.Now(), http.StatusOK
	defer func() {
		metrics.ObserveRequest(time.Since(start), status, op)
	}()
	defer func() {
		if rec := recover(); rec != nil {
			zap.L().Error("panic", zap.Any("err", rec))
			status = http.StatusInternalServerError
			utils.ErrResponse(w, status, controller.ErrInternalError)
		}
	}()

	params := r.URL.Query()
	page, perr := strconv.Atoi(params.Get("page"))
	if perr != nil || page < 1 {
		page = 1
	}
	size, serr := strconv.Atoi(params.Get("size"))
	if serr != nil || size < 1 {
		size = consts.DefaultPageSize
	}

	res, err := h.ctrl.ListUsers(r.Context(), page, size)
	if err != nil {
		status = http.StatusInternalServerError
		utils.ErrResponse(w, status, controller.ErrInternalError)
		return
	}
	utils.SuccessPaginatedResponse(w, status, res)
}
// getUser handles GET /api/users/{id} and returns the user whose UUID is id.
//
// Responses: 400 when id is not a UUID or is the nil UUID, 404 when the
// controller reports ErrNotFound, 500 on other errors, and 200 with the
// user otherwise.
func (h *Handler) getUser(w http.ResponseWriter, r *http.Request) {
	s, c := time.Now(), http.StatusOK
	const op = "sso.getUser.hdl"
	defer func() {
		metrics.ObserveRequest(time.Since(s), c, op)
	}()
	defer func() {
		if err := recover(); err != nil {
			zap.L().Error("panic", zap.Any("err", err))
			c = http.StatusInternalServerError
			utils.ErrResponse(w, c, controller.ErrInternalError)
		}
	}()

	uid, err := uuid.Parse(strings.TrimPrefix(r.URL.Path, "/api/users/"))
	if err != nil || uid == uuid.Nil {
		// A malformed id in the URL is a client error, as in getPerm. 401 told
		// clients they were unauthenticated on a route that needs no token.
		c = http.StatusBadRequest
		zap.L().Debug(
			"failed to parse uid",
			zap.String("op", op), zap.Error(err),
		)
		utils.ErrResponse(w, c, controller.ErrParseUUID)
		return
	}

	res, err := h.ctrl.GetUserByID(r.Context(), uid)
	switch {
	case errors.Is(err, controller.ErrNotFound):
		c = http.StatusNotFound
		utils.ErrResponse(w, c, err)
		return
	case err != nil:
		c = http.StatusInternalServerError
		utils.ErrResponse(w, c, controller.ErrInternalError)
		return
	}
	utils.SuccessResponse(w, c, res)
}
// updateUser handles PUT /api/users/{id}. It decodes a model.User from the
// JSON body, validates it with validation.UserValidation, and updates the
// user with that UUID.
//
// Responses: 400 for a malformed/nil id or an undecodable/invalid body, 404
// when the controller reports ErrNotFound, 500 on other errors, and 200 "OK"
// otherwise.
//
// NOTE(review): the route is wrapped in authMiddleware, but nothing here
// compares the authenticated "uid" context value with the path id or checks
// permissions. Any authenticated caller can update any user unless the
// controller enforces this. Confirm.
func (h *Handler) updateUser(w http.ResponseWriter, r *http.Request) {
	s, c := time.Now(), http.StatusOK
	const op = "sso.updateUser.hdl"
	defer func() {
		metrics.ObserveRequest(time.Since(s), c, op)
	}()
	defer func() {
		if err := recover(); err != nil {
			zap.L().Error("panic", zap.Any("err", err))
			c = http.StatusInternalServerError
			utils.ErrResponse(w, c, controller.ErrInternalError)
		}
	}()

	uid, err := uuid.Parse(strings.TrimPrefix(r.URL.Path, "/api/users/"))
	if err != nil || uid == uuid.Nil {
		// A malformed path id is a client error (400), not an authentication
		// failure.
		c = http.StatusBadRequest
		zap.L().Debug(
			"failed to parse uid",
			zap.String("op", op), zap.Error(err),
		)
		utils.ErrResponse(w, c, controller.ErrParseUUID)
		return
	}

	req := &model.User{}
	if err = json.NewDecoder(r.Body).Decode(req); err != nil {
		c = http.StatusBadRequest
		zap.L().Debug(
			"failed to decode request",
			zap.String("op", op), zap.Error(err),
		)
		utils.ErrResponse(w, c, controller.ErrDecodeRequest)
		return
	}
	if err = validation.UserValidation(req); err != nil {
		c = http.StatusBadRequest
		zap.L().Debug(
			"failed to validate obj",
			zap.String("op", op), zap.Error(err),
		)
		utils.ErrResponse(w, c, err)
		return
	}

	err = h.ctrl.UpdateUser(r.Context(), uid, req)
	switch {
	case errors.Is(err, controller.ErrNotFound):
		c = http.StatusNotFound
		utils.ErrResponse(w, c, err)
		return
	case err != nil:
		c = http.StatusInternalServerError
		utils.ErrResponse(w, c, controller.ErrInternalError)
		return
	}
	utils.SuccessResponse(w, c, "OK")
}
// deleteUser handles DELETE /api/users/{id} and deletes the user with that
// UUID.
//
// Responses: 400 when id is not a UUID or is the nil UUID (the same check
// getUser and updateUser apply), 404 when the controller reports
// ErrNotFound, 500 on other errors, and 204 No Content (empty body) on
// success.
//
// NOTE(review): as with updateUser, no check ties the authenticated "uid"
// to the path id. Confirm authorization is enforced elsewhere.
func (h *Handler) deleteUser(w http.ResponseWriter, r *http.Request) {
	s, c := time.Now(), http.StatusNoContent
	const op = "sso.deleteUser.hdl"
	defer func() {
		metrics.ObserveRequest(time.Since(s), c, op)
	}()
	defer func() {
		if err := recover(); err != nil {
			zap.L().Error("panic", zap.Any("err", err))
			c = http.StatusInternalServerError
			utils.ErrResponse(w, c, controller.ErrInternalError)
		}
	}()

	uid, err := uuid.Parse(strings.TrimPrefix(r.URL.Path, "/api/users/"))
	if err != nil || uid == uuid.Nil {
		c = http.StatusBadRequest
		zap.L().Debug(
			"failed to parse uid",
			zap.String("op", op), zap.Error(err),
		)
		utils.ErrResponse(w, c, controller.ErrParseUUID)
		return
	}

	err = h.ctrl.DeleteUser(r.Context(), uid)
	switch {
	case errors.Is(err, controller.ErrNotFound):
		c = http.StatusNotFound
		utils.ErrResponse(w, c, err)
		return
	case err != nil:
		c = http.StatusInternalServerError
		utils.ErrResponse(w, c, controller.ErrInternalError)
		return
	}
	// A 204 response must not carry a body; net/http rejects body writes with
	// http.ErrBodyNotAllowed. Send only the status line.
	w.WriteHeader(c)
}
package db
import (
	"database/sql"
	"fmt"
	"net"
	"net/url"

	conf "github.com/JMURv/sso/pkg/config"
	dbutils "github.com/JMURv/sso/pkg/utils/db"
	"go.uber.org/zap"
)
// Repository is the persistence layer. It is backed by a *sql.DB opened with
// the "postgres" driver in New.
type Repository struct {
	// conn is the shared connection pool used by every repository method.
	conn *sql.DB
}
// New opens the PostgreSQL pool described by conf, pings it, applies
// migrations, and seeds initial data via mustPrecreate. Connection,
// ping, and migration failures are logged with zap's Fatal, which exits the
// process. mustPrecreate panics on its own failures.
//
// The DSN is assembled with net/url so the user, password and database name
// are percent-encoded. With Sprintf, a password containing '@', '/', ':' or
// '#' produced a malformed URL. net.JoinHostPort also brackets IPv6 hosts.
func New(conf *conf.DBConfig) *Repository {
	dsn := url.URL{
		Scheme:   "postgres",
		User:     url.UserPassword(conf.User, conf.Password),
		Host:     net.JoinHostPort(conf.Host, fmt.Sprint(conf.Port)),
		Path:     "/" + conf.Database,
		RawQuery: "sslmode=disable",
	}

	conn, err := sql.Open("postgres", dsn.String())
	if err != nil {
		zap.L().Fatal("Failed to connect to the database", zap.Error(err))
	}
	if err = conn.Ping(); err != nil {
		zap.L().Fatal("Failed to ping the database", zap.Error(err))
	}
	if err = dbutils.ApplyMigrations(conn, conf); err != nil {
		zap.L().Fatal("Failed to apply migrations", zap.Error(err))
	}

	mustPrecreate(conn)
	return &Repository{conn: conn}
}
// Close closes the repository's *sql.DB pool and returns its error.
func (r *Repository) Close() error {
	pool := r.conn
	return pool.Close()
}
package db
import (
"database/sql"
"github.com/JMURv/sso/pkg/model"
"github.com/goccy/go-json"
"github.com/google/uuid"
"go.uber.org/zap"
"golang.org/x/crypto/bcrypt"
"os"
)
func mustPrecreate(db *sql.DB) {
var count int64
if err := db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count); err != nil {
panic(err)
}
if count == 0 {
type usrWithPerms struct {
Name string `json:"name"`
Password string `json:"password"`
Email string `json:"email"`
Avatar string `json:"avatar"`
Address string `json:"address"`
Phone string `json:"phone"`
Perms []model.Permission `json:"permissions"`
}
bytes, err := os.ReadFile("precreate.json")
if err != nil {
panic(err)
}
p := make([]usrWithPerms, 0, 2)
if err = json.Unmarshal(bytes, &p); err != nil {
panic(err)
}
for _, v := range p {
tx, err := db.Begin()
if err != nil {
panic(err)
}
password, err := bcrypt.GenerateFromPassword([]byte(v.Password), bcrypt.DefaultCost)
if err != nil {
panic(err)
}
v.Password = string(password)
var userID uuid.UUID
err = tx.QueryRow(
`INSERT INTO users (name, password, email, avatar, address, phone)
VALUES ($1, $2, $3, $4, $5, $6) RETURNING id`,
v.Name,
v.Password,
v.Email,
v.Avatar,
v.Address,
v.Phone,
).Scan(&userID)
if err != nil {
panic(err)
}
for _, perm := range v.Perms {
var permID uint64
err := tx.QueryRow(`SELECT id FROM permission WHERE name = $1`, perm.Name).Scan(&permID)
if err != nil && err == sql.ErrNoRows {
if err := tx.QueryRow(permCreate, perm.Name).Scan(&permID); err != nil {
tx.Rollback()
panic(err)
}
} else if err != nil {
tx.Rollback()
panic(err)
}
if _, err = tx.Exec(userCreatePermQ, userID, permID, true); err != nil {
tx.Rollback()
panic(err)
}
}
if err := tx.Commit(); err != nil {
panic(err)
}
}
zap.L().Debug("Users and permissions have been created")
} else {
zap.L().Debug("Users and permissions already exist")
}
}
package db
import (
"context"
"database/sql"
"errors"
repo "github.com/JMURv/sso/internal/repository"
"github.com/JMURv/sso/pkg/model"
"github.com/opentracing/opentracing-go"
"strings"
)
// ListPermissions returns one page of permissions with pagination metadata.
// page is 1-based and size is the page length; both must be positive.
//
// A non-positive size used to panic: make() rejects a negative capacity, and
// the ceiling division for TotalPages divided by zero when size was 0. Such
// input is now rejected with an error. Queries run with ctx, so a cancelled
// request cancels its database work. Other errors come from the count query
// (permSelect), the list query (permList), or row scanning.
func (r *Repository) ListPermissions(ctx context.Context, page, size int) (*model.PaginatedPermission, error) {
	const op = "sso.ListPermissions.repo"
	span, _ := opentracing.StartSpanFromContext(ctx, op)
	defer span.Finish()

	if page < 1 || size < 1 {
		return nil, errors.New("page and size must be positive")
	}

	var count int64
	if err := r.conn.QueryRowContext(ctx, permSelect).Scan(&count); err != nil {
		return nil, err
	}

	rows, err := r.conn.QueryContext(ctx, permList, size, (page-1)*size)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	res := make([]*model.Permission, 0, size)
	for rows.Next() {
		var p model.Permission
		if err = rows.Scan(&p.ID, &p.Name); err != nil {
			return nil, err
		}
		res = append(res, &p)
	}
	if err = rows.Err(); err != nil {
		return nil, err
	}

	// Ceiling division: a partial last page still counts as a page.
	totalPages := int((count + int64(size) - 1) / int64(size))
	return &model.PaginatedPermission{
		Data:        res,
		Count:       count,
		TotalPages:  totalPages,
		CurrentPage: page,
		HasNextPage: page < totalPages,
	}, nil
}
// GetPermission loads the permission with the given id (via permGet). It
// returns repo.ErrNotFound when no row matches, and any other query/scan
// error unchanged.
func (r *Repository) GetPermission(ctx context.Context, id uint64) (*model.Permission, error) {
	const op = "sso.GetPermission.repo"
	span, _ := opentracing.StartSpanFromContext(ctx, op)
	defer span.Finish()

	var perm model.Permission
	switch err := r.conn.QueryRow(permGet, id).Scan(&perm.ID, &perm.Name); {
	case errors.Is(err, sql.ErrNoRows):
		return nil, repo.ErrNotFound
	case err != nil:
		return nil, err
	}
	return &perm, nil
}
// CreatePerm inserts a permission named req.Name (via permCreate) inside a
// transaction and returns the generated id.
//
// Any driver error whose text contains "unique constraint" is reported as
// repo.ErrAlreadyExists. Other errors, including from Begin and Commit, are
// returned unchanged.
func (r *Repository) CreatePerm(ctx context.Context, req *model.Permission) (uint64, error) {
	const op = "sso.CreatePerm.repo"
	span, _ := opentracing.StartSpanFromContext(ctx, op)
	defer span.Finish()

	tx, err := r.conn.Begin()
	if err != nil {
		return 0, err
	}

	var newID uint64
	if scanErr := tx.QueryRow(permCreate, req.Name).Scan(&newID); scanErr != nil {
		tx.Rollback()
		if strings.Contains(scanErr.Error(), "unique constraint") {
			return 0, repo.ErrAlreadyExists
		}
		return 0, scanErr
	}

	if commitErr := tx.Commit(); commitErr != nil {
		return 0, commitErr
	}
	return newID, nil
}
// UpdatePerm renames the permission with the given id to req.Name (via
// permUpdate) inside a transaction. It returns repo.ErrNotFound when no row
// was affected.
//
// Previously the not-found path returned without rolling back, leaking the
// open transaction and the pooled connection it holds. The deferred Rollback
// now releases the transaction on every early return; after a successful
// Commit it only returns sql.ErrTxDone, which is ignored. A RowsAffected
// error is now returned instead of being reported as ErrNotFound.
func (r *Repository) UpdatePerm(ctx context.Context, id uint64, req *model.Permission) error {
	const op = "sso.UpdatePerm.repo"
	span, _ := opentracing.StartSpanFromContext(ctx, op)
	defer span.Finish()

	tx, err := r.conn.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	defer func() { _ = tx.Rollback() }()

	res, err := tx.ExecContext(ctx, permUpdate, req.Name, id)
	if err != nil {
		return err
	}
	aff, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if aff == 0 {
		return repo.ErrNotFound
	}
	return tx.Commit()
}
// DeletePerm deletes the permission with the given id (via permDelete). It
// returns repo.ErrNotFound when no row was removed.
func (r *Repository) DeletePerm(ctx context.Context, id uint64) error {
	const op = "sso.DeletePerm.repo"
	span, _ := opentracing.StartSpanFromContext(ctx, op)
	defer span.Finish()

	result, err := r.conn.Exec(permDelete, id)
	if err != nil {
		return err
	}
	// The RowsAffected error is deliberately ignored. On failure n is 0, which
	// is reported as not found.
	if n, _ := result.RowsAffected(); n != 0 {
		return nil
	}
	return repo.ErrNotFound
}
package db
import (
"context"
"database/sql"
repo "github.com/JMURv/sso/internal/repository"
md "github.com/JMURv/sso/pkg/model"
utils "github.com/JMURv/sso/pkg/utils/db"
"github.com/google/uuid"
"github.com/lib/pq"
"github.com/opentracing/opentracing-go"
"golang.org/x/crypto/bcrypt"
"strings"
)
// SearchUser returns one page of users matching query as a "%query%"
// substring pattern, with pagination metadata. page is 1-based and size is
// the page length.
//
// The pattern is bound twice in both the count query (userSearchSelectQ) and
// the page query (userSearchQ), presumably once per searched column (confirm
// against the SQL). Each row's permission array is decoded with
// utils.ScanPermissions.
func (r *Repository) SearchUser(ctx context.Context, query string, page, size int) (*md.PaginatedUser, error) {
	const op = "sso.SearchUser.repo"
	span, _ := opentracing.StartSpanFromContext(ctx, op)
	defer span.Finish()

	pattern := "%" + query + "%"

	var total int64
	if err := r.conn.QueryRow(userSearchSelectQ, pattern, pattern).Scan(&total); err != nil {
		return nil, err
	}

	rows, err := r.conn.Query(userSearchQ, pattern, pattern, size, (page-1)*size)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	users := make([]*md.User, 0, size)
	for rows.Next() {
		u := new(md.User)
		rawPerms := make([]string, 0, 5)
		if err = rows.Scan(
			&u.ID,
			&u.Name,
			&u.Password,
			&u.Email,
			&u.Avatar,
			&u.Address,
			&u.Phone,
			&u.CreatedAt,
			&u.UpdatedAt,
			pq.Array(&rawPerms),
		); err != nil {
			return nil, err
		}
		if u.Permissions, err = utils.ScanPermissions(rawPerms); err != nil {
			return nil, err
		}
		users = append(users, u)
	}
	if rowsErr := rows.Err(); rowsErr != nil {
		return nil, rowsErr
	}

	// Ceiling division: a partial last page still counts as a page.
	pages := int((total + int64(size) - 1) / int64(size))
	return &md.PaginatedUser{
		Data:        users,
		Count:       total,
		TotalPages:  pages,
		CurrentPage: page,
		HasNextPage: page < pages,
	}, nil
}
func (r *Repository) ListUsers(ctx context.Context, page, size int) (*md.PaginatedUser, error) {
const op = "sso.ListUsers.repo"
span, _ := opentracing.StartSpanFromContext(ctx, op)
defer span.Finish()
var count int64
if err := r.conn.QueryRow(userSelectQ).Scan(&count); err != nil {
return nil, err
}
rows, err := r.conn.Query(userListQ, size, (page-1)*size)
if err != nil {
return nil, err
}
defer rows.Close()
res := make([]*md.User, 0, size)
for rows.Next() {
user := &md.User{}
perms := make([]string, 0, 5)
if err := rows.Scan(
&user.ID,
&user.Name,
&user.Password,
&user.Email,
&user.Avatar,
&user.Address,
&user.Phone,
&user.CreatedAt,
&user.UpdatedAt,
pq.Array(&perms),
); err != nil {
return nil, err
}
user.Permissions, err = utils.ScanPermissions(perms)
if err != nil {
return nil, err
}
res = append(res, user)
}
if err := rows.Err(); err != nil {
return nil, err
}
totalPages := int((count + int64(size) - 1) / int64(size))
return &md.PaginatedUser{
Data: res,
Count: count,
TotalPages: totalPages,
CurrentPage: page,
HasNextPage: page < totalPages,
}, nil
}
// GetUserByID loads the user with the given ID via userGetByIDQ, including its
// permission array, which is decoded with utils.ScanPermissions.
//
// It returns repo.ErrNotFound when the query yields no row.
func (r *Repository) GetUserByID(ctx context.Context, userID uuid.UUID) (*md.User, error) {
	const op = "sso.GetUserByID.repo"
	span, _ := opentracing.StartSpanFromContext(ctx, op)
	defer span.Finish()

	u := new(md.User)
	rawPerms := make([]string, 0, 5)
	row := r.conn.QueryRow(userGetByIDQ, userID)
	if scanErr := row.Scan(
		&u.ID,
		&u.Name,
		&u.Password,
		&u.Email,
		&u.Avatar,
		&u.Address,
		&u.Phone,
		&u.CreatedAt,
		&u.UpdatedAt,
		pq.Array(&rawPerms),
	); scanErr != nil {
		if scanErr == sql.ErrNoRows {
			return nil, repo.ErrNotFound
		}
		return nil, scanErr
	}

	decoded, err := utils.ScanPermissions(rawPerms)
	if err != nil {
		return nil, err
	}
	u.Permissions = decoded
	return u, nil
}
// GetUserByEmail loads the user matched by userGetByEmailQ for email.
//
// Unlike GetUserByID, no permissions column is scanned here, so the returned
// user's Permissions field is left at its zero value. It returns
// repo.ErrNotFound when the query yields no row.
func (r *Repository) GetUserByEmail(ctx context.Context, email string) (*md.User, error) {
	const op = "sso.GetUserByEmail.repo"
	span, _ := opentracing.StartSpanFromContext(ctx, op)
	defer span.Finish()

	var u md.User
	dest := []any{
		&u.ID,
		&u.Name,
		&u.Password,
		&u.Email,
		&u.Avatar,
		&u.Address,
		&u.Phone,
		&u.CreatedAt,
		&u.UpdatedAt,
	}
	switch err := r.conn.QueryRow(userGetByEmailQ, email).Scan(dest...); {
	case err == sql.ErrNoRows:
		return nil, repo.ErrNotFound
	case err != nil:
		return nil, err
	}
	return &u, nil
}
// CreateUser hashes req.Password with bcrypt at bcrypt.DefaultCost. It then
// inserts the user and one userCreatePermQ row per entry in req.Permissions,
// all in a single transaction, and returns the ID scanned back from
// userCreateQ.
//
// Side effect: req.Password is overwritten with the bcrypt hash.
//
// Errors:
//   - a bcrypt failure yields repo.ErrGeneratingPassword;
//   - an insert error whose message contains "unique constraint" yields
//     repo.ErrAlreadyExists;
//   - any other failure is returned unchanged, after rolling back the
//     transaction where one was open.
func (r *Repository) CreateUser(ctx context.Context, req *md.User) (uuid.UUID, error) {
	const op = "sso.CreateUser.repo"
	span, _ := opentracing.StartSpanFromContext(ctx, op)
	defer span.Finish()

	hashed, hashErr := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
	if hashErr != nil {
		return uuid.Nil, repo.ErrGeneratingPassword
	}
	req.Password = string(hashed)

	tx, beginErr := r.conn.Begin()
	if beginErr != nil {
		return uuid.Nil, beginErr
	}
	// abort rolls the transaction back and reports cause to the caller.
	abort := func(cause error) (uuid.UUID, error) {
		tx.Rollback()
		return uuid.Nil, cause
	}

	var newID uuid.UUID
	insertErr := tx.QueryRow(
		userCreateQ,
		req.Name,
		req.Password,
		req.Email,
		req.Avatar,
		req.Address,
		req.Phone,
	).Scan(&newID)
	switch {
	case insertErr == nil:
		// Inserted; continue with the permission links.
	case strings.Contains(insertErr.Error(), "unique constraint"):
		return abort(repo.ErrAlreadyExists)
	default:
		return abort(insertErr)
	}

	for _, perm := range req.Permissions {
		if _, linkErr := tx.Exec(userCreatePermQ, newID, perm.ID, perm.Value); linkErr != nil {
			return abort(linkErr)
		}
	}

	if commitErr := tx.Commit(); commitErr != nil {
		return uuid.Nil, commitErr
	}
	return newID, nil
}
// UpdateUser overwrites the stored fields of the user identified by id via
// userUpdateQ. It then replaces the user's permission links: userDeletePermQ
// runs for id, followed by one userCreatePermQ per entry in req.Permissions.
// Everything happens in a single transaction.
//
// It returns repo.ErrNotFound when the update affected no rows. The
// transaction is rolled back on every early-return path. Previously a failed
// update Exec or a not-found result left the tx, and its pooled connection,
// open.
//
// NOTE(review): req.Password is bound as-is here, whereas CreateUser hashes it
// with bcrypt — confirm callers pass an already-hashed value.
func (r *Repository) UpdateUser(ctx context.Context, id uuid.UUID, req *md.User) error {
	const op = "sso.UpdateUser.repo"
	span, _ := opentracing.StartSpanFromContext(ctx, op)
	defer span.Finish()

	tx, err := r.conn.Begin()
	if err != nil {
		return err
	}
	// After a successful Commit, Rollback is a harmless no-op (sql.ErrTxDone),
	// so deferring it guarantees cleanup on every error path.
	defer func() { _ = tx.Rollback() }()

	res, err := tx.Exec(
		userUpdateQ,
		req.Name,
		req.Password,
		req.Email,
		req.Avatar,
		req.Address,
		req.Phone,
		id,
	)
	if err != nil {
		return err
	}
	aff, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if aff == 0 {
		return repo.ErrNotFound
	}

	// Replace the permission set wholesale: clear the existing links for id,
	// then insert the ones in req.Permissions.
	if _, err := tx.Exec(userDeletePermQ, id); err != nil {
		return err
	}
	for _, p := range req.Permissions {
		if _, err := tx.Exec(userCreatePermQ, id, p.ID, p.Value); err != nil {
			return err
		}
	}
	return tx.Commit()
}
// DeleteUser runs userDeleteQ for the user identified by id, passing ctx to
// ExecContext. It returns repo.ErrNotFound when no rows were affected.
func (r *Repository) DeleteUser(ctx context.Context, id uuid.UUID) error {
	const op = "sso.DeleteUser.repo"
	span, _ := opentracing.StartSpanFromContext(ctx, op)
	defer span.Finish()

	result, execErr := r.conn.ExecContext(ctx, userDeleteQ, id)
	if execErr != nil {
		return execErr
	}
	// Any RowsAffected error is ignored, so it reads as 0 rows and is
	// reported as not found.
	if affected, _ := result.RowsAffected(); affected != 0 {
		return nil
	}
	return repo.ErrNotFound
}