helper.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. package callbacks
  2. import (
  3. "reflect"
  4. "sort"
  5. "gorm.io/gorm"
  6. "gorm.io/gorm/clause"
  7. )
  8. // ConvertMapToValuesForCreate convert map to values
  9. func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]interface{}) (values clause.Values) {
  10. values.Columns = make([]clause.Column, 0, len(mapValue))
  11. selectColumns, restricted := stmt.SelectAndOmitColumns(true, false)
  12. keys := make([]string, 0, len(mapValue))
  13. for k := range mapValue {
  14. keys = append(keys, k)
  15. }
  16. sort.Strings(keys)
  17. for _, k := range keys {
  18. value := mapValue[k]
  19. if stmt.Schema != nil {
  20. if field := stmt.Schema.LookUpField(k); field != nil {
  21. k = field.DBName
  22. }
  23. }
  24. if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
  25. values.Columns = append(values.Columns, clause.Column{Name: k})
  26. if len(values.Values) == 0 {
  27. values.Values = [][]interface{}{{}}
  28. }
  29. values.Values[0] = append(values.Values[0], value)
  30. }
  31. }
  32. return
  33. }
  34. // ConvertSliceOfMapToValuesForCreate convert slice of map to values
  35. func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[string]interface{}) (values clause.Values) {
  36. columns := make([]string, 0, len(mapValues))
  37. // when the length of mapValues is zero,return directly here
  38. // no need to call stmt.SelectAndOmitColumns method
  39. if len(mapValues) == 0 {
  40. stmt.AddError(gorm.ErrEmptySlice)
  41. return
  42. }
  43. var (
  44. result = make(map[string][]interface{}, len(mapValues))
  45. selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
  46. )
  47. for idx, mapValue := range mapValues {
  48. for k, v := range mapValue {
  49. if stmt.Schema != nil {
  50. if field := stmt.Schema.LookUpField(k); field != nil {
  51. k = field.DBName
  52. }
  53. }
  54. if _, ok := result[k]; !ok {
  55. if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
  56. result[k] = make([]interface{}, len(mapValues))
  57. columns = append(columns, k)
  58. } else {
  59. continue
  60. }
  61. }
  62. result[k][idx] = v
  63. }
  64. }
  65. sort.Strings(columns)
  66. values.Values = make([][]interface{}, len(mapValues))
  67. values.Columns = make([]clause.Column, len(columns))
  68. for idx, column := range columns {
  69. values.Columns[idx] = clause.Column{Name: column}
  70. for i, v := range result[column] {
  71. if len(values.Values[i]) == 0 {
  72. values.Values[i] = make([]interface{}, len(columns))
  73. }
  74. values.Values[i][idx] = v
  75. }
  76. }
  77. return
  78. }
  79. func hasReturning(tx *gorm.DB, supportReturning bool) (bool, gorm.ScanMode) {
  80. if supportReturning {
  81. if c, ok := tx.Statement.Clauses["RETURNING"]; ok {
  82. returning, _ := c.Expression.(clause.Returning)
  83. if len(returning.Columns) == 0 || (len(returning.Columns) == 1 && returning.Columns[0].Name == "*") {
  84. return true, 0
  85. }
  86. return true, gorm.ScanUpdate
  87. }
  88. }
  89. return false, 0
  90. }
  91. func checkMissingWhereConditions(db *gorm.DB) {
  92. if !db.AllowGlobalUpdate && db.Error == nil {
  93. where, withCondition := db.Statement.Clauses["WHERE"]
  94. if withCondition {
  95. if _, withSoftDelete := db.Statement.Clauses["soft_delete_enabled"]; withSoftDelete {
  96. whereClause, _ := where.Expression.(clause.Where)
  97. withCondition = len(whereClause.Exprs) > 1
  98. }
  99. }
  100. if !withCondition {
  101. db.AddError(gorm.ErrMissingWhereClause)
  102. }
  103. return
  104. }
  105. }
  106. type visitMap = map[reflect.Value]bool
  107. // Check if circular values, return true if loaded
  108. func loadOrStoreVisitMap(visitMap *visitMap, v reflect.Value) (loaded bool) {
  109. if v.Kind() == reflect.Ptr {
  110. v = v.Elem()
  111. }
  112. switch v.Kind() {
  113. case reflect.Slice, reflect.Array:
  114. loaded = true
  115. for i := 0; i < v.Len(); i++ {
  116. if !loadOrStoreVisitMap(visitMap, v.Index(i)) {
  117. loaded = false
  118. }
  119. }
  120. case reflect.Struct, reflect.Interface:
  121. if v.CanAddr() {
  122. p := v.Addr()
  123. if _, ok := (*visitMap)[p]; ok {
  124. return true
  125. }
  126. (*visitMap)[p] = true
  127. }
  128. }
  129. return
  130. }