update.go 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. package callbacks
  2. import (
  3. "reflect"
  4. "sort"
  5. "gorm.io/gorm"
  6. "gorm.io/gorm/clause"
  7. "gorm.io/gorm/schema"
  8. "gorm.io/gorm/utils"
  9. )
  10. func SetupUpdateReflectValue(db *gorm.DB) {
  11. if db.Error == nil && db.Statement.Schema != nil {
  12. if !db.Statement.ReflectValue.CanAddr() || db.Statement.Model != db.Statement.Dest {
  13. db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model)
  14. for db.Statement.ReflectValue.Kind() == reflect.Ptr {
  15. db.Statement.ReflectValue = db.Statement.ReflectValue.Elem()
  16. }
  17. if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
  18. for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
  19. if _, ok := dest[rel.Name]; ok {
  20. db.AddError(rel.Field.Set(db.Statement.Context, db.Statement.ReflectValue, dest[rel.Name]))
  21. }
  22. }
  23. }
  24. }
  25. }
  26. }
  27. // BeforeUpdate before update hooks
  28. func BeforeUpdate(db *gorm.DB) {
  29. if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
  30. callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
  31. if db.Statement.Schema.BeforeSave {
  32. if i, ok := value.(BeforeSaveInterface); ok {
  33. called = true
  34. db.AddError(i.BeforeSave(tx))
  35. }
  36. }
  37. if db.Statement.Schema.BeforeUpdate {
  38. if i, ok := value.(BeforeUpdateInterface); ok {
  39. called = true
  40. db.AddError(i.BeforeUpdate(tx))
  41. }
  42. }
  43. return called
  44. })
  45. }
  46. }
  47. // Update update hook
  48. func Update(config *Config) func(db *gorm.DB) {
  49. supportReturning := utils.Contains(config.UpdateClauses, "RETURNING")
  50. return func(db *gorm.DB) {
  51. if db.Error != nil {
  52. return
  53. }
  54. if db.Statement.Schema != nil {
  55. for _, c := range db.Statement.Schema.UpdateClauses {
  56. db.Statement.AddClause(c)
  57. }
  58. }
  59. if db.Statement.SQL.Len() == 0 {
  60. db.Statement.SQL.Grow(180)
  61. db.Statement.AddClauseIfNotExists(clause.Update{})
  62. if _, ok := db.Statement.Clauses["SET"]; !ok {
  63. if set := ConvertToAssignments(db.Statement); len(set) != 0 {
  64. defer delete(db.Statement.Clauses, "SET")
  65. db.Statement.AddClause(set)
  66. } else {
  67. return
  68. }
  69. }
  70. db.Statement.Build(db.Statement.BuildClauses...)
  71. }
  72. checkMissingWhereConditions(db)
  73. if !db.DryRun && db.Error == nil {
  74. if ok, mode := hasReturning(db, supportReturning); ok {
  75. if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
  76. dest := db.Statement.Dest
  77. db.Statement.Dest = db.Statement.ReflectValue.Addr().Interface()
  78. gorm.Scan(rows, db, mode)
  79. db.Statement.Dest = dest
  80. db.AddError(rows.Close())
  81. }
  82. } else {
  83. result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
  84. if db.AddError(err) == nil {
  85. db.RowsAffected, _ = result.RowsAffected()
  86. }
  87. }
  88. }
  89. }
  90. }
  91. // AfterUpdate after update hooks
  92. func AfterUpdate(db *gorm.DB) {
  93. if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
  94. callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
  95. if db.Statement.Schema.AfterUpdate {
  96. if i, ok := value.(AfterUpdateInterface); ok {
  97. called = true
  98. db.AddError(i.AfterUpdate(tx))
  99. }
  100. }
  101. if db.Statement.Schema.AfterSave {
  102. if i, ok := value.(AfterSaveInterface); ok {
  103. called = true
  104. db.AddError(i.AfterSave(tx))
  105. }
  106. }
  107. return called
  108. })
  109. }
  110. }
  111. // ConvertToAssignments convert to update assignments
  112. func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
  113. var (
  114. selectColumns, restricted = stmt.SelectAndOmitColumns(false, true)
  115. assignValue func(field *schema.Field, value interface{})
  116. )
  117. switch stmt.ReflectValue.Kind() {
  118. case reflect.Slice, reflect.Array:
  119. assignValue = func(field *schema.Field, value interface{}) {
  120. for i := 0; i < stmt.ReflectValue.Len(); i++ {
  121. if stmt.ReflectValue.CanAddr() {
  122. field.Set(stmt.Context, stmt.ReflectValue.Index(i), value)
  123. }
  124. }
  125. }
  126. case reflect.Struct:
  127. assignValue = func(field *schema.Field, value interface{}) {
  128. if stmt.ReflectValue.CanAddr() {
  129. field.Set(stmt.Context, stmt.ReflectValue, value)
  130. }
  131. }
  132. default:
  133. assignValue = func(field *schema.Field, value interface{}) {
  134. }
  135. }
  136. updatingValue := reflect.ValueOf(stmt.Dest)
  137. for updatingValue.Kind() == reflect.Ptr {
  138. updatingValue = updatingValue.Elem()
  139. }
  140. if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
  141. switch stmt.ReflectValue.Kind() {
  142. case reflect.Slice, reflect.Array:
  143. if size := stmt.ReflectValue.Len(); size > 0 {
  144. var isZero bool
  145. for i := 0; i < size; i++ {
  146. for _, field := range stmt.Schema.PrimaryFields {
  147. _, isZero = field.ValueOf(stmt.Context, stmt.ReflectValue.Index(i))
  148. if !isZero {
  149. break
  150. }
  151. }
  152. }
  153. if !isZero {
  154. _, primaryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields)
  155. column, values := schema.ToQueryValues("", stmt.Schema.PrimaryFieldDBNames, primaryValues)
  156. stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
  157. }
  158. }
  159. case reflect.Struct:
  160. for _, field := range stmt.Schema.PrimaryFields {
  161. if value, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero {
  162. stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
  163. }
  164. }
  165. }
  166. }
  167. switch value := updatingValue.Interface().(type) {
  168. case map[string]interface{}:
  169. set = make([]clause.Assignment, 0, len(value))
  170. keys := make([]string, 0, len(value))
  171. for k := range value {
  172. keys = append(keys, k)
  173. }
  174. sort.Strings(keys)
  175. for _, k := range keys {
  176. kv := value[k]
  177. if _, ok := kv.(*gorm.DB); ok {
  178. kv = []interface{}{kv}
  179. }
  180. if stmt.Schema != nil {
  181. if field := stmt.Schema.LookUpField(k); field != nil {
  182. if field.DBName != "" {
  183. if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
  184. set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: kv})
  185. assignValue(field, value[k])
  186. }
  187. } else if v, ok := selectColumns[field.Name]; (ok && v) || (!ok && !restricted) {
  188. assignValue(field, value[k])
  189. }
  190. continue
  191. }
  192. }
  193. if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) {
  194. set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: kv})
  195. }
  196. }
  197. if !stmt.SkipHooks && stmt.Schema != nil {
  198. for _, dbName := range stmt.Schema.DBNames {
  199. field := stmt.Schema.LookUpField(dbName)
  200. if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil {
  201. if v, ok := selectColumns[field.DBName]; (ok && v) || !ok {
  202. now := stmt.DB.NowFunc()
  203. assignValue(field, now)
  204. if field.AutoUpdateTime == schema.UnixNanosecond {
  205. set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()})
  206. } else if field.AutoUpdateTime == schema.UnixMillisecond {
  207. set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixMilli()})
  208. } else if field.AutoUpdateTime == schema.UnixSecond {
  209. set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()})
  210. } else {
  211. set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now})
  212. }
  213. }
  214. }
  215. }
  216. }
  217. default:
  218. updatingSchema := stmt.Schema
  219. var isDiffSchema bool
  220. if !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
  221. // different schema
  222. updatingStmt := &gorm.Statement{DB: stmt.DB}
  223. if err := updatingStmt.Parse(stmt.Dest); err == nil {
  224. updatingSchema = updatingStmt.Schema
  225. isDiffSchema = true
  226. }
  227. }
  228. switch updatingValue.Kind() {
  229. case reflect.Struct:
  230. set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName))
  231. for _, dbName := range stmt.Schema.DBNames {
  232. if field := updatingSchema.LookUpField(dbName); field != nil {
  233. if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model {
  234. if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) {
  235. value, isZero := field.ValueOf(stmt.Context, updatingValue)
  236. if !stmt.SkipHooks && field.AutoUpdateTime > 0 {
  237. if field.AutoUpdateTime == schema.UnixNanosecond {
  238. value = stmt.DB.NowFunc().UnixNano()
  239. } else if field.AutoUpdateTime == schema.UnixMillisecond {
  240. value = stmt.DB.NowFunc().UnixMilli()
  241. } else if field.AutoUpdateTime == schema.UnixSecond {
  242. value = stmt.DB.NowFunc().Unix()
  243. } else {
  244. value = stmt.DB.NowFunc()
  245. }
  246. isZero = false
  247. }
  248. if (ok || !isZero) && field.Updatable {
  249. set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value})
  250. assignField := field
  251. if isDiffSchema {
  252. if originField := stmt.Schema.LookUpField(dbName); originField != nil {
  253. assignField = originField
  254. }
  255. }
  256. assignValue(assignField, value)
  257. }
  258. }
  259. } else {
  260. if value, isZero := field.ValueOf(stmt.Context, updatingValue); !isZero {
  261. stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}})
  262. }
  263. }
  264. }
  265. }
  266. default:
  267. stmt.AddError(gorm.ErrInvalidData)
  268. }
  269. }
  270. return
  271. }