create.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396
  1. package callbacks
  2. import (
  3. "fmt"
  4. "reflect"
  5. "strings"
  6. "gorm.io/gorm"
  7. "gorm.io/gorm/clause"
  8. "gorm.io/gorm/schema"
  9. "gorm.io/gorm/utils"
  10. )
  11. // BeforeCreate before create hooks
  12. func BeforeCreate(db *gorm.DB) {
  13. if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) {
  14. callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
  15. if db.Statement.Schema.BeforeSave {
  16. if i, ok := value.(BeforeSaveInterface); ok {
  17. called = true
  18. db.AddError(i.BeforeSave(tx))
  19. }
  20. }
  21. if db.Statement.Schema.BeforeCreate {
  22. if i, ok := value.(BeforeCreateInterface); ok {
  23. called = true
  24. db.AddError(i.BeforeCreate(tx))
  25. }
  26. }
  27. return called
  28. })
  29. }
  30. }
  31. // Create create hook
  32. func Create(config *Config) func(db *gorm.DB) {
  33. supportReturning := utils.Contains(config.CreateClauses, "RETURNING")
  34. return func(db *gorm.DB) {
  35. if db.Error != nil {
  36. return
  37. }
  38. if db.Statement.Schema != nil {
  39. if !db.Statement.Unscoped {
  40. for _, c := range db.Statement.Schema.CreateClauses {
  41. db.Statement.AddClause(c)
  42. }
  43. }
  44. if supportReturning && len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 {
  45. if _, ok := db.Statement.Clauses["RETURNING"]; !ok {
  46. fromColumns := make([]clause.Column, 0, len(db.Statement.Schema.FieldsWithDefaultDBValue))
  47. for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue {
  48. fromColumns = append(fromColumns, clause.Column{Name: field.DBName})
  49. }
  50. db.Statement.AddClause(clause.Returning{Columns: fromColumns})
  51. }
  52. }
  53. }
  54. if db.Statement.SQL.Len() == 0 {
  55. db.Statement.SQL.Grow(180)
  56. db.Statement.AddClauseIfNotExists(clause.Insert{})
  57. db.Statement.AddClause(ConvertToCreateValues(db.Statement))
  58. db.Statement.Build(db.Statement.BuildClauses...)
  59. }
  60. isDryRun := !db.DryRun && db.Error == nil
  61. if !isDryRun {
  62. return
  63. }
  64. ok, mode := hasReturning(db, supportReturning)
  65. if ok {
  66. if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok {
  67. if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.DoNothing {
  68. mode |= gorm.ScanOnConflictDoNothing
  69. }
  70. }
  71. rows, err := db.Statement.ConnPool.QueryContext(
  72. db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...,
  73. )
  74. if db.AddError(err) == nil {
  75. defer func() {
  76. db.AddError(rows.Close())
  77. }()
  78. gorm.Scan(rows, db, mode)
  79. }
  80. return
  81. }
  82. result, err := db.Statement.ConnPool.ExecContext(
  83. db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...,
  84. )
  85. if err != nil {
  86. db.AddError(err)
  87. return
  88. }
  89. db.RowsAffected, _ = result.RowsAffected()
  90. if db.RowsAffected == 0 {
  91. return
  92. }
  93. var (
  94. pkField *schema.Field
  95. pkFieldName = "@id"
  96. )
  97. insertID, err := result.LastInsertId()
  98. insertOk := err == nil && insertID > 0
  99. if !insertOk {
  100. if !supportReturning {
  101. db.AddError(err)
  102. }
  103. return
  104. }
  105. if db.Statement.Schema != nil {
  106. if db.Statement.Schema.PrioritizedPrimaryField == nil || !db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
  107. return
  108. }
  109. pkField = db.Statement.Schema.PrioritizedPrimaryField
  110. pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName
  111. }
  112. // append @id column with value for auto-increment primary key
  113. // the @id value is correct, when: 1. without setting auto-increment primary key, 2. database AutoIncrementIncrement = 1
  114. switch values := db.Statement.Dest.(type) {
  115. case map[string]interface{}:
  116. values[pkFieldName] = insertID
  117. case *map[string]interface{}:
  118. (*values)[pkFieldName] = insertID
  119. case []map[string]interface{}, *[]map[string]interface{}:
  120. mapValues, ok := values.([]map[string]interface{})
  121. if !ok {
  122. if v, ok := values.(*[]map[string]interface{}); ok {
  123. if *v != nil {
  124. mapValues = *v
  125. }
  126. }
  127. }
  128. if config.LastInsertIDReversed {
  129. insertID -= int64(len(mapValues)-1) * schema.DefaultAutoIncrementIncrement
  130. }
  131. for _, mapValue := range mapValues {
  132. if mapValue != nil {
  133. mapValue[pkFieldName] = insertID
  134. }
  135. insertID += schema.DefaultAutoIncrementIncrement
  136. }
  137. default:
  138. if pkField == nil {
  139. return
  140. }
  141. switch db.Statement.ReflectValue.Kind() {
  142. case reflect.Slice, reflect.Array:
  143. if config.LastInsertIDReversed {
  144. for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
  145. rv := db.Statement.ReflectValue.Index(i)
  146. if reflect.Indirect(rv).Kind() != reflect.Struct {
  147. break
  148. }
  149. _, isZero := pkField.ValueOf(db.Statement.Context, rv)
  150. if isZero {
  151. db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
  152. insertID -= pkField.AutoIncrementIncrement
  153. }
  154. }
  155. } else {
  156. for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
  157. rv := db.Statement.ReflectValue.Index(i)
  158. if reflect.Indirect(rv).Kind() != reflect.Struct {
  159. break
  160. }
  161. if _, isZero := pkField.ValueOf(db.Statement.Context, rv); isZero {
  162. db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
  163. insertID += pkField.AutoIncrementIncrement
  164. }
  165. }
  166. }
  167. case reflect.Struct:
  168. _, isZero := pkField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
  169. if isZero {
  170. db.AddError(pkField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID))
  171. }
  172. }
  173. }
  174. }
  175. }
  176. // AfterCreate after create hooks
  177. func AfterCreate(db *gorm.DB) {
  178. if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) {
  179. callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
  180. if db.Statement.Schema.AfterCreate {
  181. if i, ok := value.(AfterCreateInterface); ok {
  182. called = true
  183. db.AddError(i.AfterCreate(tx))
  184. }
  185. }
  186. if db.Statement.Schema.AfterSave {
  187. if i, ok := value.(AfterSaveInterface); ok {
  188. called = true
  189. db.AddError(i.AfterSave(tx))
  190. }
  191. }
  192. return called
  193. })
  194. }
  195. }
  // ConvertToCreateValues builds the INSERT column list and row values for
// stmt.Dest.
//
// Map destinations (map[string]interface{}, []map[string]interface{} and
// pointers to either) are delegated to ConvertMapToValuesForCreate and
// ConvertSliceOfMapToValuesForCreate. Any other destination is read through
// stmt.ReflectValue and stmt.Schema and must be a struct or a slice/array whose
// elements are structs (or pointers to structs). For those:
//   - columns come from stmt.Schema.DBNames filtered by Select/Omit, excluding
//     fields with HasDefaultValue but no DefaultValueInterface; those are listed
//     in FieldsWithDefaultDBValue and appended afterwards only when the value(s)
//     being inserted actually set them;
//   - zero-valued columns are filled with DefaultValueInterface, or with the
//     current time (stmt.DB.NowFunc) for auto create/update time fields, and the
//     filled value is also written back into the destination via field.Set;
//   - auto update time fields are refreshed even when non-zero if the
//     "gorm:update_track_time" setting is present; the setting is deleted here.
//
// Errors (gorm.ErrEmptySlice for an empty slice, gorm.ErrInvalidData for an
// invalid slice element or an unsupported kind, and field.Set failures) are
// recorded with stmt.AddError rather than returned, and values may be only
// partially built in that case.
//
// Finally, if stmt has an ON CONFLICT clause with UpdateAll set, its DoUpdates
// are populated from the built columns. Primary keys, auto create time fields
// and fields with a database-side default other than NULL are skipped. Auto
// update time fields are assigned the current time, as Unix nano/milli/seconds
// when the field uses that representation. DoNothing is set if nothing is left
// to update, Columns default to the primary key fields, and the modified clause
// is added back to stmt.
func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
	curTime := stmt.DB.NowFunc()

	switch value := stmt.Dest.(type) {
	case map[string]interface{}:
		values = ConvertMapToValuesForCreate(stmt, value)
	case *map[string]interface{}:
		values = ConvertMapToValuesForCreate(stmt, *value)
	case []map[string]interface{}:
		values = ConvertSliceOfMapToValuesForCreate(stmt, value)
	case *[]map[string]interface{}:
		values = ConvertSliceOfMapToValuesForCreate(stmt, *value)
	default:
		var (
			selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
			_, updateTrackTime        = stmt.Get("gorm:update_track_time")
			isZero                    bool
		)
		// The track-time flag applies to this conversion only.
		stmt.Settings.Delete("gorm:update_track_time")

		values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))}

		// Initial columns: every field except those whose default is left to the
		// database (HasDefaultValue without a DefaultValueInterface).
		for _, db := range stmt.Schema.DBNames {
			if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil {
				// Included when explicitly selected, or when not mentioned in
				// Select/Omit and either the selection is not restricted or the
				// field is an auto create/update time field.
				if v, ok := selectColumns[db]; (ok && v) || (!ok && (!restricted || field.AutoCreateTime > 0 || field.AutoUpdateTime > 0)) {
					values.Columns = append(values.Columns, clause.Column{Name: db})
				}
			}
		}

		switch stmt.ReflectValue.Kind() {
		case reflect.Slice, reflect.Array:
			rValLen := stmt.ReflectValue.Len()
			if rValLen == 0 {
				stmt.AddError(gorm.ErrEmptySlice)
				return
			}

			// Pre-size buffers for the expected row/column count.
			// NOTE(review): this replaces any Vars already present on stmt.
			stmt.SQL.Grow(rValLen * 18)
			stmt.Vars = make([]interface{}, 0, rValLen*len(values.Columns))
			values.Values = make([][]interface{}, rValLen)

			// For each database-defaulted field that at least one row sets: a
			// per-row slice holding the set values (nil for rows that left it zero).
			defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{}
			for i := 0; i < rValLen; i++ {
				rv := reflect.Indirect(stmt.ReflectValue.Index(i))
				if !rv.IsValid() {
					stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData))
					return
				}

				values.Values[i] = make([]interface{}, len(values.Columns))
				for idx, column := range values.Columns {
					field := stmt.Schema.FieldsByDBName[column.Name]
					if values.Values[i][idx], isZero = field.ValueOf(stmt.Context, rv); isZero {
						if field.DefaultValueInterface != nil {
							values.Values[i][idx] = field.DefaultValueInterface
							stmt.AddError(field.Set(stmt.Context, rv, field.DefaultValueInterface))
						} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
							// Set first, then re-read so the value is in the field's own
							// representation (e.g. a Unix timestamp).
							stmt.AddError(field.Set(stmt.Context, rv, curTime))
							values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv)
						}
					} else if field.AutoUpdateTime > 0 && updateTrackTime {
						stmt.AddError(field.Set(stmt.Context, rv, curTime))
						values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv)
					}
				}

				for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
					if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
						if rvOfvalue, isZero := field.ValueOf(stmt.Context, rv); !isZero {
							if len(defaultValueFieldsHavingValue[field]) == 0 {
								defaultValueFieldsHavingValue[field] = make([]interface{}, rValLen)
							}
							defaultValueFieldsHavingValue[field][i] = rvOfvalue
						}
					}
				}
			}

			// Add one column per database-defaulted field that some row set. Rows
			// that left it zero get stmt.DefaultValueOf(field) instead.
			for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
				if vs, ok := defaultValueFieldsHavingValue[field]; ok {
					values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
					for idx := range values.Values {
						if vs[idx] == nil {
							values.Values[idx] = append(values.Values[idx], stmt.DefaultValueOf(field))
						} else {
							values.Values[idx] = append(values.Values[idx], vs[idx])
						}
					}
				}
			}
		case reflect.Struct:
			values.Values = [][]interface{}{make([]interface{}, len(values.Columns))}
			for idx, column := range values.Columns {
				field := stmt.Schema.FieldsByDBName[column.Name]
				if values.Values[0][idx], isZero = field.ValueOf(stmt.Context, stmt.ReflectValue); isZero {
					if field.DefaultValueInterface != nil {
						values.Values[0][idx] = field.DefaultValueInterface
						stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, field.DefaultValueInterface))
					} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
						stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime))
						values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue)
					}
				} else if field.AutoUpdateTime > 0 && updateTrackTime {
					stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime))
					values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue)
				}
			}

			// Database-defaulted fields become columns only when the struct sets them.
			for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
				// NOTE(review): && binds tighter than ||, so this reads as
				// (ok && v) || ((!ok && !restricted) && field.DefaultValueInterface == nil).
				// The DefaultValueInterface check therefore does not apply to explicitly
				// selected fields, and the slice branch has no such check at all.
				// Confirm the intended grouping.
				if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) && field.DefaultValueInterface == nil {
					if rvOfvalue, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero {
						values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
						values.Values[0] = append(values.Values[0], rvOfvalue)
					}
				}
			}
		default:
			stmt.AddError(gorm.ErrInvalidData)
		}
	}

	// For ON CONFLICT ... UpdateAll, derive the update assignments from the
	// columns being inserted.
	if c, ok := stmt.Clauses["ON CONFLICT"]; ok {
		if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.UpdateAll {
			if stmt.Schema != nil && len(values.Columns) >= 1 {
				selectColumns, restricted := stmt.SelectAndOmitColumns(true, true)

				columns := make([]string, 0, len(values.Columns)-1)
				for _, column := range values.Columns {
					if field := stmt.Schema.LookUpField(column.Name); field != nil {
						if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
							// Skip primary keys, auto create time fields, and fields
							// whose database default is anything other than NULL.
							if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil ||
								strings.EqualFold(field.DefaultValue, "NULL")) && field.AutoCreateTime == 0 {
								if field.AutoUpdateTime > 0 {
									// Auto update time fields are set to "now" in the field's
									// representation, not to the inserted (excluded) value.
									assignment := clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: curTime}
									switch field.AutoUpdateTime {
									case schema.UnixNanosecond:
										assignment.Value = curTime.UnixNano()
									case schema.UnixMillisecond:
										assignment.Value = curTime.UnixMilli()
									case schema.UnixSecond:
										assignment.Value = curTime.Unix()
									}

									onConflict.DoUpdates = append(onConflict.DoUpdates, assignment)
								} else {
									columns = append(columns, column.Name)
								}
							}
						}
					}
				}

				onConflict.DoUpdates = append(onConflict.DoUpdates, clause.AssignmentColumns(columns)...)
				if len(onConflict.DoUpdates) == 0 {
					onConflict.DoNothing = true
				}

				// use primary fields as default OnConflict columns
				if len(onConflict.Columns) == 0 {
					for _, field := range stmt.Schema.PrimaryFields {
						onConflict.Columns = append(onConflict.Columns, clause.Column{Name: field.DBName})
					}
				}
				stmt.AddClause(onConflict)
			}
		}
	}

	return values
}