package jot import ( "errors" "fmt" "github.com/golang-jwt/jwt/v5" "net/http" "strings" "time" ) // Auth is the type used to instantiate this package. type Auth struct { Issuer string // Who issues the token, e.g. company.com. Audience string // Who is the token for, e.g. company.com. Secret string // A strong secret, used to sign the tokens. TokenExpiry time.Duration // When does the token expire, e.g. time.Minute * 15. RefreshExpiry time.Duration // When does the refresh token expire, e.g. time.Hour * 24. CookieDomain string // The domain, for refresh cookies. CookiePath string // The path, for refresh cookies. CookieName string // The name of the refresh token cookie. } // User is a generic type used to hold the minimal amount of data we require in order to issue tokens. type User struct { ID int `json:"id"` FirstName string `json:"first_name"` LastName string `json:"last_name"` } // TokenPairs is the type used to generate JSON containing the JWT token and the refresh token. type TokenPairs struct { Token string `json:"access_token"` RefreshToken string `json:"refresh_token"` } // Claims is the type used to describe the claims in a given token. type Claims struct { jwt.RegisteredClaims } // New returns an instance of Auth, with sensible defaults where possible. Naturally, // any of defaults can be overridden. func New(d string) Auth { return Auth{ Issuer: d, Audience: d, TokenExpiry: time.Minute * 15, RefreshExpiry: time.Hour * 24, CookieName: "refresh_token", CookiePath: "/", CookieDomain: d, } } // GetTokenFromHeaderAndVerify extracts a token from the Authorization header, verifies it, and returns the // token, the claims, and error, if any. func (j *Auth) GetTokenFromHeaderAndVerify(w http.ResponseWriter, r *http.Request) (string, *Claims, error) { w.Header().Add("Vary", "Authorization") // Get the Authorization header. authHeader := r.Header.Get("Authorization") // Sanity check. if authHeader == "" { return "", nil, errors.New("no auth header") } // Split the header up on spaces. headerParts := strings.Split(authHeader, " ") if len(headerParts) != 2 { return "", nil, errors.New("invalid auth header") } // Check to see if we have the word "Bearer" in the right spot (we should). if headerParts[0] != "Bearer" { return "", nil, errors.New("unauthorized - no bearer") } // Get the actual token. token := headerParts[1] // Declare an empty Claims variable. claims := &Claims{} // Parse the token with our claims (we read into claims), using our secret (from the receiver). _, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) { // validate the signing algorithm is what we expect if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } return []byte(j.Secret), nil }) // Check for errors. Note that this catches expired tokens as well. if err != nil { // return an easy to spot error if the token is expired if strings.HasPrefix(err.Error(), "token is expired by") { return "", nil, errors.New("expired token") } return "", nil, err } // Make sure we issued this token. if claims.Issuer != j.Issuer { // we did not issue this token return "", nil, errors.New("incorrect issuer") } // If we get this far, the token is valid, so we return it, along with the claims. return token, claims, nil } // GenerateTokenPair takes a user of type jot.User and attempts to generate a pair of tokens for that user // (jwt and refresh tokens). func (j *Auth) GenerateTokenPair(user *User) (TokenPairs, error) { // Create token. token := jwt.New(jwt.SigningMethodHS256) // Set claims. claims := token.Claims.(jwt.MapClaims) claims["name"] = fmt.Sprintf("%s %s", user.FirstName, user.LastName) claims["sub"] = fmt.Sprint(user.ID) claims["aud"] = j.Audience claims["iss"] = j.Issuer claims["iat"] = time.Now().UTC().Unix() claims["typ"] = "JWT" // Set expiry; should be short! claims["exp"] = time.Now().UTC().Add(j.TokenExpiry).Unix() // Create signed token. signedAccessToken, err := token.SignedString([]byte(j.Secret)) if err != nil { return TokenPairs{}, err } // Create refresh token and set claims (just subject, issued at, and expiry). refreshToken := jwt.New(jwt.SigningMethodHS256) refreshTokenClaims := refreshToken.Claims.(jwt.MapClaims) refreshTokenClaims["sub"] = fmt.Sprint(user.ID) refreshTokenClaims["iat"] = time.Now().UTC().Unix() // Set expiry; must be longer than JWT token expiry! refreshTokenClaims["exp"] = time.Now().UTC().Add(j.RefreshExpiry).Unix() // Create signed refresh token. signedRefreshToken, err := refreshToken.SignedString([]byte(j.Secret)) if err != nil { return TokenPairs{}, err } // Create token pairs and populate with signed tokens. var tokenPairs = TokenPairs{ Token: signedAccessToken, RefreshToken: signedRefreshToken, } // Return the token pair, and no error. return tokenPairs, nil } // GetRefreshCookie returns a cookie containing the refresh token. Note that the cookie is http only, secure, // and set to same site strict mode. func (j *Auth) GetRefreshCookie(refreshToken string) *http.Cookie { return &http.Cookie{ Name: j.CookieName, Path: j.CookiePath, Value: refreshToken, Expires: time.Now().Add(j.RefreshExpiry), MaxAge: int(j.RefreshExpiry.Seconds()), SameSite: http.SameSiteStrictMode, Domain: j.CookieDomain, HttpOnly: true, Secure: true, } } // GetExpiredRefreshCookie is a convenience method to return a cookie suitable for forcing a user's browser // to delete the existing cookie. func (j *Auth) GetExpiredRefreshCookie() *http.Cookie { return &http.Cookie{ Name: j.CookieName, Value: "", Domain: j.CookieDomain, Path: j.CookiePath, MaxAge: -1, Expires: time.Unix(0, 0), HttpOnly: true, SameSite: http.SameSiteStrictMode, Secure: true, } }