123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135 |
- package server
- import (
- "net/http"
- "runtime/debug"
- "time"
- "github.com/gin-gonic/gin"
- "github.com/google/uuid"
- )
- func (s *Server) corsMiddleware() gin.HandlerFunc {
- return func(c *gin.Context) {
- origin := c.Request.Header.Get("Origin")
- if len(s.config.AllowedOrigins) == 0 {
- s.logger.Warn("no allowed origins configured")
- }
- allowed := false
- for _, allowedOrigin := range s.config.AllowedOrigins {
- if origin == allowedOrigin {
- c.Writer.Header().Set("Access-Control-Allow-Origin", origin)
- allowed = true
- break
- }
- }
- if !allowed && origin != "" {
- s.logger.WithField("origin", origin).Warn("blocked request from unauthorized origin")
- c.AbortWithError(http.StatusForbidden, NewCORSError(origin, ErrUnauthorizedOrigin, "origin not allowed"))
- return
- }
- c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
- c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, X-Request-ID")
- c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
- if c.Request.Method == "OPTIONS" {
- c.AbortWithStatus(http.StatusNoContent)
- return
- }
- c.Next()
- }
- }
- func (s *Server) requestIDMiddleware() gin.HandlerFunc {
- return func(c *gin.Context) {
- requestID := c.GetHeader("X-Request-ID")
- if requestID == "" {
- requestID = uuid.New().String()
- }
- c.Set("RequestID", requestID)
- c.Header("X-Request-ID", requestID)
- c.Next()
- }
- }
- func (s *Server) loggerMiddleware() gin.HandlerFunc {
- return func(c *gin.Context) {
- start := time.Now()
- path := c.Request.URL.Path
- query := c.Request.URL.RawQuery
- c.Next()
- s.logger.WithFields(map[string]interface{}{
- "status_code": c.Writer.Status(),
- "method": c.Request.Method,
- "path": path,
- "query": query,
- "ip": c.ClientIP(),
- "latency": time.Since(start),
- "request_id": c.GetString("RequestID"),
- "user_agent": c.Request.UserAgent(),
- }).Info("request completed")
- }
- }
- func (s *Server) recoveryMiddleware() gin.HandlerFunc {
- return func(c *gin.Context) {
- defer func() {
- if err := recover(); err != nil {
- requestID := c.GetString("RequestID")
- s.logger.WithFields(map[string]interface{}{
- "error": err,
- "request_id": requestID,
- "stack": string(debug.Stack()),
- }).Error("panic recovered")
- c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{
- "error": ErrInternalServer.Error(),
- "request_id": requestID,
- })
- }
- }()
- c.Next()
- }
- }
- func (s *Server) errorHandlerMiddleware() gin.HandlerFunc {
- return func(c *gin.Context) {
- c.Next()
- if len(c.Errors) > 0 {
- err := c.Errors.Last().Err
- requestID := c.GetString("RequestID")
- // Log the error with appropriate level and details
- s.logger.WithFields(map[string]interface{}{
- "error": err.Error(),
- "request_id": requestID,
- }).Error("request error")
- // Create a server error with request context
- serverErr := NewRequestError(
- "handle_request",
- c.Request.Method,
- c.Request.URL.Path,
- StatusCode(err),
- err,
- "request processing failed",
- )
- response := gin.H{
- "error": serverErr.Error(),
- "request_id": requestID,
- }
- c.JSON(serverErr.StatusCode, response)
- }
- }
- }
|