package auth import ( "context" "fmt" "time" "git.linuxforward.com/byop/byop-engine/dbstore" // For user lookup "git.linuxforward.com/byop/byop-engine/models" // For user model and roles "github.com/golang-jwt/jwt" "golang.org/x/crypto/bcrypt" // For password comparison ) // Token types const ( AccessToken = "access" RefreshToken = "refresh" ) // Claims represents the JWT claims structure type Claims struct { jwt.StandardClaims UserID string `json:"user_id"` // Changed from ClientID to UserID Role string `json:"role"` Type string `json:"type"` // Token type: "access" or "refresh" } // JWTService implements the auth.Service interface using JWT tokens type JWTService struct { privateKey []byte tokenDuration time.Duration refreshTokenDuration time.Duration tokenStore TokenStore // Interface for blacklist storage userStore *dbstore.SQLiteStore // Added userStore for Login method } // TokenStore defines storage operations for token management type TokenStore interface { IsBlacklisted(ctx context.Context, token string) (bool, error) Blacklist(ctx context.Context, token string, expiry time.Time) error } // NewJWTService creates a new JWT-based auth service func NewJWTService(privateKey []byte, tokenDuration time.Duration, store TokenStore, userStore *dbstore.SQLiteStore) *JWTService { return &JWTService{ privateKey: privateKey, tokenDuration: tokenDuration, refreshTokenDuration: tokenDuration * 24, // Refresh tokens valid for 24x longer than access tokens tokenStore: store, userStore: userStore, // Initialize userStore } } // Login authenticates a user and generates tokens. func (s *JWTService) Login(ctx context.Context, email string, password string) (*TokenResponse, error) { user, err := s.userStore.GetUserByEmail(ctx, email) if err != nil { if models.IsErrNotFound(err) { // Use IsErrNotFound to check for ErrNotFound type return nil, models.NewErrUnauthorized("invalid_credentials", err) // Pass original err as cause } return nil, models.NewErrInternalServer("db_error_get_user", err) } if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)); err != nil { return nil, ErrInvalidCredentials } // User is authenticated, now generate tokens return s.GenerateToken(ctx, fmt.Sprintf("%d", user.ID), string(user.Role)) } // generateTokenWithClaims generates a JWT token with the given claims func (s *JWTService) generateTokenWithClaims(claims *Claims) (string, error) { token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) signedToken, err := token.SignedString(s.privateKey) if err != nil { return "", err } return signedToken, nil } // GenerateToken generates new access and refresh tokens for a user func (s *JWTService) GenerateToken(ctx context.Context, userID string, role string) (*TokenResponse, error) { now := time.Now() accessExpiration := now.Add(s.tokenDuration) refreshExpiration := now.Add(s.refreshTokenDuration) accessClaims := &Claims{ UserID: userID, Role: role, Type: AccessToken, StandardClaims: jwt.StandardClaims{ ExpiresAt: accessExpiration.Unix(), IssuedAt: now.Unix(), Subject: userID, // Set subject to userID }, } refreshClaims := &Claims{ UserID: userID, Role: role, // Role might not be strictly necessary in refresh token, but can be included Type: RefreshToken, StandardClaims: jwt.StandardClaims{ ExpiresAt: refreshExpiration.Unix(), IssuedAt: now.Unix(), Subject: userID, // Set subject to userID }, } accessToken, err := s.generateTokenWithClaims(accessClaims) if err != nil { return nil, ErrInvalidToken // Or a more specific generation error } refreshToken, err := s.generateTokenWithClaims(refreshClaims) if err != nil { return nil, ErrInvalidToken // Or a more specific generation error } return &TokenResponse{ AccessToken: accessToken, RefreshToken: refreshToken, ExpiresIn: int64(s.tokenDuration.Seconds()), TokenType: "Bearer", UserID: userID, UserRole: role, }, nil } // ValidateToken validates a JWT token and returns the userID and role if valid func (s *JWTService) ValidateToken(ctx context.Context, tokenString string) (string, string, error) { isBlacklisted, err := s.tokenStore.IsBlacklisted(ctx, tokenString) if err != nil { // Consider wrapping this error or returning a standard auth error return "", "", err } if isBlacklisted { return "", "", ErrTokenBlacklisted } claims, err := s.parseToken(tokenString) if err != nil { return "", "", err // parseToken already returns auth-specific errors like ErrTokenExpired, ErrInvalidToken } if claims.Type != AccessToken { return "", "", ErrInvalidToken // Use standard error for wrong token type } return claims.UserID, claims.Role, nil } // parseToken parses and validates a JWT token and returns the claims if valid func (s *JWTService) parseToken(tokenString string) (*Claims, error) { token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } return s.privateKey, nil }) if err != nil { if ve, ok := err.(*jwt.ValidationError); ok { if ve.Errors&jwt.ValidationErrorMalformed != 0 { return nil, ErrInvalidToken } else if ve.Errors&jwt.ValidationErrorExpired != 0 { return nil, ErrTokenExpired } else if ve.Errors&jwt.ValidationErrorNotValidYet != 0 { return nil, ErrInvalidToken // Or a specific "not valid yet" error } } return nil, ErrInvalidToken // Default for other parsing errors } claims, ok := token.Claims.(*Claims) if !ok || !token.Valid { return nil, ErrInvalidToken } return claims, nil } // RefreshToken validates a refresh token and generates new access and refresh tokens func (s *JWTService) RefreshToken(ctx context.Context, refreshTokenString string) (*TokenResponse, error) { isBlacklisted, err := s.tokenStore.IsBlacklisted(ctx, refreshTokenString) if err != nil { return nil, err } if isBlacklisted { return nil, ErrTokenBlacklisted } claims, err := s.parseToken(refreshTokenString) if err != nil { return nil, err } if claims.Type != RefreshToken { return nil, ErrInvalidToken // Not a refresh token } // Optional: Blacklist the old refresh token to prevent reuse if not already handled by single-use policy // expiryOldToken := time.Unix(claims.ExpiresAt, 0) // if err := s.tokenStore.Blacklist(ctx, refreshTokenString, expiryOldToken); err != nil { // return nil, err // Or log and continue if blacklisting is not critical path for refresh // } // Generate new access and refresh tokens using UserID and Role from the valid refresh token return s.GenerateToken(ctx, claims.UserID, claims.Role) } // Logout invalidates a token by blacklisting it func (s *JWTService) Logout(ctx context.Context, tokenString string) error { claims, err := s.parseToken(tokenString) if err != nil { // If token is already invalid (e.g. expired, malformed), blacklisting might not be necessary or might fail. // Depending on policy, might return nil or the parsing error. // For now, return the error to indicate why blacklisting couldn't proceed based on token data. return err } // Blacklist both access and potentially associated refresh tokens if strategy requires. // For simplicity, blacklisting the provided token (which should be an access token). expiry := time.Unix(claims.ExpiresAt, 0) if time.Now().After(expiry) { return ErrTokenExpired // No need to blacklist already expired token } return s.tokenStore.Blacklist(ctx, tokenString, expiry) }