compiler.go 9.3 KB


  1. package optdec
  2. import (
  3. "fmt"
  4. "reflect"
  5. "github.com/bytedance/sonic/option"
  6. "github.com/bytedance/sonic/internal/rt"
  7. "github.com/bytedance/sonic/internal/caching"
  8. )
  9. var (
  10. programCache = caching.CreateProgramCache()
  11. )
  12. func findOrCompile(vt *rt.GoType) (decFunc, error) {
  13. makeDecoder := func(vt *rt.GoType, _ ...interface{}) (interface{}, error) {
  14. ret, err := newCompiler().compileType(vt.Pack())
  15. return ret, err
  16. }
  17. if val := programCache.Get(vt); val != nil {
  18. return val.(decFunc), nil
  19. } else if ret, err := programCache.Compute(vt, makeDecoder); err == nil {
  20. return ret.(decFunc), nil
  21. } else {
  22. return nil, err
  23. }
  24. }
  25. type compiler struct {
  26. visited map[reflect.Type]bool
  27. depth int
  28. counts int
  29. opts option.CompileOptions
  30. namedPtr bool
  31. }
  // newCompiler returns a compiler with an empty visited set, zero depth and
  // counts, and the default compile options.
  func newCompiler() *compiler {
  	c := new(compiler)
  	c.visited = map[reflect.Type]bool{}
  	c.opts = option.DefaultCompileOptions()
  	return c
  }
  38. func (self *compiler) apply(opts option.CompileOptions) *compiler {
  39. self.opts = opts
  40. return self
  41. }
  // _CompileMaxDepth is the deepest type nesting compiler.enter accepts before
  // aborting compilation with a stack-overflow panic.
  const _CompileMaxDepth = 1 << 12 // 4096
  // enter marks vt as being compiled and increases the nesting depth. When
  // the depth goes past _CompileMaxDepth it panics with *stackOverflow.
  func (c *compiler) enter(vt reflect.Type) {
  	c.visited[vt] = true
  	c.depth++
  	if c.depth <= _CompileMaxDepth {
  		return
  	}
  	panic(*stackOverflow)
  }
  // exit undoes enter: the nesting depth drops by one and vt is no longer
  // flagged as in progress.
  func (c *compiler) exit(vt reflect.Type) {
  	c.depth--
  	c.visited[vt] = false
  }
  // compileInt chooses the decoder for the word-sized kinds int, uint and
  // uintptr based on vt.Size(). A size of 4 selects the 32-bit decoders and a
  // size of 8 selects the 64-bit decoders. Any other size panics with a
  // "not supported pointer size" message. A kind outside the three handled
  // here panics with "unreachable".
  func (c *compiler) compileInt(vt reflect.Type) decFunc {
  	size := vt.Size()
  	if size != 4 && size != 8 {
  		panic("not supported pointer size: " + fmt.Sprint(size))
  	}

  	wide := size == 8
  	switch kind := vt.Kind(); {
  	case kind == reflect.Int && wide:
  		return &i64Decoder{}
  	case kind == reflect.Int:
  		return &i32Decoder{}
  	case (kind == reflect.Uint || kind == reflect.Uintptr) && wide:
  		return &u64Decoder{}
  	case kind == reflect.Uint || kind == reflect.Uintptr:
  		return &u32Decoder{}
  	}
  	panic("unreachable")
  }
  79. func (c *compiler) rescue(ep *error) {
  80. if val := recover(); val != nil {
  81. if err, ok := val.(error); ok {
  82. *ep = err
  83. } else {
  84. panic(val)
  85. }
  86. }
  87. }
  // compileType compiles the decoder for vt. If compilation panics with an
  // error value (via rescue), that error is returned and the decoder is nil.
  // Non-error panics propagate.
  //
  // The named result is dec rather than rt. The name rt would shadow the
  // imported rt package for the whole function body.
  func (c *compiler) compileType(vt reflect.Type) (dec decFunc, err error) {
  	defer c.rescue(&err)
  	dec = c.compile(vt)
  	return dec, err
  }
  // compile returns the decoder for vt, checking in this order:
  //  1. If vt is already being compiled higher up the stack, return a
  //     recuriveDecoder. This breaks the cycle.
  //  2. If *vt implements json.Unmarshaler or encoding.TextUnmarshaler, use
  //     the decoder from tryCompilePtrUnmarshaler.
  //  3. Otherwise, use the kind-based decoder from compileBasic.
  func (c *compiler) compile(vt reflect.Type) decFunc {
  	if c.visited[vt] {
  		return &recuriveDecoder{typ: rt.UnpackType(vt)}
  	}
  	if dec := c.tryCompilePtrUnmarshaler(vt, false); dec != nil {
  		return dec
  	}
  	return c.compileBasic(vt)
  }
  // compileBasic selects a decoder purely from vt's reflect.Kind, without the
  // recursion and unmarshaler checks that compile performs. Kinds with no
  // decoder listed here (e.g. chan, func, complex) get an
  // unsupportedTypeDecoder. c.counts is incremented every time the function
  // exits, including an exit by panic.
  func (c *compiler) compileBasic(vt reflect.Type) decFunc {
  	defer func() { c.counts++ }()

  	switch vt.Kind() {
  	// Fixed-width scalars.
  	case reflect.Bool:
  		return &boolDecoder{}
  	case reflect.Int8:
  		return &i8Decoder{}
  	case reflect.Int16:
  		return &i16Decoder{}
  	case reflect.Int32:
  		return &i32Decoder{}
  	case reflect.Int64:
  		return &i64Decoder{}
  	case reflect.Uint8:
  		return &u8Decoder{}
  	case reflect.Uint16:
  		return &u16Decoder{}
  	case reflect.Uint32:
  		return &u32Decoder{}
  	case reflect.Uint64:
  		return &u64Decoder{}
  	case reflect.Float32:
  		return &f32Decoder{}
  	case reflect.Float64:
  		return &f64Decoder{}
  	// Word-sized integers depend on the platform's pointer size.
  	case reflect.Int, reflect.Uint, reflect.Uintptr:
  		return c.compileInt(vt)
  	// Strings and composite kinds.
  	case reflect.String:
  		return c.compileString(vt)
  	case reflect.Array:
  		return c.compileArray(vt)
  	case reflect.Interface:
  		return c.compileInterface(vt)
  	case reflect.Map:
  		return c.compileMap(vt)
  	case reflect.Ptr:
  		return c.compilePtr(vt)
  	case reflect.Slice:
  		return c.compileSlice(vt)
  	case reflect.Struct:
  		return c.compileStruct(vt)
  	}
  	return &unsupportedTypeDecoder{typ: rt.UnpackType(vt)}
  }
  158. func (c *compiler) compilePtr(vt reflect.Type) decFunc {
  159. c.enter(vt)
  160. defer c.exit(vt)
  161. // special logic for Named Ptr, issue 379
  162. if reflect.PtrTo(vt.Elem()) != vt {
  163. c.namedPtr = true
  164. return &ptrDecoder{
  165. typ: rt.UnpackType(vt.Elem()),
  166. deref: c.compileBasic(vt.Elem()),
  167. }
  168. }
  169. return &ptrDecoder{
  170. typ: rt.UnpackType(vt.Elem()),
  171. deref: c.compile(vt.Elem()),
  172. }
  173. }
  // compileArray builds an arrayDecoder for the fixed-length array type vt.
  // The decoder records the array length and the element type, and carries
  // the compiled decoder for the elements.
  func (c *compiler) compileArray(vt reflect.Type) decFunc {
  	c.enter(vt)
  	defer c.exit(vt)

  	elem := vt.Elem()
  	elemDec := c.compile(elem)
  	return &arrayDecoder{
  		len:      vt.Len(),
  		elemType: rt.UnpackType(elem),
  		elemDec:  elemDec,
  		typ:      vt,
  	}
  }
  // compileString returns a numberDecoder when vt is jsonNumberType. Every
  // other string-kinded type gets a stringDecoder.
  func (c *compiler) compileString(vt reflect.Type) decFunc {
  	if vt != jsonNumberType {
  		return &stringDecoder{}
  	}
  	return &numberDecoder{}
  }
  // tryCompileSliceUnmarshaler returns a generic sliceDecoder for the slice
  // type vt when a pointer to its element type implements json.Unmarshaler
  // or encoding.TextUnmarshaler. It returns nil otherwise. compileSlice
  // calls it before its type-specialised slice decoders, so elements with
  // their own unmarshal method never get one of those decoders.
  //
  // The original had two byte-identical branches, one per interface. Both
  // cases build the same decoder because c.compile on the element already
  // resolves to the matching unmarshal decoder (via tryCompilePtrUnmarshaler),
  // so they are merged into a single condition.
  func (c *compiler) tryCompileSliceUnmarshaler(vt reflect.Type) decFunc {
  	elem := vt.Elem()
  	pt := reflect.PtrTo(elem)
  	if !pt.Implements(jsonUnmarshalerType) && !pt.Implements(encodingTextUnmarshalerType) {
  		return nil
  	}
  	return &sliceDecoder{
  		elemType: rt.UnpackType(elem),
  		elemDec:  c.compile(elem),
  		typ:      vt,
  	}
  }
  // compileSlice builds the decoder for the slice type vt, trying these in
  // order:
  //  1. Slices with uint8-kinded elements go to compileSliceBytes.
  //  2. Elements whose pointer implements an unmarshaler get the generic
  //     decoder from tryCompileSliceUnmarshaler.
  //  3. []interface{}, 32/64-bit integer elements, and string elements
  //     (other than rt.JsonNumberType) get dedicated decoders.
  //  4. Everything else gets a generic sliceDecoder over the compiled
  //     element decoder.
  func (c *compiler) compileSlice(vt reflect.Type) decFunc {
  	c.enter(vt)
  	defer c.exit(vt)

  	elem := vt.Elem()
  	et := rt.UnpackType(elem)

  	// []byte-like slices have their own decoder family.
  	if et.Kind() == reflect.Uint8 {
  		return c.compileSliceBytes(vt)
  	}
  	if dec := c.tryCompileSliceUnmarshaler(vt); dec != nil {
  		return dec
  	}

  	// Dedicated decoders for common element types avoid per-element
  	// function calls.
  	switch {
  	case vt == reflect.TypeOf([]interface{}{}):
  		return &sliceEfaceDecoder{}
  	case et.IsInt32():
  		return &sliceI32Decoder{}
  	case et.IsInt64():
  		return &sliceI64Decoder{}
  	case et.IsUint32():
  		return &sliceU32Decoder{}
  	case et.IsUint64():
  		return &sliceU64Decoder{}
  	case et.Kind() == reflect.String && et != rt.JsonNumberType:
  		return &sliceStringDecoder{}
  	}

  	return &sliceDecoder{
  		elemType: et,
  		elemDec:  c.compile(elem),
  		typ:      vt,
  	}
  }
  // compileSliceBytes handles slices whose element kind is uint8.
  //
  // If a pointer to the element type implements json.Unmarshaler or
  // encoding.TextUnmarshaler, it returns a sliceBytesUnmarshalerDecoder
  // carrying the compiled element decoder. Otherwise it returns the plain
  // sliceBytesDecoder.
  //
  // The original spelled out two byte-identical branches, one per interface.
  // They are merged into a single condition.
  func (c *compiler) compileSliceBytes(vt reflect.Type) decFunc {
  	elem := vt.Elem()
  	ep := reflect.PtrTo(elem)
  	if ep.Implements(jsonUnmarshalerType) || ep.Implements(encodingTextUnmarshalerType) {
  		return &sliceBytesUnmarshalerDecoder{
  			elemType: rt.UnpackType(elem),
  			elemDec:  c.compile(elem),
  			typ:      vt,
  		}
  	}
  	return &sliceBytesDecoder{}
  }
  // compileInterface builds the decoder for the interface type vt.
  // An interface with no methods uses efaceDecoder. A non-empty interface
  // that itself implements json.Unmarshaler or encoding.TextUnmarshaler
  // (checked in that order) gets the matching unmarshal decoder. Any other
  // interface gets an ifaceDecoder.
  func (c *compiler) compileInterface(vt reflect.Type) decFunc {
  	c.enter(vt)
  	defer c.exit(vt)

  	if vt.NumMethod() == 0 {
  		return &efaceDecoder{}
  	}

  	typ := rt.UnpackType(vt)
  	switch {
  	case vt.Implements(jsonUnmarshalerType):
  		return &unmarshalJSONDecoder{typ: typ}
  	case vt.Implements(encodingTextUnmarshalerType):
  		return &unmarshalTextDecoder{typ: typ}
  	default:
  		return &ifaceDecoder{typ: typ}
  	}
  }
  // compileMap builds the decoder for the map type vt, trying these in order:
  //  1. A key type whose pointer implements encoding.TextUnmarshaler gets a
  //     mapDecoder using that key unmarshaler.
  //  2. map[string]interface{} and map[string]string get dedicated decoders.
  //  3. String-kinded keys (other than rt.JsonNumberType) and 32/64-bit
  //     integer keys get dedicated decoders paired with the matching
  //     rt.GetMap*Assign helper.
  //  4. Anything else gets a generic mapDecoder whose key decoder comes from
  //     compileMapKey.
  func (c *compiler) compileMap(vt reflect.Type) decFunc {
  	c.enter(vt)
  	defer c.exit(vt)

  	mt := rt.MapType(rt.UnpackType(vt))

  	if keyDec := tryCompileKeyUnmarshaler(vt); keyDec != nil {
  		return &mapDecoder{
  			mapType: mt,
  			keyDec:  keyDec,
  			elemDec: c.compile(vt.Elem()),
  		}
  	}

  	// The most common maps get dedicated decoders to avoid function calls.
  	switch vt {
  	case reflect.TypeOf(map[string]interface{}{}):
  		return &mapEfaceDecoder{}
  	case reflect.TypeOf(map[string]string{}):
  		return &mapStringDecoder{}
  	}

  	key := mt.Key
  	switch {
  	case key.Kind() == reflect.String && key != rt.JsonNumberType:
  		return &mapStrKeyDecoder{
  			mapType: mt,
  			assign:  rt.GetMapStrAssign(vt),
  			elemDec: c.compile(vt.Elem()),
  		}
  	case key.IsInt64():
  		return &mapI64KeyDecoder{
  			mapType: mt,
  			elemDec: c.compile(vt.Elem()),
  			assign:  rt.GetMap64Assign(vt),
  		}
  	case key.IsInt32():
  		return &mapI32KeyDecoder{
  			mapType: mt,
  			elemDec: c.compile(vt.Elem()),
  			assign:  rt.GetMap32Assign(vt),
  		}
  	case key.IsUint64():
  		return &mapU64KeyDecoder{
  			mapType: mt,
  			elemDec: c.compile(vt.Elem()),
  			assign:  rt.GetMap64Assign(vt),
  		}
  	case key.IsUint32():
  		return &mapU32KeyDecoder{
  			mapType: mt,
  			elemDec: c.compile(vt.Elem()),
  			assign:  rt.GetMap32Assign(vt),
  		}
  	}

  	// Generic map.
  	// NOTE(review): compileMapKey returns nil for key kinds it does not list
  	// (e.g. bool keys); confirm that mapDecoder copes with a nil keyDec.
  	return &mapDecoder{
  		mapType: mt,
  		keyDec:  c.compileMapKey(vt),
  		elemDec: c.compile(vt.Elem()),
  	}
  }
  // tryCompileKeyUnmarshaler returns decodeKeyTextUnmarshaler when a pointer
  // to vt's key type implements encoding.TextUnmarshaler, and nil otherwise.
  // json.Unmarshaler is deliberately not consulted for map keys, matching
  // encoding/json.
  func tryCompileKeyUnmarshaler(vt reflect.Type) decKey {
  	if !reflect.PtrTo(vt.Key()).Implements(encodingTextUnmarshalerType) {
  		return nil
  	}
  	return decodeKeyTextUnmarshaler
  }
  // compileMapKey returns the key decoder used by the generic mapDecoder
  // path. It covers 8/16-bit integers, floats, and rt.JsonNumberType keys.
  // Every other key kind yields nil.
  func (c *compiler) compileMapKey(vt reflect.Type) decKey {
  	kt := vt.Key()
  	switch kt.Kind() {
  	case reflect.Int8:
  		return decodeKeyI8
  	case reflect.Int16:
  		return decodeKeyI16
  	case reflect.Uint8:
  		return decodeKeyU8
  	case reflect.Uint16:
  		return decodeKeyU16
  	// NOTE: encoding/json itself cannot use floats as map keys.
  	case reflect.Float32:
  		return decodeFloat32Key
  	case reflect.Float64:
  		return decodeFloat64Key
  	case reflect.String:
  		if rt.UnpackType(kt) == rt.JsonNumberType {
  			return decodeJsonNumberKey
  		}
  	}
  	return nil
  }
  // tryCompilePtrUnmarshaler returns a decoder that delegates to the type's
  // own unmarshal method when *vt implements json.Unmarshaler or
  // encoding.TextUnmarshaler, checked in that order. Otherwise it returns
  // nil. Checking *vt rather than vt picks up methods declared with either
  // value or pointer receivers. This matters for named types (see issue 379).
  //
  // strOpt is forwarded to unmarshalJSONDecoder. TextUnmarshaler types do not
  // support the `string` tag option, so strOpt == true for such a type
  // panics via panicForInvalidStrType.
  func (c *compiler) tryCompilePtrUnmarshaler(vt reflect.Type, strOpt bool) decFunc {
  	pt := reflect.PtrTo(vt)

  	// json.Unmarshaler takes precedence.
  	if pt.Implements(jsonUnmarshalerType) {
  		return &unmarshalJSONDecoder{
  			typ:    rt.UnpackType(pt),
  			strOpt: strOpt,
  		}
  	}

  	// encoding.TextUnmarshaler (the original comment wrongly said
  	// TextMarshaler); the `string` tag option is rejected here.
  	if pt.Implements(encodingTextUnmarshalerType) {
  		if strOpt {
  			panicForInvalidStrType(vt)
  		}
  		return &unmarshalTextDecoder{
  			typ: rt.UnpackType(pt),
  		}
  	}
  	return nil
  }
  // panicForInvalidStrType panics with the value error_type builds for vt.
  // It is used to reject the `string` tag option on types that do not
  // support it, and it never returns normally.
  func panicForInvalidStrType(vt reflect.Type) {
  	err := error_type(rt.UnpackType(vt))
  	panic(err)
  }