finisher_api.go 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772
  1. package gorm
  2. import (
  3. "database/sql"
  4. "errors"
  5. "fmt"
  6. "hash/maphash"
  7. "reflect"
  8. "strings"
  9. "gorm.io/gorm/clause"
  10. "gorm.io/gorm/logger"
  11. "gorm.io/gorm/schema"
  12. "gorm.io/gorm/utils"
  13. )
  14. // Create inserts value, returning the inserted data's primary key in value's id
  15. func (db *DB) Create(value interface{}) (tx *DB) {
  16. if db.CreateBatchSize > 0 {
  17. return db.CreateInBatches(value, db.CreateBatchSize)
  18. }
  19. tx = db.getInstance()
  20. tx.Statement.Dest = value
  21. return tx.callbacks.Create().Execute(tx)
  22. }
  23. // CreateInBatches inserts value in batches of batchSize
  24. func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
  25. reflectValue := reflect.Indirect(reflect.ValueOf(value))
  26. switch reflectValue.Kind() {
  27. case reflect.Slice, reflect.Array:
  28. var rowsAffected int64
  29. tx = db.getInstance()
  30. // the reflection length judgment of the optimized value
  31. reflectLen := reflectValue.Len()
  32. callFc := func(tx *DB) error {
  33. for i := 0; i < reflectLen; i += batchSize {
  34. ends := i + batchSize
  35. if ends > reflectLen {
  36. ends = reflectLen
  37. }
  38. subtx := tx.getInstance()
  39. subtx.Statement.Dest = reflectValue.Slice(i, ends).Interface()
  40. subtx.callbacks.Create().Execute(subtx)
  41. if subtx.Error != nil {
  42. return subtx.Error
  43. }
  44. rowsAffected += subtx.RowsAffected
  45. }
  46. return nil
  47. }
  48. if tx.SkipDefaultTransaction || reflectLen <= batchSize {
  49. tx.AddError(callFc(tx.Session(&Session{})))
  50. } else {
  51. tx.AddError(tx.Transaction(callFc))
  52. }
  53. tx.RowsAffected = rowsAffected
  54. default:
  55. tx = db.getInstance()
  56. tx.Statement.Dest = value
  57. tx = tx.callbacks.Create().Execute(tx)
  58. }
  59. return
  60. }
  61. // Save updates value in database. If value doesn't contain a matching primary key, value is inserted.
  62. func (db *DB) Save(value interface{}) (tx *DB) {
  63. tx = db.getInstance()
  64. tx.Statement.Dest = value
  65. reflectValue := reflect.Indirect(reflect.ValueOf(value))
  66. for reflectValue.Kind() == reflect.Ptr || reflectValue.Kind() == reflect.Interface {
  67. reflectValue = reflect.Indirect(reflectValue)
  68. }
  69. switch reflectValue.Kind() {
  70. case reflect.Slice, reflect.Array:
  71. if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok {
  72. tx = tx.Clauses(clause.OnConflict{UpdateAll: true})
  73. }
  74. tx = tx.callbacks.Create().Execute(tx.Set("gorm:update_track_time", true))
  75. case reflect.Struct:
  76. if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {
  77. for _, pf := range tx.Statement.Schema.PrimaryFields {
  78. if _, isZero := pf.ValueOf(tx.Statement.Context, reflectValue); isZero {
  79. return tx.callbacks.Create().Execute(tx)
  80. }
  81. }
  82. }
  83. fallthrough
  84. default:
  85. selectedUpdate := len(tx.Statement.Selects) != 0
  86. // when updating, use all fields including those zero-value fields
  87. if !selectedUpdate {
  88. tx.Statement.Selects = append(tx.Statement.Selects, "*")
  89. }
  90. updateTx := tx.callbacks.Update().Execute(tx.Session(&Session{Initialized: true}))
  91. if updateTx.Error == nil && updateTx.RowsAffected == 0 && !updateTx.DryRun && !selectedUpdate {
  92. return tx.Session(&Session{SkipHooks: true}).Clauses(clause.OnConflict{UpdateAll: true}).Create(value)
  93. }
  94. return updateTx
  95. }
  96. return
  97. }
  98. // First finds the first record ordered by primary key, matching given conditions conds
  99. func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
  100. tx = db.Limit(1).Order(clause.OrderByColumn{
  101. Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
  102. })
  103. if len(conds) > 0 {
  104. if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
  105. tx.Statement.AddClause(clause.Where{Exprs: exprs})
  106. }
  107. }
  108. tx.Statement.RaiseErrorOnNotFound = true
  109. tx.Statement.Dest = dest
  110. return tx.callbacks.Query().Execute(tx)
  111. }
  112. // Take finds the first record returned by the database in no specified order, matching given conditions conds
  113. func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
  114. tx = db.Limit(1)
  115. if len(conds) > 0 {
  116. if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
  117. tx.Statement.AddClause(clause.Where{Exprs: exprs})
  118. }
  119. }
  120. tx.Statement.RaiseErrorOnNotFound = true
  121. tx.Statement.Dest = dest
  122. return tx.callbacks.Query().Execute(tx)
  123. }
  124. // Last finds the last record ordered by primary key, matching given conditions conds
  125. func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) {
  126. tx = db.Limit(1).Order(clause.OrderByColumn{
  127. Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
  128. Desc: true,
  129. })
  130. if len(conds) > 0 {
  131. if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
  132. tx.Statement.AddClause(clause.Where{Exprs: exprs})
  133. }
  134. }
  135. tx.Statement.RaiseErrorOnNotFound = true
  136. tx.Statement.Dest = dest
  137. return tx.callbacks.Query().Execute(tx)
  138. }
  139. // Find finds all records matching given conditions conds
  140. func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
  141. tx = db.getInstance()
  142. if len(conds) > 0 {
  143. if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
  144. tx.Statement.AddClause(clause.Where{Exprs: exprs})
  145. }
  146. }
  147. tx.Statement.Dest = dest
  148. return tx.callbacks.Query().Execute(tx)
  149. }
  150. // FindInBatches finds all records in batches of batchSize
  151. func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB {
  152. var (
  153. tx = db.Order(clause.OrderByColumn{
  154. Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
  155. }).Session(&Session{})
  156. queryDB = tx
  157. rowsAffected int64
  158. batch int
  159. )
  160. // user specified offset or limit
  161. var totalSize int
  162. if c, ok := tx.Statement.Clauses["LIMIT"]; ok {
  163. if limit, ok := c.Expression.(clause.Limit); ok {
  164. if limit.Limit != nil {
  165. totalSize = *limit.Limit
  166. }
  167. if totalSize > 0 && batchSize > totalSize {
  168. batchSize = totalSize
  169. }
  170. // reset to offset to 0 in next batch
  171. tx = tx.Offset(-1).Session(&Session{})
  172. }
  173. }
  174. for {
  175. result := queryDB.Limit(batchSize).Find(dest)
  176. rowsAffected += result.RowsAffected
  177. batch++
  178. if result.Error == nil && result.RowsAffected != 0 {
  179. fcTx := result.Session(&Session{NewDB: true})
  180. fcTx.RowsAffected = result.RowsAffected
  181. tx.AddError(fc(fcTx, batch))
  182. } else if result.Error != nil {
  183. tx.AddError(result.Error)
  184. }
  185. if tx.Error != nil || int(result.RowsAffected) < batchSize {
  186. break
  187. }
  188. if totalSize > 0 {
  189. if totalSize <= int(rowsAffected) {
  190. break
  191. }
  192. if totalSize/batchSize == batch {
  193. batchSize = totalSize % batchSize
  194. }
  195. }
  196. // Optimize for-break
  197. resultsValue := reflect.Indirect(reflect.ValueOf(dest))
  198. if result.Statement.Schema.PrioritizedPrimaryField == nil {
  199. tx.AddError(ErrPrimaryKeyRequired)
  200. break
  201. }
  202. primaryValue, zero := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1))
  203. if zero {
  204. tx.AddError(ErrPrimaryKeyRequired)
  205. break
  206. }
  207. queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue})
  208. }
  209. tx.RowsAffected = rowsAffected
  210. return tx
  211. }
  212. func (db *DB) assignInterfacesToValue(values ...interface{}) {
  213. for _, value := range values {
  214. switch v := value.(type) {
  215. case []clause.Expression:
  216. for _, expr := range v {
  217. if eq, ok := expr.(clause.Eq); ok {
  218. switch column := eq.Column.(type) {
  219. case string:
  220. if field := db.Statement.Schema.LookUpField(column); field != nil {
  221. db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, eq.Value))
  222. }
  223. case clause.Column:
  224. if field := db.Statement.Schema.LookUpField(column.Name); field != nil {
  225. db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, eq.Value))
  226. }
  227. }
  228. } else if andCond, ok := expr.(clause.AndConditions); ok {
  229. db.assignInterfacesToValue(andCond.Exprs)
  230. }
  231. }
  232. case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}:
  233. if exprs := db.Statement.BuildCondition(value); len(exprs) > 0 {
  234. db.assignInterfacesToValue(exprs)
  235. }
  236. default:
  237. if s, err := schema.Parse(value, db.cacheStore, db.NamingStrategy); err == nil {
  238. reflectValue := reflect.Indirect(reflect.ValueOf(value))
  239. switch reflectValue.Kind() {
  240. case reflect.Struct:
  241. for _, f := range s.Fields {
  242. if f.Readable {
  243. if v, isZero := f.ValueOf(db.Statement.Context, reflectValue); !isZero {
  244. if field := db.Statement.Schema.LookUpField(f.Name); field != nil {
  245. db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, v))
  246. }
  247. }
  248. }
  249. }
  250. }
  251. } else if len(values) > 0 {
  252. if exprs := db.Statement.BuildCondition(values[0], values[1:]...); len(exprs) > 0 {
  253. db.assignInterfacesToValue(exprs)
  254. }
  255. return
  256. }
  257. }
  258. }
  259. }
  260. // FirstOrInit finds the first matching record, otherwise if not found initializes a new instance with given conds.
  261. // Each conds must be a struct or map.
  262. //
  263. // FirstOrInit never modifies the database. It is often used with Assign and Attrs.
  264. //
  265. // // assign an email if the record is not found
  266. // db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user)
  267. // // user -> User{Name: "non_existing", Email: "fake@fake.org"}
  268. //
  269. // // assign email regardless of if record is found
  270. // db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user)
  271. // // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
  272. func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
  273. queryTx := db.Limit(1).Order(clause.OrderByColumn{
  274. Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
  275. })
  276. if tx = queryTx.Find(dest, conds...); tx.RowsAffected == 0 {
  277. if c, ok := tx.Statement.Clauses["WHERE"]; ok {
  278. if where, ok := c.Expression.(clause.Where); ok {
  279. tx.assignInterfacesToValue(where.Exprs)
  280. }
  281. }
  282. // initialize with attrs, conds
  283. if len(tx.Statement.attrs) > 0 {
  284. tx.assignInterfacesToValue(tx.Statement.attrs...)
  285. }
  286. }
  287. // initialize with attrs, conds
  288. if len(tx.Statement.assigns) > 0 {
  289. tx.assignInterfacesToValue(tx.Statement.assigns...)
  290. }
  291. return
  292. }
  293. // FirstOrCreate finds the first matching record, otherwise if not found creates a new instance with given conds.
  294. // Each conds must be a struct or map.
  295. //
  296. // Using FirstOrCreate in conjunction with Assign will result in an update to the database even if the record exists.
  297. //
  298. // // assign an email if the record is not found
  299. // result := db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrCreate(&user)
  300. // // user -> User{Name: "non_existing", Email: "fake@fake.org"}
  301. // // result.RowsAffected -> 1
  302. //
  303. // // assign email regardless of if record is found
  304. // result := db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrCreate(&user)
  305. // // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
  306. // // result.RowsAffected -> 1
  307. func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
  308. tx = db.getInstance()
  309. queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{
  310. Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
  311. })
  312. result := queryTx.Find(dest, conds...)
  313. if result.Error != nil {
  314. tx.Error = result.Error
  315. return tx
  316. }
  317. if result.RowsAffected == 0 {
  318. if c, ok := result.Statement.Clauses["WHERE"]; ok {
  319. if where, ok := c.Expression.(clause.Where); ok {
  320. result.assignInterfacesToValue(where.Exprs)
  321. }
  322. }
  323. // initialize with attrs, conds
  324. if len(db.Statement.attrs) > 0 {
  325. result.assignInterfacesToValue(db.Statement.attrs...)
  326. }
  327. // initialize with attrs, conds
  328. if len(db.Statement.assigns) > 0 {
  329. result.assignInterfacesToValue(db.Statement.assigns...)
  330. }
  331. return tx.Create(dest)
  332. } else if len(db.Statement.assigns) > 0 {
  333. exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...)
  334. assigns := map[string]interface{}{}
  335. for i := 0; i < len(exprs); i++ {
  336. expr := exprs[i]
  337. if eq, ok := expr.(clause.AndConditions); ok {
  338. exprs = append(exprs, eq.Exprs...)
  339. } else if eq, ok := expr.(clause.Eq); ok {
  340. switch column := eq.Column.(type) {
  341. case string:
  342. assigns[column] = eq.Value
  343. case clause.Column:
  344. assigns[column.Name] = eq.Value
  345. }
  346. }
  347. }
  348. return tx.Model(dest).Updates(assigns)
  349. }
  350. return tx
  351. }
  352. // Update updates column with value using callbacks. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields
  353. func (db *DB) Update(column string, value interface{}) (tx *DB) {
  354. tx = db.getInstance()
  355. tx.Statement.Dest = map[string]interface{}{column: value}
  356. return tx.callbacks.Update().Execute(tx)
  357. }
  358. // Updates updates attributes using callbacks. values must be a struct or map. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields
  359. func (db *DB) Updates(values interface{}) (tx *DB) {
  360. tx = db.getInstance()
  361. tx.Statement.Dest = values
  362. return tx.callbacks.Update().Execute(tx)
  363. }
  364. func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) {
  365. tx = db.getInstance()
  366. tx.Statement.Dest = map[string]interface{}{column: value}
  367. tx.Statement.SkipHooks = true
  368. return tx.callbacks.Update().Execute(tx)
  369. }
  370. func (db *DB) UpdateColumns(values interface{}) (tx *DB) {
  371. tx = db.getInstance()
  372. tx.Statement.Dest = values
  373. tx.Statement.SkipHooks = true
  374. return tx.callbacks.Update().Execute(tx)
  375. }
  376. // Delete deletes value matching given conditions. If value contains primary key it is included in the conditions. If
  377. // value includes a deleted_at field, then Delete performs a soft delete instead by setting deleted_at with the current
  378. // time if null.
  379. func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) {
  380. tx = db.getInstance()
  381. if len(conds) > 0 {
  382. if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
  383. tx.Statement.AddClause(clause.Where{Exprs: exprs})
  384. }
  385. }
  386. tx.Statement.Dest = value
  387. return tx.callbacks.Delete().Execute(tx)
  388. }
  389. func (db *DB) Count(count *int64) (tx *DB) {
  390. tx = db.getInstance()
  391. if tx.Statement.Model == nil {
  392. tx.Statement.Model = tx.Statement.Dest
  393. defer func() {
  394. tx.Statement.Model = nil
  395. }()
  396. }
  397. if selectClause, ok := db.Statement.Clauses["SELECT"]; ok {
  398. defer func() {
  399. tx.Statement.Clauses["SELECT"] = selectClause
  400. }()
  401. } else {
  402. defer delete(tx.Statement.Clauses, "SELECT")
  403. }
  404. if len(tx.Statement.Selects) == 0 {
  405. tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(*)"}})
  406. } else if !strings.HasPrefix(strings.TrimSpace(strings.ToLower(tx.Statement.Selects[0])), "count(") {
  407. expr := clause.Expr{SQL: "count(*)"}
  408. if len(tx.Statement.Selects) == 1 {
  409. dbName := tx.Statement.Selects[0]
  410. fields := strings.FieldsFunc(dbName, utils.IsValidDBNameChar)
  411. if len(fields) == 1 || (len(fields) == 3 && (strings.ToUpper(fields[1]) == "AS" || fields[1] == ".")) {
  412. if tx.Statement.Parse(tx.Statement.Model) == nil {
  413. if f := tx.Statement.Schema.LookUpField(dbName); f != nil {
  414. dbName = f.DBName
  415. }
  416. }
  417. if tx.Statement.Distinct {
  418. expr = clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: dbName}}}
  419. } else if dbName != "*" {
  420. expr = clause.Expr{SQL: "COUNT(?)", Vars: []interface{}{clause.Column{Name: dbName}}}
  421. }
  422. }
  423. }
  424. tx.Statement.AddClause(clause.Select{Expression: expr})
  425. }
  426. if orderByClause, ok := db.Statement.Clauses["ORDER BY"]; ok {
  427. if _, ok := db.Statement.Clauses["GROUP BY"]; !ok {
  428. delete(tx.Statement.Clauses, "ORDER BY")
  429. defer func() {
  430. tx.Statement.Clauses["ORDER BY"] = orderByClause
  431. }()
  432. }
  433. }
  434. tx.Statement.Dest = count
  435. tx = tx.callbacks.Query().Execute(tx)
  436. if _, ok := db.Statement.Clauses["GROUP BY"]; ok || tx.RowsAffected != 1 {
  437. *count = tx.RowsAffected
  438. }
  439. return
  440. }
  441. func (db *DB) Row() *sql.Row {
  442. tx := db.getInstance().Set("rows", false)
  443. tx = tx.callbacks.Row().Execute(tx)
  444. row, ok := tx.Statement.Dest.(*sql.Row)
  445. if !ok && tx.DryRun {
  446. db.Logger.Error(tx.Statement.Context, ErrDryRunModeUnsupported.Error())
  447. }
  448. return row
  449. }
  450. func (db *DB) Rows() (*sql.Rows, error) {
  451. tx := db.getInstance().Set("rows", true)
  452. tx = tx.callbacks.Row().Execute(tx)
  453. rows, ok := tx.Statement.Dest.(*sql.Rows)
  454. if !ok && tx.DryRun && tx.Error == nil {
  455. tx.Error = ErrDryRunModeUnsupported
  456. }
  457. return rows, tx.Error
  458. }
  459. // Scan scans selected value to the struct dest
  460. func (db *DB) Scan(dest interface{}) (tx *DB) {
  461. config := *db.Config
  462. currentLogger, newLogger := config.Logger, logger.Recorder.New()
  463. config.Logger = newLogger
  464. tx = db.getInstance()
  465. tx.Config = &config
  466. if rows, err := tx.Rows(); err == nil {
  467. if rows.Next() {
  468. tx.ScanRows(rows, dest)
  469. } else {
  470. tx.RowsAffected = 0
  471. tx.AddError(rows.Err())
  472. }
  473. tx.AddError(rows.Close())
  474. }
  475. currentLogger.Trace(tx.Statement.Context, newLogger.BeginAt, func() (string, int64) {
  476. return newLogger.SQL, tx.RowsAffected
  477. }, tx.Error)
  478. tx.Logger = currentLogger
  479. return
  480. }
  481. // Pluck queries a single column from a model, returning in the slice dest. E.g.:
  482. //
  483. // var ages []int64
  484. // db.Model(&users).Pluck("age", &ages)
  485. func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
  486. tx = db.getInstance()
  487. if tx.Statement.Model != nil {
  488. if tx.Statement.Parse(tx.Statement.Model) == nil {
  489. if f := tx.Statement.Schema.LookUpField(column); f != nil {
  490. column = f.DBName
  491. }
  492. }
  493. }
  494. if len(tx.Statement.Selects) != 1 {
  495. fields := strings.FieldsFunc(column, utils.IsValidDBNameChar)
  496. tx.Statement.AddClauseIfNotExists(clause.Select{
  497. Distinct: tx.Statement.Distinct,
  498. Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}},
  499. })
  500. }
  501. tx.Statement.Dest = dest
  502. return tx.callbacks.Query().Execute(tx)
  503. }
  504. func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
  505. tx := db.getInstance()
  506. if err := tx.Statement.Parse(dest); !errors.Is(err, schema.ErrUnsupportedDataType) {
  507. tx.AddError(err)
  508. }
  509. tx.Statement.Dest = dest
  510. tx.Statement.ReflectValue = reflect.ValueOf(dest)
  511. for tx.Statement.ReflectValue.Kind() == reflect.Ptr {
  512. elem := tx.Statement.ReflectValue.Elem()
  513. if !elem.IsValid() {
  514. elem = reflect.New(tx.Statement.ReflectValue.Type().Elem())
  515. tx.Statement.ReflectValue.Set(elem)
  516. }
  517. tx.Statement.ReflectValue = elem
  518. }
  519. Scan(rows, tx, ScanInitialized)
  520. return tx.Error
  521. }
  522. // Connection uses a db connection to execute an arbitrary number of commands in fc. When finished, the connection is
  523. // returned to the connection pool.
  524. func (db *DB) Connection(fc func(tx *DB) error) (err error) {
  525. if db.Error != nil {
  526. return db.Error
  527. }
  528. tx := db.getInstance()
  529. sqlDB, err := tx.DB()
  530. if err != nil {
  531. return
  532. }
  533. conn, err := sqlDB.Conn(tx.Statement.Context)
  534. if err != nil {
  535. return
  536. }
  537. defer conn.Close()
  538. tx.Statement.ConnPool = conn
  539. return fc(tx)
  540. }
  541. // Transaction start a transaction as a block, return error will rollback, otherwise to commit. Transaction executes an
  542. // arbitrary number of commands in fc within a transaction. On success the changes are committed; if an error occurs
  543. // they are rolled back.
  544. func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) {
  545. panicked := true
  546. if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
  547. // nested transaction
  548. if !db.DisableNestedTransaction {
  549. spID := new(maphash.Hash).Sum64()
  550. err = db.SavePoint(fmt.Sprintf("sp%d", spID)).Error
  551. if err != nil {
  552. return
  553. }
  554. defer func() {
  555. // Make sure to rollback when panic, Block error or Commit error
  556. if panicked || err != nil {
  557. db.RollbackTo(fmt.Sprintf("sp%d", spID))
  558. }
  559. }()
  560. }
  561. err = fc(db.Session(&Session{NewDB: db.clone == 1}))
  562. } else {
  563. tx := db.Begin(opts...)
  564. if tx.Error != nil {
  565. return tx.Error
  566. }
  567. defer func() {
  568. // Make sure to rollback when panic, Block error or Commit error
  569. if panicked || err != nil {
  570. tx.Rollback()
  571. }
  572. }()
  573. if err = fc(tx); err == nil {
  574. panicked = false
  575. return tx.Commit().Error
  576. }
  577. }
  578. panicked = false
  579. return
  580. }
  581. // Begin begins a transaction with any transaction options opts
  582. func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
  583. var (
  584. // clone statement
  585. tx = db.getInstance().Session(&Session{Context: db.Statement.Context, NewDB: db.clone == 1})
  586. opt *sql.TxOptions
  587. err error
  588. )
  589. if len(opts) > 0 {
  590. opt = opts[0]
  591. }
  592. switch beginner := tx.Statement.ConnPool.(type) {
  593. case TxBeginner:
  594. tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
  595. case ConnPoolBeginner:
  596. tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
  597. default:
  598. err = ErrInvalidTransaction
  599. }
  600. if err != nil {
  601. tx.AddError(err)
  602. }
  603. return tx
  604. }
  605. // Commit commits the changes in a transaction
  606. func (db *DB) Commit() *DB {
  607. if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() {
  608. db.AddError(committer.Commit())
  609. } else {
  610. db.AddError(ErrInvalidTransaction)
  611. }
  612. return db
  613. }
  614. // Rollback rollbacks the changes in a transaction
  615. func (db *DB) Rollback() *DB {
  616. if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
  617. if !reflect.ValueOf(committer).IsNil() {
  618. db.AddError(committer.Rollback())
  619. }
  620. } else {
  621. db.AddError(ErrInvalidTransaction)
  622. }
  623. return db
  624. }
  625. func (db *DB) SavePoint(name string) *DB {
  626. if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
  627. // close prepared statement, because SavePoint not support prepared statement.
  628. // e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html
  629. var (
  630. preparedStmtTx *PreparedStmtTX
  631. isPreparedStmtTx bool
  632. )
  633. // close prepared statement, because SavePoint not support prepared statement.
  634. if preparedStmtTx, isPreparedStmtTx = db.Statement.ConnPool.(*PreparedStmtTX); isPreparedStmtTx {
  635. db.Statement.ConnPool = preparedStmtTx.Tx
  636. }
  637. db.AddError(savePointer.SavePoint(db, name))
  638. // restore prepared statement
  639. if isPreparedStmtTx {
  640. db.Statement.ConnPool = preparedStmtTx
  641. }
  642. } else {
  643. db.AddError(ErrUnsupportedDriver)
  644. }
  645. return db
  646. }
  647. func (db *DB) RollbackTo(name string) *DB {
  648. if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
  649. // close prepared statement, because RollbackTo not support prepared statement.
  650. // e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html
  651. var (
  652. preparedStmtTx *PreparedStmtTX
  653. isPreparedStmtTx bool
  654. )
  655. // close prepared statement, because SavePoint not support prepared statement.
  656. if preparedStmtTx, isPreparedStmtTx = db.Statement.ConnPool.(*PreparedStmtTX); isPreparedStmtTx {
  657. db.Statement.ConnPool = preparedStmtTx.Tx
  658. }
  659. db.AddError(savePointer.RollbackTo(db, name))
  660. // restore prepared statement
  661. if isPreparedStmtTx {
  662. db.Statement.ConnPool = preparedStmtTx
  663. }
  664. } else {
  665. db.AddError(ErrUnsupportedDriver)
  666. }
  667. return db
  668. }
  669. // Exec executes raw sql
  670. func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
  671. tx = db.getInstance()
  672. tx.Statement.SQL = strings.Builder{}
  673. if strings.Contains(sql, "@") {
  674. clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement)
  675. } else {
  676. clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement)
  677. }
  678. return tx.callbacks.Raw().Execute(tx)
  679. }