123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345 |
- package callbacks
- import (
- "fmt"
- "reflect"
- "sort"
- "strings"
- "gorm.io/gorm"
- "gorm.io/gorm/clause"
- "gorm.io/gorm/schema"
- "gorm.io/gorm/utils"
- )
- // parsePreloadMap extracts nested preloads. e.g.
- //
- // // schema has a "k0" relation and a "k7.k8" embedded relation
- // parsePreloadMap(schema, map[string][]interface{}{
- // clause.Associations: {"arg1"},
- // "k1": {"arg2"},
- // "k2.k3": {"arg3"},
- // "k4.k5.k6": {"arg4"},
- // })
- // // preloadMap is
- // map[string]map[string][]interface{}{
- // "k0": {},
- // "k7": {
- // "k8": {},
- // },
- // "k1": {},
- // "k2": {
- // "k3": {"arg3"},
- // },
- // "k4": {
- // "k5.k6": {"arg4"},
- // },
- // }
- func parsePreloadMap(s *schema.Schema, preloads map[string][]interface{}) map[string]map[string][]interface{} {
- preloadMap := map[string]map[string][]interface{}{}
- setPreloadMap := func(name, value string, args []interface{}) {
- if _, ok := preloadMap[name]; !ok {
- preloadMap[name] = map[string][]interface{}{}
- }
- if value != "" {
- preloadMap[name][value] = args
- }
- }
- for name, args := range preloads {
- preloadFields := strings.Split(name, ".")
- value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), ".")
- if preloadFields[0] == clause.Associations {
- for _, relation := range s.Relationships.Relations {
- if relation.Schema == s {
- setPreloadMap(relation.Name, value, args)
- }
- }
- for embedded, embeddedRelations := range s.Relationships.EmbeddedRelations {
- for _, value := range embeddedValues(embeddedRelations) {
- setPreloadMap(embedded, value, args)
- }
- }
- } else {
- setPreloadMap(preloadFields[0], value, args)
- }
- }
- return preloadMap
- }
- func embeddedValues(embeddedRelations *schema.Relationships) []string {
- if embeddedRelations == nil {
- return nil
- }
- names := make([]string, 0, len(embeddedRelations.Relations)+len(embeddedRelations.EmbeddedRelations))
- for _, relation := range embeddedRelations.Relations {
- // skip first struct name
- names = append(names, strings.Join(relation.Field.EmbeddedBindNames[1:], "."))
- }
- for _, relations := range embeddedRelations.EmbeddedRelations {
- names = append(names, embeddedValues(relations)...)
- }
- return names
- }
- // preloadEntryPoint enters layer by layer. It will call real preload if it finds the right entry point.
- // If the current relationship is embedded or joined, current query will be ignored.
- //
- //nolint:cyclop
- func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}) error {
- preloadMap := parsePreloadMap(db.Statement.Schema, preloads)
- // avoid random traversal of the map
- preloadNames := make([]string, 0, len(preloadMap))
- for key := range preloadMap {
- preloadNames = append(preloadNames, key)
- }
- sort.Strings(preloadNames)
- isJoined := func(name string) (joined bool, nestedJoins []string) {
- for _, join := range joins {
- if _, ok := relationships.Relations[join]; ok && name == join {
- joined = true
- continue
- }
- joinNames := strings.SplitN(join, ".", 2)
- if len(joinNames) == 2 {
- if _, ok := relationships.Relations[joinNames[0]]; ok && name == joinNames[0] {
- joined = true
- nestedJoins = append(nestedJoins, joinNames[1])
- }
- }
- }
- return joined, nestedJoins
- }
- for _, name := range preloadNames {
- if relations := relationships.EmbeddedRelations[name]; relations != nil {
- if err := preloadEntryPoint(db, joins, relations, preloadMap[name], associationsConds); err != nil {
- return err
- }
- } else if rel := relationships.Relations[name]; rel != nil {
- if joined, nestedJoins := isJoined(name); joined {
- switch rv := db.Statement.ReflectValue; rv.Kind() {
- case reflect.Slice, reflect.Array:
- if rv.Len() > 0 {
- reflectValue := rel.FieldSchema.MakeSlice().Elem()
- for i := 0; i < rv.Len(); i++ {
- frv := rel.Field.ReflectValueOf(db.Statement.Context, rv.Index(i))
- if frv.Kind() != reflect.Ptr {
- reflectValue = reflect.Append(reflectValue, frv.Addr())
- } else {
- if frv.IsNil() {
- continue
- }
- reflectValue = reflect.Append(reflectValue, frv)
- }
- }
- tx := preloadDB(db, reflectValue, reflectValue.Interface())
- if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
- return err
- }
- }
- case reflect.Struct, reflect.Pointer:
- reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv)
- tx := preloadDB(db, reflectValue, reflectValue.Interface())
- if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
- return err
- }
- default:
- return gorm.ErrInvalidData
- }
- } else {
- tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks})
- tx.Statement.ReflectValue = db.Statement.ReflectValue
- tx.Statement.Unscoped = db.Statement.Unscoped
- if err := preload(tx, rel, append(preloads[name], associationsConds...), preloadMap[name]); err != nil {
- return err
- }
- }
- } else {
- return fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)
- }
- }
- return nil
- }
- func preloadDB(db *gorm.DB, reflectValue reflect.Value, dest interface{}) *gorm.DB {
- tx := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true})
- db.Statement.Settings.Range(func(k, v interface{}) bool {
- tx.Statement.Settings.Store(k, v)
- return true
- })
- if err := tx.Statement.Parse(dest); err != nil {
- tx.AddError(err)
- return tx
- }
- tx.Statement.ReflectValue = reflectValue
- tx.Statement.Unscoped = db.Statement.Unscoped
- return tx
- }
// preload queries the association rel for the owner record(s) held in
// tx.Statement.ReflectValue and assigns the loaded records back onto them.
//
// Parameters:
//   - tx: session whose Statement.ReflectValue holds the owner record(s).
//   - rel: the relationship to load.
//   - conds: conditions for the association query. Values of type
//     func(*gorm.DB) *gorm.DB are applied as scopes; everything else is passed
//     to Find as inline conditions. They are applied only when there are key
//     values to query.
//   - preloads: nested preload paths registered on the association query via
//     tx.Preload.
//
// When rel has a join table, the join table is queried first to translate
// owner keys into related-record keys; otherwise the owners' key values are
// matched directly. Returns nil early when the owners yield no key values;
// otherwise returns the first error from the join-table or association query,
// an error when a loaded record matches no owner key, or tx.Error (which
// includes errors recorded while setting the relation field).
func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error {
	var (
		reflectValue     = tx.Statement.ReflectValue
		relForeignKeys   []string
		relForeignFields []*schema.Field
		foreignFields    []*schema.Field
		foreignValues    [][]interface{}
		identityMap      = map[string][]reflect.Value{}
		inlineConds      []interface{}
	)

	if rel.JoinTable != nil {
		var (
			joinForeignFields    = make([]*schema.Field, 0, len(rel.References))
			joinRelForeignFields = make([]*schema.Field, 0, len(rel.References))
			joinForeignKeys      = make([]string, 0, len(rel.References))
		)

		// Sort references into owner-side keys (join column <-> owner field),
		// fixed-value filters, and related-side keys (join column <-> related
		// table column).
		for _, ref := range rel.References {
			if ref.OwnPrimaryKey {
				joinForeignKeys = append(joinForeignKeys, ref.ForeignKey.DBName)
				joinForeignFields = append(joinForeignFields, ref.ForeignKey)
				foreignFields = append(foreignFields, ref.PrimaryKey)
			} else if ref.PrimaryValue != "" {
				// fixed-value reference: filter on the literal value
				tx = tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
			} else {
				joinRelForeignFields = append(joinRelForeignFields, ref.ForeignKey)
				relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName)
				relForeignFields = append(relForeignFields, ref.PrimaryKey)
			}
		}

		// index the owner records by their key values
		joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields)
		if len(joinForeignValues) == 0 {
			return nil
		}

		// load the join-table rows that reference any of the owners
		joinResults := rel.JoinTable.MakeSlice().Elem()
		column, values := schema.ToQueryValues(clause.CurrentTable, joinForeignKeys, joinForeignValues)
		if err := tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error; err != nil {
			return err
		}

		// convert join identity map to relation identity map
		// (keyed by the related-side values of each join row, valued by the
		// owners that join row points back to)
		fieldValues := make([]interface{}, len(joinForeignFields))
		joinFieldValues := make([]interface{}, len(joinRelForeignFields))
		for i := 0; i < joinResults.Len(); i++ {
			joinIndexValue := joinResults.Index(i)
			for idx, field := range joinForeignFields {
				fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue)
			}

			for idx, field := range joinRelForeignFields {
				joinFieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue)
			}

			if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok {
				joinKey := utils.ToStringKey(joinFieldValues...)
				identityMap[joinKey] = append(identityMap[joinKey], results...)
			}
		}

		// the related-side values found in the join rows drive the association query
		_, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, joinResults, joinRelForeignFields)
	} else {
		// No join table: the related table's key columns are matched directly
		// against the owners' key values.
		for _, ref := range rel.References {
			if ref.OwnPrimaryKey {
				relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
				relForeignFields = append(relForeignFields, ref.ForeignKey)
				foreignFields = append(foreignFields, ref.PrimaryKey)
			} else if ref.PrimaryValue != "" {
				// fixed-value reference: filter on the literal value
				tx = tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue})
			} else {
				relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName)
				relForeignFields = append(relForeignFields, ref.PrimaryKey)
				foreignFields = append(foreignFields, ref.ForeignKey)
			}
		}

		identityMap, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields)
		if len(foreignValues) == 0 {
			return nil
		}
	}

	// nested preload
	for p, pvs := range preloads {
		tx = tx.Preload(p, pvs...)
	}

	reflectResults := rel.FieldSchema.MakeSlice().Elem()
	column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues)

	if len(values) != 0 {
		// scope functions rewrite tx; anything else becomes an inline Find condition
		for _, cond := range conds {
			if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok {
				tx = fc(tx)
			} else {
				inlineConds = append(inlineConds, cond)
			}
		}

		if err := tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error; err != nil {
			return err
		}
	}

	fieldValues := make([]interface{}, len(relForeignFields))

	// clean up old values before preloading
	// (has-many / many-to-many fields get an empty slice, other relation
	// fields get a fresh value built with reflect.New)
	switch reflectValue.Kind() {
	case reflect.Struct:
		switch rel.Type {
		case schema.HasMany, schema.Many2Many:
			tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()))
		default:
			tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.New(rel.Field.FieldType).Interface()))
		}
	case reflect.Slice, reflect.Array:
		for i := 0; i < reflectValue.Len(); i++ {
			switch rel.Type {
			case schema.HasMany, schema.Many2Many:
				tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface()))
			default:
				tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface()))
			}
		}
	}

	// attach every loaded record to each owner whose key it matches
	for i := 0; i < reflectResults.Len(); i++ {
		elem := reflectResults.Index(i)
		for idx, field := range relForeignFields {
			fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, elem)
		}

		datas, ok := identityMap[utils.ToStringKey(fieldValues...)]
		if !ok {
			return fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", elem.Interface())
		}

		for _, data := range datas {
			reflectFieldValue := rel.Field.ReflectValueOf(tx.Statement.Context, data)
			// allocate a nil pointer field before indirecting into it
			if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() {
				reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem()))
			}

			reflectFieldValue = reflect.Indirect(reflectFieldValue)
			switch reflectFieldValue.Kind() {
			case reflect.Struct:
				// single-valued relation: overwrite with the loaded record
				tx.AddError(rel.Field.Set(tx.Statement.Context, data, elem.Interface()))
			case reflect.Slice, reflect.Array:
				// multi-valued relation: append elem as-is when the slice holds
				// pointers, otherwise append elem.Elem()
				if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr {
					tx.AddError(rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem).Interface()))
				} else {
					tx.AddError(rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem.Elem()).Interface()))
				}
			}
		}
	}

	return tx.Error
}
|