gorm.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526
  1. package gorm
  2. import (
  3. "context"
  4. "database/sql"
  5. "fmt"
  6. "reflect"
  7. "sort"
  8. "sync"
  9. "time"
  10. "gorm.io/gorm/clause"
  11. "gorm.io/gorm/logger"
  12. "gorm.io/gorm/schema"
  13. )
  // preparedStmtDBKey is the key under which the shared PreparedStmtDB is
  // stored in Config.cacheStore.
  const preparedStmtDBKey = "preparedStmt"
  // Config holds the GORM configuration. It is embedded in DB, and because it
  // has Apply and AfterInitialize methods it can also be passed to Open as an
  // Option.
  type Config struct {
  	// SkipDefaultTransaction disables the default behavior of running single
  	// create, update and delete operations inside a transaction (which GORM
  	// does to ensure database data integrity).
  	SkipDefaultTransaction bool
  	// NamingStrategy is the naming strategy for tables and columns.
  	// Open falls back to schema.NamingStrategy{IdentifierMaxLength: 64} when nil.
  	NamingStrategy schema.Namer
  	// FullSaveAssociations enables full saving of associations.
  	FullSaveAssociations bool
  	// Logger is the logger used by GORM. Open falls back to logger.Default when nil.
  	Logger logger.Interface
  	// NowFunc is the function used when creating a new timestamp.
  	// Open falls back to the local current time when nil.
  	NowFunc func() time.Time
  	// DryRun generates SQL without executing it.
  	DryRun bool
  	// PrepareStmt executes the given query through a cached prepared statement.
  	PrepareStmt bool
  	// PrepareStmtMaxSize and PrepareStmtTTL configure LRU expiry of the
  	// prepared statement cache. Per the original note the defaults are
  	// maxsize = max int64 and ttl = 1h; they are applied outside this file
  	// (presumably in NewPreparedStmtDB, TODO confirm).
  	PrepareStmtMaxSize int
  	PrepareStmtTTL time.Duration
  	// DisableAutomaticPing makes Open skip the ping it otherwise issues after
  	// initialization.
  	DisableAutomaticPing bool
  	// DisableForeignKeyConstraintWhenMigrating disables foreign key
  	// constraints when migrating (consumed outside this file).
  	DisableForeignKeyConstraintWhenMigrating bool
  	// IgnoreRelationshipsWhenMigrating ignores relationships when migrating
  	// (consumed outside this file).
  	IgnoreRelationshipsWhenMigrating bool
  	// DisableNestedTransaction disables nested transactions.
  	DisableNestedTransaction bool
  	// AllowGlobalUpdate allows global updates.
  	AllowGlobalUpdate bool
  	// QueryFields executes SQL queries with all fields of the table.
  	QueryFields bool
  	// CreateBatchSize is the default batch size for creates.
  	CreateBatchSize int
  	// TranslateError enables error translation: AddError passes errors
  	// through the Dialector when it implements ErrorTranslator.
  	TranslateError bool
  	// PropagateUnscoped propagates Unscoped to every other nested statement.
  	PropagateUnscoped bool

  	// ClauseBuilders holds the registered clause builders, keyed by name.
  	ClauseBuilders map[string]clause.ClauseBuilder
  	// ConnPool is the database connection pool.
  	ConnPool ConnPool
  	// Dialector is the database dialector.
  	Dialector
  	// Plugins holds the registered plugins, keyed by plugin name.
  	Plugins map[string]Plugin

  	// callbacks is the callback manager; Open sets it via initializeCallbacks.
  	callbacks *callbacks
  	// cacheStore is an internal store; it holds the PreparedStmtDB under
  	// preparedStmtDBKey.
  	cacheStore *sync.Map
  }
  66. // Apply update config to new config
  67. func (c *Config) Apply(config *Config) error {
  68. if config != c {
  69. *config = *c
  70. }
  71. return nil
  72. }
  73. // AfterInitialize initialize plugins after db connected
  74. func (c *Config) AfterInitialize(db *DB) error {
  75. if db != nil {
  76. for _, plugin := range c.Plugins {
  77. if err := plugin.Initialize(db); err != nil {
  78. return err
  79. }
  80. }
  81. }
  82. return nil
  83. }
  // Option is the interface implemented by values that can be passed to Open.
  type Option interface {
  	// Apply is called by Open with the Config being built, before the
  	// dialector is initialized. A non-nil error aborts Open.
  	Apply(*Config) error
  	// AfterInitialize is called from a deferred function when Open returns,
  	// with the DB that Open is returning; a non-nil error becomes Open's
  	// error result.
  	AfterInitialize(*DB) error
  }
  // DB is the GORM database handle.
  type DB struct {
  	// Config is the embedded configuration; its fields and methods are
  	// promoted onto DB.
  	*Config
  	// Error is the error accumulated on this instance (see AddError).
  	Error error
  	// RowsAffected is not written in this file; presumably the number of rows
  	// affected by the last operation (TODO confirm against the callbacks).
  	RowsAffected int64
  	// Statement is the statement this instance operates on.
  	Statement *Statement
  	// clone controls what getInstance returns: 0 reuses this DB as is, 1
  	// creates a new DB with a fresh Statement, and any other positive value
  	// (Session uses 2) creates a new DB with a clone of the current Statement.
  	clone int
  }
  // Session is the configuration used when creating a new session with the
  // DB.Session method. Boolean fields can only switch a setting on for the
  // session; a false value leaves the inherited setting unchanged.
  type Session struct {
  	// DryRun enables Config.DryRun on the session.
  	DryRun bool
  	// PrepareStmt routes the session's statement through the shared
  	// prepared statement cache.
  	PrepareStmt bool
  	// NewDB, when false, makes the session hand out clones of its current
  	// Statement (clone = 2); when true, derived instances start from a fresh
  	// Statement.
  	NewDB bool
  	// Initialized makes Session return an instance obtained via getInstance
  	// instead of the bare session value.
  	Initialized bool
  	// SkipHooks sets Statement.SkipHooks on the session's statement.
  	SkipHooks bool
  	// SkipDefaultTransaction enables Config.SkipDefaultTransaction.
  	SkipDefaultTransaction bool
  	// DisableNestedTransaction enables Config.DisableNestedTransaction.
  	DisableNestedTransaction bool
  	// AllowGlobalUpdate enables Config.AllowGlobalUpdate.
  	AllowGlobalUpdate bool
  	// FullSaveAssociations enables Config.FullSaveAssociations.
  	FullSaveAssociations bool
  	// PropagateUnscoped enables Config.PropagateUnscoped.
  	PropagateUnscoped bool
  	// QueryFields enables Config.QueryFields.
  	QueryFields bool
  	// Context, when non-nil, replaces the statement's context.
  	Context context.Context
  	// Logger, when non-nil, replaces the session's logger.
  	Logger logger.Interface
  	// NowFunc, when non-nil, replaces the session's timestamp function.
  	NowFunc func() time.Time
  	// CreateBatchSize, when greater than zero, replaces the session's
  	// create batch size.
  	CreateBatchSize int
  }
  115. // Open initialize db session based on dialector
  116. func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
  117. config := &Config{}
  118. sort.Slice(opts, func(i, j int) bool {
  119. _, isConfig := opts[i].(*Config)
  120. _, isConfig2 := opts[j].(*Config)
  121. return isConfig && !isConfig2
  122. })
  123. for _, opt := range opts {
  124. if opt != nil {
  125. if applyErr := opt.Apply(config); applyErr != nil {
  126. return nil, applyErr
  127. }
  128. defer func(opt Option) {
  129. if errr := opt.AfterInitialize(db); errr != nil {
  130. err = errr
  131. }
  132. }(opt)
  133. }
  134. }
  135. if d, ok := dialector.(interface{ Apply(*Config) error }); ok {
  136. if err = d.Apply(config); err != nil {
  137. return
  138. }
  139. }
  140. if config.NamingStrategy == nil {
  141. config.NamingStrategy = schema.NamingStrategy{IdentifierMaxLength: 64} // Default Identifier length is 64
  142. }
  143. if config.Logger == nil {
  144. config.Logger = logger.Default
  145. }
  146. if config.NowFunc == nil {
  147. config.NowFunc = func() time.Time { return time.Now().Local() }
  148. }
  149. if dialector != nil {
  150. config.Dialector = dialector
  151. }
  152. if config.Plugins == nil {
  153. config.Plugins = map[string]Plugin{}
  154. }
  155. if config.cacheStore == nil {
  156. config.cacheStore = &sync.Map{}
  157. }
  158. db = &DB{Config: config, clone: 1}
  159. db.callbacks = initializeCallbacks(db)
  160. if config.ClauseBuilders == nil {
  161. config.ClauseBuilders = map[string]clause.ClauseBuilder{}
  162. }
  163. if config.Dialector != nil {
  164. err = config.Dialector.Initialize(db)
  165. if err != nil {
  166. if db, _ := db.DB(); db != nil {
  167. _ = db.Close()
  168. }
  169. }
  170. if config.TranslateError {
  171. if _, ok := db.Dialector.(ErrorTranslator); !ok {
  172. config.Logger.Warn(context.Background(), "The TranslateError option is enabled, but the Dialector %s does not implement ErrorTranslator.", db.Dialector.Name())
  173. }
  174. }
  175. }
  176. if config.PrepareStmt {
  177. preparedStmt := NewPreparedStmtDB(db.ConnPool, config.PrepareStmtMaxSize, config.PrepareStmtTTL)
  178. db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
  179. db.ConnPool = preparedStmt
  180. }
  181. db.Statement = &Statement{
  182. DB: db,
  183. ConnPool: db.ConnPool,
  184. Context: context.Background(),
  185. Clauses: map[string]clause.Clause{},
  186. }
  187. if err == nil && !config.DisableAutomaticPing {
  188. if pinger, ok := db.ConnPool.(interface{ Ping() error }); ok {
  189. err = pinger.Ping()
  190. }
  191. }
  192. if err != nil {
  193. config.Logger.Error(context.Background(), "failed to initialize database, got error %v", err)
  194. }
  195. return
  196. }
  197. // Session create new db session
  198. func (db *DB) Session(config *Session) *DB {
  199. var (
  200. txConfig = *db.Config
  201. tx = &DB{
  202. Config: &txConfig,
  203. Statement: db.Statement,
  204. Error: db.Error,
  205. clone: 1,
  206. }
  207. )
  208. if config.CreateBatchSize > 0 {
  209. tx.Config.CreateBatchSize = config.CreateBatchSize
  210. }
  211. if config.SkipDefaultTransaction {
  212. tx.Config.SkipDefaultTransaction = true
  213. }
  214. if config.AllowGlobalUpdate {
  215. txConfig.AllowGlobalUpdate = true
  216. }
  217. if config.FullSaveAssociations {
  218. txConfig.FullSaveAssociations = true
  219. }
  220. if config.PropagateUnscoped {
  221. txConfig.PropagateUnscoped = true
  222. }
  223. if config.Context != nil || config.PrepareStmt || config.SkipHooks {
  224. tx.Statement = tx.Statement.clone()
  225. tx.Statement.DB = tx
  226. }
  227. if config.Context != nil {
  228. tx.Statement.Context = config.Context
  229. }
  230. if config.PrepareStmt {
  231. var preparedStmt *PreparedStmtDB
  232. if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok {
  233. preparedStmt = v.(*PreparedStmtDB)
  234. } else {
  235. preparedStmt = NewPreparedStmtDB(db.ConnPool, db.PrepareStmtMaxSize, db.PrepareStmtTTL)
  236. db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
  237. }
  238. switch t := tx.Statement.ConnPool.(type) {
  239. case Tx:
  240. tx.Statement.ConnPool = &PreparedStmtTX{
  241. Tx: t,
  242. PreparedStmtDB: preparedStmt,
  243. }
  244. default:
  245. tx.Statement.ConnPool = &PreparedStmtDB{
  246. ConnPool: db.Config.ConnPool,
  247. Mux: preparedStmt.Mux,
  248. Stmts: preparedStmt.Stmts,
  249. }
  250. }
  251. txConfig.ConnPool = tx.Statement.ConnPool
  252. txConfig.PrepareStmt = true
  253. }
  254. if config.SkipHooks {
  255. tx.Statement.SkipHooks = true
  256. }
  257. if config.DisableNestedTransaction {
  258. txConfig.DisableNestedTransaction = true
  259. }
  260. if !config.NewDB {
  261. tx.clone = 2
  262. }
  263. if config.DryRun {
  264. tx.Config.DryRun = true
  265. }
  266. if config.QueryFields {
  267. tx.Config.QueryFields = true
  268. }
  269. if config.Logger != nil {
  270. tx.Config.Logger = config.Logger
  271. }
  272. if config.NowFunc != nil {
  273. tx.Config.NowFunc = config.NowFunc
  274. }
  275. if config.Initialized {
  276. tx = tx.getInstance()
  277. }
  278. return tx
  279. }
  280. // WithContext change current instance db's context to ctx
  281. func (db *DB) WithContext(ctx context.Context) *DB {
  282. return db.Session(&Session{Context: ctx})
  283. }
  284. // Debug start debug mode
  285. func (db *DB) Debug() (tx *DB) {
  286. tx = db.getInstance()
  287. return tx.Session(&Session{
  288. Logger: db.Logger.LogMode(logger.Info),
  289. })
  290. }
  291. // Set store value with key into current db instance's context
  292. func (db *DB) Set(key string, value interface{}) *DB {
  293. tx := db.getInstance()
  294. tx.Statement.Settings.Store(key, value)
  295. return tx
  296. }
  297. // Get get value with key from current db instance's context
  298. func (db *DB) Get(key string) (interface{}, bool) {
  299. return db.Statement.Settings.Load(key)
  300. }
  301. // InstanceSet store value with key into current db instance's context
  302. func (db *DB) InstanceSet(key string, value interface{}) *DB {
  303. tx := db.getInstance()
  304. tx.Statement.Settings.Store(fmt.Sprintf("%p", tx.Statement)+key, value)
  305. return tx
  306. }
  307. // InstanceGet get value with key from current db instance's context
  308. func (db *DB) InstanceGet(key string) (interface{}, bool) {
  309. return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key)
  310. }
  311. // Callback returns callback manager
  312. func (db *DB) Callback() *callbacks {
  313. return db.callbacks
  314. }
  315. // AddError add error to db
  316. func (db *DB) AddError(err error) error {
  317. if err != nil {
  318. if db.Config.TranslateError {
  319. if errTranslator, ok := db.Dialector.(ErrorTranslator); ok {
  320. err = errTranslator.Translate(err)
  321. }
  322. }
  323. if db.Error == nil {
  324. db.Error = err
  325. } else {
  326. db.Error = fmt.Errorf("%v; %w", db.Error, err)
  327. }
  328. }
  329. return db.Error
  330. }
  331. // DB returns `*sql.DB`
  332. func (db *DB) DB() (*sql.DB, error) {
  333. connPool := db.ConnPool
  334. if db.Statement != nil && db.Statement.ConnPool != nil {
  335. connPool = db.Statement.ConnPool
  336. }
  337. if tx, ok := connPool.(*sql.Tx); ok && tx != nil {
  338. return (*sql.DB)(reflect.ValueOf(tx).Elem().FieldByName("db").UnsafePointer()), nil
  339. }
  340. if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil {
  341. if sqldb, err := dbConnector.GetDBConn(); sqldb != nil || err != nil {
  342. return sqldb, err
  343. }
  344. }
  345. if sqldb, ok := connPool.(*sql.DB); ok && sqldb != nil {
  346. return sqldb, nil
  347. }
  348. return nil, ErrInvalidDB
  349. }
  350. func (db *DB) getInstance() *DB {
  351. if db.clone > 0 {
  352. tx := &DB{Config: db.Config, Error: db.Error}
  353. if db.clone == 1 {
  354. // clone with new statement
  355. tx.Statement = &Statement{
  356. DB: tx,
  357. ConnPool: db.Statement.ConnPool,
  358. Context: db.Statement.Context,
  359. Clauses: map[string]clause.Clause{},
  360. Vars: make([]interface{}, 0, 8),
  361. SkipHooks: db.Statement.SkipHooks,
  362. }
  363. if db.Config.PropagateUnscoped {
  364. tx.Statement.Unscoped = db.Statement.Unscoped
  365. }
  366. } else {
  367. // with clone statement
  368. tx.Statement = db.Statement.clone()
  369. tx.Statement.DB = tx
  370. }
  371. return tx
  372. }
  373. return db
  374. }
  375. // Expr returns clause.Expr, which can be used to pass SQL expression as params
  376. func Expr(expr string, args ...interface{}) clause.Expr {
  377. return clause.Expr{SQL: expr, Vars: args}
  378. }
  379. // SetupJoinTable setup join table schema
  380. func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error {
  381. var (
  382. tx = db.getInstance()
  383. stmt = tx.Statement
  384. modelSchema, joinSchema *schema.Schema
  385. )
  386. err := stmt.Parse(model)
  387. if err != nil {
  388. return err
  389. }
  390. modelSchema = stmt.Schema
  391. err = stmt.Parse(joinTable)
  392. if err != nil {
  393. return err
  394. }
  395. joinSchema = stmt.Schema
  396. relation, ok := modelSchema.Relationships.Relations[field]
  397. isRelation := ok && relation.JoinTable != nil
  398. if !isRelation {
  399. return fmt.Errorf("failed to find relation: %s", field)
  400. }
  401. for _, ref := range relation.References {
  402. f := joinSchema.LookUpField(ref.ForeignKey.DBName)
  403. if f == nil {
  404. return fmt.Errorf("missing field %s for join table", ref.ForeignKey.DBName)
  405. }
  406. f.DataType = ref.ForeignKey.DataType
  407. f.GORMDataType = ref.ForeignKey.GORMDataType
  408. if f.Size == 0 {
  409. f.Size = ref.ForeignKey.Size
  410. }
  411. ref.ForeignKey = f
  412. }
  413. for name, rel := range relation.JoinTable.Relationships.Relations {
  414. if _, ok := joinSchema.Relationships.Relations[name]; !ok {
  415. rel.Schema = joinSchema
  416. joinSchema.Relationships.Relations[name] = rel
  417. }
  418. }
  419. relation.JoinTable = joinSchema
  420. return nil
  421. }
  422. // Use use plugin
  423. func (db *DB) Use(plugin Plugin) error {
  424. name := plugin.Name()
  425. if _, ok := db.Plugins[name]; ok {
  426. return ErrRegistered
  427. }
  428. if err := plugin.Initialize(db); err != nil {
  429. return err
  430. }
  431. db.Plugins[name] = plugin
  432. return nil
  433. }
  434. // ToSQL for generate SQL string.
  435. //
  436. // db.ToSQL(func(tx *gorm.DB) *gorm.DB {
  437. // return tx.Model(&User{}).Where(&User{Name: "foo", Age: 20})
  438. // .Limit(10).Offset(5)
  439. // .Order("name ASC")
  440. // .First(&User{})
  441. // })
  442. func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string {
  443. tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true}))
  444. stmt := tx.Statement
  445. return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
  446. }