middleware.go 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. package server
  2. import (
  3. "net/http"
  4. "runtime/debug"
  5. "time"
  6. "github.com/gin-gonic/gin"
  7. "github.com/google/uuid"
  8. )
  9. func (s *Server) corsMiddleware() gin.HandlerFunc {
  10. return func(c *gin.Context) {
  11. origin := c.Request.Header.Get("Origin")
  12. if len(s.config.AllowedOrigins) == 0 {
  13. s.logger.Warn("no allowed origins configured")
  14. }
  15. allowed := false
  16. for _, allowedOrigin := range s.config.AllowedOrigins {
  17. if origin == allowedOrigin {
  18. c.Writer.Header().Set("Access-Control-Allow-Origin", origin)
  19. allowed = true
  20. break
  21. }
  22. }
  23. if !allowed && origin != "" {
  24. s.logger.WithField("origin", origin).Warn("blocked request from unauthorized origin")
  25. c.AbortWithError(http.StatusForbidden, NewCORSError(origin, ErrUnauthorizedOrigin, "origin not allowed"))
  26. return
  27. }
  28. c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
  29. c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, X-Request-ID")
  30. c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
  31. if c.Request.Method == "OPTIONS" {
  32. c.AbortWithStatus(http.StatusNoContent)
  33. return
  34. }
  35. c.Next()
  36. }
  37. }
  38. func (s *Server) requestIDMiddleware() gin.HandlerFunc {
  39. return func(c *gin.Context) {
  40. requestID := c.GetHeader("X-Request-ID")
  41. if requestID == "" {
  42. requestID = uuid.New().String()
  43. }
  44. c.Set("RequestID", requestID)
  45. c.Header("X-Request-ID", requestID)
  46. c.Next()
  47. }
  48. }
  49. func (s *Server) loggerMiddleware() gin.HandlerFunc {
  50. return func(c *gin.Context) {
  51. start := time.Now()
  52. path := c.Request.URL.Path
  53. query := c.Request.URL.RawQuery
  54. c.Next()
  55. s.logger.WithFields(map[string]interface{}{
  56. "status_code": c.Writer.Status(),
  57. "method": c.Request.Method,
  58. "path": path,
  59. "query": query,
  60. "ip": c.ClientIP(),
  61. "latency": time.Since(start),
  62. "request_id": c.GetString("RequestID"),
  63. "user_agent": c.Request.UserAgent(),
  64. }).Info("request completed")
  65. }
  66. }
  67. func (s *Server) recoveryMiddleware() gin.HandlerFunc {
  68. return func(c *gin.Context) {
  69. defer func() {
  70. if err := recover(); err != nil {
  71. requestID := c.GetString("RequestID")
  72. s.logger.WithFields(map[string]interface{}{
  73. "error": err,
  74. "request_id": requestID,
  75. "stack": string(debug.Stack()),
  76. }).Error("panic recovered")
  77. c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{
  78. "error": ErrInternalServer.Error(),
  79. "request_id": requestID,
  80. })
  81. }
  82. }()
  83. c.Next()
  84. }
  85. }
  86. func (s *Server) errorHandlerMiddleware() gin.HandlerFunc {
  87. return func(c *gin.Context) {
  88. c.Next()
  89. if len(c.Errors) > 0 {
  90. err := c.Errors.Last().Err
  91. requestID := c.GetString("RequestID")
  92. // Log the error with appropriate level and details
  93. s.logger.WithFields(map[string]interface{}{
  94. "error": err.Error(),
  95. "request_id": requestID,
  96. }).Error("request error")
  97. // Create a server error with request context
  98. serverErr := NewRequestError(
  99. "handle_request",
  100. c.Request.Method,
  101. c.Request.URL.Path,
  102. StatusCode(err),
  103. err,
  104. "request processing failed",
  105. )
  106. response := gin.H{
  107. "error": serverErr.Error(),
  108. "request_id": requestID,
  109. }
  110. c.JSON(serverErr.StatusCode, response)
  111. }
  112. }
  113. }