jwt.go 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. package auth
  import (
  	"context"
  	"crypto/rand"
  	"encoding/hex"
  	"errors"
  	"time"

  	"github.com/golang-jwt/jwt"
  )
  // Values of the "type" claim, used to tell the two kinds of issued token apart.
  const (
  	// AccessToken marks the short-lived token that ValidateToken accepts.
  	AccessToken = "access"
  	// RefreshToken marks the longer-lived token that can only be exchanged
  	// through JWTService.RefreshToken.
  	RefreshToken = "refresh"
  )
  13. // Claims represents the JWT claims structure
  14. type Claims struct {
  15. jwt.StandardClaims
  16. ClientID string `json:"client_id"`
  17. Role string `json:"role"`
  18. Type string `json:"type"` // Token type: "access" or "refresh"
  19. }
  20. // JWTService implements the auth.Service interface using JWT tokens
  21. type JWTService struct {
  22. privateKey []byte
  23. tokenDuration time.Duration
  24. refreshTokenDuration time.Duration // Duration for refresh tokens (typically longer)
  25. tokenStore TokenStore // Interface for blacklist storage
  26. }
  // TokenStore is the revocation backend used by JWTService. Tokens are keyed
  // by their full encoded string.
  type TokenStore interface {
  	// Blacklist marks token as revoked. expiry is the token's own expiration
  	// time; presumably implementations may drop the entry once it has passed.
  	Blacklist(ctx context.Context, token string, expiry time.Time) error
  	// IsBlacklisted reports whether token has been revoked.
  	IsBlacklisted(ctx context.Context, token string) (bool, error)
  }
  32. // NewJWTService creates a new JWT-based auth service
  33. func NewJWTService(privateKey []byte, tokenDuration time.Duration, store TokenStore) *JWTService {
  34. return &JWTService{
  35. privateKey: privateKey,
  36. tokenDuration: tokenDuration,
  37. refreshTokenDuration: tokenDuration * 24, // Refresh tokens valid for 24x longer than access tokens
  38. tokenStore: store,
  39. }
  40. }
  41. // generateTokenWithClaims generates a JWT token with the given claims
  42. func (s *JWTService) generateTokenWithClaims(claims *Claims) (string, error) {
  43. token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
  44. signedToken, err := token.SignedString(s.privateKey)
  45. if err != nil {
  46. return "", err
  47. }
  48. return signedToken, nil
  49. }
  50. // GenerateToken generates new access and refresh tokens for a client
  51. func (s *JWTService) GenerateToken(ctx context.Context, clientID string, role string) (*TokenResponse, error) {
  52. now := time.Now()
  53. accessExpiration := now.Add(s.tokenDuration)
  54. refreshExpiration := now.Add(s.refreshTokenDuration)
  55. // Create access token claims
  56. accessClaims := &Claims{
  57. ClientID: clientID,
  58. Role: role,
  59. Type: AccessToken,
  60. StandardClaims: jwt.StandardClaims{
  61. ExpiresAt: accessExpiration.Unix(),
  62. IssuedAt: now.Unix(),
  63. },
  64. }
  65. // Create refresh token claims
  66. refreshClaims := &Claims{
  67. ClientID: clientID,
  68. Role: role,
  69. Type: RefreshToken,
  70. StandardClaims: jwt.StandardClaims{
  71. ExpiresAt: refreshExpiration.Unix(),
  72. IssuedAt: now.Unix(),
  73. },
  74. }
  75. // Generate the tokens
  76. accessToken, err := s.generateTokenWithClaims(accessClaims)
  77. if err != nil {
  78. return nil, err
  79. }
  80. refreshToken, err := s.generateTokenWithClaims(refreshClaims)
  81. if err != nil {
  82. return nil, err
  83. }
  84. return &TokenResponse{
  85. AccessToken: accessToken,
  86. RefreshToken: refreshToken,
  87. ExpiresIn: int64(s.tokenDuration.Seconds()),
  88. TokenType: "Bearer",
  89. }, nil
  90. }
  91. // ValidateToken validates a JWT token and returns the client ID and role if valid
  92. func (s *JWTService) ValidateToken(ctx context.Context, tokenString string) (string, string, error) {
  93. // Check if the token is blacklisted
  94. isBlacklisted, err := s.tokenStore.IsBlacklisted(ctx, tokenString)
  95. if err != nil {
  96. return "", "", err
  97. }
  98. if isBlacklisted {
  99. return "", "", ErrTokenBlacklisted
  100. }
  101. // Parse and validate the token
  102. claims, err := s.parseToken(tokenString)
  103. if err != nil {
  104. return "", "", err
  105. }
  106. // For validation purposes, we only accept access tokens
  107. if claims.Type != AccessToken {
  108. return "", "", errors.New("invalid token type")
  109. }
  110. return claims.ClientID, claims.Role, nil
  111. }
  112. // parseToken parses and validates a JWT token and returns the claims if valid
  113. func (s *JWTService) parseToken(tokenString string) (*Claims, error) {
  114. token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
  115. if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
  116. return nil, errors.New("unexpected signing method")
  117. }
  118. return s.privateKey, nil
  119. })
  120. if err != nil {
  121. if ve, ok := err.(*jwt.ValidationError); ok {
  122. if ve.Errors&jwt.ValidationErrorExpired != 0 {
  123. return nil, ErrTokenExpired
  124. }
  125. }
  126. return nil, ErrInvalidToken
  127. }
  128. // Extract claims
  129. claims, ok := token.Claims.(*Claims)
  130. if !ok || !token.Valid {
  131. return nil, ErrInvalidToken
  132. }
  133. return claims, nil
  134. }
  135. // RefreshToken validates a refresh token and generates new access and refresh tokens
  136. func (s *JWTService) RefreshToken(ctx context.Context, refreshTokenString string) (*TokenResponse, error) {
  137. // Check if the refresh token is blacklisted
  138. isBlacklisted, err := s.tokenStore.IsBlacklisted(ctx, refreshTokenString)
  139. if err != nil {
  140. return nil, err
  141. }
  142. if isBlacklisted {
  143. return nil, ErrTokenBlacklisted
  144. }
  145. // Parse and validate the token
  146. claims, err := s.parseToken(refreshTokenString)
  147. if err != nil {
  148. return nil, err
  149. }
  150. // Check if it's a refresh token
  151. if claims.Type != RefreshToken {
  152. return nil, errors.New("not a refresh token")
  153. }
  154. // Blacklist the old refresh token to prevent reuse
  155. expiry := time.Unix(claims.ExpiresAt, 0)
  156. if err := s.tokenStore.Blacklist(ctx, refreshTokenString, expiry); err != nil {
  157. return nil, err
  158. }
  159. // Generate new access and refresh tokens
  160. return s.GenerateToken(ctx, claims.ClientID, claims.Role)
  161. }
  162. // Logout invalidates a token by blacklisting it
  163. func (s *JWTService) Logout(ctx context.Context, tokenString string) error {
  164. // Parse the token to extract expiration time
  165. claims, err := s.parseToken(tokenString)
  166. if err != nil {
  167. return err
  168. }
  169. expiry := time.Unix(claims.ExpiresAt, 0)
  170. return s.tokenStore.Blacklist(ctx, tokenString, expiry)
  171. }