package auth

import (
	"context"
	"crypto/rand"
	"encoding/hex"
	"errors"
	"time"

	"github.com/golang-jwt/jwt"
)

// Token types, stored in Claims.Type to tell the two kinds of token apart.
const (
	AccessToken  = "access"
	RefreshToken = "refresh"
)

// Claims represents the JWT claims structure issued and accepted by JWTService.
type Claims struct {
	jwt.StandardClaims
	ClientID string `json:"client_id"`
	Role     string `json:"role"`
	Type     string `json:"type"` // Token type: AccessToken or RefreshToken
}

// JWTService implements the auth.Service interface using HS256-signed JWT tokens.
type JWTService struct {
	privateKey           []byte        // HMAC secret used both to sign and to verify tokens
	tokenDuration        time.Duration // lifetime of access tokens
	refreshTokenDuration time.Duration // lifetime of refresh tokens (set to 24x tokenDuration by NewJWTService)
	tokenStore           TokenStore    // blacklist storage for revoked tokens
}

// TokenStore defines storage operations for token revocation (blacklisting).
type TokenStore interface {
	// IsBlacklisted reports whether token has been revoked.
	IsBlacklisted(ctx context.Context, token string) (bool, error)
	// Blacklist revokes token. JWTService passes the token's own expiration
	// time as expiry, since an expired token is rejected regardless.
	Blacklist(ctx context.Context, token string, expiry time.Time) error
}

// NewJWTService creates a JWT-based auth service. privateKey is used as the
// HMAC-SHA256 secret; access tokens live for tokenDuration and refresh tokens
// for 24 times as long; store records revoked tokens.
func NewJWTService(privateKey []byte, tokenDuration time.Duration, store TokenStore) *JWTService {
	return &JWTService{
		privateKey:           privateKey,
		tokenDuration:        tokenDuration,
		refreshTokenDuration: tokenDuration * 24, // refresh tokens valid 24x longer than access tokens
		tokenStore:           store,
	}
}

// generateTokenWithClaims signs claims with HS256 using the service key and
// returns the compact serialized token.
func (s *JWTService) generateTokenWithClaims(claims *Claims) (string, error) {
	return jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(s.privateKey)
}

// newTokenID returns a random 128-bit, hex-encoded identifier for the JWT
// "jti" claim.
func newTokenID() (string, error) {
	var b [16]byte
	if _, err := rand.Read(b[:]); err != nil {
		return "", err
	}
	return hex.EncodeToString(b[:]), nil
}

// issueToken builds and signs one token of tokenType for clientID/role,
// issued at now and expiring ttl later. Each token gets a fresh random ID.
func (s *JWTService) issueToken(clientID, role, tokenType string, now time.Time, ttl time.Duration) (string, error) {
	id, err := newTokenID()
	if err != nil {
		return "", err
	}
	return s.generateTokenWithClaims(&Claims{
		ClientID: clientID,
		Role:     role,
		Type:     tokenType,
		StandardClaims: jwt.StandardClaims{
			Id:        id,
			ExpiresAt: now.Add(ttl).Unix(),
			IssuedAt:  now.Unix(),
		},
	})
}

// GenerateToken issues a new access/refresh token pair for clientID with the
// given role. Both tokens share one issue time; the access token expires after
// tokenDuration and the refresh token after refreshTokenDuration. ExpiresIn in
// the response is the access-token lifetime in whole seconds.
//
// Every token carries a random "jti" claim. Without it, HS256 signing is
// deterministic and second-granular timestamps make two tokens minted for the
// same client and role within one second byte-identical, so revoking one
// (e.g. the refresh token consumed by RefreshToken) would revoke the other.
func (s *JWTService) GenerateToken(ctx context.Context, clientID string, role string) (*TokenResponse, error) {
	now := time.Now()

	accessToken, err := s.issueToken(clientID, role, AccessToken, now, s.tokenDuration)
	if err != nil {
		return nil, err
	}
	refreshToken, err := s.issueToken(clientID, role, RefreshToken, now, s.refreshTokenDuration)
	if err != nil {
		return nil, err
	}

	return &TokenResponse{
		AccessToken:  accessToken,
		RefreshToken: refreshToken,
		ExpiresIn:    int64(s.tokenDuration.Seconds()),
		TokenType:    "Bearer",
	}, nil
}

// ValidateToken checks that tokenString is a non-blacklisted, correctly
// signed, unexpired access token and returns its client ID. It returns
// ErrTokenBlacklisted for revoked tokens, the parseToken error for
// malformed/expired ones, and an error for refresh tokens.
func (s *JWTService) ValidateToken(ctx context.Context, tokenString string) (string, error) {
	isBlacklisted, err := s.tokenStore.IsBlacklisted(ctx, tokenString)
	if err != nil {
		return "", err
	}
	if isBlacklisted {
		return "", ErrTokenBlacklisted
	}

	claims, err := s.parseToken(tokenString)
	if err != nil {
		return "", err
	}

	// Only access tokens authorize requests; refresh tokens are for RefreshToken.
	if claims.Type != AccessToken {
		return "", errors.New("invalid token type")
	}
	return claims.ClientID, nil
}

// parseToken verifies tokenString's signature and standard claims and returns
// its Claims. It returns ErrTokenExpired when the token is otherwise valid but
// past its expiry, and ErrInvalidToken for every other failure.
func (s *JWTService) parseToken(tokenString string) (*Claims, error) {
	token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
		// Only HMAC algorithms are accepted: the key is a shared secret, and a
		// token must not be able to select a different verification scheme.
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, errors.New("unexpected signing method")
		}
		return s.privateKey, nil
	})
	if err != nil {
		// ValidationError.Errors is a bit set that can flag several failures
		// at once; report "expired" only when expiry is the sole problem, so a
		// token with a bad signature is never classified as merely expired.
		var ve *jwt.ValidationError
		if errors.As(err, &ve) && ve.Errors == jwt.ValidationErrorExpired {
			return nil, ErrTokenExpired
		}
		return nil, ErrInvalidToken
	}

	claims, ok := token.Claims.(*Claims)
	if !ok || !token.Valid {
		return nil, ErrInvalidToken
	}
	return claims, nil
}

// RefreshToken validates a refresh token and
generates new access and refresh tokens func (s *JWTService) RefreshToken(ctx context.Context, refreshTokenString string) (*TokenResponse, error) { // Check if the refresh token is blacklisted isBlacklisted, err := s.tokenStore.IsBlacklisted(ctx, refreshTokenString) if err != nil { return nil, err } if isBlacklisted { return nil, ErrTokenBlacklisted } // Parse and validate the token claims, err := s.parseToken(refreshTokenString) if err != nil { return nil, err } // Check if it's a refresh token if claims.Type != RefreshToken { return nil, errors.New("not a refresh token") } // Blacklist the old refresh token to prevent reuse expiry := time.Unix(claims.ExpiresAt, 0) if err := s.tokenStore.Blacklist(ctx, refreshTokenString, expiry); err != nil { return nil, err } // Generate new access and refresh tokens return s.GenerateToken(ctx, claims.ClientID, claims.Role) } // Logout invalidates a token by blacklisting it func (s *JWTService) Logout(ctx context.Context, tokenString string) error { // Parse the token to extract expiration time claims, err := s.parseToken(tokenString) if err != nil { return err } expiry := time.Unix(claims.ExpiresAt, 0) return s.tokenStore.Blacklist(ctx, tokenString, expiry) }