123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772 |
- package gorm
- import (
- "database/sql"
- "errors"
- "fmt"
- "hash/maphash"
- "reflect"
- "strings"
- "gorm.io/gorm/clause"
- "gorm.io/gorm/logger"
- "gorm.io/gorm/schema"
- "gorm.io/gorm/utils"
- )
- // Create inserts value, returning the inserted data's primary key in value's id
- func (db *DB) Create(value interface{}) (tx *DB) {
- if db.CreateBatchSize > 0 {
- return db.CreateInBatches(value, db.CreateBatchSize)
- }
- tx = db.getInstance()
- tx.Statement.Dest = value
- return tx.callbacks.Create().Execute(tx)
- }
- // CreateInBatches inserts value in batches of batchSize
- func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
- reflectValue := reflect.Indirect(reflect.ValueOf(value))
- switch reflectValue.Kind() {
- case reflect.Slice, reflect.Array:
- var rowsAffected int64
- tx = db.getInstance()
- // the reflection length judgment of the optimized value
- reflectLen := reflectValue.Len()
- callFc := func(tx *DB) error {
- for i := 0; i < reflectLen; i += batchSize {
- ends := i + batchSize
- if ends > reflectLen {
- ends = reflectLen
- }
- subtx := tx.getInstance()
- subtx.Statement.Dest = reflectValue.Slice(i, ends).Interface()
- subtx.callbacks.Create().Execute(subtx)
- if subtx.Error != nil {
- return subtx.Error
- }
- rowsAffected += subtx.RowsAffected
- }
- return nil
- }
- if tx.SkipDefaultTransaction || reflectLen <= batchSize {
- tx.AddError(callFc(tx.Session(&Session{})))
- } else {
- tx.AddError(tx.Transaction(callFc))
- }
- tx.RowsAffected = rowsAffected
- default:
- tx = db.getInstance()
- tx.Statement.Dest = value
- tx = tx.callbacks.Create().Execute(tx)
- }
- return
- }
- // Save updates value in database. If value doesn't contain a matching primary key, value is inserted.
- func (db *DB) Save(value interface{}) (tx *DB) {
- tx = db.getInstance()
- tx.Statement.Dest = value
- reflectValue := reflect.Indirect(reflect.ValueOf(value))
- for reflectValue.Kind() == reflect.Ptr || reflectValue.Kind() == reflect.Interface {
- reflectValue = reflect.Indirect(reflectValue)
- }
- switch reflectValue.Kind() {
- case reflect.Slice, reflect.Array:
- if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok {
- tx = tx.Clauses(clause.OnConflict{UpdateAll: true})
- }
- tx = tx.callbacks.Create().Execute(tx.Set("gorm:update_track_time", true))
- case reflect.Struct:
- if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil {
- for _, pf := range tx.Statement.Schema.PrimaryFields {
- if _, isZero := pf.ValueOf(tx.Statement.Context, reflectValue); isZero {
- return tx.callbacks.Create().Execute(tx)
- }
- }
- }
- fallthrough
- default:
- selectedUpdate := len(tx.Statement.Selects) != 0
- // when updating, use all fields including those zero-value fields
- if !selectedUpdate {
- tx.Statement.Selects = append(tx.Statement.Selects, "*")
- }
- updateTx := tx.callbacks.Update().Execute(tx.Session(&Session{Initialized: true}))
- if updateTx.Error == nil && updateTx.RowsAffected == 0 && !updateTx.DryRun && !selectedUpdate {
- return tx.Session(&Session{SkipHooks: true}).Clauses(clause.OnConflict{UpdateAll: true}).Create(value)
- }
- return updateTx
- }
- return
- }
- // First finds the first record ordered by primary key, matching given conditions conds
- func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
- tx = db.Limit(1).Order(clause.OrderByColumn{
- Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
- })
- if len(conds) > 0 {
- if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
- tx.Statement.AddClause(clause.Where{Exprs: exprs})
- }
- }
- tx.Statement.RaiseErrorOnNotFound = true
- tx.Statement.Dest = dest
- return tx.callbacks.Query().Execute(tx)
- }
- // Take finds the first record returned by the database in no specified order, matching given conditions conds
- func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
- tx = db.Limit(1)
- if len(conds) > 0 {
- if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
- tx.Statement.AddClause(clause.Where{Exprs: exprs})
- }
- }
- tx.Statement.RaiseErrorOnNotFound = true
- tx.Statement.Dest = dest
- return tx.callbacks.Query().Execute(tx)
- }
- // Last finds the last record ordered by primary key, matching given conditions conds
- func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) {
- tx = db.Limit(1).Order(clause.OrderByColumn{
- Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
- Desc: true,
- })
- if len(conds) > 0 {
- if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
- tx.Statement.AddClause(clause.Where{Exprs: exprs})
- }
- }
- tx.Statement.RaiseErrorOnNotFound = true
- tx.Statement.Dest = dest
- return tx.callbacks.Query().Execute(tx)
- }
- // Find finds all records matching given conditions conds
- func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {
- tx = db.getInstance()
- if len(conds) > 0 {
- if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
- tx.Statement.AddClause(clause.Where{Exprs: exprs})
- }
- }
- tx.Statement.Dest = dest
- return tx.callbacks.Query().Execute(tx)
- }
- // FindInBatches finds all records in batches of batchSize
- func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB {
- var (
- tx = db.Order(clause.OrderByColumn{
- Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
- }).Session(&Session{})
- queryDB = tx
- rowsAffected int64
- batch int
- )
- // user specified offset or limit
- var totalSize int
- if c, ok := tx.Statement.Clauses["LIMIT"]; ok {
- if limit, ok := c.Expression.(clause.Limit); ok {
- if limit.Limit != nil {
- totalSize = *limit.Limit
- }
- if totalSize > 0 && batchSize > totalSize {
- batchSize = totalSize
- }
- // reset to offset to 0 in next batch
- tx = tx.Offset(-1).Session(&Session{})
- }
- }
- for {
- result := queryDB.Limit(batchSize).Find(dest)
- rowsAffected += result.RowsAffected
- batch++
- if result.Error == nil && result.RowsAffected != 0 {
- fcTx := result.Session(&Session{NewDB: true})
- fcTx.RowsAffected = result.RowsAffected
- tx.AddError(fc(fcTx, batch))
- } else if result.Error != nil {
- tx.AddError(result.Error)
- }
- if tx.Error != nil || int(result.RowsAffected) < batchSize {
- break
- }
- if totalSize > 0 {
- if totalSize <= int(rowsAffected) {
- break
- }
- if totalSize/batchSize == batch {
- batchSize = totalSize % batchSize
- }
- }
- // Optimize for-break
- resultsValue := reflect.Indirect(reflect.ValueOf(dest))
- if result.Statement.Schema.PrioritizedPrimaryField == nil {
- tx.AddError(ErrPrimaryKeyRequired)
- break
- }
- primaryValue, zero := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1))
- if zero {
- tx.AddError(ErrPrimaryKeyRequired)
- break
- }
- queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue})
- }
- tx.RowsAffected = rowsAffected
- return tx
- }
- func (db *DB) assignInterfacesToValue(values ...interface{}) {
- for _, value := range values {
- switch v := value.(type) {
- case []clause.Expression:
- for _, expr := range v {
- if eq, ok := expr.(clause.Eq); ok {
- switch column := eq.Column.(type) {
- case string:
- if field := db.Statement.Schema.LookUpField(column); field != nil {
- db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, eq.Value))
- }
- case clause.Column:
- if field := db.Statement.Schema.LookUpField(column.Name); field != nil {
- db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, eq.Value))
- }
- }
- } else if andCond, ok := expr.(clause.AndConditions); ok {
- db.assignInterfacesToValue(andCond.Exprs)
- }
- }
- case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}:
- if exprs := db.Statement.BuildCondition(value); len(exprs) > 0 {
- db.assignInterfacesToValue(exprs)
- }
- default:
- if s, err := schema.Parse(value, db.cacheStore, db.NamingStrategy); err == nil {
- reflectValue := reflect.Indirect(reflect.ValueOf(value))
- switch reflectValue.Kind() {
- case reflect.Struct:
- for _, f := range s.Fields {
- if f.Readable {
- if v, isZero := f.ValueOf(db.Statement.Context, reflectValue); !isZero {
- if field := db.Statement.Schema.LookUpField(f.Name); field != nil {
- db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, v))
- }
- }
- }
- }
- }
- } else if len(values) > 0 {
- if exprs := db.Statement.BuildCondition(values[0], values[1:]...); len(exprs) > 0 {
- db.assignInterfacesToValue(exprs)
- }
- return
- }
- }
- }
- }
- // FirstOrInit finds the first matching record, otherwise if not found initializes a new instance with given conds.
- // Each conds must be a struct or map.
- //
- // FirstOrInit never modifies the database. It is often used with Assign and Attrs.
- //
- // // assign an email if the record is not found
- // db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user)
- // // user -> User{Name: "non_existing", Email: "fake@fake.org"}
- //
- // // assign email regardless of if record is found
- // db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user)
- // // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
- func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
- queryTx := db.Limit(1).Order(clause.OrderByColumn{
- Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
- })
- if tx = queryTx.Find(dest, conds...); tx.RowsAffected == 0 {
- if c, ok := tx.Statement.Clauses["WHERE"]; ok {
- if where, ok := c.Expression.(clause.Where); ok {
- tx.assignInterfacesToValue(where.Exprs)
- }
- }
- // initialize with attrs, conds
- if len(tx.Statement.attrs) > 0 {
- tx.assignInterfacesToValue(tx.Statement.attrs...)
- }
- }
- // initialize with attrs, conds
- if len(tx.Statement.assigns) > 0 {
- tx.assignInterfacesToValue(tx.Statement.assigns...)
- }
- return
- }
- // FirstOrCreate finds the first matching record, otherwise if not found creates a new instance with given conds.
- // Each conds must be a struct or map.
- //
- // Using FirstOrCreate in conjunction with Assign will result in an update to the database even if the record exists.
- //
- // // assign an email if the record is not found
- // result := db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrCreate(&user)
- // // user -> User{Name: "non_existing", Email: "fake@fake.org"}
- // // result.RowsAffected -> 1
- //
- // // assign email regardless of if record is found
- // result := db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrCreate(&user)
- // // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
- // // result.RowsAffected -> 1
- func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
- tx = db.getInstance()
- queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{
- Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
- })
- result := queryTx.Find(dest, conds...)
- if result.Error != nil {
- tx.Error = result.Error
- return tx
- }
- if result.RowsAffected == 0 {
- if c, ok := result.Statement.Clauses["WHERE"]; ok {
- if where, ok := c.Expression.(clause.Where); ok {
- result.assignInterfacesToValue(where.Exprs)
- }
- }
- // initialize with attrs, conds
- if len(db.Statement.attrs) > 0 {
- result.assignInterfacesToValue(db.Statement.attrs...)
- }
- // initialize with attrs, conds
- if len(db.Statement.assigns) > 0 {
- result.assignInterfacesToValue(db.Statement.assigns...)
- }
- return tx.Create(dest)
- } else if len(db.Statement.assigns) > 0 {
- exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...)
- assigns := map[string]interface{}{}
- for i := 0; i < len(exprs); i++ {
- expr := exprs[i]
- if eq, ok := expr.(clause.AndConditions); ok {
- exprs = append(exprs, eq.Exprs...)
- } else if eq, ok := expr.(clause.Eq); ok {
- switch column := eq.Column.(type) {
- case string:
- assigns[column] = eq.Value
- case clause.Column:
- assigns[column.Name] = eq.Value
- }
- }
- }
- return tx.Model(dest).Updates(assigns)
- }
- return tx
- }
- // Update updates column with value using callbacks. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields
- func (db *DB) Update(column string, value interface{}) (tx *DB) {
- tx = db.getInstance()
- tx.Statement.Dest = map[string]interface{}{column: value}
- return tx.callbacks.Update().Execute(tx)
- }
- // Updates updates attributes using callbacks. values must be a struct or map. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields
- func (db *DB) Updates(values interface{}) (tx *DB) {
- tx = db.getInstance()
- tx.Statement.Dest = values
- return tx.callbacks.Update().Execute(tx)
- }
- func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) {
- tx = db.getInstance()
- tx.Statement.Dest = map[string]interface{}{column: value}
- tx.Statement.SkipHooks = true
- return tx.callbacks.Update().Execute(tx)
- }
- func (db *DB) UpdateColumns(values interface{}) (tx *DB) {
- tx = db.getInstance()
- tx.Statement.Dest = values
- tx.Statement.SkipHooks = true
- return tx.callbacks.Update().Execute(tx)
- }
- // Delete deletes value matching given conditions. If value contains primary key it is included in the conditions. If
- // value includes a deleted_at field, then Delete performs a soft delete instead by setting deleted_at with the current
- // time if null.
- func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) {
- tx = db.getInstance()
- if len(conds) > 0 {
- if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
- tx.Statement.AddClause(clause.Where{Exprs: exprs})
- }
- }
- tx.Statement.Dest = value
- return tx.callbacks.Delete().Execute(tx)
- }
- func (db *DB) Count(count *int64) (tx *DB) {
- tx = db.getInstance()
- if tx.Statement.Model == nil {
- tx.Statement.Model = tx.Statement.Dest
- defer func() {
- tx.Statement.Model = nil
- }()
- }
- if selectClause, ok := db.Statement.Clauses["SELECT"]; ok {
- defer func() {
- tx.Statement.Clauses["SELECT"] = selectClause
- }()
- } else {
- defer delete(tx.Statement.Clauses, "SELECT")
- }
- if len(tx.Statement.Selects) == 0 {
- tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(*)"}})
- } else if !strings.HasPrefix(strings.TrimSpace(strings.ToLower(tx.Statement.Selects[0])), "count(") {
- expr := clause.Expr{SQL: "count(*)"}
- if len(tx.Statement.Selects) == 1 {
- dbName := tx.Statement.Selects[0]
- fields := strings.FieldsFunc(dbName, utils.IsValidDBNameChar)
- if len(fields) == 1 || (len(fields) == 3 && (strings.ToUpper(fields[1]) == "AS" || fields[1] == ".")) {
- if tx.Statement.Parse(tx.Statement.Model) == nil {
- if f := tx.Statement.Schema.LookUpField(dbName); f != nil {
- dbName = f.DBName
- }
- }
- if tx.Statement.Distinct {
- expr = clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: dbName}}}
- } else if dbName != "*" {
- expr = clause.Expr{SQL: "COUNT(?)", Vars: []interface{}{clause.Column{Name: dbName}}}
- }
- }
- }
- tx.Statement.AddClause(clause.Select{Expression: expr})
- }
- if orderByClause, ok := db.Statement.Clauses["ORDER BY"]; ok {
- if _, ok := db.Statement.Clauses["GROUP BY"]; !ok {
- delete(tx.Statement.Clauses, "ORDER BY")
- defer func() {
- tx.Statement.Clauses["ORDER BY"] = orderByClause
- }()
- }
- }
- tx.Statement.Dest = count
- tx = tx.callbacks.Query().Execute(tx)
- if _, ok := db.Statement.Clauses["GROUP BY"]; ok || tx.RowsAffected != 1 {
- *count = tx.RowsAffected
- }
- return
- }
- func (db *DB) Row() *sql.Row {
- tx := db.getInstance().Set("rows", false)
- tx = tx.callbacks.Row().Execute(tx)
- row, ok := tx.Statement.Dest.(*sql.Row)
- if !ok && tx.DryRun {
- db.Logger.Error(tx.Statement.Context, ErrDryRunModeUnsupported.Error())
- }
- return row
- }
- func (db *DB) Rows() (*sql.Rows, error) {
- tx := db.getInstance().Set("rows", true)
- tx = tx.callbacks.Row().Execute(tx)
- rows, ok := tx.Statement.Dest.(*sql.Rows)
- if !ok && tx.DryRun && tx.Error == nil {
- tx.Error = ErrDryRunModeUnsupported
- }
- return rows, tx.Error
- }
- // Scan scans selected value to the struct dest
- func (db *DB) Scan(dest interface{}) (tx *DB) {
- config := *db.Config
- currentLogger, newLogger := config.Logger, logger.Recorder.New()
- config.Logger = newLogger
- tx = db.getInstance()
- tx.Config = &config
- if rows, err := tx.Rows(); err == nil {
- if rows.Next() {
- tx.ScanRows(rows, dest)
- } else {
- tx.RowsAffected = 0
- tx.AddError(rows.Err())
- }
- tx.AddError(rows.Close())
- }
- currentLogger.Trace(tx.Statement.Context, newLogger.BeginAt, func() (string, int64) {
- return newLogger.SQL, tx.RowsAffected
- }, tx.Error)
- tx.Logger = currentLogger
- return
- }
- // Pluck queries a single column from a model, returning in the slice dest. E.g.:
- //
- // var ages []int64
- // db.Model(&users).Pluck("age", &ages)
- func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
- tx = db.getInstance()
- if tx.Statement.Model != nil {
- if tx.Statement.Parse(tx.Statement.Model) == nil {
- if f := tx.Statement.Schema.LookUpField(column); f != nil {
- column = f.DBName
- }
- }
- }
- if len(tx.Statement.Selects) != 1 {
- fields := strings.FieldsFunc(column, utils.IsValidDBNameChar)
- tx.Statement.AddClauseIfNotExists(clause.Select{
- Distinct: tx.Statement.Distinct,
- Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}},
- })
- }
- tx.Statement.Dest = dest
- return tx.callbacks.Query().Execute(tx)
- }
- func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
- tx := db.getInstance()
- if err := tx.Statement.Parse(dest); !errors.Is(err, schema.ErrUnsupportedDataType) {
- tx.AddError(err)
- }
- tx.Statement.Dest = dest
- tx.Statement.ReflectValue = reflect.ValueOf(dest)
- for tx.Statement.ReflectValue.Kind() == reflect.Ptr {
- elem := tx.Statement.ReflectValue.Elem()
- if !elem.IsValid() {
- elem = reflect.New(tx.Statement.ReflectValue.Type().Elem())
- tx.Statement.ReflectValue.Set(elem)
- }
- tx.Statement.ReflectValue = elem
- }
- Scan(rows, tx, ScanInitialized)
- return tx.Error
- }
- // Connection uses a db connection to execute an arbitrary number of commands in fc. When finished, the connection is
- // returned to the connection pool.
- func (db *DB) Connection(fc func(tx *DB) error) (err error) {
- if db.Error != nil {
- return db.Error
- }
- tx := db.getInstance()
- sqlDB, err := tx.DB()
- if err != nil {
- return
- }
- conn, err := sqlDB.Conn(tx.Statement.Context)
- if err != nil {
- return
- }
- defer conn.Close()
- tx.Statement.ConnPool = conn
- return fc(tx)
- }
- // Transaction start a transaction as a block, return error will rollback, otherwise to commit. Transaction executes an
- // arbitrary number of commands in fc within a transaction. On success the changes are committed; if an error occurs
- // they are rolled back.
- func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) {
- panicked := true
- if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
- // nested transaction
- if !db.DisableNestedTransaction {
- spID := new(maphash.Hash).Sum64()
- err = db.SavePoint(fmt.Sprintf("sp%d", spID)).Error
- if err != nil {
- return
- }
- defer func() {
- // Make sure to rollback when panic, Block error or Commit error
- if panicked || err != nil {
- db.RollbackTo(fmt.Sprintf("sp%d", spID))
- }
- }()
- }
- err = fc(db.Session(&Session{NewDB: db.clone == 1}))
- } else {
- tx := db.Begin(opts...)
- if tx.Error != nil {
- return tx.Error
- }
- defer func() {
- // Make sure to rollback when panic, Block error or Commit error
- if panicked || err != nil {
- tx.Rollback()
- }
- }()
- if err = fc(tx); err == nil {
- panicked = false
- return tx.Commit().Error
- }
- }
- panicked = false
- return
- }
- // Begin begins a transaction with any transaction options opts
- func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
- var (
- // clone statement
- tx = db.getInstance().Session(&Session{Context: db.Statement.Context, NewDB: db.clone == 1})
- opt *sql.TxOptions
- err error
- )
- if len(opts) > 0 {
- opt = opts[0]
- }
- switch beginner := tx.Statement.ConnPool.(type) {
- case TxBeginner:
- tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
- case ConnPoolBeginner:
- tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
- default:
- err = ErrInvalidTransaction
- }
- if err != nil {
- tx.AddError(err)
- }
- return tx
- }
- // Commit commits the changes in a transaction
- func (db *DB) Commit() *DB {
- if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() {
- db.AddError(committer.Commit())
- } else {
- db.AddError(ErrInvalidTransaction)
- }
- return db
- }
- // Rollback rollbacks the changes in a transaction
- func (db *DB) Rollback() *DB {
- if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
- if !reflect.ValueOf(committer).IsNil() {
- db.AddError(committer.Rollback())
- }
- } else {
- db.AddError(ErrInvalidTransaction)
- }
- return db
- }
- func (db *DB) SavePoint(name string) *DB {
- if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
- // close prepared statement, because SavePoint not support prepared statement.
- // e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html
- var (
- preparedStmtTx *PreparedStmtTX
- isPreparedStmtTx bool
- )
- // close prepared statement, because SavePoint not support prepared statement.
- if preparedStmtTx, isPreparedStmtTx = db.Statement.ConnPool.(*PreparedStmtTX); isPreparedStmtTx {
- db.Statement.ConnPool = preparedStmtTx.Tx
- }
- db.AddError(savePointer.SavePoint(db, name))
- // restore prepared statement
- if isPreparedStmtTx {
- db.Statement.ConnPool = preparedStmtTx
- }
- } else {
- db.AddError(ErrUnsupportedDriver)
- }
- return db
- }
- func (db *DB) RollbackTo(name string) *DB {
- if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
- // close prepared statement, because RollbackTo not support prepared statement.
- // e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html
- var (
- preparedStmtTx *PreparedStmtTX
- isPreparedStmtTx bool
- )
- // close prepared statement, because SavePoint not support prepared statement.
- if preparedStmtTx, isPreparedStmtTx = db.Statement.ConnPool.(*PreparedStmtTX); isPreparedStmtTx {
- db.Statement.ConnPool = preparedStmtTx.Tx
- }
- db.AddError(savePointer.RollbackTo(db, name))
- // restore prepared statement
- if isPreparedStmtTx {
- db.Statement.ConnPool = preparedStmtTx
- }
- } else {
- db.AddError(ErrUnsupportedDriver)
- }
- return db
- }
- // Exec executes raw sql
- func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
- tx = db.getInstance()
- tx.Statement.SQL = strings.Builder{}
- if strings.Contains(sql, "@") {
- clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement)
- } else {
- clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement)
- }
- return tx.callbacks.Raw().Execute(tx)
- }
|