jwt.go 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. package auth
import (
	"context"
	"crypto/rand"
	"encoding/hex"
	"fmt"
	"time"

	"git.linuxforward.com/byop/byop-engine/dbstore" // For user lookup
	"git.linuxforward.com/byop/byop-engine/models"  // For user model and roles
	"github.com/golang-jwt/jwt"
	"golang.org/x/crypto/bcrypt" // For password comparison
)
  11. // Token types
  12. const (
  13. AccessToken = "access"
  14. RefreshToken = "refresh"
  15. )
// Claims is the JWT payload that JWTService signs and verifies.
//
// The embedded jwt.StandardClaims carries the registered claims
// (GenerateToken sets expiry, issued-at and subject); the remaining fields
// are application-specific.
type Claims struct {
	jwt.StandardClaims
	UserID string `json:"user_id"` // authenticated user's ID; GenerateToken also copies it into the subject claim
	Role   string `json:"role"`    // role passed to GenerateToken when the token was issued
	Type   string `json:"type"`    // token kind: AccessToken ("access") or RefreshToken ("refresh")
}
// JWTService implements the package's auth service (the auth.Service
// interface, declared outside this view) using HMAC-signed JWTs.
type JWTService struct {
	privateKey           []byte               // HMAC secret: signs new tokens (HS256) and verifies presented ones
	tokenDuration        time.Duration        // access-token lifetime
	refreshTokenDuration time.Duration        // refresh-token lifetime (NewJWTService sets 24x tokenDuration)
	tokenStore           TokenStore           // revocation list: written by Logout, read by ValidateToken and RefreshToken
	userStore            *dbstore.SQLiteStore // user lookup by email for Login
}
// TokenStore is the revocation (blacklist) storage used by JWTService.
// Entries are keyed by the full encoded token string.
type TokenStore interface {
	// IsBlacklisted reports whether token has been revoked.
	IsBlacklisted(ctx context.Context, token string) (bool, error)
	// Blacklist revokes token. JWTService passes the token's own expiration
	// time as expiry; presumably the entry may be purged after that —
	// implementation-defined, confirm in the concrete store.
	Blacklist(ctx context.Context, token string, expiry time.Time) error
}
  36. // NewJWTService creates a new JWT-based auth service
  37. func NewJWTService(privateKey []byte, tokenDuration time.Duration, store TokenStore, userStore *dbstore.SQLiteStore) *JWTService {
  38. return &JWTService{
  39. privateKey: privateKey,
  40. tokenDuration: tokenDuration,
  41. refreshTokenDuration: tokenDuration * 24, // Refresh tokens valid for 24x longer than access tokens
  42. tokenStore: store,
  43. userStore: userStore, // Initialize userStore
  44. }
  45. }
  46. // Login authenticates a user and generates tokens.
  47. func (s *JWTService) Login(ctx context.Context, email string, password string) (*TokenResponse, error) {
  48. user, err := s.userStore.GetUserByEmail(ctx, email)
  49. if err != nil {
  50. if models.IsErrNotFound(err) { // Use IsErrNotFound to check for ErrNotFound type
  51. return nil, models.NewErrUnauthorized("invalid_credentials", err) // Pass original err as cause
  52. }
  53. return nil, models.NewErrInternalServer("db_error_get_user", err)
  54. }
  55. if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)); err != nil {
  56. return nil, ErrInvalidCredentials
  57. }
  58. // User is authenticated, now generate tokens
  59. return s.GenerateToken(ctx, fmt.Sprintf("%d", user.ID), string(user.Role))
  60. }
  61. // generateTokenWithClaims generates a JWT token with the given claims
  62. func (s *JWTService) generateTokenWithClaims(claims *Claims) (string, error) {
  63. token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
  64. signedToken, err := token.SignedString(s.privateKey)
  65. if err != nil {
  66. return "", err
  67. }
  68. return signedToken, nil
  69. }
  70. // GenerateToken generates new access and refresh tokens for a user
  71. func (s *JWTService) GenerateToken(ctx context.Context, userID string, role string) (*TokenResponse, error) {
  72. now := time.Now()
  73. accessExpiration := now.Add(s.tokenDuration)
  74. refreshExpiration := now.Add(s.refreshTokenDuration)
  75. accessClaims := &Claims{
  76. UserID: userID,
  77. Role: role,
  78. Type: AccessToken,
  79. StandardClaims: jwt.StandardClaims{
  80. ExpiresAt: accessExpiration.Unix(),
  81. IssuedAt: now.Unix(),
  82. Subject: userID, // Set subject to userID
  83. },
  84. }
  85. refreshClaims := &Claims{
  86. UserID: userID,
  87. Role: role, // Role might not be strictly necessary in refresh token, but can be included
  88. Type: RefreshToken,
  89. StandardClaims: jwt.StandardClaims{
  90. ExpiresAt: refreshExpiration.Unix(),
  91. IssuedAt: now.Unix(),
  92. Subject: userID, // Set subject to userID
  93. },
  94. }
  95. accessToken, err := s.generateTokenWithClaims(accessClaims)
  96. if err != nil {
  97. return nil, ErrInvalidToken // Or a more specific generation error
  98. }
  99. refreshToken, err := s.generateTokenWithClaims(refreshClaims)
  100. if err != nil {
  101. return nil, ErrInvalidToken // Or a more specific generation error
  102. }
  103. return &TokenResponse{
  104. AccessToken: accessToken,
  105. RefreshToken: refreshToken,
  106. ExpiresIn: int64(s.tokenDuration.Seconds()),
  107. TokenType: "Bearer",
  108. UserID: userID,
  109. UserRole: role,
  110. }, nil
  111. }
  112. // ValidateToken validates a JWT token and returns the userID and role if valid
  113. func (s *JWTService) ValidateToken(ctx context.Context, tokenString string) (string, string, error) {
  114. isBlacklisted, err := s.tokenStore.IsBlacklisted(ctx, tokenString)
  115. if err != nil {
  116. // Consider wrapping this error or returning a standard auth error
  117. return "", "", err
  118. }
  119. if isBlacklisted {
  120. return "", "", ErrTokenBlacklisted
  121. }
  122. claims, err := s.parseToken(tokenString)
  123. if err != nil {
  124. return "", "", err // parseToken already returns auth-specific errors like ErrTokenExpired, ErrInvalidToken
  125. }
  126. if claims.Type != AccessToken {
  127. return "", "", ErrInvalidToken // Use standard error for wrong token type
  128. }
  129. return claims.UserID, claims.Role, nil
  130. }
  131. // parseToken parses and validates a JWT token and returns the claims if valid
  132. func (s *JWTService) parseToken(tokenString string) (*Claims, error) {
  133. token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
  134. if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
  135. return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
  136. }
  137. return s.privateKey, nil
  138. })
  139. if err != nil {
  140. if ve, ok := err.(*jwt.ValidationError); ok {
  141. if ve.Errors&jwt.ValidationErrorMalformed != 0 {
  142. return nil, ErrInvalidToken
  143. } else if ve.Errors&jwt.ValidationErrorExpired != 0 {
  144. return nil, ErrTokenExpired
  145. } else if ve.Errors&jwt.ValidationErrorNotValidYet != 0 {
  146. return nil, ErrInvalidToken // Or a specific "not valid yet" error
  147. }
  148. }
  149. return nil, ErrInvalidToken // Default for other parsing errors
  150. }
  151. claims, ok := token.Claims.(*Claims)
  152. if !ok || !token.Valid {
  153. return nil, ErrInvalidToken
  154. }
  155. return claims, nil
  156. }
  157. // RefreshToken validates a refresh token and generates new access and refresh tokens
  158. func (s *JWTService) RefreshToken(ctx context.Context, refreshTokenString string) (*TokenResponse, error) {
  159. isBlacklisted, err := s.tokenStore.IsBlacklisted(ctx, refreshTokenString)
  160. if err != nil {
  161. return nil, err
  162. }
  163. if isBlacklisted {
  164. return nil, ErrTokenBlacklisted
  165. }
  166. claims, err := s.parseToken(refreshTokenString)
  167. if err != nil {
  168. return nil, err
  169. }
  170. if claims.Type != RefreshToken {
  171. return nil, ErrInvalidToken // Not a refresh token
  172. }
  173. // Optional: Blacklist the old refresh token to prevent reuse if not already handled by single-use policy
  174. // expiryOldToken := time.Unix(claims.ExpiresAt, 0)
  175. // if err := s.tokenStore.Blacklist(ctx, refreshTokenString, expiryOldToken); err != nil {
  176. // return nil, err // Or log and continue if blacklisting is not critical path for refresh
  177. // }
  178. // Generate new access and refresh tokens using UserID and Role from the valid refresh token
  179. return s.GenerateToken(ctx, claims.UserID, claims.Role)
  180. }
  181. // Logout invalidates a token by blacklisting it
  182. func (s *JWTService) Logout(ctx context.Context, tokenString string) error {
  183. claims, err := s.parseToken(tokenString)
  184. if err != nil {
  185. // If token is already invalid (e.g. expired, malformed), blacklisting might not be necessary or might fail.
  186. // Depending on policy, might return nil or the parsing error.
  187. // For now, return the error to indicate why blacklisting couldn't proceed based on token data.
  188. return err
  189. }
  190. // Blacklist both access and potentially associated refresh tokens if strategy requires.
  191. // For simplicity, blacklisting the provided token (which should be an access token).
  192. expiry := time.Unix(claims.ExpiresAt, 0)
  193. if time.Now().After(expiry) {
  194. return ErrTokenExpired // No need to blacklist already expired token
  195. }
  196. return s.tokenStore.Blacklist(ctx, tokenString, expiry)
  197. }