lazy.go

// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package impl

import (
	"fmt"
	"math/bits"
	"os"
	"reflect"
	"sort"
	"sync/atomic"

	"google.golang.org/protobuf/encoding/protowire"
	"google.golang.org/protobuf/internal/errors"
	"google.golang.org/protobuf/internal/protolazy"
	"google.golang.org/protobuf/reflect/protoreflect"
	preg "google.golang.org/protobuf/reflect/protoregistry"
	piface "google.golang.org/protobuf/runtime/protoiface"
)
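
// enableLazy controls whether lazy unmarshaling is used. It defaults to on
// and can be turned off either with GOPROTODEBUG=nolazy or by calling
// EnableLazyUnmarshal(false).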
var enableLazy int32 = func() int32 {
	if os.Getenv("GOPROTODEBUG") == "nolazy" {
		return 0
	}
	return 1
}()

// EnableLazyUnmarshal enables lazy unmarshaling.
func EnableLazyUnmarshal(enable bool) {
	if enable {
		atomic.StoreInt32(&enableLazy, 1)
		return
	}
	atomic.StoreInt32(&enableLazy, 0)
}

// LazyEnabled reports whether lazy unmarshalling is currently enabled.
func LazyEnabled() bool {
	return atomic.LoadInt32(&enableLazy) != 0
}

// UnmarshalField unmarshals a field in a message.
func UnmarshalField(m interface{}, num protowire.Number) {
	switch m := m.(type) {
	case *messageState:
		m.messageInfo().lazyUnmarshal(m.pointer(), num)
	case *messageReflectWrapper:
		m.messageInfo().lazyUnmarshal(m.pointer(), num)
	default:
		panic(fmt.Sprintf("unsupported wrapper type %T", m))
	}
}
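
// lazyUnmarshal locates the deferred wire-format bytes for field num in the
// message's lazy buffer, decodes them into a freshly allocated value, and
// installs that value atomically, but only if the field pointer in the
// message is still nil (i.e. no other goroutine expanded the field first).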
func (mi *MessageInfo) lazyUnmarshal(p pointer, num protoreflect.FieldNumber) {
	var f *coderFieldInfo
	if int(num) < len(mi.denseCoderFields) {
		f = mi.denseCoderFields[num]
	} else {
		f = mi.coderFields[num]
	}
	if f == nil {
		panic(fmt.Sprintf("lazyUnmarshal: no field info for %v.%v", mi.Desc.FullName(), num))
	}
	lazy := *p.Apply(mi.lazyOffset).LazyInfoPtr()
	start, end, found, _, multipleEntries := lazy.FindFieldInProto(uint32(num))
	if !found && multipleEntries == nil {
		panic(fmt.Sprintf("lazyUnmarshal: can't find field data for %v.%v", mi.Desc.FullName(), num))
	}
	// The actual pointer in the message cannot be set until the whole struct is
	// filled in; otherwise we would have races. Create a separate value and set
	// it atomically, if we won the race and the pointer in the original message
	// is still nil.
	fp := pointerOfValue(reflect.New(f.ft))
	if multipleEntries != nil {
		for _, entry := range multipleEntries {
			mi.unmarshalField(lazy.Buffer()[entry.Start:entry.End], fp, f, lazy, lazy.UnmarshalFlags())
		}
	} else {
		mi.unmarshalField(lazy.Buffer()[start:end], fp, f, lazy, lazy.UnmarshalFlags())
	}
	p.Apply(f.offset).AtomicSetPointerIfNil(fp.Elem())
}
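
// unmarshalField decodes every occurrence of the single field described by f
// from the wire-format bytes in b into p, skipping over any other field
// numbers that appear in the buffer.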
func (mi *MessageInfo) unmarshalField(b []byte, p pointer, f *coderFieldInfo, lazyInfo *protolazy.XXX_lazyUnmarshalInfo, flags piface.UnmarshalInputFlags) error {
	opts := lazyUnmarshalOptions
	opts.flags |= flags
	for len(b) > 0 {
		// Parse the tag (field number and wire type).
		var tag uint64
		if b[0] < 0x80 {
			tag = uint64(b[0])
			b = b[1:]
		} else if len(b) >= 2 && b[1] < 128 {
			tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
			b = b[2:]
		} else {
			var n int
			tag, n = protowire.ConsumeVarint(b)
			if n < 0 {
				return errors.New("invalid wire data")
			}
			b = b[n:]
		}
		var num protowire.Number
		if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
			return errors.New("invalid wire data")
		} else {
			num = protowire.Number(n)
		}
		wtyp := protowire.Type(tag & 7)
		if num == f.num {
			o, err := f.funcs.unmarshal(b, p, wtyp, f, opts)
			if err == nil {
				b = b[o.n:]
				continue
			}
			if err != errUnknown {
				return err
			}
		}
		n := protowire.ConsumeFieldValue(num, wtyp, b)
		if n < 0 {
			return errors.New("invalid wire data")
		}
		b = b[n:]
	}
	return nil
}
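
// skipField validates a message- or group-typed field without fully
// unmarshaling it, resolving the sub-message type from the global registry if
// the validation info does not already carry it. The returned status tells
// the caller whether the field can safely be left for lazy decoding.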
func (mi *MessageInfo) skipField(b []byte, f *coderFieldInfo, wtyp protowire.Type, opts unmarshalOptions) (out unmarshalOutput, _ ValidationStatus) {
	fmi := f.validation.mi
	if fmi == nil {
		fd := mi.Desc.Fields().ByNumber(f.num)
		if fd == nil {
			return out, ValidationUnknown
		}
		messageName := fd.Message().FullName()
		messageType, err := preg.GlobalTypes.FindMessageByName(messageName)
		if err != nil {
			return out, ValidationUnknown
		}
		var ok bool
		fmi, ok = messageType.(*MessageInfo)
		if !ok {
			return out, ValidationUnknown
		}
	}
	fmi.init()
	switch f.validation.typ {
	case validationTypeMessage:
		if wtyp != protowire.BytesType {
			return out, ValidationWrongWireType
		}
		v, n := protowire.ConsumeBytes(b)
		if n < 0 {
			return out, ValidationInvalid
		}
		out, st := fmi.validate(v, 0, opts)
		out.n = n
		return out, st
	case validationTypeGroup:
		if wtyp != protowire.StartGroupType {
			return out, ValidationWrongWireType
		}
		out, st := fmi.validate(b, f.num, opts)
		return out, st
	default:
		return out, ValidationUnknown
	}
}

// unmarshalPointerLazy is similar to unmarshalPointerEager, but it
// specifically handles lazy unmarshalling. It expects lazyOffset and
// presenceOffset to both be valid.
func (mi *MessageInfo) unmarshalPointerLazy(b []byte, p pointer, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, err error) {
	initialized := true
	var requiredMask uint64
	var lazy **protolazy.XXX_lazyUnmarshalInfo
	var presence presence
	var lazyIndex []protolazy.IndexEntry
	var lastNum protowire.Number
	outOfOrder := false
	lazyDecode := false
	presence = p.Apply(mi.presenceOffset).PresenceInfo()
	lazy = p.Apply(mi.lazyOffset).LazyInfoPtr()
	if !presence.AnyPresent(mi.presenceSize) {
		if opts.CanBeLazy() {
			// If the message contains existing data, we need to merge into it.
			// Lazy unmarshaling doesn't merge, so only enable it when the
			// message is empty (has no presence bitmap).
			lazyDecode = true
			if *lazy == nil {
				*lazy = &protolazy.XXX_lazyUnmarshalInfo{}
			}
			(*lazy).SetUnmarshalFlags(opts.flags)
			if !opts.AliasBuffer() {
				// Make a copy of the buffer for lazy unmarshaling.
				// Set the AliasBuffer flag so recursive unmarshal
				// operations reuse the copy.
				b = append([]byte{}, b...)
				opts.flags |= piface.UnmarshalAliasBuffer
			}
			(*lazy).SetBuffer(b)
		}
	}

	// Track special handling of lazy fields.
	//
	// In the common case, all fields are lazyValidateOnly (and lazyFields remains nil).
	// In the event that validation for a field fails, this map tracks handling of the field.
	type lazyAction uint8
	const (
		lazyValidateOnly   lazyAction = iota // validate the field only
		lazyUnmarshalNow                     // eagerly unmarshal the field
		lazyUnmarshalLater                   // unmarshal the field after the message is fully processed
	)
	var lazyFields map[*coderFieldInfo]lazyAction
	var exts *map[int32]ExtensionField
	start := len(b)
	pos := 0
	for len(b) > 0 {
		// Parse the tag (field number and wire type).
		var tag uint64
		if b[0] < 0x80 {
			tag = uint64(b[0])
			b = b[1:]
		} else if len(b) >= 2 && b[1] < 128 {
			tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
			b = b[2:]
		} else {
			var n int
			tag, n = protowire.ConsumeVarint(b)
			if n < 0 {
				return out, errDecode
			}
			b = b[n:]
		}
		var num protowire.Number
		if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
			return out, errors.New("invalid field number")
		} else {
			num = protowire.Number(n)
		}
		wtyp := protowire.Type(tag & 7)
		if wtyp == protowire.EndGroupType {
			if num != groupTag {
				return out, errors.New("mismatching end group marker")
			}
			groupTag = 0
			break
		}
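		// Look up the coder for this field number: small numbers are indexed
		// directly in denseCoderFields, larger ones fall back to coderFields.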
		var f *coderFieldInfo
		if int(num) < len(mi.denseCoderFields) {
			f = mi.denseCoderFields[num]
		} else {
			f = mi.coderFields[num]
		}
		var n int
		err := errUnknown
		discardUnknown := false
	Field:
		switch {
		case f != nil:
			if f.funcs.unmarshal == nil {
				break
			}
			if f.isLazy && lazyDecode {
				switch {
				case lazyFields == nil || lazyFields[f] == lazyValidateOnly:
					// Attempt to validate this field and leave it for later lazy unmarshaling.
					o, valid := mi.skipField(b, f, wtyp, opts)
					switch valid {
					case ValidationValid:
						// Skip over the valid field and continue.
						err = nil
						presence.SetPresentUnatomic(f.presenceIndex, mi.presenceSize)
						requiredMask |= f.validation.requiredBit
						if !o.initialized {
							initialized = false
						}
						n = o.n
						break Field
					case ValidationInvalid:
						return out, errors.New("invalid proto wire format")
					case ValidationWrongWireType:
						break Field
					case ValidationUnknown:
						if lazyFields == nil {
							lazyFields = make(map[*coderFieldInfo]lazyAction)
						}
						if presence.Present(f.presenceIndex) {
							// We were unable to determine if the field is valid or not,
							// and we've already skipped over at least one instance of this
							// field. Clear the presence bit (so if we stop decoding early,
							// we don't leave a partially-initialized field around) and flag
							// the field for unmarshaling before we return.
							presence.ClearPresent(f.presenceIndex)
							lazyFields[f] = lazyUnmarshalLater
							discardUnknown = true
							break Field
						} else {
							// We were unable to determine if the field is valid or not,
							// but this is the first time we've seen it. Flag it as needing
							// eager unmarshaling and fall through to the eager unmarshal case below.
							lazyFields[f] = lazyUnmarshalNow
						}
					}
				case lazyFields[f] == lazyUnmarshalLater:
					// This field will be unmarshaled in a separate pass below.
					// Skip over it here.
					discardUnknown = true
					break Field
				default:
					// Eagerly unmarshal the field.
				}
			}
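			// If we are not lazily decoding and this lazy field is already
			// marked present but its pointer is still nil, its bytes are still
			// deferred from an earlier lazy unmarshal; expand them now so the
			// eager unmarshal below merges into real data.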
			if f.isLazy && !lazyDecode && presence.Present(f.presenceIndex) {
				if p.Apply(f.offset).AtomicGetPointer().IsNil() {
					mi.lazyUnmarshal(p, f.num)
				}
			}
			var o unmarshalOutput
			o, err = f.funcs.unmarshal(b, p.Apply(f.offset), wtyp, f, opts)
			n = o.n
			if err != nil {
				break
			}
			requiredMask |= f.validation.requiredBit
			if f.funcs.isInit != nil && !o.initialized {
				initialized = false
			}
			if f.presenceIndex != noPresence {
				presence.SetPresentUnatomic(f.presenceIndex, mi.presenceSize)
			}
		default:
			// Possible extension.
			if exts == nil && mi.extensionOffset.IsValid() {
				exts = p.Apply(mi.extensionOffset).Extensions()
				if *exts == nil {
					*exts = make(map[int32]ExtensionField)
				}
			}
			if exts == nil {
				break
			}
			var o unmarshalOutput
			o, err = mi.unmarshalExtension(b, num, wtyp, *exts, opts)
			if err != nil {
				break
			}
			n = o.n
			if !o.initialized {
				initialized = false
			}
		}
		if err != nil {
			if err != errUnknown {
				return out, err
			}
			n = protowire.ConsumeFieldValue(num, wtyp, b)
			if n < 0 {
				return out, errDecode
			}
			if !discardUnknown && !opts.DiscardUnknown() && mi.unknownOffset.IsValid() {
				u := mi.mutableUnknownBytes(p)
				*u = protowire.AppendTag(*u, num, wtyp)
				*u = append(*u, b[:n]...)
			}
		}
		b = b[n:]
		end := start - len(b)
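		// While lazily decoding, record the byte range of each deferred field
		// so it can be located again later; contiguous occurrences of the same
		// field number are coalesced into a single index entry.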
		if lazyDecode && f != nil && f.isLazy {
			if num != lastNum {
				lazyIndex = append(lazyIndex, protolazy.IndexEntry{
					FieldNum: uint32(num),
					Start:    uint32(pos),
					End:      uint32(end),
				})
			} else {
				i := len(lazyIndex) - 1
				lazyIndex[i].End = uint32(end)
				lazyIndex[i].MultipleContiguous = true
			}
		}
		if num < lastNum {
			outOfOrder = true
		}
		pos = end
		lastNum = num
	}
	if groupTag != 0 {
		return out, errors.New("missing end group marker")
	}
	if lazyFields != nil {
		// Some fields failed validation, and now need to be unmarshaled.
		for f, action := range lazyFields {
			if action != lazyUnmarshalLater {
				continue
			}
			initialized = false
			if *lazy == nil {
				*lazy = &protolazy.XXX_lazyUnmarshalInfo{}
			}
			if err := mi.unmarshalField((*lazy).Buffer(), p.Apply(f.offset), f, *lazy, opts.flags); err != nil {
				return out, err
			}
			presence.SetPresentUnatomic(f.presenceIndex, mi.presenceSize)
		}
	}
	if lazyDecode {
		if outOfOrder {
			sort.Slice(lazyIndex, func(i, j int) bool {
				return lazyIndex[i].FieldNum < lazyIndex[j].FieldNum ||
					(lazyIndex[i].FieldNum == lazyIndex[j].FieldNum &&
						lazyIndex[i].Start < lazyIndex[j].Start)
			})
		}
		if *lazy == nil {
			*lazy = &protolazy.XXX_lazyUnmarshalInfo{}
		}
		(*lazy).SetIndex(lazyIndex)
	}
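	// The message is only initialized if every required field was seen;
	// requiredMask accumulates a bit for each required field encountered above.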
	if mi.numRequiredFields > 0 && bits.OnesCount64(requiredMask) != int(mi.numRequiredFields) {
		initialized = false
	}
	if initialized {
		out.initialized = true
	}
	out.n = start - len(b)
	return out, nil
}