preload.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345
  1. package callbacks
  2. import (
  3. "fmt"
  4. "reflect"
  5. "sort"
  6. "strings"
  7. "gorm.io/gorm"
  8. "gorm.io/gorm/clause"
  9. "gorm.io/gorm/schema"
  10. "gorm.io/gorm/utils"
  11. )
  12. // parsePreloadMap extracts nested preloads. e.g.
  13. //
  14. // // schema has a "k0" relation and a "k7.k8" embedded relation
  15. // parsePreloadMap(schema, map[string][]interface{}{
  16. // clause.Associations: {"arg1"},
  17. // "k1": {"arg2"},
  18. // "k2.k3": {"arg3"},
  19. // "k4.k5.k6": {"arg4"},
  20. // })
  21. // // preloadMap is
  22. // map[string]map[string][]interface{}{
  23. // "k0": {},
  24. // "k7": {
  25. // "k8": {},
  26. // },
  27. // "k1": {},
  28. // "k2": {
  29. // "k3": {"arg3"},
  30. // },
  31. // "k4": {
  32. // "k5.k6": {"arg4"},
  33. // },
  34. // }
  35. func parsePreloadMap(s *schema.Schema, preloads map[string][]interface{}) map[string]map[string][]interface{} {
  36. preloadMap := map[string]map[string][]interface{}{}
  37. setPreloadMap := func(name, value string, args []interface{}) {
  38. if _, ok := preloadMap[name]; !ok {
  39. preloadMap[name] = map[string][]interface{}{}
  40. }
  41. if value != "" {
  42. preloadMap[name][value] = args
  43. }
  44. }
  45. for name, args := range preloads {
  46. preloadFields := strings.Split(name, ".")
  47. value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), ".")
  48. if preloadFields[0] == clause.Associations {
  49. for _, relation := range s.Relationships.Relations {
  50. if relation.Schema == s {
  51. setPreloadMap(relation.Name, value, args)
  52. }
  53. }
  54. for embedded, embeddedRelations := range s.Relationships.EmbeddedRelations {
  55. for _, value := range embeddedValues(embeddedRelations) {
  56. setPreloadMap(embedded, value, args)
  57. }
  58. }
  59. } else {
  60. setPreloadMap(preloadFields[0], value, args)
  61. }
  62. }
  63. return preloadMap
  64. }
  65. func embeddedValues(embeddedRelations *schema.Relationships) []string {
  66. if embeddedRelations == nil {
  67. return nil
  68. }
  69. names := make([]string, 0, len(embeddedRelations.Relations)+len(embeddedRelations.EmbeddedRelations))
  70. for _, relation := range embeddedRelations.Relations {
  71. // skip first struct name
  72. names = append(names, strings.Join(relation.Field.EmbeddedBindNames[1:], "."))
  73. }
  74. for _, relations := range embeddedRelations.EmbeddedRelations {
  75. names = append(names, embeddedValues(relations)...)
  76. }
  77. return names
  78. }
  79. // preloadEntryPoint enters layer by layer. It will call real preload if it finds the right entry point.
  80. // If the current relationship is embedded or joined, current query will be ignored.
  81. //
  82. //nolint:cyclop
  83. func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}) error {
  84. preloadMap := parsePreloadMap(db.Statement.Schema, preloads)
  85. // avoid random traversal of the map
  86. preloadNames := make([]string, 0, len(preloadMap))
  87. for key := range preloadMap {
  88. preloadNames = append(preloadNames, key)
  89. }
  90. sort.Strings(preloadNames)
  91. isJoined := func(name string) (joined bool, nestedJoins []string) {
  92. for _, join := range joins {
  93. if _, ok := relationships.Relations[join]; ok && name == join {
  94. joined = true
  95. continue
  96. }
  97. joinNames := strings.SplitN(join, ".", 2)
  98. if len(joinNames) == 2 {
  99. if _, ok := relationships.Relations[joinNames[0]]; ok && name == joinNames[0] {
  100. joined = true
  101. nestedJoins = append(nestedJoins, joinNames[1])
  102. }
  103. }
  104. }
  105. return joined, nestedJoins
  106. }
  107. for _, name := range preloadNames {
  108. if relations := relationships.EmbeddedRelations[name]; relations != nil {
  109. if err := preloadEntryPoint(db, joins, relations, preloadMap[name], associationsConds); err != nil {
  110. return err
  111. }
  112. } else if rel := relationships.Relations[name]; rel != nil {
  113. if joined, nestedJoins := isJoined(name); joined {
  114. switch rv := db.Statement.ReflectValue; rv.Kind() {
  115. case reflect.Slice, reflect.Array:
  116. if rv.Len() > 0 {
  117. reflectValue := rel.FieldSchema.MakeSlice().Elem()
  118. for i := 0; i < rv.Len(); i++ {
  119. frv := rel.Field.ReflectValueOf(db.Statement.Context, rv.Index(i))
  120. if frv.Kind() != reflect.Ptr {
  121. reflectValue = reflect.Append(reflectValue, frv.Addr())
  122. } else {
  123. if frv.IsNil() {
  124. continue
  125. }
  126. reflectValue = reflect.Append(reflectValue, frv)
  127. }
  128. }
  129. tx := preloadDB(db, reflectValue, reflectValue.Interface())
  130. if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
  131. return err
  132. }
  133. }
  134. case reflect.Struct, reflect.Pointer:
  135. reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv)
  136. tx := preloadDB(db, reflectValue, reflectValue.Interface())
  137. if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
  138. return err
  139. }
  140. default:
  141. return gorm.ErrInvalidData
  142. }
  143. } else {
  144. tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks})
  145. tx.Statement.ReflectValue = db.Statement.ReflectValue
  146. tx.Statement.Unscoped = db.Statement.Unscoped
  147. if err := preload(tx, rel, append(preloads[name], associationsConds...), preloadMap[name]); err != nil {
  148. return err
  149. }
  150. }
  151. } else {
  152. return fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)
  153. }
  154. }
  155. return nil
  156. }
  157. func preloadDB(db *gorm.DB, reflectValue reflect.Value, dest interface{}) *gorm.DB {
  158. tx := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true})
  159. db.Statement.Settings.Range(func(k, v interface{}) bool {
  160. tx.Statement.Settings.Store(k, v)
  161. return true
  162. })
  163. if err := tx.Statement.Parse(dest); err != nil {
  164. tx.AddError(err)
  165. return tx
  166. }
  167. tx.Statement.ReflectValue = reflectValue
  168. tx.Statement.Unscoped = db.Statement.Unscoped
  169. return tx
  170. }
  171. func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error {
  172. var (
  173. reflectValue = tx.Statement.ReflectValue
  174. relForeignKeys []string
  175. relForeignFields []*schema.Field
  176. foreignFields []*schema.Field
  177. foreignValues [][]interface{}
  178. identityMap = map[string][]reflect.Value{}
  179. inlineConds []interface{}
  180. )
  181. if rel.JoinTable != nil {
  182. var (
  183. joinForeignFields = make([]*schema.Field, 0, len(rel.References))
  184. joinRelForeignFields = make([]*schema.Field, 0, len(rel.References))
  185. joinForeignKeys = make([]string, 0, len(rel.References))
  186. )
  187. for _, ref := range rel.References {
  188. if ref.OwnPrimaryKey {
  189. joinForeignKeys = append(joinForeignKeys, ref.ForeignKey.DBName)
  190. joinForeignFields = append(joinForeignFields, ref.ForeignKey)
  191. foreignFields = append(foreignFields, ref.PrimaryKey)
  192. } else if ref.PrimaryValue != "" {
  193. tx = tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
  194. } else {
  195. joinRelForeignFields = append(joinRelForeignFields, ref.ForeignKey)
  196. relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName)
  197. relForeignFields = append(relForeignFields, ref.PrimaryKey)
  198. }
  199. }
  200. joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields)
  201. if len(joinForeignValues) == 0 {
  202. return nil
  203. }
  204. joinResults := rel.JoinTable.MakeSlice().Elem()
  205. column, values := schema.ToQueryValues(clause.CurrentTable, joinForeignKeys, joinForeignValues)
  206. if err := tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error; err != nil {
  207. return err
  208. }
  209. // convert join identity map to relation identity map
  210. fieldValues := make([]interface{}, len(joinForeignFields))
  211. joinFieldValues := make([]interface{}, len(joinRelForeignFields))
  212. for i := 0; i < joinResults.Len(); i++ {
  213. joinIndexValue := joinResults.Index(i)
  214. for idx, field := range joinForeignFields {
  215. fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue)
  216. }
  217. for idx, field := range joinRelForeignFields {
  218. joinFieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue)
  219. }
  220. if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok {
  221. joinKey := utils.ToStringKey(joinFieldValues...)
  222. identityMap[joinKey] = append(identityMap[joinKey], results...)
  223. }
  224. }
  225. _, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, joinResults, joinRelForeignFields)
  226. } else {
  227. for _, ref := range rel.References {
  228. if ref.OwnPrimaryKey {
  229. relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
  230. relForeignFields = append(relForeignFields, ref.ForeignKey)
  231. foreignFields = append(foreignFields, ref.PrimaryKey)
  232. } else if ref.PrimaryValue != "" {
  233. tx = tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
  234. } else {
  235. relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName)
  236. relForeignFields = append(relForeignFields, ref.PrimaryKey)
  237. foreignFields = append(foreignFields, ref.ForeignKey)
  238. }
  239. }
  240. identityMap, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields)
  241. if len(foreignValues) == 0 {
  242. return nil
  243. }
  244. }
  245. // nested preload
  246. for p, pvs := range preloads {
  247. tx = tx.Preload(p, pvs...)
  248. }
  249. reflectResults := rel.FieldSchema.MakeSlice().Elem()
  250. column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues)
  251. if len(values) != 0 {
  252. for _, cond := range conds {
  253. if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok {
  254. tx = fc(tx)
  255. } else {
  256. inlineConds = append(inlineConds, cond)
  257. }
  258. }
  259. if err := tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error; err != nil {
  260. return err
  261. }
  262. }
  263. fieldValues := make([]interface{}, len(relForeignFields))
  264. // clean up old values before preloading
  265. switch reflectValue.Kind() {
  266. case reflect.Struct:
  267. switch rel.Type {
  268. case schema.HasMany, schema.Many2Many:
  269. tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()))
  270. default:
  271. tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.New(rel.Field.FieldType).Interface()))
  272. }
  273. case reflect.Slice, reflect.Array:
  274. for i := 0; i < reflectValue.Len(); i++ {
  275. switch rel.Type {
  276. case schema.HasMany, schema.Many2Many:
  277. tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()))
  278. default:
  279. tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()))
  280. }
  281. }
  282. }
  283. for i := 0; i < reflectResults.Len(); i++ {
  284. elem := reflectResults.Index(i)
  285. for idx, field := range relForeignFields {
  286. fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, elem)
  287. }
  288. datas, ok := identityMap[utils.ToStringKey(fieldValues...)]
  289. if !ok {
  290. return fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", elem.Interface())
  291. }
  292. for _, data := range datas {
  293. reflectFieldValue := rel.Field.ReflectValueOf(tx.Statement.Context, data)
  294. if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() {
  295. reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem()))
  296. }
  297. reflectFieldValue = reflect.Indirect(reflectFieldValue)
  298. switch reflectFieldValue.Kind() {
  299. case reflect.Struct:
  300. tx.AddError(rel.Field.Set(tx.Statement.Context, data, elem.Interface()))
  301. case reflect.Slice, reflect.Array:
  302. if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr {
  303. tx.AddError(rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem).Interface()))
  304. } else {
  305. tx.AddError(rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()))
  306. }
  307. }
  308. }
  309. }
  310. return tx.Error
  311. }