123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230 |
- package auth
- import (
- "context"
- "fmt"
- "time"
- "git.linuxforward.com/byop/byop-engine/dbstore" // For user lookup
- "git.linuxforward.com/byop/byop-engine/models" // For user model and roles
- "github.com/golang-jwt/jwt"
- "golang.org/x/crypto/bcrypt" // For password comparison
- )
// Token kinds stored in the "type" claim of every token issued by
// JWTService, so an access token and a refresh token cannot be used in
// each other's place.
const (
	// AccessToken is the type of token accepted by ValidateToken.
	AccessToken = "access"

	// RefreshToken is the type of token accepted by RefreshToken when
	// minting a new token pair.
	RefreshToken = "refresh"
)
// Claims is the JWT payload issued and accepted by JWTService.
//
// Keep the field order stable: it determines the key order of the encoded
// JSON payload, so reordering fields changes the bytes of newly signed tokens.
type Claims struct {
	jwt.StandardClaims
	// UserID identifies the user the token was issued for; GenerateToken also
	// copies it into the standard "sub" claim.
	UserID string `json:"user_id"`
	// Role is the role string supplied when the token was generated.
	Role string `json:"role"`
	// Type is AccessToken or RefreshToken.
	Type string `json:"type"`
}
// JWTService issues HS256-signed access/refresh token pairs, validates them,
// and revokes them through a TokenStore blacklist.
// NOTE(review): presumably satisfies auth.Service (declared outside this
// view) — confirm with a compile-time assertion there.
type JWTService struct {
	privateKey           []byte               // HMAC secret used to sign and verify tokens
	tokenDuration        time.Duration        // lifetime of access tokens
	refreshTokenDuration time.Duration        // lifetime of refresh tokens
	tokenStore           TokenStore           // blacklist of revoked tokens
	userStore            *dbstore.SQLiteStore // user lookup used by Login
}
// TokenStore persists revoked (blacklisted) tokens so that JWTService can
// refuse them before they reach their natural expiry.
type TokenStore interface {
	// Blacklist records token as revoked. expiry is the token's own
	// expiration time; presumably the store may forget the entry after it.
	Blacklist(ctx context.Context, token string, expiry time.Time) error

	// IsBlacklisted reports whether token has previously been revoked.
	IsBlacklisted(ctx context.Context, token string) (bool, error)
}
- // NewJWTService creates a new JWT-based auth service
- func NewJWTService(privateKey []byte, tokenDuration time.Duration, store TokenStore, userStore *dbstore.SQLiteStore) *JWTService {
- return &JWTService{
- privateKey: privateKey,
- tokenDuration: tokenDuration,
- refreshTokenDuration: tokenDuration * 24, // Refresh tokens valid for 24x longer than access tokens
- tokenStore: store,
- userStore: userStore, // Initialize userStore
- }
- }
- // Login authenticates a user and generates tokens.
- func (s *JWTService) Login(ctx context.Context, email string, password string) (*TokenResponse, error) {
- user, err := s.userStore.GetUserByEmail(ctx, email)
- if err != nil {
- if models.IsErrNotFound(err) { // Use IsErrNotFound to check for ErrNotFound type
- return nil, models.NewErrUnauthorized("invalid_credentials", err) // Pass original err as cause
- }
- return nil, models.NewErrInternalServer("db_error_get_user", err)
- }
- if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)); err != nil {
- return nil, ErrInvalidCredentials
- }
- // User is authenticated, now generate tokens
- return s.GenerateToken(ctx, fmt.Sprintf("%d", user.ID), string(user.Role))
- }
- // generateTokenWithClaims generates a JWT token with the given claims
- func (s *JWTService) generateTokenWithClaims(claims *Claims) (string, error) {
- token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
- signedToken, err := token.SignedString(s.privateKey)
- if err != nil {
- return "", err
- }
- return signedToken, nil
- }
- // GenerateToken generates new access and refresh tokens for a user
- func (s *JWTService) GenerateToken(ctx context.Context, userID string, role string) (*TokenResponse, error) {
- now := time.Now()
- accessExpiration := now.Add(s.tokenDuration)
- refreshExpiration := now.Add(s.refreshTokenDuration)
- accessClaims := &Claims{
- UserID: userID,
- Role: role,
- Type: AccessToken,
- StandardClaims: jwt.StandardClaims{
- ExpiresAt: accessExpiration.Unix(),
- IssuedAt: now.Unix(),
- Subject: userID, // Set subject to userID
- },
- }
- refreshClaims := &Claims{
- UserID: userID,
- Role: role, // Role might not be strictly necessary in refresh token, but can be included
- Type: RefreshToken,
- StandardClaims: jwt.StandardClaims{
- ExpiresAt: refreshExpiration.Unix(),
- IssuedAt: now.Unix(),
- Subject: userID, // Set subject to userID
- },
- }
- accessToken, err := s.generateTokenWithClaims(accessClaims)
- if err != nil {
- return nil, ErrInvalidToken // Or a more specific generation error
- }
- refreshToken, err := s.generateTokenWithClaims(refreshClaims)
- if err != nil {
- return nil, ErrInvalidToken // Or a more specific generation error
- }
- return &TokenResponse{
- AccessToken: accessToken,
- RefreshToken: refreshToken,
- ExpiresIn: int64(s.tokenDuration.Seconds()),
- TokenType: "Bearer",
- UserID: userID,
- UserRole: role,
- }, nil
- }
- // ValidateToken validates a JWT token and returns the userID and role if valid
- func (s *JWTService) ValidateToken(ctx context.Context, tokenString string) (string, string, error) {
- isBlacklisted, err := s.tokenStore.IsBlacklisted(ctx, tokenString)
- if err != nil {
- // Consider wrapping this error or returning a standard auth error
- return "", "", err
- }
- if isBlacklisted {
- return "", "", ErrTokenBlacklisted
- }
- claims, err := s.parseToken(tokenString)
- if err != nil {
- return "", "", err // parseToken already returns auth-specific errors like ErrTokenExpired, ErrInvalidToken
- }
- if claims.Type != AccessToken {
- return "", "", ErrInvalidToken // Use standard error for wrong token type
- }
- return claims.UserID, claims.Role, nil
- }
- // parseToken parses and validates a JWT token and returns the claims if valid
- func (s *JWTService) parseToken(tokenString string) (*Claims, error) {
- token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
- if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
- return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
- }
- return s.privateKey, nil
- })
- if err != nil {
- if ve, ok := err.(*jwt.ValidationError); ok {
- if ve.Errors&jwt.ValidationErrorMalformed != 0 {
- return nil, ErrInvalidToken
- } else if ve.Errors&jwt.ValidationErrorExpired != 0 {
- return nil, ErrTokenExpired
- } else if ve.Errors&jwt.ValidationErrorNotValidYet != 0 {
- return nil, ErrInvalidToken // Or a specific "not valid yet" error
- }
- }
- return nil, ErrInvalidToken // Default for other parsing errors
- }
- claims, ok := token.Claims.(*Claims)
- if !ok || !token.Valid {
- return nil, ErrInvalidToken
- }
- return claims, nil
- }
- // RefreshToken validates a refresh token and generates new access and refresh tokens
- func (s *JWTService) RefreshToken(ctx context.Context, refreshTokenString string) (*TokenResponse, error) {
- isBlacklisted, err := s.tokenStore.IsBlacklisted(ctx, refreshTokenString)
- if err != nil {
- return nil, err
- }
- if isBlacklisted {
- return nil, ErrTokenBlacklisted
- }
- claims, err := s.parseToken(refreshTokenString)
- if err != nil {
- return nil, err
- }
- if claims.Type != RefreshToken {
- return nil, ErrInvalidToken // Not a refresh token
- }
- // Optional: Blacklist the old refresh token to prevent reuse if not already handled by single-use policy
- // expiryOldToken := time.Unix(claims.ExpiresAt, 0)
- // if err := s.tokenStore.Blacklist(ctx, refreshTokenString, expiryOldToken); err != nil {
- // return nil, err // Or log and continue if blacklisting is not critical path for refresh
- // }
- // Generate new access and refresh tokens using UserID and Role from the valid refresh token
- return s.GenerateToken(ctx, claims.UserID, claims.Role)
- }
- // Logout invalidates a token by blacklisting it
- func (s *JWTService) Logout(ctx context.Context, tokenString string) error {
- claims, err := s.parseToken(tokenString)
- if err != nil {
- // If token is already invalid (e.g. expired, malformed), blacklisting might not be necessary or might fail.
- // Depending on policy, might return nil or the parsing error.
- // For now, return the error to indicate why blacklisting couldn't proceed based on token data.
- return err
- }
- // Blacklist both access and potentially associated refresh tokens if strategy requires.
- // For simplicity, blacklisting the provided token (which should be an access token).
- expiry := time.Unix(claims.ExpiresAt, 0)
- if time.Now().After(expiry) {
- return ErrTokenExpired // No need to blacklist already expired token
- }
- return s.tokenStore.Blacklist(ctx, tokenString, expiry)
- }
|