123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121 |
- package auth
import (
	"context"
	"crypto/rand"
	"encoding/hex"
	"errors"
	"fmt"
	"time"

	"github.com/golang-jwt/jwt"
)
- // Claims represents the JWT claims structure
- type Claims struct {
- jwt.StandardClaims
- ClientID string `json:"client_id"`
- Role string `json:"role"`
- }
- // JWTService implements the auth.Service interface using JWT tokens
- type JWTService struct {
- privateKey []byte
- tokenDuration time.Duration
- tokenStore TokenStore // Interface for blacklist storage
- }
// TokenStore is the persistence abstraction JWTService relies on for token
// revocation.
type TokenStore interface {
	// IsBlacklisted reports whether token has previously been blacklisted.
	IsBlacklisted(ctx context.Context, token string) (bool, error)

	// Blacklist records token as revoked. JWTService passes the token's own
	// expiration time as expiry; presumably implementations may discard the
	// entry once that time has passed — TODO confirm with implementations.
	Blacklist(ctx context.Context, token string, expiry time.Time) error
}
- // NewJWTService creates a new JWT-based auth service
- func NewJWTService(privateKey []byte, tokenDuration time.Duration, store TokenStore) *JWTService {
- return &JWTService{
- privateKey: privateKey,
- tokenDuration: tokenDuration,
- tokenStore: store,
- }
- }
- // GenerateToken generates a new JWT token for a client
- func (s *JWTService) GenerateToken(ctx context.Context, clientID string, role string) (string, error) {
- expirationTime := time.Now().Add(s.tokenDuration)
- claims := &Claims{
- ClientID: clientID,
- Role: role,
- StandardClaims: jwt.StandardClaims{
- ExpiresAt: expirationTime.Unix(),
- },
- }
- token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
- signedToken, err := token.SignedString(s.privateKey)
- if err != nil {
- return "", err
- }
- return signedToken, nil
- }
- // ValidateToken validates a JWT token and returns the client ID if valid
- func (s *JWTService) ValidateToken(ctx context.Context, tokenString string) (string, error) {
- // Check if the token is blacklisted
- isBlacklisted, err := s.tokenStore.IsBlacklisted(ctx, tokenString)
- if err != nil {
- return "", err
- }
- if isBlacklisted {
- return "", errors.New("token is blacklisted")
- }
- // Parse and validate the token
- token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
- if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
- return nil, errors.New("unexpected signing method")
- }
- return s.privateKey, nil
- })
- if err != nil {
- return "", err
- }
- // Extract claims
- claims, ok := token.Claims.(*Claims)
- if !ok || !token.Valid {
- return "", errors.New("invalid token")
- }
- return claims.ClientID, nil
- }
- // RefreshToken refreshes an existing token and returns a new one
- func (s *JWTService) RefreshToken(ctx context.Context, tokenString string) (string, error) {
- // Validate the existing token
- clientID, err := s.ValidateToken(ctx, tokenString)
- if err != nil {
- return "", err
- }
- // Generate a new token
- return s.GenerateToken(ctx, clientID, "")
- }
- // Logout invalidates a token by blacklisting it
- func (s *JWTService) Logout(ctx context.Context, tokenString string) error {
- // Parse the token to extract expiration time
- token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
- return s.privateKey, nil
- })
- if err != nil {
- return err
- }
- claims, ok := token.Claims.(*Claims)
- if !ok || !token.Valid {
- return errors.New("invalid token")
- }
- expiry := time.Unix(claims.ExpiresAt, 0)
- return s.tokenStore.Blacklist(ctx, tokenString, expiry)
- }
|