1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798 |
- package middleware
import (
	"context"
	"errors"
	"net/http"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
)
- // SecurityHeaders adds security-related headers to all responses
- func SecurityHeaders() gin.HandlerFunc {
- return func(c *gin.Context) {
- // Security headers
- c.Header("X-Content-Type-Options", "nosniff")
- c.Header("X-Frame-Options", "DENY")
- c.Header("X-XSS-Protection", "1; mode=block")
- c.Header("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
- c.Header("Content-Security-Policy", "default-src 'self'; frame-ancestors 'none'")
- c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
- c.Header("Permissions-Policy", "geolocation=(), microphone=(), camera=()")
- c.Next()
- }
- }
- // RequestSanitizer sanitizes incoming requests
- func RequestSanitizer() gin.HandlerFunc {
- return func(c *gin.Context) {
- // Sanitize headers
- sanitizeHeaders(c)
- // Block potentially dangerous file extensions
- if containsDangerousExtension(c.Request.URL.Path) {
- c.AbortWithStatus(http.StatusForbidden)
- return
- }
- c.Next()
- }
- }
- // TimeoutMiddleware adds a timeout to the request context
- func TimeoutMiddleware(timeout time.Duration) gin.HandlerFunc {
- return func(c *gin.Context) {
- // Wrap the request in a timeout
- ch := make(chan struct{})
- go func() {
- c.Next()
- ch <- struct{}{}
- }()
- select {
- case <-ch:
- return
- case <-time.After(timeout):
- c.AbortWithStatusJSON(http.StatusGatewayTimeout, gin.H{
- "code": "REQUEST_TIMEOUT",
- "error": "Request timeout exceeded",
- })
- return
- }
- }
- }
- // Helper functions
- func sanitizeHeaders(c *gin.Context) {
- // Remove potentially dangerous headers
- c.Request.Header.Del("X-Forwarded-For")
- c.Request.Header.Del("X-Real-IP")
- c.Request.Header.Del("X-Forwarded-Proto")
- // Sanitize User-Agent
- if ua := c.Request.Header.Get("User-Agent"); ua != "" {
- c.Request.Header.Set("User-Agent", sanitizeString(ua))
- }
- }
// dangerousExtensions lists the lowercase URL path suffixes that
// containsDangerousExtension rejects. It is kept at package scope so the list
// is not reallocated on every request.
var dangerousExtensions = []string{".php", ".asp", ".aspx", ".jsp", ".cgi", ".exe", ".bat", ".cmd", ".sh", ".pl"}

// containsDangerousExtension reports whether path ends in one of
// dangerousExtensions, compared case-insensitively.
//
// Trailing slashes and dots are ignored before comparing, so variants such as
// "/shell.php/", "/shell.php." or "/shell.php/." cannot slip past a plain
// suffix check.
func containsDangerousExtension(path string) bool {
	p := strings.TrimRight(strings.ToLower(path), "/.")
	for _, ext := range dangerousExtensions {
		if strings.HasSuffix(p, ext) {
			return true
		}
	}
	return false
}
// sanitizeString returns s with every control character removed. That covers
// the C0 range U+0000–U+001F (including tab, CR and LF), DEL (U+007F) and the
// C1 range U+0080–U+009F. All other characters, including non-ASCII text, are
// kept.
//
// Bytes that are not valid UTF-8 are decoded by strings.Map as U+FFFD and
// kept as that replacement character.
func sanitizeString(s string) string {
	return strings.Map(func(r rune) rune {
		if r < 0x20 || (r >= 0x7f && r <= 0x9f) {
			return -1 // a negative result tells strings.Map to drop the rune
		}
		return r
	}, s)
}
|