123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200 |
- package auth
import (
	"context"
	"crypto/rand"
	"encoding/hex"
	"errors"
	"time"

	"github.com/golang-jwt/jwt"
)
// Values stored in the "type" claim of tokens issued by this package.
const (
	// AccessToken is the type value that ValidateToken accepts.
	AccessToken = "access"

	// RefreshToken is the type value that the RefreshToken method accepts.
	RefreshToken = "refresh"
)
- // Claims represents the JWT claims structure
- type Claims struct {
- jwt.StandardClaims
- ClientID string `json:"client_id"`
- Role string `json:"role"`
- Type string `json:"type"` // Token type: "access" or "refresh"
- }
- // JWTService implements the auth.Service interface using JWT tokens
- type JWTService struct {
- privateKey []byte
- tokenDuration time.Duration
- refreshTokenDuration time.Duration // Duration for refresh tokens (typically longer)
- tokenStore TokenStore // Interface for blacklist storage
- }
// TokenStore is the blacklist backend JWTService uses to revoke tokens and to
// check whether a presented token has been revoked.
type TokenStore interface {
	// Blacklist records token as revoked. The callers in this file pass the
	// token's own expiration time as expiry.
	// NOTE(review): presumably an implementation may drop the entry once
	// expiry has passed; confirm against the concrete store.
	Blacklist(ctx context.Context, token string, expiry time.Time) error

	// IsBlacklisted reports whether token has been revoked.
	IsBlacklisted(ctx context.Context, token string) (bool, error)
}
- // NewJWTService creates a new JWT-based auth service
- func NewJWTService(privateKey []byte, tokenDuration time.Duration, store TokenStore) *JWTService {
- return &JWTService{
- privateKey: privateKey,
- tokenDuration: tokenDuration,
- refreshTokenDuration: tokenDuration * 24, // Refresh tokens valid for 24x longer than access tokens
- tokenStore: store,
- }
- }
- // generateTokenWithClaims generates a JWT token with the given claims
- func (s *JWTService) generateTokenWithClaims(claims *Claims) (string, error) {
- token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
- signedToken, err := token.SignedString(s.privateKey)
- if err != nil {
- return "", err
- }
- return signedToken, nil
- }
- // GenerateToken generates new access and refresh tokens for a client
- func (s *JWTService) GenerateToken(ctx context.Context, clientID string, role string) (*TokenResponse, error) {
- now := time.Now()
- accessExpiration := now.Add(s.tokenDuration)
- refreshExpiration := now.Add(s.refreshTokenDuration)
- // Create access token claims
- accessClaims := &Claims{
- ClientID: clientID,
- Role: role,
- Type: AccessToken,
- StandardClaims: jwt.StandardClaims{
- ExpiresAt: accessExpiration.Unix(),
- IssuedAt: now.Unix(),
- },
- }
- // Create refresh token claims
- refreshClaims := &Claims{
- ClientID: clientID,
- Role: role,
- Type: RefreshToken,
- StandardClaims: jwt.StandardClaims{
- ExpiresAt: refreshExpiration.Unix(),
- IssuedAt: now.Unix(),
- },
- }
- // Generate the tokens
- accessToken, err := s.generateTokenWithClaims(accessClaims)
- if err != nil {
- return nil, err
- }
- refreshToken, err := s.generateTokenWithClaims(refreshClaims)
- if err != nil {
- return nil, err
- }
- return &TokenResponse{
- AccessToken: accessToken,
- RefreshToken: refreshToken,
- ExpiresIn: int64(s.tokenDuration.Seconds()),
- TokenType: "Bearer",
- }, nil
- }
- // ValidateToken validates a JWT token and returns the client ID and role if valid
- func (s *JWTService) ValidateToken(ctx context.Context, tokenString string) (string, string, error) {
- // Check if the token is blacklisted
- isBlacklisted, err := s.tokenStore.IsBlacklisted(ctx, tokenString)
- if err != nil {
- return "", "", err
- }
- if isBlacklisted {
- return "", "", ErrTokenBlacklisted
- }
- // Parse and validate the token
- claims, err := s.parseToken(tokenString)
- if err != nil {
- return "", "", err
- }
- // For validation purposes, we only accept access tokens
- if claims.Type != AccessToken {
- return "", "", errors.New("invalid token type")
- }
- return claims.ClientID, claims.Role, nil
- }
- // parseToken parses and validates a JWT token and returns the claims if valid
- func (s *JWTService) parseToken(tokenString string) (*Claims, error) {
- token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
- if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
- return nil, errors.New("unexpected signing method")
- }
- return s.privateKey, nil
- })
- if err != nil {
- if ve, ok := err.(*jwt.ValidationError); ok {
- if ve.Errors&jwt.ValidationErrorExpired != 0 {
- return nil, ErrTokenExpired
- }
- }
- return nil, ErrInvalidToken
- }
- // Extract claims
- claims, ok := token.Claims.(*Claims)
- if !ok || !token.Valid {
- return nil, ErrInvalidToken
- }
- return claims, nil
- }
- // RefreshToken validates a refresh token and generates new access and refresh tokens
- func (s *JWTService) RefreshToken(ctx context.Context, refreshTokenString string) (*TokenResponse, error) {
- // Check if the refresh token is blacklisted
- isBlacklisted, err := s.tokenStore.IsBlacklisted(ctx, refreshTokenString)
- if err != nil {
- return nil, err
- }
- if isBlacklisted {
- return nil, ErrTokenBlacklisted
- }
- // Parse and validate the token
- claims, err := s.parseToken(refreshTokenString)
- if err != nil {
- return nil, err
- }
- // Check if it's a refresh token
- if claims.Type != RefreshToken {
- return nil, errors.New("not a refresh token")
- }
- // Blacklist the old refresh token to prevent reuse
- expiry := time.Unix(claims.ExpiresAt, 0)
- if err := s.tokenStore.Blacklist(ctx, refreshTokenString, expiry); err != nil {
- return nil, err
- }
- // Generate new access and refresh tokens
- return s.GenerateToken(ctx, claims.ClientID, claims.Role)
- }
- // Logout invalidates a token by blacklisting it
- func (s *JWTService) Logout(ctx context.Context, tokenString string) error {
- // Parse the token to extract expiration time
- claims, err := s.parseToken(tokenString)
- if err != nil {
- return err
- }
- expiry := time.Unix(claims.ExpiresAt, 0)
- return s.tokenStore.Blacklist(ctx, tokenString, expiry)
- }
|