middleware.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. package server
  2. import (
  3. "net/http"
  4. "runtime/debug"
  5. "time"
  6. "git.linuxforward.com/byom/byom-golang-lib/pkg/errors"
  7. "github.com/gin-gonic/gin"
  8. "github.com/google/uuid"
  9. )
  10. func (s *Server) corsMiddleware() gin.HandlerFunc {
  11. return func(c *gin.Context) {
  12. origin := c.Request.Header.Get("Origin")
  13. if len(s.config.AllowedOrigins) == 0 {
  14. s.logger.Warn("no allowed origins configured")
  15. }
  16. allowed := false
  17. for _, allowedOrigin := range s.config.AllowedOrigins {
  18. if origin == allowedOrigin {
  19. c.Writer.Header().Set("Access-Control-Allow-Origin", origin)
  20. allowed = true
  21. break
  22. }
  23. }
  24. if !allowed && origin != "" {
  25. s.logger.WithField("origin", origin).Warn("blocked request from unauthorized origin")
  26. }
  27. c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
  28. c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, X-Request-ID")
  29. c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
  30. if c.Request.Method == "OPTIONS" {
  31. c.AbortWithStatus(http.StatusNoContent)
  32. return
  33. }
  34. c.Next()
  35. }
  36. }
  37. func (s *Server) requestIDMiddleware() gin.HandlerFunc {
  38. return func(c *gin.Context) {
  39. requestID := c.GetHeader("X-Request-ID")
  40. if requestID == "" {
  41. requestID = uuid.New().String()
  42. }
  43. c.Set("RequestID", requestID)
  44. c.Header("X-Request-ID", requestID)
  45. c.Next()
  46. }
  47. }
  48. func (s *Server) loggerMiddleware() gin.HandlerFunc {
  49. return func(c *gin.Context) {
  50. start := time.Now()
  51. path := c.Request.URL.Path
  52. query := c.Request.URL.RawQuery
  53. c.Next()
  54. s.logger.WithFields(map[string]interface{}{
  55. "status_code": c.Writer.Status(),
  56. "method": c.Request.Method,
  57. "path": path,
  58. "query": query,
  59. "ip": c.ClientIP(),
  60. "latency": time.Since(start),
  61. "request_id": c.GetString("RequestID"),
  62. "user_agent": c.Request.UserAgent(),
  63. }).Info("request completed")
  64. }
  65. }
  66. func (s *Server) recoveryMiddleware() gin.HandlerFunc {
  67. return func(c *gin.Context) {
  68. defer func() {
  69. if err := recover(); err != nil {
  70. requestID := c.GetString("RequestID")
  71. s.logger.WithFields(map[string]interface{}{
  72. "error": err,
  73. "request_id": requestID,
  74. "stack": string(debug.Stack()),
  75. }).Error("panic recovered")
  76. c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{
  77. "error": errors.ErrInternalServer.Error(),
  78. "request_id": requestID,
  79. })
  80. }
  81. }()
  82. c.Next()
  83. }
  84. }
  85. func (s *Server) errorHandlerMiddleware() gin.HandlerFunc {
  86. return func(c *gin.Context) {
  87. c.Next()
  88. if len(c.Errors) > 0 {
  89. err := c.Errors.Last().Err
  90. requestID := c.GetString("RequestID")
  91. // Log the error with appropriate level and details
  92. s.logger.WithFields(map[string]interface{}{
  93. "error": err.Error(),
  94. "request_id": requestID,
  95. }).Error("request error")
  96. // Determine status code based on error type
  97. var statusCode int
  98. var response gin.H
  99. switch {
  100. case errors.Is(err, errors.ErrNotFound):
  101. statusCode = http.StatusNotFound
  102. case errors.Is(err, errors.ErrUnauthorized):
  103. statusCode = http.StatusUnauthorized
  104. case errors.Is(err, errors.ErrInvalidInput):
  105. statusCode = http.StatusBadRequest
  106. default:
  107. statusCode = http.StatusInternalServerError
  108. }
  109. response = gin.H{
  110. "error": err.Error(),
  111. "request_id": requestID,
  112. }
  113. c.JSON(statusCode, response)
  114. }
  115. }
  116. }