compiler.go 26 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939
  1. package encoder
  2. import (
  3. "context"
  4. "encoding"
  5. "encoding/json"
  6. "reflect"
  7. "sync"
  8. "sync/atomic"
  9. "unsafe"
  10. "github.com/goccy/go-json/internal/errors"
  11. "github.com/goccy/go-json/internal/runtime"
  12. )
  13. type marshalerContext interface {
  14. MarshalJSON(context.Context) ([]byte, error)
  15. }
  16. var (
  17. marshalJSONType = reflect.TypeOf((*json.Marshaler)(nil)).Elem()
  18. marshalJSONContextType = reflect.TypeOf((*marshalerContext)(nil)).Elem()
  19. marshalTextType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
  20. jsonNumberType = reflect.TypeOf(json.Number(""))
  21. cachedOpcodeSets []*OpcodeSet
  22. cachedOpcodeMap unsafe.Pointer // map[uintptr]*OpcodeSet
  23. typeAddr *runtime.TypeAddr
  24. initEncoderOnce sync.Once
  25. )
  26. func initEncoder() {
  27. initEncoderOnce.Do(func() {
  28. typeAddr = runtime.AnalyzeTypeAddr()
  29. if typeAddr == nil {
  30. typeAddr = &runtime.TypeAddr{}
  31. }
  32. cachedOpcodeSets = make([]*OpcodeSet, typeAddr.AddrRange>>typeAddr.AddrShift+1)
  33. })
  34. }
  35. func loadOpcodeMap() map[uintptr]*OpcodeSet {
  36. p := atomic.LoadPointer(&cachedOpcodeMap)
  37. return *(*map[uintptr]*OpcodeSet)(unsafe.Pointer(&p))
  38. }
  39. func storeOpcodeSet(typ uintptr, set *OpcodeSet, m map[uintptr]*OpcodeSet) {
  40. newOpcodeMap := make(map[uintptr]*OpcodeSet, len(m)+1)
  41. newOpcodeMap[typ] = set
  42. for k, v := range m {
  43. newOpcodeMap[k] = v
  44. }
  45. atomic.StorePointer(&cachedOpcodeMap, *(*unsafe.Pointer)(unsafe.Pointer(&newOpcodeMap)))
  46. }
  47. func compileToGetCodeSetSlowPath(typeptr uintptr) (*OpcodeSet, error) {
  48. opcodeMap := loadOpcodeMap()
  49. if codeSet, exists := opcodeMap[typeptr]; exists {
  50. return codeSet, nil
  51. }
  52. codeSet, err := newCompiler().compile(typeptr)
  53. if err != nil {
  54. return nil, err
  55. }
  56. storeOpcodeSet(typeptr, codeSet, opcodeMap)
  57. return codeSet, nil
  58. }
  59. func getFilteredCodeSetIfNeeded(ctx *RuntimeContext, codeSet *OpcodeSet) (*OpcodeSet, error) {
  60. if (ctx.Option.Flag & ContextOption) == 0 {
  61. return codeSet, nil
  62. }
  63. query := FieldQueryFromContext(ctx.Option.Context)
  64. if query == nil {
  65. return codeSet, nil
  66. }
  67. ctx.Option.Flag |= FieldQueryOption
  68. cacheCodeSet := codeSet.getQueryCache(query.Hash())
  69. if cacheCodeSet != nil {
  70. return cacheCodeSet, nil
  71. }
  72. queryCodeSet, err := newCompiler().codeToOpcodeSet(codeSet.Type, codeSet.Code.Filter(query))
  73. if err != nil {
  74. return nil, err
  75. }
  76. codeSet.setQueryCache(query.Hash(), queryCodeSet)
  77. return queryCodeSet, nil
  78. }
  79. type Compiler struct {
  80. structTypeToCode map[uintptr]*StructCode
  81. }
  82. func newCompiler() *Compiler {
  83. return &Compiler{
  84. structTypeToCode: map[uintptr]*StructCode{},
  85. }
  86. }
  87. func (c *Compiler) compile(typeptr uintptr) (*OpcodeSet, error) {
  88. // noescape trick for header.typ ( reflect.*rtype )
  89. typ := *(**runtime.Type)(unsafe.Pointer(&typeptr))
  90. code, err := c.typeToCode(typ)
  91. if err != nil {
  92. return nil, err
  93. }
  94. return c.codeToOpcodeSet(typ, code)
  95. }
// codeToOpcodeSet lowers a Code tree into an OpcodeSet with four opcode
// variants. There is a no-escape-key list and an escape-key list (built with
// compileContext.escapeKey set), and each also has an "interface" copy made
// by copyToInterfaceOpcode.
//
// Each variant gets its own fresh compileContext, so struct and recursive-code
// bookkeeping is not shared between them. Only the no-escape variant goes
// through Validate, and a validation error is returned as-is.
//
// NOTE(review): statement order matters here. Both lists are copied with
// copyOpcode before setTotalLengthToInterfaceOp runs, and the interface
// copies are taken after that. Check those helpers before reordering.
func (c *Compiler) codeToOpcodeSet(typ *runtime.Type, code Code) (*OpcodeSet, error) {
	noescapeKeyCode := c.codeToOpcode(&compileContext{
		structTypeToCodes: map[uintptr]Opcodes{},
		recursiveCodes:    &Opcodes{},
	}, typ, code)
	if err := noescapeKeyCode.Validate(); err != nil {
		return nil, err
	}
	escapeKeyCode := c.codeToOpcode(&compileContext{
		structTypeToCodes: map[uintptr]Opcodes{},
		recursiveCodes:    &Opcodes{},
		escapeKey:         true,
	}, typ, code)
	noescapeKeyCode = copyOpcode(noescapeKeyCode)
	escapeKeyCode = copyOpcode(escapeKeyCode)
	setTotalLengthToInterfaceOp(noescapeKeyCode)
	setTotalLengthToInterfaceOp(escapeKeyCode)
	interfaceNoescapeKeyCode := copyToInterfaceOpcode(noescapeKeyCode)
	interfaceEscapeKeyCode := copyToInterfaceOpcode(escapeKeyCode)
	// CodeLength is taken from the no-escape variant only.
	codeLength := noescapeKeyCode.TotalLength()
	return &OpcodeSet{
		Type:                     typ,
		NoescapeKeyCode:          noescapeKeyCode,
		EscapeKeyCode:            escapeKeyCode,
		InterfaceNoescapeKeyCode: interfaceNoescapeKeyCode,
		InterfaceEscapeKeyCode:   interfaceEscapeKeyCode,
		CodeLength:               codeLength,
		// EndCode is derived from the interface no-escape variant.
		EndCode: ToEndCode(interfaceNoescapeKeyCode),
		// The original Code tree is kept so field-query filtering can recompile
		// from it (see getFilteredCodeSetIfNeeded).
		Code:       code,
		QueryCache: map[string]*OpcodeSet{},
	}, nil
}
// typeToCode converts a top-level type into a Code tree. Checks run in this
// order:
//
//  1. If typ itself should use its own MarshalJSON (see implementsMarshalJSON)
//     or MarshalText (see implementsMarshalText), return a marshaler code.
//  2. If typ is a pointer, dereference it once; isPtr records this. Repeat the
//     marshaler checks on the element. A match still builds the marshaler
//     code from the original (pointer) type.
//  3. Otherwise dispatch on the kind of the possibly dereferenced type. Kinds
//     not listed (arrays, nested pointers, ...) go to typeToCodeWithPtr.
func (c *Compiler) typeToCode(typ *runtime.Type) (Code, error) {
	switch {
	case c.implementsMarshalJSON(typ):
		return c.marshalJSONCode(typ)
	case c.implementsMarshalText(typ):
		return c.marshalTextCode(typ)
	}
	isPtr := false
	orgType := typ
	if typ.Kind() == reflect.Ptr {
		typ = typ.Elem()
		isPtr = true
	}
	switch {
	case c.implementsMarshalJSON(typ):
		return c.marshalJSONCode(orgType)
	case c.implementsMarshalText(typ):
		return c.marshalTextCode(orgType)
	}
	switch typ.Kind() {
	case reflect.Slice:
		elem := typ.Elem()
		if elem.Kind() == reflect.Uint8 {
			// Slices of uint8 use the bytes encoding unless *elem has its own
			// JSON or text marshaler.
			p := runtime.PtrTo(elem)
			if !c.implementsMarshalJSONType(p) && !p.Implements(marshalTextType) {
				return c.bytesCode(typ, isPtr)
			}
		}
		return c.sliceCode(typ)
	case reflect.Map:
		if isPtr {
			// A pointer to a map is compiled as a pointer code over *typ.
			return c.ptrCode(runtime.PtrTo(typ))
		}
		return c.mapCode(typ)
	case reflect.Struct:
		return c.structCode(typ, isPtr)
	case reflect.Int:
		return c.intCode(typ, isPtr)
	case reflect.Int8:
		return c.int8Code(typ, isPtr)
	case reflect.Int16:
		return c.int16Code(typ, isPtr)
	case reflect.Int32:
		return c.int32Code(typ, isPtr)
	case reflect.Int64:
		return c.int64Code(typ, isPtr)
	case reflect.Uint, reflect.Uintptr:
		return c.uintCode(typ, isPtr)
	case reflect.Uint8:
		return c.uint8Code(typ, isPtr)
	case reflect.Uint16:
		return c.uint16Code(typ, isPtr)
	case reflect.Uint32:
		return c.uint32Code(typ, isPtr)
	case reflect.Uint64:
		return c.uint64Code(typ, isPtr)
	case reflect.Float32:
		return c.float32Code(typ, isPtr)
	case reflect.Float64:
		return c.float64Code(typ, isPtr)
	case reflect.String:
		return c.stringCode(typ, isPtr)
	case reflect.Bool:
		return c.boolCode(typ, isPtr)
	case reflect.Interface:
		return c.interfaceCode(typ, isPtr)
	default:
		// If the dereferenced type still satisfies TextMarshaler but was not
		// taken by implementsMarshalText above (a pointer element whose own
		// element also implements it), fall back to the original pointer
		// type.
		if isPtr && typ.Implements(marshalTextType) {
			typ = orgType
		}
		return c.typeToCodeWithPtr(typ, isPtr)
	}
}
  201. func (c *Compiler) typeToCodeWithPtr(typ *runtime.Type, isPtr bool) (Code, error) {
  202. switch {
  203. case c.implementsMarshalJSON(typ):
  204. return c.marshalJSONCode(typ)
  205. case c.implementsMarshalText(typ):
  206. return c.marshalTextCode(typ)
  207. }
  208. switch typ.Kind() {
  209. case reflect.Ptr:
  210. return c.ptrCode(typ)
  211. case reflect.Slice:
  212. elem := typ.Elem()
  213. if elem.Kind() == reflect.Uint8 {
  214. p := runtime.PtrTo(elem)
  215. if !c.implementsMarshalJSONType(p) && !p.Implements(marshalTextType) {
  216. return c.bytesCode(typ, false)
  217. }
  218. }
  219. return c.sliceCode(typ)
  220. case reflect.Array:
  221. return c.arrayCode(typ)
  222. case reflect.Map:
  223. return c.mapCode(typ)
  224. case reflect.Struct:
  225. return c.structCode(typ, isPtr)
  226. case reflect.Interface:
  227. return c.interfaceCode(typ, false)
  228. case reflect.Int:
  229. return c.intCode(typ, false)
  230. case reflect.Int8:
  231. return c.int8Code(typ, false)
  232. case reflect.Int16:
  233. return c.int16Code(typ, false)
  234. case reflect.Int32:
  235. return c.int32Code(typ, false)
  236. case reflect.Int64:
  237. return c.int64Code(typ, false)
  238. case reflect.Uint:
  239. return c.uintCode(typ, false)
  240. case reflect.Uint8:
  241. return c.uint8Code(typ, false)
  242. case reflect.Uint16:
  243. return c.uint16Code(typ, false)
  244. case reflect.Uint32:
  245. return c.uint32Code(typ, false)
  246. case reflect.Uint64:
  247. return c.uint64Code(typ, false)
  248. case reflect.Uintptr:
  249. return c.uintCode(typ, false)
  250. case reflect.Float32:
  251. return c.float32Code(typ, false)
  252. case reflect.Float64:
  253. return c.float64Code(typ, false)
  254. case reflect.String:
  255. return c.stringCode(typ, false)
  256. case reflect.Bool:
  257. return c.boolCode(typ, false)
  258. }
  259. return nil, &errors.UnsupportedTypeError{Type: runtime.RType2Type(typ)}
  260. }
  261. const intSize = 32 << (^uint(0) >> 63)
  262. //nolint:unparam
  263. func (c *Compiler) intCode(typ *runtime.Type, isPtr bool) (*IntCode, error) {
  264. return &IntCode{typ: typ, bitSize: intSize, isPtr: isPtr}, nil
  265. }
  266. //nolint:unparam
  267. func (c *Compiler) int8Code(typ *runtime.Type, isPtr bool) (*IntCode, error) {
  268. return &IntCode{typ: typ, bitSize: 8, isPtr: isPtr}, nil
  269. }
  270. //nolint:unparam
  271. func (c *Compiler) int16Code(typ *runtime.Type, isPtr bool) (*IntCode, error) {
  272. return &IntCode{typ: typ, bitSize: 16, isPtr: isPtr}, nil
  273. }
  274. //nolint:unparam
  275. func (c *Compiler) int32Code(typ *runtime.Type, isPtr bool) (*IntCode, error) {
  276. return &IntCode{typ: typ, bitSize: 32, isPtr: isPtr}, nil
  277. }
  278. //nolint:unparam
  279. func (c *Compiler) int64Code(typ *runtime.Type, isPtr bool) (*IntCode, error) {
  280. return &IntCode{typ: typ, bitSize: 64, isPtr: isPtr}, nil
  281. }
  282. //nolint:unparam
  283. func (c *Compiler) uintCode(typ *runtime.Type, isPtr bool) (*UintCode, error) {
  284. return &UintCode{typ: typ, bitSize: intSize, isPtr: isPtr}, nil
  285. }
  286. //nolint:unparam
  287. func (c *Compiler) uint8Code(typ *runtime.Type, isPtr bool) (*UintCode, error) {
  288. return &UintCode{typ: typ, bitSize: 8, isPtr: isPtr}, nil
  289. }
  290. //nolint:unparam
  291. func (c *Compiler) uint16Code(typ *runtime.Type, isPtr bool) (*UintCode, error) {
  292. return &UintCode{typ: typ, bitSize: 16, isPtr: isPtr}, nil
  293. }
  294. //nolint:unparam
  295. func (c *Compiler) uint32Code(typ *runtime.Type, isPtr bool) (*UintCode, error) {
  296. return &UintCode{typ: typ, bitSize: 32, isPtr: isPtr}, nil
  297. }
  298. //nolint:unparam
  299. func (c *Compiler) uint64Code(typ *runtime.Type, isPtr bool) (*UintCode, error) {
  300. return &UintCode{typ: typ, bitSize: 64, isPtr: isPtr}, nil
  301. }
  302. //nolint:unparam
  303. func (c *Compiler) float32Code(typ *runtime.Type, isPtr bool) (*FloatCode, error) {
  304. return &FloatCode{typ: typ, bitSize: 32, isPtr: isPtr}, nil
  305. }
  306. //nolint:unparam
  307. func (c *Compiler) float64Code(typ *runtime.Type, isPtr bool) (*FloatCode, error) {
  308. return &FloatCode{typ: typ, bitSize: 64, isPtr: isPtr}, nil
  309. }
  310. //nolint:unparam
  311. func (c *Compiler) stringCode(typ *runtime.Type, isPtr bool) (*StringCode, error) {
  312. return &StringCode{typ: typ, isPtr: isPtr}, nil
  313. }
  314. //nolint:unparam
  315. func (c *Compiler) boolCode(typ *runtime.Type, isPtr bool) (*BoolCode, error) {
  316. return &BoolCode{typ: typ, isPtr: isPtr}, nil
  317. }
  318. //nolint:unparam
  319. func (c *Compiler) intStringCode(typ *runtime.Type) (*IntCode, error) {
  320. return &IntCode{typ: typ, bitSize: intSize, isString: true}, nil
  321. }
  322. //nolint:unparam
  323. func (c *Compiler) int8StringCode(typ *runtime.Type) (*IntCode, error) {
  324. return &IntCode{typ: typ, bitSize: 8, isString: true}, nil
  325. }
  326. //nolint:unparam
  327. func (c *Compiler) int16StringCode(typ *runtime.Type) (*IntCode, error) {
  328. return &IntCode{typ: typ, bitSize: 16, isString: true}, nil
  329. }
  330. //nolint:unparam
  331. func (c *Compiler) int32StringCode(typ *runtime.Type) (*IntCode, error) {
  332. return &IntCode{typ: typ, bitSize: 32, isString: true}, nil
  333. }
  334. //nolint:unparam
  335. func (c *Compiler) int64StringCode(typ *runtime.Type) (*IntCode, error) {
  336. return &IntCode{typ: typ, bitSize: 64, isString: true}, nil
  337. }
  338. //nolint:unparam
  339. func (c *Compiler) uintStringCode(typ *runtime.Type) (*UintCode, error) {
  340. return &UintCode{typ: typ, bitSize: intSize, isString: true}, nil
  341. }
  342. //nolint:unparam
  343. func (c *Compiler) uint8StringCode(typ *runtime.Type) (*UintCode, error) {
  344. return &UintCode{typ: typ, bitSize: 8, isString: true}, nil
  345. }
  346. //nolint:unparam
  347. func (c *Compiler) uint16StringCode(typ *runtime.Type) (*UintCode, error) {
  348. return &UintCode{typ: typ, bitSize: 16, isString: true}, nil
  349. }
  350. //nolint:unparam
  351. func (c *Compiler) uint32StringCode(typ *runtime.Type) (*UintCode, error) {
  352. return &UintCode{typ: typ, bitSize: 32, isString: true}, nil
  353. }
  354. //nolint:unparam
  355. func (c *Compiler) uint64StringCode(typ *runtime.Type) (*UintCode, error) {
  356. return &UintCode{typ: typ, bitSize: 64, isString: true}, nil
  357. }
  358. //nolint:unparam
  359. func (c *Compiler) bytesCode(typ *runtime.Type, isPtr bool) (*BytesCode, error) {
  360. return &BytesCode{typ: typ, isPtr: isPtr}, nil
  361. }
  362. //nolint:unparam
  363. func (c *Compiler) interfaceCode(typ *runtime.Type, isPtr bool) (*InterfaceCode, error) {
  364. return &InterfaceCode{typ: typ, isPtr: isPtr}, nil
  365. }
  366. //nolint:unparam
  367. func (c *Compiler) marshalJSONCode(typ *runtime.Type) (*MarshalJSONCode, error) {
  368. return &MarshalJSONCode{
  369. typ: typ,
  370. isAddrForMarshaler: c.isPtrMarshalJSONType(typ),
  371. isNilableType: c.isNilableType(typ),
  372. isMarshalerContext: typ.Implements(marshalJSONContextType) || runtime.PtrTo(typ).Implements(marshalJSONContextType),
  373. }, nil
  374. }
  375. //nolint:unparam
  376. func (c *Compiler) marshalTextCode(typ *runtime.Type) (*MarshalTextCode, error) {
  377. return &MarshalTextCode{
  378. typ: typ,
  379. isAddrForMarshaler: c.isPtrMarshalTextType(typ),
  380. isNilableType: c.isNilableType(typ),
  381. }, nil
  382. }
  383. func (c *Compiler) ptrCode(typ *runtime.Type) (*PtrCode, error) {
  384. code, err := c.typeToCodeWithPtr(typ.Elem(), true)
  385. if err != nil {
  386. return nil, err
  387. }
  388. ptr, ok := code.(*PtrCode)
  389. if ok {
  390. return &PtrCode{typ: typ, value: ptr.value, ptrNum: ptr.ptrNum + 1}, nil
  391. }
  392. return &PtrCode{typ: typ, value: code, ptrNum: 1}, nil
  393. }
  394. func (c *Compiler) sliceCode(typ *runtime.Type) (*SliceCode, error) {
  395. elem := typ.Elem()
  396. code, err := c.listElemCode(elem)
  397. if err != nil {
  398. return nil, err
  399. }
  400. if code.Kind() == CodeKindStruct {
  401. structCode := code.(*StructCode)
  402. structCode.enableIndirect()
  403. }
  404. return &SliceCode{typ: typ, value: code}, nil
  405. }
  406. func (c *Compiler) arrayCode(typ *runtime.Type) (*ArrayCode, error) {
  407. elem := typ.Elem()
  408. code, err := c.listElemCode(elem)
  409. if err != nil {
  410. return nil, err
  411. }
  412. if code.Kind() == CodeKindStruct {
  413. structCode := code.(*StructCode)
  414. structCode.enableIndirect()
  415. }
  416. return &ArrayCode{typ: typ, value: code}, nil
  417. }
  418. func (c *Compiler) mapCode(typ *runtime.Type) (*MapCode, error) {
  419. keyCode, err := c.mapKeyCode(typ.Key())
  420. if err != nil {
  421. return nil, err
  422. }
  423. valueCode, err := c.mapValueCode(typ.Elem())
  424. if err != nil {
  425. return nil, err
  426. }
  427. if valueCode.Kind() == CodeKindStruct {
  428. structCode := valueCode.(*StructCode)
  429. structCode.enableIndirect()
  430. }
  431. return &MapCode{typ: typ, key: keyCode, value: valueCode}, nil
  432. }
  433. func (c *Compiler) listElemCode(typ *runtime.Type) (Code, error) {
  434. switch {
  435. case c.implementsMarshalJSONType(typ) || c.implementsMarshalJSONType(runtime.PtrTo(typ)):
  436. return c.marshalJSONCode(typ)
  437. case !typ.Implements(marshalTextType) && runtime.PtrTo(typ).Implements(marshalTextType):
  438. return c.marshalTextCode(typ)
  439. case typ.Kind() == reflect.Map:
  440. return c.ptrCode(runtime.PtrTo(typ))
  441. default:
  442. // isPtr was originally used to indicate whether the type of top level is pointer.
  443. // However, since the slice/array element is a specification that can get the pointer address, explicitly set isPtr to true.
  444. // See here for related issues: https://github.com/goccy/go-json/issues/370
  445. code, err := c.typeToCodeWithPtr(typ, true)
  446. if err != nil {
  447. return nil, err
  448. }
  449. ptr, ok := code.(*PtrCode)
  450. if ok {
  451. if ptr.value.Kind() == CodeKindMap {
  452. ptr.ptrNum++
  453. }
  454. }
  455. return code, nil
  456. }
  457. }
  458. func (c *Compiler) mapKeyCode(typ *runtime.Type) (Code, error) {
  459. switch {
  460. case c.implementsMarshalText(typ):
  461. return c.marshalTextCode(typ)
  462. }
  463. switch typ.Kind() {
  464. case reflect.Ptr:
  465. return c.ptrCode(typ)
  466. case reflect.String:
  467. return c.stringCode(typ, false)
  468. case reflect.Int:
  469. return c.intStringCode(typ)
  470. case reflect.Int8:
  471. return c.int8StringCode(typ)
  472. case reflect.Int16:
  473. return c.int16StringCode(typ)
  474. case reflect.Int32:
  475. return c.int32StringCode(typ)
  476. case reflect.Int64:
  477. return c.int64StringCode(typ)
  478. case reflect.Uint:
  479. return c.uintStringCode(typ)
  480. case reflect.Uint8:
  481. return c.uint8StringCode(typ)
  482. case reflect.Uint16:
  483. return c.uint16StringCode(typ)
  484. case reflect.Uint32:
  485. return c.uint32StringCode(typ)
  486. case reflect.Uint64:
  487. return c.uint64StringCode(typ)
  488. case reflect.Uintptr:
  489. return c.uintStringCode(typ)
  490. }
  491. return nil, &errors.UnsupportedTypeError{Type: runtime.RType2Type(typ)}
  492. }
  493. func (c *Compiler) mapValueCode(typ *runtime.Type) (Code, error) {
  494. switch typ.Kind() {
  495. case reflect.Map:
  496. return c.ptrCode(runtime.PtrTo(typ))
  497. default:
  498. code, err := c.typeToCodeWithPtr(typ, false)
  499. if err != nil {
  500. return nil, err
  501. }
  502. ptr, ok := code.(*PtrCode)
  503. if ok {
  504. if ptr.value.Kind() == CodeKindMap {
  505. ptr.ptrNum++
  506. }
  507. }
  508. return code, nil
  509. }
  510. }
  511. func (c *Compiler) structCode(typ *runtime.Type, isPtr bool) (*StructCode, error) {
  512. typeptr := uintptr(unsafe.Pointer(typ))
  513. if code, exists := c.structTypeToCode[typeptr]; exists {
  514. derefCode := *code
  515. derefCode.isRecursive = true
  516. return &derefCode, nil
  517. }
  518. indirect := runtime.IfaceIndir(typ)
  519. code := &StructCode{typ: typ, isPtr: isPtr, isIndirect: indirect}
  520. c.structTypeToCode[typeptr] = code
  521. fieldNum := typ.NumField()
  522. tags := c.typeToStructTags(typ)
  523. fields := []*StructFieldCode{}
  524. for i, tag := range tags {
  525. isOnlyOneFirstField := i == 0 && fieldNum == 1
  526. field, err := c.structFieldCode(code, tag, isPtr, isOnlyOneFirstField)
  527. if err != nil {
  528. return nil, err
  529. }
  530. if field.isAnonymous {
  531. structCode := field.getAnonymousStruct()
  532. if structCode != nil {
  533. structCode.removeFieldsByTags(tags)
  534. if c.isAssignableIndirect(field, isPtr) {
  535. if indirect {
  536. structCode.isIndirect = true
  537. } else {
  538. structCode.isIndirect = false
  539. }
  540. }
  541. }
  542. } else {
  543. structCode := field.getStruct()
  544. if structCode != nil {
  545. if indirect {
  546. // if parent is indirect type, set child indirect property to true
  547. structCode.isIndirect = true
  548. } else {
  549. // if parent is not indirect type, set child indirect property to false.
  550. // but if parent's indirect is false and isPtr is true, then indirect must be true.
  551. // Do this only if indirectConversion is enabled at the end of compileStruct.
  552. structCode.isIndirect = false
  553. }
  554. }
  555. }
  556. fields = append(fields, field)
  557. }
  558. fieldMap := c.getFieldMap(fields)
  559. duplicatedFieldMap := c.getDuplicatedFieldMap(fieldMap)
  560. code.fields = c.filteredDuplicatedFields(fields, duplicatedFieldMap)
  561. if !code.disableIndirectConversion && !indirect && isPtr {
  562. code.enableIndirect()
  563. }
  564. delete(c.structTypeToCode, typeptr)
  565. return code, nil
  566. }
  567. func toElemType(t *runtime.Type) *runtime.Type {
  568. for t.Kind() == reflect.Ptr {
  569. t = t.Elem()
  570. }
  571. return t
  572. }
// structFieldCode builds the StructFieldCode for the struct field described
// by tag.
//
// Parameters:
//   - structCode: the enclosing struct's code. Its isIndirect and
//     disableIndirectConversion may be changed here.
//   - isPtr: whether the enclosing struct is reached through a pointer.
//   - isOnlyOneFirstField: true for field 0 of a struct with exactly one field.
//
// Cases, checked in order:
//  1. Pointer-reached single-field struct whose non-nilable field type has a
//     pointer-receiver MarshalJSON, then the same for MarshalText. The
//     pointer position moves from the struct head to the field, and the
//     parent's indirect handling is turned off.
//  2. Pointer-reached struct whose field type has a pointer-receiver
//     MarshalJSON, then MarshalText. The field's address is used for the
//     marshaler (isAddrForMarshaler) and the nil check is skipped.
//  3. Otherwise the field type is compiled via typeToCodeWithPtr.
func (c *Compiler) structFieldCode(structCode *StructCode, tag *runtime.StructTag, isPtr, isOnlyOneFirstField bool) (*StructFieldCode, error) {
	field := tag.Field
	fieldType := runtime.Type2RType(field.Type)
	isIndirectSpecialCase := isPtr && isOnlyOneFirstField
	fieldCode := &StructFieldCode{
		typ:    fieldType,
		key:    tag.Key,
		tag:    tag,
		offset: field.Offset,
		// Treated as anonymous (later expanded by getFieldMap) only for
		// untagged embedded fields whose pointer-stripped type is a struct.
		isAnonymous:   field.Anonymous && !tag.IsTaggedKey && toElemType(fieldType).Kind() == reflect.Struct,
		isTaggedKey:   tag.IsTaggedKey,
		isNilableType: c.isNilableType(fieldType),
		isNilCheck:    true,
	}
	switch {
	case c.isMovePointerPositionFromHeadToFirstMarshalJSONFieldCase(fieldType, isIndirectSpecialCase):
		code, err := c.marshalJSONCode(fieldType)
		if err != nil {
			return nil, err
		}
		fieldCode.value = code
		fieldCode.isAddrForMarshaler = true
		fieldCode.isNilCheck = false
		structCode.isIndirect = false
		structCode.disableIndirectConversion = true
	case c.isMovePointerPositionFromHeadToFirstMarshalTextFieldCase(fieldType, isIndirectSpecialCase):
		code, err := c.marshalTextCode(fieldType)
		if err != nil {
			return nil, err
		}
		fieldCode.value = code
		fieldCode.isAddrForMarshaler = true
		fieldCode.isNilCheck = false
		structCode.isIndirect = false
		structCode.disableIndirectConversion = true
	case isPtr && c.isPtrMarshalJSONType(fieldType):
		// *struct{ field T }
		// func (*T) MarshalJSON() ([]byte, error)
		code, err := c.marshalJSONCode(fieldType)
		if err != nil {
			return nil, err
		}
		fieldCode.value = code
		fieldCode.isAddrForMarshaler = true
		fieldCode.isNilCheck = false
	case isPtr && c.isPtrMarshalTextType(fieldType):
		// *struct{ field T }
		// func (*T) MarshalText() ([]byte, error)
		code, err := c.marshalTextCode(fieldType)
		if err != nil {
			return nil, err
		}
		fieldCode.value = code
		fieldCode.isAddrForMarshaler = true
		fieldCode.isNilCheck = false
	default:
		code, err := c.typeToCodeWithPtr(fieldType, isPtr)
		if err != nil {
			return nil, err
		}
		// Pointer and interface values set isNextOpPtrType on the field.
		switch code.Kind() {
		case CodeKindPtr, CodeKindInterface:
			fieldCode.isNextOpPtrType = true
		}
		fieldCode.value = code
	}
	return fieldCode, nil
}
  641. func (c *Compiler) isAssignableIndirect(fieldCode *StructFieldCode, isPtr bool) bool {
  642. if isPtr {
  643. return false
  644. }
  645. codeType := fieldCode.value.Kind()
  646. if codeType == CodeKindMarshalJSON {
  647. return false
  648. }
  649. if codeType == CodeKindMarshalText {
  650. return false
  651. }
  652. return true
  653. }
  654. func (c *Compiler) getFieldMap(fields []*StructFieldCode) map[string][]*StructFieldCode {
  655. fieldMap := map[string][]*StructFieldCode{}
  656. for _, field := range fields {
  657. if field.isAnonymous {
  658. for k, v := range c.getAnonymousFieldMap(field) {
  659. fieldMap[k] = append(fieldMap[k], v...)
  660. }
  661. continue
  662. }
  663. fieldMap[field.key] = append(fieldMap[field.key], field)
  664. }
  665. return fieldMap
  666. }
  667. func (c *Compiler) getAnonymousFieldMap(field *StructFieldCode) map[string][]*StructFieldCode {
  668. fieldMap := map[string][]*StructFieldCode{}
  669. structCode := field.getAnonymousStruct()
  670. if structCode == nil || structCode.isRecursive {
  671. fieldMap[field.key] = append(fieldMap[field.key], field)
  672. return fieldMap
  673. }
  674. for k, v := range c.getFieldMapFromAnonymousParent(structCode.fields) {
  675. fieldMap[k] = append(fieldMap[k], v...)
  676. }
  677. return fieldMap
  678. }
  679. func (c *Compiler) getFieldMapFromAnonymousParent(fields []*StructFieldCode) map[string][]*StructFieldCode {
  680. fieldMap := map[string][]*StructFieldCode{}
  681. for _, field := range fields {
  682. if field.isAnonymous {
  683. for k, v := range c.getAnonymousFieldMap(field) {
  684. // Do not handle tagged key when embedding more than once
  685. for _, vv := range v {
  686. vv.isTaggedKey = false
  687. }
  688. fieldMap[k] = append(fieldMap[k], v...)
  689. }
  690. continue
  691. }
  692. fieldMap[field.key] = append(fieldMap[field.key], field)
  693. }
  694. return fieldMap
  695. }
  696. func (c *Compiler) getDuplicatedFieldMap(fieldMap map[string][]*StructFieldCode) map[*StructFieldCode]struct{} {
  697. duplicatedFieldMap := map[*StructFieldCode]struct{}{}
  698. for _, fields := range fieldMap {
  699. if len(fields) == 1 {
  700. continue
  701. }
  702. if c.isTaggedKeyOnly(fields) {
  703. for _, field := range fields {
  704. if field.isTaggedKey {
  705. continue
  706. }
  707. duplicatedFieldMap[field] = struct{}{}
  708. }
  709. } else {
  710. for _, field := range fields {
  711. duplicatedFieldMap[field] = struct{}{}
  712. }
  713. }
  714. }
  715. return duplicatedFieldMap
  716. }
  717. func (c *Compiler) filteredDuplicatedFields(fields []*StructFieldCode, duplicatedFieldMap map[*StructFieldCode]struct{}) []*StructFieldCode {
  718. filteredFields := make([]*StructFieldCode, 0, len(fields))
  719. for _, field := range fields {
  720. if field.isAnonymous {
  721. structCode := field.getAnonymousStruct()
  722. if structCode != nil && !structCode.isRecursive {
  723. structCode.fields = c.filteredDuplicatedFields(structCode.fields, duplicatedFieldMap)
  724. if len(structCode.fields) > 0 {
  725. filteredFields = append(filteredFields, field)
  726. }
  727. continue
  728. }
  729. }
  730. if _, exists := duplicatedFieldMap[field]; exists {
  731. continue
  732. }
  733. filteredFields = append(filteredFields, field)
  734. }
  735. return filteredFields
  736. }
  737. func (c *Compiler) isTaggedKeyOnly(fields []*StructFieldCode) bool {
  738. var taggedKeyFieldCount int
  739. for _, field := range fields {
  740. if field.isTaggedKey {
  741. taggedKeyFieldCount++
  742. }
  743. }
  744. return taggedKeyFieldCount == 1
  745. }
  746. func (c *Compiler) typeToStructTags(typ *runtime.Type) runtime.StructTags {
  747. tags := runtime.StructTags{}
  748. fieldNum := typ.NumField()
  749. for i := 0; i < fieldNum; i++ {
  750. field := typ.Field(i)
  751. if runtime.IsIgnoredStructField(field) {
  752. continue
  753. }
  754. tags = append(tags, runtime.StructTagFromField(field))
  755. }
  756. return tags
  757. }
  758. // *struct{ field T } => struct { field *T }
  759. // func (*T) MarshalJSON() ([]byte, error)
  760. func (c *Compiler) isMovePointerPositionFromHeadToFirstMarshalJSONFieldCase(typ *runtime.Type, isIndirectSpecialCase bool) bool {
  761. return isIndirectSpecialCase && !c.isNilableType(typ) && c.isPtrMarshalJSONType(typ)
  762. }
  763. // *struct{ field T } => struct { field *T }
  764. // func (*T) MarshalText() ([]byte, error)
  765. func (c *Compiler) isMovePointerPositionFromHeadToFirstMarshalTextFieldCase(typ *runtime.Type, isIndirectSpecialCase bool) bool {
  766. return isIndirectSpecialCase && !c.isNilableType(typ) && c.isPtrMarshalTextType(typ)
  767. }
  768. func (c *Compiler) implementsMarshalJSON(typ *runtime.Type) bool {
  769. if !c.implementsMarshalJSONType(typ) {
  770. return false
  771. }
  772. if typ.Kind() != reflect.Ptr {
  773. return true
  774. }
  775. // type kind is reflect.Ptr
  776. if !c.implementsMarshalJSONType(typ.Elem()) {
  777. return true
  778. }
  779. // needs to dereference
  780. return false
  781. }
  782. func (c *Compiler) implementsMarshalText(typ *runtime.Type) bool {
  783. if !typ.Implements(marshalTextType) {
  784. return false
  785. }
  786. if typ.Kind() != reflect.Ptr {
  787. return true
  788. }
  789. // type kind is reflect.Ptr
  790. if !typ.Elem().Implements(marshalTextType) {
  791. return true
  792. }
  793. // needs to dereference
  794. return false
  795. }
  796. func (c *Compiler) isNilableType(typ *runtime.Type) bool {
  797. if !runtime.IfaceIndir(typ) {
  798. return true
  799. }
  800. switch typ.Kind() {
  801. case reflect.Ptr:
  802. return true
  803. case reflect.Map:
  804. return true
  805. case reflect.Func:
  806. return true
  807. default:
  808. return false
  809. }
  810. }
  811. func (c *Compiler) implementsMarshalJSONType(typ *runtime.Type) bool {
  812. return typ.Implements(marshalJSONType) || typ.Implements(marshalJSONContextType)
  813. }
  814. func (c *Compiler) isPtrMarshalJSONType(typ *runtime.Type) bool {
  815. return !c.implementsMarshalJSONType(typ) && c.implementsMarshalJSONType(runtime.PtrTo(typ))
  816. }
  817. func (c *Compiler) isPtrMarshalTextType(typ *runtime.Type) bool {
  818. return !typ.Implements(marshalTextType) && runtime.PtrTo(typ).Implements(marshalTextType)
  819. }
  820. func (c *Compiler) codeToOpcode(ctx *compileContext, typ *runtime.Type, code Code) *Opcode {
  821. codes := code.ToOpcode(ctx)
  822. codes.Last().Next = newEndOp(ctx, typ)
  823. c.linkRecursiveCode(ctx)
  824. return codes.First()
  825. }
// linkRecursiveCode resolves the jump targets of the recursive-struct
// opcodes collected in ctx.recursiveCodes.
//
// For each such opcode, the opcodes registered for its type in
// ctx.structTypeToCodes (keyed by type pointer) are copied with copyOpcode.
// The head op is converted via PtrHeadToHead, and the list is terminated
// with an OpRecursiveEnd op. The result and its lengths are stored into the
// opcode's Jmp (*CompiledCode), which is then marked Linked. Later opcodes
// of a type already linked in this call get that CompiledCode's value copied
// into their own Jmp.
func (c *Compiler) linkRecursiveCode(ctx *compileContext) {
	recursiveCodes := map[uintptr]*CompiledCode{}
	for _, recursive := range *ctx.recursiveCodes {
		typeptr := uintptr(unsafe.Pointer(recursive.Type))
		codes := ctx.structTypeToCodes[typeptr]
		if recursiveCode, ok := recursiveCodes[typeptr]; ok {
			// Already linked for this type in this call: share the result.
			*recursive.Jmp = *recursiveCode
			continue
		}
		code := copyOpcode(codes.First())
		code.Op = code.Op.PtrHeadToHead()
		lastCode := newEndOp(&compileContext{}, recursive.Type)
		lastCode.Op = OpRecursiveEnd
		// OpRecursiveEnd must be set before calling TotalLength.
		code.End.Next = lastCode
		totalLength := code.TotalLength()
		// Idx, ElemIdx and Length must be set after calling TotalLength. They
		// are consecutive uintptr-sized slots starting at (totalLength+1).
		lastCode.Idx = uint32((totalLength + 1) * uintptrSize)
		lastCode.ElemIdx = lastCode.Idx + uintptrSize
		lastCode.Length = lastCode.Idx + 2*uintptrSize
		// Extend the lengths to allocate slots for elemIdx + length.
		curTotalLength := uintptr(recursive.TotalLength()) + 3
		nextTotalLength := uintptr(totalLength) + 3
		compiled := recursive.Jmp
		compiled.Code = code
		compiled.CurLen = curTotalLength
		compiled.NextLen = nextTotalLength
		compiled.Linked = true
		recursiveCodes[typeptr] = compiled
	}
}