// Package middleware provides Gin middleware for response security headers,
// basic request sanitization, and per-request timeouts.
package middleware

import (
	"context"
	"errors"
	"net/http"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
)

// dangerousExtensions lists request-path suffixes that RequestSanitizer
// rejects with 403 Forbidden. Matching is case-insensitive (see
// containsDangerousExtension). Declared once at package level so the list is
// not rebuilt on every request.
var dangerousExtensions = []string{".php", ".asp", ".aspx", ".jsp", ".cgi", ".exe", ".bat", ".cmd", ".sh", ".pl"}

// SecurityHeaders returns middleware that sets security-related response
// headers on every response and then invokes the remaining handlers.
func SecurityHeaders() gin.HandlerFunc {
	return func(c *gin.Context) {
		c.Header("X-Content-Type-Options", "nosniff")
		c.Header("X-Frame-Options", "DENY")
		// Modern browsers have removed the legacy XSS auditor, and where it
		// still exists "1; mode=block" has been shown to introduce XSS issues
		// of its own. OWASP recommends disabling it explicitly and relying on
		// Content-Security-Policy instead.
		c.Header("X-XSS-Protection", "0")
		c.Header("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
		c.Header("Content-Security-Policy", "default-src 'self'; frame-ancestors 'none'")
		c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
		c.Header("Permissions-Policy", "geolocation=(), microphone=(), camera=()")
		c.Next()
	}
}

// RequestSanitizer returns middleware that strips client-supplied proxy
// headers and removes control characters from User-Agent (see
// sanitizeHeaders). It aborts with 403 Forbidden when the request path ends
// in one of dangerousExtensions; otherwise it invokes the remaining handlers.
func RequestSanitizer() gin.HandlerFunc {
	return func(c *gin.Context) {
		sanitizeHeaders(c)

		if containsDangerousExtension(c.Request.URL.Path) {
			c.AbortWithStatus(http.StatusForbidden)
			return
		}

		c.Next()
	}
}

// TimeoutMiddleware returns middleware that attaches a deadline of timeout to
// the request context and then runs the remaining handlers.
//
// If the deadline has expired when the handlers return and no response has
// been written yet, it responds with 504 Gateway Timeout and the JSON body
// {"code": "REQUEST_TIMEOUT", "error": "Request timeout exceeded"}.
//
// Handlers must pass c.Request.Context() to blocking work (database calls,
// outbound requests, ...) for the deadline to cut that work short. A handler
// that ignores the context runs to completion.
func TimeoutMiddleware(timeout time.Duration) gin.HandlerFunc {
	return func(c *gin.Context) {
		ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
		defer cancel()
		c.Request = c.Request.WithContext(ctx)

		// Handlers run on this goroutine on purpose. A gin.Context and its
		// ResponseWriter must not be used concurrently, and Gin reuses the
		// Context once this handler returns. Running c.Next() on a second
		// goroutine races with the timeout response, leaks that goroutine,
		// and hides handler panics from gin's Recovery middleware.
		c.Next()

		// Report a timeout only for an expired deadline, not for plain
		// cancellation of the parent context (e.g. the client disconnecting).
		// Also skip it if a handler has already started a response.
		if errors.Is(ctx.Err(), context.DeadlineExceeded) && !c.Writer.Written() {
			c.AbortWithStatusJSON(http.StatusGatewayTimeout, gin.H{
				"code":  "REQUEST_TIMEOUT",
				"error": "Request timeout exceeded",
			})
		}
	}
}

// sanitizeHeaders deletes the X-Forwarded-For, X-Real-IP and
// X-Forwarded-Proto request headers, which clients can set freely. It also
// strips control characters from a non-empty User-Agent header.
//
// NOTE(review): these headers are also discarded when they were set by a
// trusted reverse proxy or load balancer in front of the service. Confirm
// that nothing downstream (e.g. gin's c.ClientIP with trusted proxies)
// depends on them.
func sanitizeHeaders(c *gin.Context) {
	c.Request.Header.Del("X-Forwarded-For")
	c.Request.Header.Del("X-Real-IP")
	c.Request.Header.Del("X-Forwarded-Proto")

	if ua := c.Request.Header.Get("User-Agent"); ua != "" {
		c.Request.Header.Set("User-Agent", sanitizeString(ua))
	}
}

// containsDangerousExtension reports whether path, lower-cased, ends with one
// of dangerousExtensions. Only the final suffix is checked, so paths such as
// "/x.php/" or "/x.php/extra" do not match.
func containsDangerousExtension(path string) bool {
	lower := strings.ToLower(path)
	for _, ext := range dangerousExtensions {
		if strings.HasSuffix(lower, ext) {
			return true
		}
	}
	return false
}

// sanitizeString returns s with ASCII control characters removed
// (U+0000–U+001F and U+007F). C1 control characters (U+0080–U+009F) are kept,
// and strings.Map turns any invalid UTF-8 bytes into U+FFFD.
func sanitizeString(s string) string {
	return strings.Map(func(r rune) rune {
		if r < 32 || r == 127 {
			return -1
		}
		return r
	}, s)
}