scan.go 10 KB


  1. package gorm
  2. import (
  3. "database/sql"
  4. "database/sql/driver"
  5. "reflect"
  6. "time"
  7. "gorm.io/gorm/schema"
  8. "gorm.io/gorm/utils"
  9. )
  10. // prepareValues prepare values slice
  11. func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) {
  12. if db.Statement.Schema != nil {
  13. for idx, name := range columns {
  14. if field := db.Statement.Schema.LookUpField(name); field != nil {
  15. values[idx] = reflect.New(reflect.PointerTo(field.FieldType)).Interface()
  16. continue
  17. }
  18. values[idx] = new(interface{})
  19. }
  20. } else if len(columnTypes) > 0 {
  21. for idx, columnType := range columnTypes {
  22. if columnType.ScanType() != nil {
  23. values[idx] = reflect.New(reflect.PointerTo(columnType.ScanType())).Interface()
  24. } else {
  25. values[idx] = new(interface{})
  26. }
  27. }
  28. } else {
  29. for idx := range columns {
  30. values[idx] = new(interface{})
  31. }
  32. }
  33. }
  // scanIntoMap copies one scanned row from values into mapValue, keyed by the
  // column name at the same index. Existing keys are overwritten.
  //
  // Each entry of values is dereferenced up to two pointer levels. If that ends
  // on a nil pointer (or the entry itself is nil) the column is stored as nil.
  // Otherwise the stored value is:
  //   - the result of Value() when the scanned value is a driver.Valuer (the
  //     error returned by Value is discarded);
  //   - a string copy when it is sql.RawBytes;
  //   - the scanned value itself in every other case.
  func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns []string) {
  	for i, name := range columns {
  		scanned := reflect.Indirect(reflect.Indirect(reflect.ValueOf(values[i])))
  		if !scanned.IsValid() {
  			mapValue[name] = nil
  			continue
  		}

  		switch v := scanned.Interface().(type) {
  		case driver.Valuer:
  			mapValue[name], _ = v.Value()
  		case sql.RawBytes:
  			mapValue[name] = string(v)
  		default:
  			mapValue[name] = v
  		}
  	}
  }
  48. func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][]*schema.Field) {
  49. for idx, field := range fields {
  50. if field != nil {
  51. values[idx] = field.NewValuePool.Get()
  52. } else if len(fields) == 1 {
  53. if reflectValue.CanAddr() {
  54. values[idx] = reflectValue.Addr().Interface()
  55. } else {
  56. values[idx] = reflectValue.Interface()
  57. }
  58. }
  59. }
  60. db.RowsAffected++
  61. db.AddError(rows.Scan(values...))
  62. joinedNestedSchemaMap := make(map[string]interface{})
  63. for idx, field := range fields {
  64. if field == nil {
  65. continue
  66. }
  67. if len(joinFields) == 0 || len(joinFields[idx]) == 0 {
  68. db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx]))
  69. } else { // joinFields count is larger than 2 when using join
  70. var isNilPtrValue bool
  71. var relValue reflect.Value
  72. // does not contain raw dbname
  73. nestedJoinSchemas := joinFields[idx][:len(joinFields[idx])-1]
  74. // current reflect value
  75. currentReflectValue := reflectValue
  76. fullRels := make([]string, 0, len(nestedJoinSchemas))
  77. for _, joinSchema := range nestedJoinSchemas {
  78. fullRels = append(fullRels, joinSchema.Name)
  79. relValue = joinSchema.ReflectValueOf(db.Statement.Context, currentReflectValue)
  80. if relValue.Kind() == reflect.Ptr {
  81. fullRelsName := utils.JoinNestedRelationNames(fullRels)
  82. // same nested structure
  83. if _, ok := joinedNestedSchemaMap[fullRelsName]; !ok {
  84. if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() {
  85. isNilPtrValue = true
  86. break
  87. }
  88. relValue.Set(reflect.New(relValue.Type().Elem()))
  89. joinedNestedSchemaMap[fullRelsName] = nil
  90. }
  91. }
  92. currentReflectValue = relValue
  93. }
  94. if !isNilPtrValue { // ignore if value is nil
  95. f := joinFields[idx][len(joinFields[idx])-1]
  96. db.AddError(f.Set(db.Statement.Context, relValue, values[idx]))
  97. }
  98. }
  99. // release data to pool
  100. field.NewValuePool.Put(values[idx])
  101. }
  102. }
  // ScanMode is a set of bit flags that controls how Scan consumes rows.
  type ScanMode uint8

  // Scan mode flags. They are distinct bits and may be combined with bitwise OR.
  const (
  	// ScanInitialized makes Scan read the first row without calling
  	// rows.Next beforehand.
  	ScanInitialized ScanMode = 1 << iota // 1
  	// ScanUpdate makes Scan write rows into the existing elements of a
  	// non-empty slice or array destination instead of rebuilding it.
  	ScanUpdate // 2
  	// ScanOnConflictDoNothing, together with ScanUpdate, makes Scan skip
  	// destination elements according to the result of Field.ValueOf on the
  	// scanned fields.
  	ScanOnConflictDoNothing // 4
  )
  111. // Scan scan rows into db statement
  112. func Scan(rows Rows, db *DB, mode ScanMode) {
  113. var (
  114. columns, _ = rows.Columns()
  115. values = make([]interface{}, len(columns))
  116. initialized = mode&ScanInitialized != 0
  117. update = mode&ScanUpdate != 0
  118. onConflictDonothing = mode&ScanOnConflictDoNothing != 0
  119. )
  120. if len(db.Statement.ColumnMapping) > 0 {
  121. for i, column := range columns {
  122. v, ok := db.Statement.ColumnMapping[column]
  123. if ok {
  124. columns[i] = v
  125. }
  126. }
  127. }
  128. db.RowsAffected = 0
  129. switch dest := db.Statement.Dest.(type) {
  130. case map[string]interface{}, *map[string]interface{}:
  131. if initialized || rows.Next() {
  132. columnTypes, _ := rows.ColumnTypes()
  133. prepareValues(values, db, columnTypes, columns)
  134. db.RowsAffected++
  135. db.AddError(rows.Scan(values...))
  136. mapValue, ok := dest.(map[string]interface{})
  137. if !ok {
  138. if v, ok := dest.(*map[string]interface{}); ok {
  139. if *v == nil {
  140. *v = map[string]interface{}{}
  141. }
  142. mapValue = *v
  143. }
  144. }
  145. scanIntoMap(mapValue, values, columns)
  146. }
  147. case *[]map[string]interface{}:
  148. columnTypes, _ := rows.ColumnTypes()
  149. for initialized || rows.Next() {
  150. prepareValues(values, db, columnTypes, columns)
  151. initialized = false
  152. db.RowsAffected++
  153. db.AddError(rows.Scan(values...))
  154. mapValue := map[string]interface{}{}
  155. scanIntoMap(mapValue, values, columns)
  156. *dest = append(*dest, mapValue)
  157. }
  158. case *int, *int8, *int16, *int32, *int64,
  159. *uint, *uint8, *uint16, *uint32, *uint64, *uintptr,
  160. *float32, *float64,
  161. *bool, *string, *time.Time,
  162. *sql.NullInt32, *sql.NullInt64, *sql.NullFloat64,
  163. *sql.NullBool, *sql.NullString, *sql.NullTime:
  164. for initialized || rows.Next() {
  165. initialized = false
  166. db.RowsAffected++
  167. db.AddError(rows.Scan(dest))
  168. }
  169. default:
  170. var (
  171. fields = make([]*schema.Field, len(columns))
  172. joinFields [][]*schema.Field
  173. sch = db.Statement.Schema
  174. reflectValue = db.Statement.ReflectValue
  175. )
  176. if reflectValue.Kind() == reflect.Interface {
  177. reflectValue = reflectValue.Elem()
  178. }
  179. reflectValueType := reflectValue.Type()
  180. switch reflectValueType.Kind() {
  181. case reflect.Array, reflect.Slice:
  182. reflectValueType = reflectValueType.Elem()
  183. }
  184. isPtr := reflectValueType.Kind() == reflect.Ptr
  185. if isPtr {
  186. reflectValueType = reflectValueType.Elem()
  187. }
  188. if sch != nil {
  189. if reflectValueType != sch.ModelType && reflectValueType.Kind() == reflect.Struct {
  190. sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
  191. }
  192. if len(columns) == 1 {
  193. // Is Pluck
  194. if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner
  195. reflectValueType.Kind() != reflect.Struct || // is not struct
  196. sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
  197. sch = nil
  198. }
  199. }
  200. // Not Pluck
  201. if sch != nil {
  202. matchedFieldCount := make(map[string]int, len(columns))
  203. for idx, column := range columns {
  204. if field := sch.LookUpField(column); field != nil && field.Readable {
  205. fields[idx] = field
  206. if count, ok := matchedFieldCount[column]; ok {
  207. // handle duplicate fields
  208. for _, selectField := range sch.Fields {
  209. if selectField.DBName == column && selectField.Readable {
  210. if count == 0 {
  211. matchedFieldCount[column]++
  212. fields[idx] = selectField
  213. break
  214. }
  215. count--
  216. }
  217. }
  218. } else {
  219. matchedFieldCount[column] = 1
  220. }
  221. } else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation
  222. if rel, ok := sch.Relationships.Relations[names[0]]; ok {
  223. subNameCount := len(names)
  224. // nested relation fields
  225. relFields := make([]*schema.Field, 0, subNameCount-1)
  226. relFields = append(relFields, rel.Field)
  227. for _, name := range names[1 : subNameCount-1] {
  228. rel = rel.FieldSchema.Relationships.Relations[name]
  229. relFields = append(relFields, rel.Field)
  230. }
  231. // latest name is raw dbname
  232. dbName := names[subNameCount-1]
  233. if field := rel.FieldSchema.LookUpField(dbName); field != nil && field.Readable {
  234. fields[idx] = field
  235. if len(joinFields) == 0 {
  236. joinFields = make([][]*schema.Field, len(columns))
  237. }
  238. relFields = append(relFields, field)
  239. joinFields[idx] = relFields
  240. continue
  241. }
  242. }
  243. var val interface{}
  244. values[idx] = &val
  245. } else {
  246. var val interface{}
  247. values[idx] = &val
  248. }
  249. }
  250. }
  251. }
  252. switch reflectValue.Kind() {
  253. case reflect.Slice, reflect.Array:
  254. var (
  255. elem reflect.Value
  256. isArrayKind = reflectValue.Kind() == reflect.Array
  257. )
  258. if !update || reflectValue.Len() == 0 {
  259. update = false
  260. if isArrayKind {
  261. db.Statement.ReflectValue.Set(reflect.Zero(reflectValue.Type()))
  262. } else {
  263. // if the slice cap is externally initialized, the externally initialized slice is directly used here
  264. if reflectValue.Cap() == 0 {
  265. db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
  266. } else {
  267. reflectValue.SetLen(0)
  268. db.Statement.ReflectValue.Set(reflectValue)
  269. }
  270. }
  271. }
  272. for initialized || rows.Next() {
  273. BEGIN:
  274. initialized = false
  275. if update {
  276. if int(db.RowsAffected) >= reflectValue.Len() {
  277. return
  278. }
  279. elem = reflectValue.Index(int(db.RowsAffected))
  280. if onConflictDonothing {
  281. for _, field := range fields {
  282. if _, ok := field.ValueOf(db.Statement.Context, elem); !ok {
  283. db.RowsAffected++
  284. goto BEGIN
  285. }
  286. }
  287. }
  288. } else {
  289. elem = reflect.New(reflectValueType)
  290. }
  291. db.scanIntoStruct(rows, elem, values, fields, joinFields)
  292. if !update {
  293. if !isPtr {
  294. elem = elem.Elem()
  295. }
  296. if isArrayKind {
  297. if reflectValue.Len() >= int(db.RowsAffected) {
  298. reflectValue.Index(int(db.RowsAffected - 1)).Set(elem)
  299. }
  300. } else {
  301. reflectValue = reflect.Append(reflectValue, elem)
  302. }
  303. }
  304. }
  305. if !update {
  306. db.Statement.ReflectValue.Set(reflectValue)
  307. }
  308. case reflect.Struct, reflect.Ptr:
  309. if initialized || rows.Next() {
  310. if mode == ScanInitialized && reflectValue.Kind() == reflect.Struct {
  311. db.Statement.ReflectValue.Set(reflect.Zero(reflectValue.Type()))
  312. }
  313. db.scanIntoStruct(rows, reflectValue, values, fields, joinFields)
  314. }
  315. default:
  316. db.AddError(rows.Scan(dest))
  317. }
  318. }
  319. if err := rows.Err(); err != nil && err != db.Error {
  320. db.AddError(err)
  321. }
  322. if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound && db.Error == nil {
  323. db.AddError(ErrRecordNotFound)
  324. }
  325. }