schema.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424
  1. package schema
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "go/ast"
  7. "reflect"
  8. "strings"
  9. "sync"
  10. "gorm.io/gorm/clause"
  11. "gorm.io/gorm/logger"
  12. )
  // callbackType names a model hook. Each value is the exact Go method name that
  // is looked up on the model via reflection, and also the name of the matching
  // boolean flag on Schema.
  type callbackType string

  // Hook names, grouped by the operation they hook into.
  const (
  	// Create hooks.
  	callbackTypeBeforeCreate callbackType = "BeforeCreate"
  	callbackTypeAfterCreate  callbackType = "AfterCreate"

  	// Update hooks.
  	callbackTypeBeforeUpdate callbackType = "BeforeUpdate"
  	callbackTypeAfterUpdate  callbackType = "AfterUpdate"

  	// Save hooks (create or update).
  	callbackTypeBeforeSave callbackType = "BeforeSave"
  	callbackTypeAfterSave  callbackType = "AfterSave"

  	// Delete hooks.
  	callbackTypeBeforeDelete callbackType = "BeforeDelete"
  	callbackTypeAfterDelete  callbackType = "AfterDelete"

  	// Query hook.
  	callbackTypeAfterFind callbackType = "AfterFind"
  )
  var (
  	// ErrUnsupportedDataType is wrapped (via %w) into the error returned when a
  	// destination is nil or does not resolve to a struct type that a schema
  	// can be parsed from. Match it with errors.Is.
  	ErrUnsupportedDataType = errors.New("unsupported data type")
  )
  27. type Schema struct {
  28. Name string
  29. ModelType reflect.Type
  30. Table string
  31. PrioritizedPrimaryField *Field
  32. DBNames []string
  33. PrimaryFields []*Field
  34. PrimaryFieldDBNames []string
  35. Fields []*Field
  36. FieldsByName map[string]*Field
  37. FieldsByBindName map[string]*Field // embedded fields is 'Embed.Field'
  38. FieldsByDBName map[string]*Field
  39. FieldsWithDefaultDBValue []*Field // fields with default value assigned by database
  40. Relationships Relationships
  41. CreateClauses []clause.Interface
  42. QueryClauses []clause.Interface
  43. UpdateClauses []clause.Interface
  44. DeleteClauses []clause.Interface
  45. BeforeCreate, AfterCreate bool
  46. BeforeUpdate, AfterUpdate bool
  47. BeforeDelete, AfterDelete bool
  48. BeforeSave, AfterSave bool
  49. AfterFind bool
  50. err error
  51. initialized chan struct{}
  52. namer Namer
  53. cacheStore *sync.Map
  54. }
  55. func (schema Schema) String() string {
  56. if schema.ModelType.Name() == "" {
  57. return fmt.Sprintf("%s(%s)", schema.Name, schema.Table)
  58. }
  59. return fmt.Sprintf("%s.%s", schema.ModelType.PkgPath(), schema.ModelType.Name())
  60. }
  61. func (schema Schema) MakeSlice() reflect.Value {
  62. slice := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(schema.ModelType)), 0, 20)
  63. results := reflect.New(slice.Type())
  64. results.Elem().Set(slice)
  65. return results
  66. }
  67. func (schema Schema) LookUpField(name string) *Field {
  68. if field, ok := schema.FieldsByDBName[name]; ok {
  69. return field
  70. }
  71. if field, ok := schema.FieldsByName[name]; ok {
  72. return field
  73. }
  74. return nil
  75. }
  76. // LookUpFieldByBindName looks for the closest field in the embedded struct.
  77. //
  78. // type Struct struct {
  79. // Embedded struct {
  80. // ID string // is selected by LookUpFieldByBindName([]string{"Embedded", "ID"}, "ID")
  81. // }
  82. // ID string // is selected by LookUpFieldByBindName([]string{"ID"}, "ID")
  83. // }
  84. func (schema Schema) LookUpFieldByBindName(bindNames []string, name string) *Field {
  85. if len(bindNames) == 0 {
  86. return nil
  87. }
  88. for i := len(bindNames) - 1; i >= 0; i-- {
  89. find := strings.Join(bindNames[:i], ".") + "." + name
  90. if field, ok := schema.FieldsByBindName[find]; ok {
  91. return field
  92. }
  93. }
  94. return nil
  95. }
  // Tabler lets a model supply its own table name. When the model implements
  // it, Parse uses TableName() instead of the name derived from the Namer.
  type Tabler interface{ TableName() string }
  99. type TablerWithNamer interface {
  100. TableName(Namer) string
  101. }
  102. // Parse get data type from dialector
  103. func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
  104. return ParseWithSpecialTableName(dest, cacheStore, namer, "")
  105. }
  106. // ParseWithSpecialTableName get data type from dialector with extra schema table
  107. func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) {
  108. if dest == nil {
  109. return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
  110. }
  111. value := reflect.ValueOf(dest)
  112. if value.Kind() == reflect.Ptr && value.IsNil() {
  113. value = reflect.New(value.Type().Elem())
  114. }
  115. modelType := reflect.Indirect(value).Type()
  116. if modelType.Kind() == reflect.Interface {
  117. modelType = reflect.Indirect(reflect.ValueOf(dest)).Elem().Type()
  118. }
  119. for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
  120. modelType = modelType.Elem()
  121. }
  122. if modelType.Kind() != reflect.Struct {
  123. if modelType.PkgPath() == "" {
  124. return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
  125. }
  126. return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
  127. }
  128. // Cache the Schema for performance,
  129. // Use the modelType or modelType + schemaTable (if it present) as cache key.
  130. var schemaCacheKey interface{}
  131. if specialTableName != "" {
  132. schemaCacheKey = fmt.Sprintf("%p-%s", modelType, specialTableName)
  133. } else {
  134. schemaCacheKey = modelType
  135. }
  136. // Load exist schema cache, return if exists
  137. if v, ok := cacheStore.Load(schemaCacheKey); ok {
  138. s := v.(*Schema)
  139. // Wait for the initialization of other goroutines to complete
  140. <-s.initialized
  141. return s, s.err
  142. }
  143. modelValue := reflect.New(modelType)
  144. tableName := namer.TableName(modelType.Name())
  145. if tabler, ok := modelValue.Interface().(Tabler); ok {
  146. tableName = tabler.TableName()
  147. }
  148. if tabler, ok := modelValue.Interface().(TablerWithNamer); ok {
  149. tableName = tabler.TableName(namer)
  150. }
  151. if en, ok := namer.(embeddedNamer); ok {
  152. tableName = en.Table
  153. }
  154. if specialTableName != "" && specialTableName != tableName {
  155. tableName = specialTableName
  156. }
  157. schema := &Schema{
  158. Name: modelType.Name(),
  159. ModelType: modelType,
  160. Table: tableName,
  161. FieldsByName: map[string]*Field{},
  162. FieldsByBindName: map[string]*Field{},
  163. FieldsByDBName: map[string]*Field{},
  164. Relationships: Relationships{Relations: map[string]*Relationship{}},
  165. cacheStore: cacheStore,
  166. namer: namer,
  167. initialized: make(chan struct{}),
  168. }
  169. // When the schema initialization is completed, the channel will be closed
  170. defer close(schema.initialized)
  171. // Load exist schema cache, return if exists
  172. if v, ok := cacheStore.Load(schemaCacheKey); ok {
  173. s := v.(*Schema)
  174. // Wait for the initialization of other goroutines to complete
  175. <-s.initialized
  176. return s, s.err
  177. }
  178. for i := 0; i < modelType.NumField(); i++ {
  179. if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) {
  180. if field := schema.ParseField(fieldStruct); field.EmbeddedSchema != nil {
  181. schema.Fields = append(schema.Fields, field.EmbeddedSchema.Fields...)
  182. } else {
  183. schema.Fields = append(schema.Fields, field)
  184. }
  185. }
  186. }
  187. for _, field := range schema.Fields {
  188. if field.DBName == "" && field.DataType != "" {
  189. field.DBName = namer.ColumnName(schema.Table, field.Name)
  190. }
  191. bindName := field.BindName()
  192. if field.DBName != "" {
  193. // nonexistence or shortest path or first appear prioritized if has permission
  194. if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) {
  195. if _, ok := schema.FieldsByDBName[field.DBName]; !ok {
  196. schema.DBNames = append(schema.DBNames, field.DBName)
  197. }
  198. schema.FieldsByDBName[field.DBName] = field
  199. schema.FieldsByName[field.Name] = field
  200. schema.FieldsByBindName[bindName] = field
  201. if v != nil && v.PrimaryKey {
  202. for idx, f := range schema.PrimaryFields {
  203. if f == v {
  204. schema.PrimaryFields = append(schema.PrimaryFields[0:idx], schema.PrimaryFields[idx+1:]...)
  205. }
  206. }
  207. }
  208. if field.PrimaryKey {
  209. schema.PrimaryFields = append(schema.PrimaryFields, field)
  210. }
  211. }
  212. }
  213. if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" {
  214. schema.FieldsByName[field.Name] = field
  215. }
  216. if of, ok := schema.FieldsByBindName[bindName]; !ok || of.TagSettings["-"] == "-" {
  217. schema.FieldsByBindName[bindName] = field
  218. }
  219. field.setupValuerAndSetter()
  220. }
  221. prioritizedPrimaryField := schema.LookUpField("id")
  222. if prioritizedPrimaryField == nil {
  223. prioritizedPrimaryField = schema.LookUpField("ID")
  224. }
  225. if prioritizedPrimaryField != nil {
  226. if prioritizedPrimaryField.PrimaryKey {
  227. schema.PrioritizedPrimaryField = prioritizedPrimaryField
  228. } else if len(schema.PrimaryFields) == 0 {
  229. prioritizedPrimaryField.PrimaryKey = true
  230. schema.PrioritizedPrimaryField = prioritizedPrimaryField
  231. schema.PrimaryFields = append(schema.PrimaryFields, prioritizedPrimaryField)
  232. }
  233. }
  234. if schema.PrioritizedPrimaryField == nil {
  235. if len(schema.PrimaryFields) == 1 {
  236. schema.PrioritizedPrimaryField = schema.PrimaryFields[0]
  237. } else if len(schema.PrimaryFields) > 1 {
  238. // If there are multiple primary keys, the AUTOINCREMENT field is prioritized
  239. for _, field := range schema.PrimaryFields {
  240. if field.AutoIncrement {
  241. schema.PrioritizedPrimaryField = field
  242. break
  243. }
  244. }
  245. }
  246. }
  247. for _, field := range schema.PrimaryFields {
  248. schema.PrimaryFieldDBNames = append(schema.PrimaryFieldDBNames, field.DBName)
  249. }
  250. for _, field := range schema.Fields {
  251. if field.DataType != "" && field.HasDefaultValue && field.DefaultValueInterface == nil {
  252. schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field)
  253. }
  254. }
  255. if field := schema.PrioritizedPrimaryField; field != nil {
  256. switch field.GORMDataType {
  257. case Int, Uint:
  258. if _, ok := field.TagSettings["AUTOINCREMENT"]; !ok {
  259. if !field.HasDefaultValue || field.DefaultValueInterface != nil {
  260. schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field)
  261. }
  262. field.HasDefaultValue = true
  263. field.AutoIncrement = true
  264. }
  265. }
  266. }
  267. callbackTypes := []callbackType{
  268. callbackTypeBeforeCreate, callbackTypeAfterCreate,
  269. callbackTypeBeforeUpdate, callbackTypeAfterUpdate,
  270. callbackTypeBeforeSave, callbackTypeAfterSave,
  271. callbackTypeBeforeDelete, callbackTypeAfterDelete,
  272. callbackTypeAfterFind,
  273. }
  274. for _, cbName := range callbackTypes {
  275. if methodValue := callBackToMethodValue(modelValue, cbName); methodValue.IsValid() {
  276. switch methodValue.Type().String() {
  277. case "func(*gorm.DB) error": // TODO hack
  278. reflect.Indirect(reflect.ValueOf(schema)).FieldByName(string(cbName)).SetBool(true)
  279. default:
  280. logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, cbName, cbName)
  281. }
  282. }
  283. }
  284. // Cache the schema
  285. if v, loaded := cacheStore.LoadOrStore(schemaCacheKey, schema); loaded {
  286. s := v.(*Schema)
  287. // Wait for the initialization of other goroutines to complete
  288. <-s.initialized
  289. return s, s.err
  290. }
  291. defer func() {
  292. if schema.err != nil {
  293. logger.Default.Error(context.Background(), schema.err.Error())
  294. cacheStore.Delete(modelType)
  295. }
  296. }()
  297. if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded {
  298. for _, field := range schema.Fields {
  299. if field.DataType == "" && field.GORMDataType == "" && (field.Creatable || field.Updatable || field.Readable) {
  300. if schema.parseRelation(field); schema.err != nil {
  301. return schema, schema.err
  302. } else {
  303. schema.FieldsByName[field.Name] = field
  304. schema.FieldsByBindName[field.BindName()] = field
  305. }
  306. }
  307. fieldValue := reflect.New(field.IndirectFieldType)
  308. fieldInterface := fieldValue.Interface()
  309. if fc, ok := fieldInterface.(CreateClausesInterface); ok {
  310. field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...)
  311. }
  312. if fc, ok := fieldInterface.(QueryClausesInterface); ok {
  313. field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...)
  314. }
  315. if fc, ok := fieldInterface.(UpdateClausesInterface); ok {
  316. field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...)
  317. }
  318. if fc, ok := fieldInterface.(DeleteClausesInterface); ok {
  319. field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...)
  320. }
  321. }
  322. }
  323. return schema, schema.err
  324. }
  325. // This unrolling is needed to show to the compiler the exact set of methods
  326. // that can be used on the modelType.
  327. // Prior to go1.22 any use of MethodByName would cause the linker to
  328. // abandon dead code elimination for the entire binary.
  329. // As of go1.22 the compiler supports one special case of a string constant
  330. // being passed to MethodByName. For enterprise customers or those building
  331. // large binaries, this gives a significant reduction in binary size.
  332. // https://github.com/golang/go/issues/62257
  333. func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect.Value {
  334. switch cbType {
  335. case callbackTypeBeforeCreate:
  336. return modelType.MethodByName(string(callbackTypeBeforeCreate))
  337. case callbackTypeAfterCreate:
  338. return modelType.MethodByName(string(callbackTypeAfterCreate))
  339. case callbackTypeBeforeUpdate:
  340. return modelType.MethodByName(string(callbackTypeBeforeUpdate))
  341. case callbackTypeAfterUpdate:
  342. return modelType.MethodByName(string(callbackTypeAfterUpdate))
  343. case callbackTypeBeforeSave:
  344. return modelType.MethodByName(string(callbackTypeBeforeSave))
  345. case callbackTypeAfterSave:
  346. return modelType.MethodByName(string(callbackTypeAfterSave))
  347. case callbackTypeBeforeDelete:
  348. return modelType.MethodByName(string(callbackTypeBeforeDelete))
  349. case callbackTypeAfterDelete:
  350. return modelType.MethodByName(string(callbackTypeAfterDelete))
  351. case callbackTypeAfterFind:
  352. return modelType.MethodByName(string(callbackTypeAfterFind))
  353. default:
  354. return reflect.ValueOf(nil)
  355. }
  356. }
  357. func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
  358. modelType := reflect.ValueOf(dest).Type()
  359. for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
  360. modelType = modelType.Elem()
  361. }
  362. if modelType.Kind() != reflect.Struct {
  363. if modelType.PkgPath() == "" {
  364. return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
  365. }
  366. return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
  367. }
  368. if v, ok := cacheStore.Load(modelType); ok {
  369. return v.(*Schema), nil
  370. }
  371. return Parse(dest, cacheStore, namer)
  372. }