jwt.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. package auth
  2. import (
  3. "context"
  4. "errors"
  5. "time"
  6. "github.com/golang-jwt/jwt"
  7. )
  8. // Claims represents the JWT claims structure
  9. type Claims struct {
  10. jwt.StandardClaims
  11. ClientID string `json:"client_id"`
  12. Role string `json:"role"`
  13. }
  14. // JWTService implements the auth.Service interface using JWT tokens
  15. type JWTService struct {
  16. privateKey []byte
  17. tokenDuration time.Duration
  18. tokenStore TokenStore // Interface for blacklist storage
  19. }
  // TokenStore is the persistence layer for revoked tokens that JWTService
// consults and updates.
type TokenStore interface {
	// Blacklist records token as revoked. JWTService.Logout passes the
	// token's own expiration time as expiry; presumably an implementation may
	// drop the entry once expiry has passed (TODO: confirm with implementations).
	Blacklist(ctx context.Context, token string, expiry time.Time) error

	// IsBlacklisted reports whether token has previously been passed to
	// Blacklist (and, depending on the implementation, not yet purged).
	IsBlacklisted(ctx context.Context, token string) (bool, error)
}
  25. // NewJWTService creates a new JWT-based auth service
  26. func NewJWTService(privateKey []byte, tokenDuration time.Duration, store TokenStore) *JWTService {
  27. return &JWTService{
  28. privateKey: privateKey,
  29. tokenDuration: tokenDuration,
  30. tokenStore: store,
  31. }
  32. }
  33. // GenerateToken generates a new JWT token for a client
  34. func (s *JWTService) GenerateToken(ctx context.Context, clientID string, role string) (string, error) {
  35. expirationTime := time.Now().Add(s.tokenDuration)
  36. claims := &Claims{
  37. ClientID: clientID,
  38. Role: role,
  39. StandardClaims: jwt.StandardClaims{
  40. ExpiresAt: expirationTime.Unix(),
  41. },
  42. }
  43. token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
  44. signedToken, err := token.SignedString(s.privateKey)
  45. if err != nil {
  46. return "", err
  47. }
  48. return signedToken, nil
  49. }
  50. // ValidateToken validates a JWT token and returns the client ID if valid
  51. func (s *JWTService) ValidateToken(ctx context.Context, tokenString string) (string, error) {
  52. // Check if the token is blacklisted
  53. isBlacklisted, err := s.tokenStore.IsBlacklisted(ctx, tokenString)
  54. if err != nil {
  55. return "", err
  56. }
  57. if isBlacklisted {
  58. return "", errors.New("token is blacklisted")
  59. }
  60. // Parse and validate the token
  61. token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
  62. if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
  63. return nil, errors.New("unexpected signing method")
  64. }
  65. return s.privateKey, nil
  66. })
  67. if err != nil {
  68. return "", err
  69. }
  70. // Extract claims
  71. claims, ok := token.Claims.(*Claims)
  72. if !ok || !token.Valid {
  73. return "", errors.New("invalid token")
  74. }
  75. return claims.ClientID, nil
  76. }
  77. // RefreshToken refreshes an existing token and returns a new one
  78. func (s *JWTService) RefreshToken(ctx context.Context, tokenString string) (string, error) {
  79. // Validate the existing token
  80. clientID, err := s.ValidateToken(ctx, tokenString)
  81. if err != nil {
  82. return "", err
  83. }
  84. // Generate a new token
  85. return s.GenerateToken(ctx, clientID, "")
  86. }
  87. // Logout invalidates a token by blacklisting it
  88. func (s *JWTService) Logout(ctx context.Context, tokenString string) error {
  89. // Parse the token to extract expiration time
  90. token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
  91. return s.privateKey, nil
  92. })
  93. if err != nil {
  94. return err
  95. }
  96. claims, ok := token.Claims.(*Claims)
  97. if !ok || !token.Valid {
  98. return errors.New("invalid token")
  99. }
  100. expiry := time.Unix(claims.ExpiresAt, 0)
  101. return s.tokenStore.Blacklist(ctx, tokenString, expiry)
  102. }