middleware.go 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. package server
  2. import (
  3. "net/http"
  4. "runtime/debug"
  5. "time"
  6. "git.linuxforward.com/byom/byom-golang-lib/pkg/validation"
  7. "github.com/gin-gonic/gin"
  8. "github.com/google/uuid"
  9. )
  10. func (s *Server) corsMiddleware() gin.HandlerFunc {
  11. return func(c *gin.Context) {
  12. origin := c.Request.Header.Get("Origin")
  13. if len(s.config.AllowedOrigins) == 0 {
  14. s.logger.Warn("no allowed origins configured")
  15. }
  16. allowed := false
  17. for _, allowedOrigin := range s.config.AllowedOrigins {
  18. if origin == allowedOrigin {
  19. c.Writer.Header().Set("Access-Control-Allow-Origin", origin)
  20. allowed = true
  21. break
  22. }
  23. }
  24. if !allowed && origin != "" {
  25. s.logger.WithField("origin", origin).Warn("blocked request from unauthorized origin")
  26. c.AbortWithError(http.StatusForbidden, NewCORSError(origin, ErrUnauthorizedOrigin, "origin not allowed"))
  27. return
  28. }
  29. c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
  30. c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, X-Request-ID")
  31. c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
  32. if c.Request.Method == "OPTIONS" {
  33. c.AbortWithStatus(http.StatusNoContent)
  34. return
  35. }
  36. c.Next()
  37. }
  38. }
  39. func (s *Server) requestIDMiddleware() gin.HandlerFunc {
  40. return func(c *gin.Context) {
  41. requestID := c.GetHeader("X-Request-ID")
  42. if requestID == "" {
  43. requestID = uuid.New().String()
  44. }
  45. c.Set("RequestID", requestID)
  46. c.Header("X-Request-ID", requestID)
  47. c.Next()
  48. }
  49. }
  50. func (s *Server) loggerMiddleware() gin.HandlerFunc {
  51. return func(c *gin.Context) {
  52. start := time.Now()
  53. path := c.Request.URL.Path
  54. query := c.Request.URL.RawQuery
  55. c.Next()
  56. s.logger.WithFields(map[string]interface{}{
  57. "status_code": c.Writer.Status(),
  58. "method": c.Request.Method,
  59. "path": path,
  60. "query": query,
  61. "ip": c.ClientIP(),
  62. "latency": time.Since(start),
  63. "request_id": c.GetString("RequestID"),
  64. "user_agent": c.Request.UserAgent(),
  65. }).Info("request completed")
  66. }
  67. }
  68. func (s *Server) recoveryMiddleware() gin.HandlerFunc {
  69. return func(c *gin.Context) {
  70. defer func() {
  71. if err := recover(); err != nil {
  72. requestID := c.GetString("RequestID")
  73. s.logger.WithFields(map[string]interface{}{
  74. "error": err,
  75. "request_id": requestID,
  76. "stack": string(debug.Stack()),
  77. }).Error("panic recovered")
  78. c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{
  79. "error": ErrInternalServer.Error(),
  80. "request_id": requestID,
  81. })
  82. }
  83. }()
  84. c.Next()
  85. }
  86. }
  87. func (s *Server) errorHandlerMiddleware() gin.HandlerFunc {
  88. return func(c *gin.Context) {
  89. c.Next()
  90. if len(c.Errors) > 0 {
  91. err := c.Errors.Last().Err
  92. requestID := c.GetString("RequestID")
  93. // Log the error with appropriate level and details
  94. s.logger.WithFields(map[string]interface{}{
  95. "error": err.Error(),
  96. "request_id": requestID,
  97. }).Error("request error")
  98. // Create a server error with request context
  99. serverErr := NewRequestError(
  100. "handle_request",
  101. c.Request.Method,
  102. c.Request.URL.Path,
  103. StatusCode(err),
  104. err,
  105. "request processing failed",
  106. )
  107. response := gin.H{
  108. "error": serverErr.Error(),
  109. "request_id": requestID,
  110. }
  111. c.JSON(serverErr.StatusCode, response)
  112. }
  113. }
  114. }
  115. // ValidateRequest returns a middleware function that validates request bodies against
  116. // the specified request type. It integrates with the gin context for error handling and logging.
  117. func ValidateRequest(validateFn func(interface{}) []validation.ValidationError) gin.HandlerFunc {
  118. return func(c *gin.Context) {
  119. requestID := c.GetString("RequestID")
  120. var req interface{}
  121. // Bind JSON request body
  122. if err := c.ShouldBindJSON(&req); err != nil {
  123. bindErr := NewRequestError(
  124. "validate_request",
  125. c.Request.Method,
  126. c.Request.URL.Path,
  127. http.StatusBadRequest,
  128. err,
  129. "invalid request format",
  130. )
  131. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{
  132. "error": bindErr.Error(),
  133. "request_id": requestID,
  134. "errors": []validation.ValidationError{{Field: "body", Message: err.Error()}},
  135. })
  136. return
  137. }
  138. // Validate request struct
  139. if errors := validateFn(req); len(errors) > 0 {
  140. validationErr := NewRequestError(
  141. "validate_request",
  142. c.Request.Method,
  143. c.Request.URL.Path,
  144. http.StatusBadRequest,
  145. ErrValidationFailed,
  146. "validation failed",
  147. )
  148. c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{
  149. "error": validationErr.Error(),
  150. "request_id": requestID,
  151. "errors": errors,
  152. })
  153. return
  154. }
  155. // Store validated request in context
  156. c.Set("ValidatedRequest", req)
  157. c.Next()
  158. }
  159. }
  160. // GetValidatedRequest retrieves the validated request from the gin context.
  161. // Returns the request and true if found and type matches, nil and false otherwise.
  162. func GetValidatedRequest[T any](c *gin.Context) (*T, bool) {
  163. if req, exists := c.Get("ValidatedRequest"); exists {
  164. if typed, ok := req.(*T); ok {
  165. return typed, true
  166. }
  167. }
  168. return nil, false
  169. }