package server import ( "net/http" "runtime/debug" "time" "git.linuxforward.com/byom/byom-golang-lib/pkg/validation" "github.com/gin-gonic/gin" "github.com/google/uuid" ) func (s *Server) corsMiddleware() gin.HandlerFunc { return func(c *gin.Context) { origin := c.Request.Header.Get("Origin") if len(s.config.AllowedOrigins) == 0 { s.logger.Warn("no allowed origins configured") } allowed := false for _, allowedOrigin := range s.config.AllowedOrigins { if origin == allowedOrigin { c.Writer.Header().Set("Access-Control-Allow-Origin", origin) allowed = true break } } if !allowed && origin != "" { s.logger.WithField("origin", origin).Warn("blocked request from unauthorized origin") c.AbortWithError(http.StatusForbidden, NewCORSError(origin, ErrUnauthorizedOrigin, "origin not allowed")) return } c.Writer.Header().Set("Access-Control-Allow-Credentials", "true") c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, X-Request-ID") c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE") if c.Request.Method == "OPTIONS" { c.AbortWithStatus(http.StatusNoContent) return } c.Next() } } func (s *Server) requestIDMiddleware() gin.HandlerFunc { return func(c *gin.Context) { requestID := c.GetHeader("X-Request-ID") if requestID == "" { requestID = uuid.New().String() } c.Set("RequestID", requestID) c.Header("X-Request-ID", requestID) c.Next() } } func (s *Server) loggerMiddleware() gin.HandlerFunc { return func(c *gin.Context) { start := time.Now() path := c.Request.URL.Path query := c.Request.URL.RawQuery c.Next() s.logger.WithFields(map[string]interface{}{ "status_code": c.Writer.Status(), "method": c.Request.Method, "path": path, "query": query, "ip": c.ClientIP(), "latency": time.Since(start), "request_id": c.GetString("RequestID"), "user_agent": c.Request.UserAgent(), }).Info("request completed") } } func (s *Server) recoveryMiddleware() gin.HandlerFunc { return func(c *gin.Context) { defer func() { if err := recover(); err != nil { requestID := c.GetString("RequestID") s.logger.WithFields(map[string]interface{}{ "error": err, "request_id": requestID, "stack": string(debug.Stack()), }).Error("panic recovered") c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{ "error": ErrInternalServer.Error(), "request_id": requestID, }) } }() c.Next() } } func (s *Server) errorHandlerMiddleware() gin.HandlerFunc { return func(c *gin.Context) { c.Next() if len(c.Errors) > 0 { err := c.Errors.Last().Err requestID := c.GetString("RequestID") // Log the error with appropriate level and details s.logger.WithFields(map[string]interface{}{ "error": err.Error(), "request_id": requestID, }).Error("request error") // Create a server error with request context serverErr := NewRequestError( "handle_request", c.Request.Method, c.Request.URL.Path, StatusCode(err), err, "request processing failed", ) response := gin.H{ "error": serverErr.Error(), "request_id": requestID, } c.JSON(serverErr.StatusCode, response) } } } // ValidateRequest returns a middleware function that validates request bodies against // the specified request type. It integrates with the gin context for error handling and logging. func ValidateRequest(validateFn func(interface{}) []validation.ValidationError) gin.HandlerFunc { return func(c *gin.Context) { requestID := c.GetString("RequestID") var req interface{} // Bind JSON request body if err := c.ShouldBindJSON(&req); err != nil { bindErr := NewRequestError( "validate_request", c.Request.Method, c.Request.URL.Path, http.StatusBadRequest, err, "invalid request format", ) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{ "error": bindErr.Error(), "request_id": requestID, "errors": []validation.ValidationError{{Field: "body", Message: err.Error()}}, }) return } // Validate request struct if errors := validateFn(req); len(errors) > 0 { validationErr := NewRequestError( "validate_request", c.Request.Method, c.Request.URL.Path, http.StatusBadRequest, ErrValidationFailed, "validation failed", ) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{ "error": validationErr.Error(), "request_id": requestID, "errors": errors, }) return } // Store validated request in context c.Set("ValidatedRequest", req) c.Next() } } // GetValidatedRequest retrieves the validated request from the gin context. // Returns the request and true if found and type matches, nil and false otherwise. func GetValidatedRequest[T any](c *gin.Context) (*T, bool) { if req, exists := c.Get("ValidatedRequest"); exists { if typed, ok := req.(*T); ok { return typed, true } } return nil, false }