sql.go 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. package logger
  import (
	"database/sql/driver"
	"fmt"
	"reflect"
	"regexp"
	"strconv"
	"strings"
	"time"
	"unicode"
	"unicode/utf8"

	"gorm.io/gorm/utils"
)
  // Literal fragments used when rendering bound parameters into log output.
const (
	// nullStr is written in place of nil values and nil pointers.
	nullStr = "NULL"

	// tmFmtWithMS is the layout for non-zero times: seconds plus up to three
	// fractional digits (the ".999" form drops trailing zeros).
	tmFmtWithMS = "2006-01-02 15:04:05.999"

	// tmFmtZero is emitted verbatim for a zero time instead of formatting it.
	tmFmtZero = "0000-00-00 00:00:00"
)
  // isPrintable reports whether s is valid UTF-8 consisting solely of runes
// accepted by unicode.IsPrint (letters, marks, numbers, punctuation, symbols
// and the ASCII space). Tabs, newlines and other control characters make it
// return false. The empty string is printable.
func isPrintable(s string) bool {
	// Invalid byte sequences decode to utf8.RuneError (U+FFFD), which
	// unicode.IsPrint accepts; reject them explicitly so that binary data is
	// not mistaken for text.
	if !utf8.ValidString(s) {
		return false
	}
	for _, r := range s {
		if !unicode.IsPrint(r) {
			return false
		}
	}
	return true
}
  26. // A list of Go types that should be converted to SQL primitives
  27. var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})}
  28. // RegEx matches only numeric values
  29. var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`)
  // isNumeric reports whether k is a built-in signed integer, unsigned integer
// (uintptr excluded) or floating-point kind. Complex kinds are not numeric here.
func isNumeric(k reflect.Kind) bool {
	switch k {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
		reflect.Float32, reflect.Float64:
		return true
	}
	return false
}
  // ExplainSQL returns sql with every placeholder replaced by a literal rendering
// of the corresponding argument in avars. The result is intended for logging
// only: values are formatted here rather than bound by a database driver, so
// executing the returned SQL may introduce a SQL injection vulnerability.
//
// Placeholders are located in one of two ways:
//   - numericPlaceholder == nil: each '?' byte is replaced, left to right, by
//     the next argument; '?' bytes beyond len(avars) are kept. Question marks
//     inside quoted literals are not distinguished and are replaced as well.
//   - numericPlaceholder != nil: each match is resolved through the pattern's
//     first capture group, which must hold a 1-based argument index (for
//     example `\$(\d+)` for "$1"). A match whose group is absent, is not an
//     integer, or is out of range is kept verbatim.
//
// String-like values are wrapped in escaper, with any escaper inside the value
// doubled. nil values and nil pointers render as NULL, a zero time.Time as
// tmFmtZero, and byte slices that are not printable UTF-8 as <binary>.
func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string {
	vars := make([]string, len(avars))

	// quote wraps s in escaper, doubling any escaper already present in s.
	quote := func(s string) string {
		return escaper + strings.ReplaceAll(s, escaper, escaper+escaper) + escaper
	}

	// formatTime renders t, using the fixed zero-date literal for the zero time.
	formatTime := func(t time.Time) string {
		if t.IsZero() {
			return escaper + tmFmtZero + escaper
		}
		return escaper + t.Format(tmFmtWithMS) + escaper
	}

	// convertParams renders v into vars[idx]; it recurses for values that
	// unwrap to another value (driver.Valuer, pointers, convertible types).
	var convertParams func(interface{}, int)
	convertParams = func(v interface{}, idx int) {
		switch v := v.(type) {
		case bool:
			vars[idx] = strconv.FormatBool(v)
		case time.Time:
			vars[idx] = formatTime(v)
		case *time.Time:
			if v == nil {
				vars[idx] = nullStr
			} else {
				vars[idx] = formatTime(*v)
			}
		case driver.Valuer:
			rv := reflect.ValueOf(v)
			// A typed nil pointer still satisfies driver.Valuer; calling Value on
			// it could panic, so it is logged as NULL instead.
			if rv.Kind() == reflect.Ptr && rv.IsNil() {
				vars[idx] = nullStr
				return
			}
			// The error is deliberately ignored: this is best-effort logging and
			// whatever Value returned is rendered (a nil result shows as NULL).
			r, _ := v.Value()
			convertParams(r, idx)
		case fmt.Stringer:
			rv := reflect.ValueOf(v)
			switch rv.Kind() {
			case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
				reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
				// %d formats the underlying number rather than calling String().
				vars[idx] = fmt.Sprintf("%d", v)
			case reflect.Float32, reflect.Float64:
				vars[idx] = fmt.Sprintf("%.6f", v)
			case reflect.Bool:
				vars[idx] = fmt.Sprintf("%t", v)
			case reflect.String:
				vars[idx] = quote(fmt.Sprintf("%v", v))
			default:
				if rv.Kind() == reflect.Ptr && rv.IsNil() {
					vars[idx] = nullStr
				} else {
					vars[idx] = quote(fmt.Sprintf("%v", v))
				}
			}
		case []byte:
			if s := string(v); isPrintable(s) {
				vars[idx] = quote(s)
			} else {
				vars[idx] = escaper + "<binary>" + escaper
			}
		case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
			vars[idx] = utils.ToString(v)
		case float32:
			vars[idx] = strconv.FormatFloat(float64(v), 'f', -1, 32)
		case float64:
			vars[idx] = strconv.FormatFloat(v, 'f', -1, 64)
		case string:
			vars[idx] = quote(v)
		default:
			// Every driver.Valuer was already handled above, so only reflection
			// based fallbacks remain here.
			rv := reflect.ValueOf(v)
			switch {
			case v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil():
				vars[idx] = nullStr
			case rv.Kind() == reflect.Ptr:
				// Non-nil pointer: render the value it points to.
				convertParams(rv.Elem().Interface(), idx)
			case isNumeric(rv.Kind()):
				if rv.CanInt() || rv.CanUint() {
					vars[idx] = fmt.Sprintf("%d", v)
				} else {
					vars[idx] = fmt.Sprintf("%.6f", v)
				}
			case rv.Kind() == reflect.String:
				// Named string types are text. Without this case they would be
				// converted to []byte below and non-printable content (tabs,
				// newlines) would be logged as <binary>.
				vars[idx] = quote(rv.String())
			default:
				for _, t := range convertibleTypes {
					if rv.Type().ConvertibleTo(t) {
						convertParams(rv.Convert(t).Interface(), idx)
						return
					}
				}
				vars[idx] = quote(fmt.Sprint(v))
			}
		}
	}

	for idx, v := range avars {
		convertParams(v, idx)
	}

	var b strings.Builder
	b.Grow(len(sql))

	if numericPlaceholder == nil {
		next := 0
		// Byte-wise scan is safe: '?' (0x3F) never occurs inside a multi-byte
		// UTF-8 sequence.
		for i := 0; i < len(sql); i++ {
			if c := sql[i]; c == '?' && next < len(vars) {
				b.WriteString(vars[next])
				next++
			} else {
				b.WriteByte(c)
			}
		}
		return b.String()
	}

	matches := numericPlaceholder.FindAllStringSubmatchIndex(sql, -1)
	if len(matches) == 0 {
		return sql
	}
	last := 0
	for _, m := range matches {
		start, end := m[0], m[1]
		b.WriteString(sql[last:start])
		last = end
		// m[2]:m[3] bounds the first capture group; m[2] is -1 when the group
		// did not participate in the match.
		if len(m) >= 4 && m[2] >= 0 {
			// Placeholder positions are 1-based ($1, $2, ...).
			if n, err := strconv.Atoi(sql[m[2]:m[3]]); err == nil && n >= 1 && n <= len(vars) {
				b.WriteString(vars[n-1])
				continue
			}
		}
		// Unresolvable placeholder: keep the original text untouched.
		b.WriteString(sql[start:end])
	}
	b.WriteString(sql[last:])
	return b.String()
}