statement.go 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744
  1. package gorm
  2. import (
  3. "context"
  4. "database/sql"
  5. "database/sql/driver"
  6. "fmt"
  7. "reflect"
  8. "regexp"
  9. "sort"
  10. "strconv"
  11. "strings"
  12. "sync"
  13. "gorm.io/gorm/clause"
  14. "gorm.io/gorm/logger"
  15. "gorm.io/gorm/schema"
  16. "gorm.io/gorm/utils"
  17. )
  18. // Statement statement
  19. type Statement struct {
  20. *DB
  21. TableExpr *clause.Expr
  22. Table string
  23. Model interface{}
  24. Unscoped bool
  25. Dest interface{}
  26. ReflectValue reflect.Value
  27. Clauses map[string]clause.Clause
  28. BuildClauses []string
  29. Distinct bool
  30. Selects []string // selected columns
  31. Omits []string // omit columns
  32. ColumnMapping map[string]string // map columns
  33. Joins []join
  34. Preloads map[string][]interface{}
  35. Settings sync.Map
  36. ConnPool ConnPool
  37. Schema *schema.Schema
  38. Context context.Context
  39. RaiseErrorOnNotFound bool
  40. SkipHooks bool
  41. SQL strings.Builder
  42. Vars []interface{}
  43. CurDestIndex int
  44. attrs []interface{}
  45. assigns []interface{}
  46. scopes []func(*DB) *DB
  47. }
  48. type join struct {
  49. Name string
  50. Conds []interface{}
  51. On *clause.Where
  52. Selects []string
  53. Omits []string
  54. JoinType clause.JoinType
  55. }
  56. // StatementModifier statement modifier interface
  57. type StatementModifier interface {
  58. ModifyStatement(*Statement)
  59. }
  60. // WriteString write string
  61. func (stmt *Statement) WriteString(str string) (int, error) {
  62. return stmt.SQL.WriteString(str)
  63. }
  64. // WriteByte write byte
  65. func (stmt *Statement) WriteByte(c byte) error {
  66. return stmt.SQL.WriteByte(c)
  67. }
  68. // WriteQuoted write quoted value
  69. func (stmt *Statement) WriteQuoted(value interface{}) {
  70. stmt.QuoteTo(&stmt.SQL, value)
  71. }
  72. // QuoteTo write quoted value to writer
  73. func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
  74. write := func(raw bool, str string) {
  75. if raw {
  76. writer.WriteString(str)
  77. } else {
  78. stmt.DB.Dialector.QuoteTo(writer, str)
  79. }
  80. }
  81. switch v := field.(type) {
  82. case clause.Table:
  83. if v.Name == clause.CurrentTable {
  84. if stmt.TableExpr != nil {
  85. stmt.TableExpr.Build(stmt)
  86. } else {
  87. write(v.Raw, stmt.Table)
  88. }
  89. } else {
  90. write(v.Raw, v.Name)
  91. }
  92. if v.Alias != "" {
  93. writer.WriteByte(' ')
  94. write(v.Raw, v.Alias)
  95. }
  96. case clause.Column:
  97. if v.Table != "" {
  98. if v.Table == clause.CurrentTable {
  99. write(v.Raw, stmt.Table)
  100. } else {
  101. write(v.Raw, v.Table)
  102. }
  103. writer.WriteByte('.')
  104. }
  105. if v.Name == clause.PrimaryKey {
  106. if stmt.Schema == nil {
  107. stmt.DB.AddError(ErrModelValueRequired)
  108. } else if stmt.Schema.PrioritizedPrimaryField != nil {
  109. write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName)
  110. } else if len(stmt.Schema.DBNames) > 0 {
  111. write(v.Raw, stmt.Schema.DBNames[0])
  112. } else {
  113. stmt.DB.AddError(ErrModelAccessibleFieldsRequired) //nolint:typecheck,errcheck
  114. }
  115. } else {
  116. write(v.Raw, v.Name)
  117. }
  118. if v.Alias != "" {
  119. writer.WriteString(" AS ")
  120. write(v.Raw, v.Alias)
  121. }
  122. case []clause.Column:
  123. writer.WriteByte('(')
  124. for idx, d := range v {
  125. if idx > 0 {
  126. writer.WriteByte(',')
  127. }
  128. stmt.QuoteTo(writer, d)
  129. }
  130. writer.WriteByte(')')
  131. case clause.Expr:
  132. v.Build(stmt)
  133. case string:
  134. stmt.DB.Dialector.QuoteTo(writer, v)
  135. case []string:
  136. writer.WriteByte('(')
  137. for idx, d := range v {
  138. if idx > 0 {
  139. writer.WriteByte(',')
  140. }
  141. stmt.DB.Dialector.QuoteTo(writer, d)
  142. }
  143. writer.WriteByte(')')
  144. default:
  145. stmt.DB.Dialector.QuoteTo(writer, fmt.Sprint(field))
  146. }
  147. }
  148. // Quote returns quoted value
  149. func (stmt *Statement) Quote(field interface{}) string {
  150. var builder strings.Builder
  151. stmt.QuoteTo(&builder, field)
  152. return builder.String()
  153. }
  154. // AddVar add var
  155. func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
  156. for idx, v := range vars {
  157. if idx > 0 {
  158. writer.WriteByte(',')
  159. }
  160. switch v := v.(type) {
  161. case sql.NamedArg:
  162. stmt.Vars = append(stmt.Vars, v.Value)
  163. case clause.Column, clause.Table:
  164. stmt.QuoteTo(writer, v)
  165. case Valuer:
  166. reflectValue := reflect.ValueOf(v)
  167. if reflectValue.Kind() == reflect.Ptr && reflectValue.IsNil() {
  168. stmt.AddVar(writer, nil)
  169. } else {
  170. stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB))
  171. }
  172. case clause.Interface:
  173. c := clause.Clause{Name: v.Name()}
  174. v.MergeClause(&c)
  175. c.Build(stmt)
  176. case clause.Expression:
  177. v.Build(stmt)
  178. case driver.Valuer:
  179. stmt.Vars = append(stmt.Vars, v)
  180. stmt.DB.Dialector.BindVarTo(writer, stmt, v)
  181. case []byte:
  182. stmt.Vars = append(stmt.Vars, v)
  183. stmt.DB.Dialector.BindVarTo(writer, stmt, v)
  184. case []interface{}:
  185. if len(v) > 0 {
  186. writer.WriteByte('(')
  187. stmt.AddVar(writer, v...)
  188. writer.WriteByte(')')
  189. } else {
  190. writer.WriteString("(NULL)")
  191. }
  192. case *DB:
  193. subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance()
  194. if v.Statement.SQL.Len() > 0 {
  195. var (
  196. vars = subdb.Statement.Vars
  197. sql = v.Statement.SQL.String()
  198. )
  199. subdb.Statement.Vars = make([]interface{}, 0, len(vars))
  200. for _, vv := range vars {
  201. subdb.Statement.Vars = append(subdb.Statement.Vars, vv)
  202. bindvar := strings.Builder{}
  203. v.Dialector.BindVarTo(&bindvar, subdb.Statement, vv)
  204. sql = strings.Replace(sql, bindvar.String(), "?", 1)
  205. }
  206. subdb.Statement.SQL.Reset()
  207. subdb.Statement.Vars = stmt.Vars
  208. if strings.Contains(sql, "@") {
  209. clause.NamedExpr{SQL: sql, Vars: vars}.Build(subdb.Statement)
  210. } else {
  211. clause.Expr{SQL: sql, Vars: vars}.Build(subdb.Statement)
  212. }
  213. } else {
  214. subdb.Statement.Vars = append(stmt.Vars, subdb.Statement.Vars...)
  215. subdb.callbacks.Query().Execute(subdb)
  216. }
  217. writer.WriteString(subdb.Statement.SQL.String())
  218. stmt.Vars = subdb.Statement.Vars
  219. default:
  220. switch rv := reflect.ValueOf(v); rv.Kind() {
  221. case reflect.Slice, reflect.Array:
  222. if rv.Len() == 0 {
  223. writer.WriteString("(NULL)")
  224. } else if rv.Type().Elem() == reflect.TypeOf(uint8(0)) {
  225. stmt.Vars = append(stmt.Vars, v)
  226. stmt.DB.Dialector.BindVarTo(writer, stmt, v)
  227. } else {
  228. writer.WriteByte('(')
  229. for i := 0; i < rv.Len(); i++ {
  230. if i > 0 {
  231. writer.WriteByte(',')
  232. }
  233. stmt.AddVar(writer, rv.Index(i).Interface())
  234. }
  235. writer.WriteByte(')')
  236. }
  237. default:
  238. stmt.Vars = append(stmt.Vars, v)
  239. stmt.DB.Dialector.BindVarTo(writer, stmt, v)
  240. }
  241. }
  242. }
  243. }
  244. // AddClause add clause
  245. func (stmt *Statement) AddClause(v clause.Interface) {
  246. if optimizer, ok := v.(StatementModifier); ok {
  247. optimizer.ModifyStatement(stmt)
  248. } else {
  249. name := v.Name()
  250. c := stmt.Clauses[name]
  251. c.Name = name
  252. v.MergeClause(&c)
  253. stmt.Clauses[name] = c
  254. }
  255. }
  256. // AddClauseIfNotExists add clause if not exists
  257. func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) {
  258. if c, ok := stmt.Clauses[v.Name()]; !ok || c.Expression == nil {
  259. stmt.AddClause(v)
  260. }
  261. }
  262. // BuildCondition build condition
  263. func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []clause.Expression {
  264. if s, ok := query.(string); ok {
  265. // if it is a number, then treats it as primary key
  266. if _, err := strconv.Atoi(s); err != nil {
  267. if s == "" && len(args) == 0 {
  268. return nil
  269. }
  270. if len(args) == 0 || (len(args) > 0 && strings.Contains(s, "?")) {
  271. // looks like a where condition
  272. return []clause.Expression{clause.Expr{SQL: s, Vars: args}}
  273. }
  274. if len(args) > 0 && strings.Contains(s, "@") {
  275. // looks like a named query
  276. return []clause.Expression{clause.NamedExpr{SQL: s, Vars: args}}
  277. }
  278. if strings.Contains(strings.TrimSpace(s), " ") {
  279. // looks like a where condition
  280. return []clause.Expression{clause.Expr{SQL: s, Vars: args}}
  281. }
  282. if len(args) == 1 {
  283. return []clause.Expression{clause.Eq{Column: s, Value: args[0]}}
  284. }
  285. }
  286. }
  287. conds := make([]clause.Expression, 0, 4)
  288. args = append([]interface{}{query}, args...)
  289. for idx, arg := range args {
  290. if arg == nil {
  291. continue
  292. }
  293. if valuer, ok := arg.(driver.Valuer); ok {
  294. arg, _ = valuer.Value()
  295. }
  296. switch v := arg.(type) {
  297. case clause.Expression:
  298. conds = append(conds, v)
  299. case *DB:
  300. v.executeScopes()
  301. if cs, ok := v.Statement.Clauses["WHERE"]; ok {
  302. if where, ok := cs.Expression.(clause.Where); ok {
  303. if len(where.Exprs) == 1 {
  304. if orConds, ok := where.Exprs[0].(clause.OrConditions); ok {
  305. where.Exprs[0] = clause.AndConditions(orConds)
  306. }
  307. }
  308. conds = append(conds, clause.And(where.Exprs...))
  309. } else if cs.Expression != nil {
  310. conds = append(conds, cs.Expression)
  311. }
  312. }
  313. case map[interface{}]interface{}:
  314. for i, j := range v {
  315. conds = append(conds, clause.Eq{Column: i, Value: j})
  316. }
  317. case map[string]string:
  318. keys := make([]string, 0, len(v))
  319. for i := range v {
  320. keys = append(keys, i)
  321. }
  322. sort.Strings(keys)
  323. for _, key := range keys {
  324. conds = append(conds, clause.Eq{Column: key, Value: v[key]})
  325. }
  326. case map[string]interface{}:
  327. keys := make([]string, 0, len(v))
  328. for i := range v {
  329. keys = append(keys, i)
  330. }
  331. sort.Strings(keys)
  332. for _, key := range keys {
  333. reflectValue := reflect.Indirect(reflect.ValueOf(v[key]))
  334. switch reflectValue.Kind() {
  335. case reflect.Slice, reflect.Array:
  336. if _, ok := v[key].(driver.Valuer); ok {
  337. conds = append(conds, clause.Eq{Column: key, Value: v[key]})
  338. } else if _, ok := v[key].(Valuer); ok {
  339. conds = append(conds, clause.Eq{Column: key, Value: v[key]})
  340. } else {
  341. // optimize reflect value length
  342. valueLen := reflectValue.Len()
  343. values := make([]interface{}, valueLen)
  344. for i := 0; i < valueLen; i++ {
  345. values[i] = reflectValue.Index(i).Interface()
  346. }
  347. conds = append(conds, clause.IN{Column: key, Values: values})
  348. }
  349. default:
  350. conds = append(conds, clause.Eq{Column: key, Value: v[key]})
  351. }
  352. }
  353. default:
  354. reflectValue := reflect.Indirect(reflect.ValueOf(arg))
  355. for reflectValue.Kind() == reflect.Ptr {
  356. reflectValue = reflectValue.Elem()
  357. }
  358. if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil {
  359. selectedColumns := map[string]bool{}
  360. if idx == 0 {
  361. for _, v := range args[1:] {
  362. if vs, ok := v.(string); ok {
  363. selectedColumns[vs] = true
  364. }
  365. }
  366. }
  367. restricted := len(selectedColumns) != 0
  368. switch reflectValue.Kind() {
  369. case reflect.Struct:
  370. for _, field := range s.Fields {
  371. selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
  372. if selected || (!restricted && field.Readable) {
  373. if v, isZero := field.ValueOf(stmt.Context, reflectValue); !isZero || selected {
  374. if field.DBName != "" {
  375. conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
  376. } else if field.DataType != "" {
  377. conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
  378. }
  379. }
  380. }
  381. }
  382. case reflect.Slice, reflect.Array:
  383. for i := 0; i < reflectValue.Len(); i++ {
  384. for _, field := range s.Fields {
  385. selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
  386. if selected || (!restricted && field.Readable) {
  387. if v, isZero := field.ValueOf(stmt.Context, reflectValue.Index(i)); !isZero || selected {
  388. if field.DBName != "" {
  389. conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
  390. } else if field.DataType != "" {
  391. conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
  392. }
  393. }
  394. }
  395. }
  396. }
  397. }
  398. if restricted {
  399. break
  400. }
  401. } else if !reflectValue.IsValid() {
  402. stmt.AddError(ErrInvalidData)
  403. } else if len(conds) == 0 {
  404. if len(args) == 1 {
  405. switch reflectValue.Kind() {
  406. case reflect.Slice, reflect.Array:
  407. // optimize reflect value length
  408. valueLen := reflectValue.Len()
  409. values := make([]interface{}, valueLen)
  410. for i := 0; i < valueLen; i++ {
  411. values[i] = reflectValue.Index(i).Interface()
  412. }
  413. if len(values) > 0 {
  414. conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values})
  415. return []clause.Expression{clause.And(conds...)}
  416. }
  417. return nil
  418. }
  419. }
  420. conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args})
  421. }
  422. }
  423. }
  424. if len(conds) > 0 {
  425. return []clause.Expression{clause.And(conds...)}
  426. }
  427. return nil
  428. }
  429. // Build build sql with clauses names
  430. func (stmt *Statement) Build(clauses ...string) {
  431. var firstClauseWritten bool
  432. for _, name := range clauses {
  433. if c, ok := stmt.Clauses[name]; ok {
  434. if firstClauseWritten {
  435. stmt.WriteByte(' ')
  436. }
  437. firstClauseWritten = true
  438. if b, ok := stmt.DB.ClauseBuilders[name]; ok {
  439. b(c, stmt)
  440. } else {
  441. c.Build(stmt)
  442. }
  443. }
  444. }
  445. }
  446. func (stmt *Statement) Parse(value interface{}) (err error) {
  447. return stmt.ParseWithSpecialTableName(value, "")
  448. }
  449. func (stmt *Statement) ParseWithSpecialTableName(value interface{}, specialTableName string) (err error) {
  450. if stmt.Schema, err = schema.ParseWithSpecialTableName(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, specialTableName); err == nil && stmt.Table == "" {
  451. if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 {
  452. stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)}
  453. stmt.Table = tables[1]
  454. return
  455. }
  456. stmt.Table = stmt.Schema.Table
  457. }
  458. return err
  459. }
  460. func (stmt *Statement) clone() *Statement {
  461. newStmt := &Statement{
  462. TableExpr: stmt.TableExpr,
  463. Table: stmt.Table,
  464. Model: stmt.Model,
  465. Unscoped: stmt.Unscoped,
  466. Dest: stmt.Dest,
  467. ReflectValue: stmt.ReflectValue,
  468. Clauses: map[string]clause.Clause{},
  469. Distinct: stmt.Distinct,
  470. Selects: stmt.Selects,
  471. Omits: stmt.Omits,
  472. ColumnMapping: stmt.ColumnMapping,
  473. Preloads: map[string][]interface{}{},
  474. ConnPool: stmt.ConnPool,
  475. Schema: stmt.Schema,
  476. Context: stmt.Context,
  477. RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound,
  478. SkipHooks: stmt.SkipHooks,
  479. }
  480. if stmt.SQL.Len() > 0 {
  481. newStmt.SQL.WriteString(stmt.SQL.String())
  482. newStmt.Vars = make([]interface{}, 0, len(stmt.Vars))
  483. newStmt.Vars = append(newStmt.Vars, stmt.Vars...)
  484. }
  485. for k, c := range stmt.Clauses {
  486. newStmt.Clauses[k] = c
  487. }
  488. for k, p := range stmt.Preloads {
  489. newStmt.Preloads[k] = p
  490. }
  491. if len(stmt.Joins) > 0 {
  492. newStmt.Joins = make([]join, len(stmt.Joins))
  493. copy(newStmt.Joins, stmt.Joins)
  494. }
  495. if len(stmt.scopes) > 0 {
  496. newStmt.scopes = make([]func(*DB) *DB, len(stmt.scopes))
  497. copy(newStmt.scopes, stmt.scopes)
  498. }
  499. stmt.Settings.Range(func(k, v interface{}) bool {
  500. newStmt.Settings.Store(k, v)
  501. return true
  502. })
  503. return newStmt
  504. }
  505. // SetColumn set column's value
  506. //
  507. // stmt.SetColumn("Name", "jinzhu") // Hooks Method
  508. // stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method
  509. func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks ...bool) {
  510. if v, ok := stmt.Dest.(map[string]interface{}); ok {
  511. v[name] = value
  512. } else if v, ok := stmt.Dest.([]map[string]interface{}); ok {
  513. for _, m := range v {
  514. m[name] = value
  515. }
  516. } else if stmt.Schema != nil {
  517. if field := stmt.Schema.LookUpField(name); field != nil {
  518. destValue := reflect.ValueOf(stmt.Dest)
  519. for destValue.Kind() == reflect.Ptr {
  520. destValue = destValue.Elem()
  521. }
  522. if stmt.ReflectValue != destValue {
  523. if !destValue.CanAddr() {
  524. destValueCanAddr := reflect.New(destValue.Type())
  525. destValueCanAddr.Elem().Set(destValue)
  526. stmt.Dest = destValueCanAddr.Interface()
  527. destValue = destValueCanAddr.Elem()
  528. }
  529. switch destValue.Kind() {
  530. case reflect.Struct:
  531. stmt.AddError(field.Set(stmt.Context, destValue, value))
  532. default:
  533. stmt.AddError(ErrInvalidData)
  534. }
  535. }
  536. switch stmt.ReflectValue.Kind() {
  537. case reflect.Slice, reflect.Array:
  538. if len(fromCallbacks) > 0 {
  539. for i := 0; i < stmt.ReflectValue.Len(); i++ {
  540. stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue.Index(i), value))
  541. }
  542. } else {
  543. stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue.Index(stmt.CurDestIndex), value))
  544. }
  545. case reflect.Struct:
  546. if !stmt.ReflectValue.CanAddr() {
  547. stmt.AddError(ErrInvalidValue)
  548. return
  549. }
  550. stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, value))
  551. }
  552. } else {
  553. stmt.AddError(ErrInvalidField)
  554. }
  555. } else {
  556. stmt.AddError(ErrInvalidField)
  557. }
  558. }
  559. // Changed check model changed or not when updating
  560. func (stmt *Statement) Changed(fields ...string) bool {
  561. modelValue := stmt.ReflectValue
  562. switch modelValue.Kind() {
  563. case reflect.Slice, reflect.Array:
  564. modelValue = stmt.ReflectValue.Index(stmt.CurDestIndex)
  565. }
  566. selectColumns, restricted := stmt.SelectAndOmitColumns(false, true)
  567. changed := func(field *schema.Field) bool {
  568. fieldValue, _ := field.ValueOf(stmt.Context, modelValue)
  569. if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
  570. if mv, mok := stmt.Dest.(map[string]interface{}); mok {
  571. if fv, ok := mv[field.Name]; ok {
  572. return !utils.AssertEqual(fv, fieldValue)
  573. } else if fv, ok := mv[field.DBName]; ok {
  574. return !utils.AssertEqual(fv, fieldValue)
  575. }
  576. } else {
  577. destValue := reflect.ValueOf(stmt.Dest)
  578. for destValue.Kind() == reflect.Ptr {
  579. destValue = destValue.Elem()
  580. }
  581. changedValue, zero := field.ValueOf(stmt.Context, destValue)
  582. if v {
  583. return !utils.AssertEqual(changedValue, fieldValue)
  584. }
  585. return !zero && !utils.AssertEqual(changedValue, fieldValue)
  586. }
  587. }
  588. return false
  589. }
  590. if len(fields) == 0 {
  591. for _, field := range stmt.Schema.FieldsByDBName {
  592. if changed(field) {
  593. return true
  594. }
  595. }
  596. } else {
  597. for _, name := range fields {
  598. if field := stmt.Schema.LookUpField(name); field != nil {
  599. if changed(field) {
  600. return true
  601. }
  602. }
  603. }
  604. }
  605. return false
  606. }
  607. var matchName = func() func(tableColumn string) (table, column string) {
  608. nameMatcher := regexp.MustCompile(`^(?:\W?(\w+?)\W?\.)?(?:(\*)|\W?(\w+?)\W?)$`)
  609. return func(tableColumn string) (table, column string) {
  610. if matches := nameMatcher.FindStringSubmatch(tableColumn); len(matches) == 4 {
  611. table = matches[1]
  612. star := matches[2]
  613. columnName := matches[3]
  614. if star != "" {
  615. return table, star
  616. }
  617. return table, columnName
  618. }
  619. return "", ""
  620. }
  621. }()
  622. // SelectAndOmitColumns get select and omit columns, select -> true, omit -> false
  623. func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) {
  624. results := map[string]bool{}
  625. notRestricted := false
  626. processColumn := func(column string, result bool) {
  627. if stmt.Schema == nil {
  628. results[column] = result
  629. } else if column == "*" {
  630. notRestricted = result
  631. for _, dbName := range stmt.Schema.DBNames {
  632. results[dbName] = result
  633. }
  634. } else if column == clause.Associations {
  635. for _, rel := range stmt.Schema.Relationships.Relations {
  636. results[rel.Name] = result
  637. }
  638. } else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" {
  639. results[field.DBName] = result
  640. } else if table, col := matchName(column); col != "" && (table == stmt.Table || table == "") {
  641. if col == "*" {
  642. for _, dbName := range stmt.Schema.DBNames {
  643. results[dbName] = result
  644. }
  645. } else {
  646. results[col] = result
  647. }
  648. } else {
  649. results[column] = result
  650. }
  651. }
  652. // select columns
  653. for _, column := range stmt.Selects {
  654. processColumn(column, true)
  655. }
  656. // omit columns
  657. for _, column := range stmt.Omits {
  658. processColumn(column, false)
  659. }
  660. if stmt.Schema != nil {
  661. for _, field := range stmt.Schema.FieldsByName {
  662. name := field.DBName
  663. if name == "" {
  664. name = field.Name
  665. }
  666. if requireCreate && !field.Creatable {
  667. results[name] = false
  668. } else if requireUpdate && !field.Updatable {
  669. results[name] = false
  670. }
  671. }
  672. }
  673. return results, !notRestricted && len(stmt.Selects) > 0
  674. }