123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453 |
- package callbacks
- import (
- "reflect"
- "strings"
- "gorm.io/gorm"
- "gorm.io/gorm/clause"
- "gorm.io/gorm/schema"
- "gorm.io/gorm/utils"
- )
- func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
- return func(db *gorm.DB) {
- if db.Error == nil && db.Statement.Schema != nil {
- selectColumns, restricted := db.Statement.SelectAndOmitColumns(create, !create)
- // Save Belongs To associations
- for _, rel := range db.Statement.Schema.Relationships.BelongsTo {
- if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
- continue
- }
- setupReferences := func(obj reflect.Value, elem reflect.Value) {
- for _, ref := range rel.References {
- if !ref.OwnPrimaryKey {
- pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem)
- db.AddError(ref.ForeignKey.Set(db.Statement.Context, obj, pv))
- if dest, ok := db.Statement.Dest.(map[string]interface{}); ok {
- dest[ref.ForeignKey.DBName] = pv
- if _, ok := dest[rel.Name]; ok {
- dest[rel.Name] = elem.Interface()
- }
- }
- }
- }
- }
- switch db.Statement.ReflectValue.Kind() {
- case reflect.Slice, reflect.Array:
- var (
- rValLen = db.Statement.ReflectValue.Len()
- objs = make([]reflect.Value, 0, rValLen)
- fieldType = rel.Field.FieldType
- isPtr = fieldType.Kind() == reflect.Ptr
- )
- if !isPtr {
- fieldType = reflect.PointerTo(fieldType)
- }
- elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
- distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
- identityMap := map[string]bool{}
- for i := 0; i < rValLen; i++ {
- obj := db.Statement.ReflectValue.Index(i)
- if reflect.Indirect(obj).Kind() != reflect.Struct {
- break
- }
- if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero { // check belongs to relation value
- rv := rel.Field.ReflectValueOf(db.Statement.Context, obj) // relation reflect value
- if !isPtr {
- rv = rv.Addr()
- }
- objs = append(objs, obj)
- elems = reflect.Append(elems, rv)
- relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
- for _, pf := range rel.FieldSchema.PrimaryFields {
- if pfv, ok := pf.ValueOf(db.Statement.Context, rv); !ok {
- relPrimaryValues = append(relPrimaryValues, pfv)
- }
- }
- cacheKey := utils.ToStringKey(relPrimaryValues...)
- if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
- if cacheKey != "" { // has primary fields
- identityMap[cacheKey] = true
- }
- distinctElems = reflect.Append(distinctElems, rv)
- }
- }
- }
- if elems.Len() > 0 {
- if saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil) == nil {
- for i := 0; i < elems.Len(); i++ {
- setupReferences(objs[i], elems.Index(i))
- }
- }
- }
- case reflect.Struct:
- if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero {
- rv := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue) // relation reflect value
- if rv.Kind() != reflect.Ptr {
- rv = rv.Addr()
- }
- if saveAssociations(db, rel, rv, selectColumns, restricted, nil) == nil {
- setupReferences(db.Statement.ReflectValue, rv)
- }
- }
- }
- }
- }
- }
- }
- func SaveAfterAssociations(create bool) func(db *gorm.DB) {
- return func(db *gorm.DB) {
- if db.Error == nil && db.Statement.Schema != nil {
- selectColumns, restricted := db.Statement.SelectAndOmitColumns(create, !create)
- // Save Has One associations
- for _, rel := range db.Statement.Schema.Relationships.HasOne {
- if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
- continue
- }
- switch db.Statement.ReflectValue.Kind() {
- case reflect.Slice, reflect.Array:
- var (
- fieldType = rel.Field.FieldType
- isPtr = fieldType.Kind() == reflect.Ptr
- )
- if !isPtr {
- fieldType = reflect.PointerTo(fieldType)
- }
- elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
- for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
- obj := db.Statement.ReflectValue.Index(i)
- if reflect.Indirect(obj).Kind() == reflect.Struct {
- if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero {
- rv := rel.Field.ReflectValueOf(db.Statement.Context, obj)
- if rv.Kind() != reflect.Ptr {
- rv = rv.Addr()
- }
- for _, ref := range rel.References {
- if ref.OwnPrimaryKey {
- fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj)
- db.AddError(ref.ForeignKey.Set(db.Statement.Context, rv, fv))
- } else if ref.PrimaryValue != "" {
- db.AddError(ref.ForeignKey.Set(db.Statement.Context, rv, ref.PrimaryValue))
- }
- }
- elems = reflect.Append(elems, rv)
- }
- }
- }
- if elems.Len() > 0 {
- assignmentColumns := make([]string, 0, len(rel.References))
- for _, ref := range rel.References {
- assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
- }
- saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns)
- }
- case reflect.Struct:
- if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero {
- f := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue)
- if f.Kind() != reflect.Ptr {
- f = f.Addr()
- }
- assignmentColumns := make([]string, 0, len(rel.References))
- for _, ref := range rel.References {
- if ref.OwnPrimaryKey {
- fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
- db.AddError(ref.ForeignKey.Set(db.Statement.Context, f, fv))
- } else if ref.PrimaryValue != "" {
- db.AddError(ref.ForeignKey.Set(db.Statement.Context, f, ref.PrimaryValue))
- }
- assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
- }
- saveAssociations(db, rel, f, selectColumns, restricted, assignmentColumns)
- }
- }
- }
- // Save Has Many associations
- for _, rel := range db.Statement.Schema.Relationships.HasMany {
- if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
- continue
- }
- fieldType := rel.Field.IndirectFieldType.Elem()
- isPtr := fieldType.Kind() == reflect.Ptr
- if !isPtr {
- fieldType = reflect.PointerTo(fieldType)
- }
- elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
- identityMap := map[string]bool{}
- appendToElems := func(v reflect.Value) {
- if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero {
- f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v))
- for i := 0; i < f.Len(); i++ {
- elem := f.Index(i)
- for _, ref := range rel.References {
- if ref.OwnPrimaryKey {
- pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, v)
- db.AddError(ref.ForeignKey.Set(db.Statement.Context, elem, pv))
- } else if ref.PrimaryValue != "" {
- db.AddError(ref.ForeignKey.Set(db.Statement.Context, elem, ref.PrimaryValue))
- }
- }
- relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
- for _, pf := range rel.FieldSchema.PrimaryFields {
- if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok {
- relPrimaryValues = append(relPrimaryValues, pfv)
- }
- }
- cacheKey := utils.ToStringKey(relPrimaryValues...)
- if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
- if cacheKey != "" { // has primary fields
- identityMap[cacheKey] = true
- }
- if isPtr {
- elems = reflect.Append(elems, elem)
- } else {
- elems = reflect.Append(elems, elem.Addr())
- }
- }
- }
- }
- }
- switch db.Statement.ReflectValue.Kind() {
- case reflect.Slice, reflect.Array:
- for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
- obj := db.Statement.ReflectValue.Index(i)
- if reflect.Indirect(obj).Kind() == reflect.Struct {
- appendToElems(obj)
- }
- }
- case reflect.Struct:
- appendToElems(db.Statement.ReflectValue)
- }
- if elems.Len() > 0 {
- assignmentColumns := make([]string, 0, len(rel.References))
- for _, ref := range rel.References {
- assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
- }
- saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns)
- }
- }
- // Save Many2Many associations
- for _, rel := range db.Statement.Schema.Relationships.Many2Many {
- if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) {
- continue
- }
- fieldType := rel.Field.IndirectFieldType.Elem()
- isPtr := fieldType.Kind() == reflect.Ptr
- if !isPtr {
- fieldType = reflect.PointerTo(fieldType)
- }
- elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
- distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10)
- joins := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(rel.JoinTable.ModelType)), 0, 10)
- objs := []reflect.Value{}
- appendToJoins := func(obj reflect.Value, elem reflect.Value) {
- joinValue := reflect.New(rel.JoinTable.ModelType)
- for _, ref := range rel.References {
- if ref.OwnPrimaryKey {
- fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj)
- db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, fv))
- } else if ref.PrimaryValue != "" {
- db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, ref.PrimaryValue))
- } else {
- fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem)
- db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, fv))
- }
- }
- joins = reflect.Append(joins, joinValue)
- }
- identityMap := map[string]bool{}
- appendToElems := func(v reflect.Value) {
- if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero {
- f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v))
- for i := 0; i < f.Len(); i++ {
- elem := f.Index(i)
- if !isPtr {
- elem = elem.Addr()
- }
- objs = append(objs, v)
- elems = reflect.Append(elems, elem)
- relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields))
- for _, pf := range rel.FieldSchema.PrimaryFields {
- if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok {
- relPrimaryValues = append(relPrimaryValues, pfv)
- }
- }
- cacheKey := utils.ToStringKey(relPrimaryValues...)
- if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] {
- if cacheKey != "" { // has primary fields
- identityMap[cacheKey] = true
- }
- distinctElems = reflect.Append(distinctElems, elem)
- }
- }
- }
- }
- switch db.Statement.ReflectValue.Kind() {
- case reflect.Slice, reflect.Array:
- for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
- obj := db.Statement.ReflectValue.Index(i)
- if reflect.Indirect(obj).Kind() == reflect.Struct {
- appendToElems(obj)
- }
- }
- case reflect.Struct:
- appendToElems(db.Statement.ReflectValue)
- }
- // optimize elems of reflect value length
- if elemLen := elems.Len(); elemLen > 0 {
- if v, ok := selectColumns[rel.Name+".*"]; !ok || v {
- saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil)
- }
- for i := 0; i < elemLen; i++ {
- appendToJoins(objs[i], elems.Index(i))
- }
- }
- if joins.Len() > 0 {
- db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(clause.OnConflict{DoNothing: true}).Session(&gorm.Session{
- SkipHooks: db.Statement.SkipHooks,
- DisableNestedTransaction: true,
- }).Create(joins.Interface()).Error)
- }
- }
- }
- }
- }
- func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingColumns []string) (onConflict clause.OnConflict) {
- if len(defaultUpdatingColumns) > 0 || stmt.DB.FullSaveAssociations {
- onConflict.Columns = make([]clause.Column, 0, len(s.PrimaryFieldDBNames))
- for _, dbName := range s.PrimaryFieldDBNames {
- onConflict.Columns = append(onConflict.Columns, clause.Column{Name: dbName})
- }
- onConflict.UpdateAll = stmt.DB.FullSaveAssociations
- if !onConflict.UpdateAll {
- onConflict.DoUpdates = clause.AssignmentColumns(defaultUpdatingColumns)
- }
- } else {
- onConflict.DoNothing = true
- }
- return
- }
- func saveAssociations(db *gorm.DB, rel *schema.Relationship, rValues reflect.Value, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error {
- // stop save association loop
- if checkAssociationsSaved(db, rValues) {
- return nil
- }
- var (
- selects, omits []string
- onConflict = onConflictOption(db.Statement, rel.FieldSchema, defaultUpdatingColumns)
- refName = rel.Name + "."
- values = rValues.Interface()
- )
- for name, ok := range selectColumns {
- columnName := ""
- if strings.HasPrefix(name, refName) {
- columnName = strings.TrimPrefix(name, refName)
- }
- if columnName != "" {
- if ok {
- selects = append(selects, columnName)
- } else {
- omits = append(omits, columnName)
- }
- }
- }
- tx := db.Session(&gorm.Session{NewDB: true}).Clauses(onConflict).Session(&gorm.Session{
- FullSaveAssociations: db.FullSaveAssociations,
- SkipHooks: db.Statement.SkipHooks,
- DisableNestedTransaction: true,
- })
- db.Statement.Settings.Range(func(k, v interface{}) bool {
- tx.Statement.Settings.Store(k, v)
- return true
- })
- if tx.Statement.FullSaveAssociations {
- tx = tx.Set("gorm:update_track_time", true)
- }
- if len(selects) > 0 {
- tx = tx.Select(selects)
- } else if restricted && len(omits) == 0 {
- tx = tx.Omit(clause.Associations)
- }
- if len(omits) > 0 {
- tx = tx.Omit(omits...)
- }
- return db.AddError(tx.Create(values).Error)
- }
- // check association values has been saved
- // if values kind is Struct, check it has been saved
- // if values kind is Slice/Array, check all items have been saved
- var visitMapStoreKey = "gorm:saved_association_map"
- func checkAssociationsSaved(db *gorm.DB, values reflect.Value) bool {
- if visit, ok := db.Get(visitMapStoreKey); ok {
- if v, ok := visit.(*visitMap); ok {
- if loadOrStoreVisitMap(v, values) {
- return true
- }
- }
- } else {
- vistMap := make(visitMap)
- loadOrStoreVisitMap(&vistMap, values)
- db.Set(visitMapStoreKey, &vistMap)
- }
- return false
- }
|