decode.go 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. // Copyright 2019 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package impl
  5. import (
  6. "math/bits"
  7. "google.golang.org/protobuf/encoding/protowire"
  8. "google.golang.org/protobuf/internal/errors"
  9. "google.golang.org/protobuf/internal/flags"
  10. "google.golang.org/protobuf/proto"
  11. "google.golang.org/protobuf/reflect/protoreflect"
  12. "google.golang.org/protobuf/reflect/protoregistry"
  13. "google.golang.org/protobuf/runtime/protoiface"
  14. )
  15. var errDecode = errors.New("cannot parse invalid wire-format data")
  16. var errRecursionDepth = errors.New("exceeded maximum recursion depth")
// unmarshalOptions is the internal, per-operation form of the unmarshal
// options. It is passed by value down the decoder call tree.
type unmarshalOptions struct {
	// flags holds the protoiface.UnmarshalInputFlags bits queried by the
	// accessor methods (DiscardUnknown, AliasBuffer, Validated, NoLazyDecoding).
	flags protoiface.UnmarshalInputFlags
	// resolver is used to look up extension types by containing message name
	// and field number when an extension field is decoded.
	resolver interface {
		FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
		FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
	}
	// depth is the remaining recursion budget. unmarshalPointer decrements it
	// on entry and fails with errRecursionDepth once it drops below zero.
	depth int
}
  25. func (o unmarshalOptions) Options() proto.UnmarshalOptions {
  26. return proto.UnmarshalOptions{
  27. Merge: true,
  28. AllowPartial: true,
  29. DiscardUnknown: o.DiscardUnknown(),
  30. Resolver: o.resolver,
  31. NoLazyDecoding: o.NoLazyDecoding(),
  32. }
  33. }
  34. func (o unmarshalOptions) DiscardUnknown() bool {
  35. return o.flags&protoiface.UnmarshalDiscardUnknown != 0
  36. }
  37. func (o unmarshalOptions) AliasBuffer() bool { return o.flags&protoiface.UnmarshalAliasBuffer != 0 }
  38. func (o unmarshalOptions) Validated() bool { return o.flags&protoiface.UnmarshalValidated != 0 }
  39. func (o unmarshalOptions) NoLazyDecoding() bool {
  40. return o.flags&protoiface.UnmarshalNoLazyDecoding != 0
  41. }
  42. func (o unmarshalOptions) CanBeLazy() bool {
  43. if o.resolver != protoregistry.GlobalTypes {
  44. return false
  45. }
  46. // We ignore the UnmarshalInvalidateSizeCache even though it's not in the default set
  47. return (o.flags & ^(protoiface.UnmarshalAliasBuffer | protoiface.UnmarshalValidated | protoiface.UnmarshalCheckRequired)) == 0
  48. }
  49. var lazyUnmarshalOptions = unmarshalOptions{
  50. resolver: protoregistry.GlobalTypes,
  51. flags: protoiface.UnmarshalAliasBuffer | protoiface.UnmarshalValidated,
  52. depth: protowire.DefaultRecursionLimit,
  53. }
// unmarshalOutput is the internal result of decoding a message, field or
// extension.
type unmarshalOutput struct {
	n int // number of bytes consumed
	// initialized is true when the decoder determined the result to be fully
	// initialized (e.g. all required fields present and every checked
	// sub-value initialized).
	initialized bool
}
  58. // unmarshal is protoreflect.Methods.Unmarshal.
  59. func (mi *MessageInfo) unmarshal(in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) {
  60. var p pointer
  61. if ms, ok := in.Message.(*messageState); ok {
  62. p = ms.pointer()
  63. } else {
  64. p = in.Message.(*messageReflectWrapper).pointer()
  65. }
  66. out, err := mi.unmarshalPointer(in.Buf, p, 0, unmarshalOptions{
  67. flags: in.Flags,
  68. resolver: in.Resolver,
  69. depth: in.Depth,
  70. })
  71. var flags protoiface.UnmarshalOutputFlags
  72. if out.initialized {
  73. flags |= protoiface.UnmarshalInitialized
  74. }
  75. return protoiface.UnmarshalOutput{
  76. Flags: flags,
  77. }, err
  78. }
  79. // errUnknown is returned during unmarshaling to indicate a parse error that
  80. // should result in a field being placed in the unknown fields section (for example,
  81. // when the wire type doesn't match) as opposed to the entire unmarshal operation
  82. // failing (for example, when a field extends past the available input).
  83. //
  84. // This is a sentinel error which should never be visible to the user.
  85. var errUnknown = errors.New("unknown")
  86. func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, err error) {
  87. mi.init()
  88. opts.depth--
  89. if opts.depth < 0 {
  90. return out, errRecursionDepth
  91. }
  92. if flags.ProtoLegacy && mi.isMessageSet {
  93. return unmarshalMessageSet(mi, b, p, opts)
  94. }
  95. lazyDecoding := LazyEnabled() // default
  96. if opts.NoLazyDecoding() {
  97. lazyDecoding = false // explicitly disabled
  98. }
  99. if mi.lazyOffset.IsValid() && lazyDecoding {
  100. return mi.unmarshalPointerLazy(b, p, groupTag, opts)
  101. }
  102. return mi.unmarshalPointerEager(b, p, groupTag, opts)
  103. }
// unmarshalPointerEager is the message unmarshalling function for all messages that are not lazy.
// The corresponding function for Lazy is in google_lazy.go.
//
// It decodes every field in b into the message at p in a single pass:
//   - field numbers with a coderFieldInfo use its unmarshal func;
//   - other numbers are tried as extensions;
//   - anything that ends up as errUnknown is skipped and, unless DiscardUnknown
//     is set, appended to the message's unknown-field bytes.
//
// When groupTag is non-zero, decoding stops at the END_GROUP tag with that
// number, and reaching the end of b without one is an error. An END_GROUP with
// any other number is always an error.
//
// On success, out.n is the number of bytes consumed (including a terminating
// END_GROUP tag). out.initialized is false if fewer than mi.numRequiredFields
// distinct required fields were seen, or if any checked sub-value reported
// itself uninitialized.
func (mi *MessageInfo) unmarshalPointerEager(b []byte, p pointer, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, err error) {
	initialized := true
	// requiredMask ORs in f.validation.requiredBit for each decoded field; its
	// popcount is compared with mi.numRequiredFields after the loop.
	var requiredMask uint64
	// exts is fetched on the first field that might be an extension.
	var exts *map[int32]ExtensionField
	var presence presence
	if mi.presenceOffset.IsValid() {
		presence = p.Apply(mi.presenceOffset).PresenceInfo()
	}
	start := len(b)
	for len(b) > 0 {
		// Parse the tag (field number and wire type).
		// Fast paths handle 1- and 2-byte varint tags inline; longer tags fall
		// back to protowire.ConsumeVarint.
		var tag uint64
		if b[0] < 0x80 {
			tag = uint64(b[0])
			b = b[1:]
		} else if len(b) >= 2 && b[1] < 128 {
			tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
			b = b[2:]
		} else {
			var n int
			tag, n = protowire.ConsumeVarint(b)
			if n < 0 {
				return out, errDecode
			}
			b = b[n:]
		}
		var num protowire.Number
		if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
			return out, errDecode
		} else {
			num = protowire.Number(n)
		}
		wtyp := protowire.Type(tag & 7)
		if wtyp == protowire.EndGroupType {
			if num != groupTag {
				return out, errDecode
			}
			// Clearing groupTag marks the group as properly terminated for the
			// check after the loop; this break exits the for loop.
			groupTag = 0
			break
		}
		// Small field numbers index denseCoderFields directly; larger ones fall
		// back to coderFields.
		var f *coderFieldInfo
		if int(num) < len(mi.denseCoderFields) {
			f = mi.denseCoderFields[num]
		} else {
			f = mi.coderFields[num]
		}
		// n is the number of bytes this field's value occupies in b.
		// err (shadowing the named result) defaults to errUnknown, meaning
		// "keep as an unknown field" unless a branch below decodes it.
		var n int
		err := errUnknown
		switch {
		case f != nil:
			if f.funcs.unmarshal == nil {
				break
			}
			var o unmarshalOutput
			o, err = f.funcs.unmarshal(b, p.Apply(f.offset), wtyp, f, opts)
			n = o.n
			if err != nil {
				break
			}
			requiredMask |= f.validation.requiredBit
			if f.funcs.isInit != nil && !o.initialized {
				initialized = false
			}
			// Mark the field as present for fields that have a presence index.
			if f.presenceIndex != noPresence {
				presence.SetPresentUnatomic(f.presenceIndex, mi.presenceSize)
			}
		default:
			// Possible extension.
			if exts == nil && mi.extensionOffset.IsValid() {
				exts = p.Apply(mi.extensionOffset).Extensions()
				if *exts == nil {
					*exts = make(map[int32]ExtensionField)
				}
			}
			// Messages without extension storage treat the field as unknown.
			if exts == nil {
				break
			}
			var o unmarshalOutput
			o, err = mi.unmarshalExtension(b, num, wtyp, *exts, opts)
			if err != nil {
				break
			}
			n = o.n
			if !o.initialized {
				initialized = false
			}
		}
		if err != nil {
			// Any error other than errUnknown aborts the whole unmarshal.
			if err != errUnknown {
				return out, err
			}
			// Skip the unrecognized value. Keep its tag and raw bytes as unknown
			// data unless unknowns are discarded or there is nowhere to store them.
			n = protowire.ConsumeFieldValue(num, wtyp, b)
			if n < 0 {
				return out, errDecode
			}
			if !opts.DiscardUnknown() && mi.unknownOffset.IsValid() {
				u := mi.mutableUnknownBytes(p)
				*u = protowire.AppendTag(*u, num, wtyp)
				*u = append(*u, b[:n]...)
			}
		}
		b = b[n:]
	}
	// A group that was never terminated by its END_GROUP tag is malformed.
	if groupTag != 0 {
		return out, errDecode
	}
	if mi.numRequiredFields > 0 && bits.OnesCount64(requiredMask) != int(mi.numRequiredFields) {
		initialized = false
	}
	if initialized {
		out.initialized = true
	}
	out.n = start - len(b)
	return out, nil
}
  221. func (mi *MessageInfo) unmarshalExtension(b []byte, num protowire.Number, wtyp protowire.Type, exts map[int32]ExtensionField, opts unmarshalOptions) (out unmarshalOutput, err error) {
  222. x := exts[int32(num)]
  223. xt := x.Type()
  224. if xt == nil {
  225. var err error
  226. xt, err = opts.resolver.FindExtensionByNumber(mi.Desc.FullName(), num)
  227. if err != nil {
  228. if err == protoregistry.NotFound {
  229. return out, errUnknown
  230. }
  231. return out, errors.New("%v: unable to resolve extension %v: %v", mi.Desc.FullName(), num, err)
  232. }
  233. }
  234. xi := getExtensionFieldInfo(xt)
  235. if xi.funcs.unmarshal == nil {
  236. return out, errUnknown
  237. }
  238. if flags.LazyUnmarshalExtensions {
  239. if opts.CanBeLazy() && x.canLazy(xt) {
  240. out, valid := skipExtension(b, xi, num, wtyp, opts)
  241. switch valid {
  242. case ValidationValid:
  243. if out.initialized {
  244. x.appendLazyBytes(xt, xi, num, wtyp, b[:out.n])
  245. exts[int32(num)] = x
  246. return out, nil
  247. }
  248. case ValidationInvalid:
  249. return out, errDecode
  250. case ValidationUnknown:
  251. }
  252. }
  253. }
  254. ival := x.Value()
  255. if !ival.IsValid() && xi.unmarshalNeedsValue {
  256. // Create a new message, list, or map value to fill in.
  257. // For enums, create a prototype value to let the unmarshal func know the
  258. // concrete type.
  259. ival = xt.New()
  260. }
  261. v, out, err := xi.funcs.unmarshal(b, ival, num, wtyp, opts)
  262. if err != nil {
  263. return out, err
  264. }
  265. if xi.funcs.isInit == nil {
  266. out.initialized = true
  267. }
  268. x.Set(xt, v)
  269. exts[int32(num)] = x
  270. return out, nil
  271. }
  272. func skipExtension(b []byte, xi *extensionFieldInfo, num protowire.Number, wtyp protowire.Type, opts unmarshalOptions) (out unmarshalOutput, _ ValidationStatus) {
  273. if xi.validation.mi == nil {
  274. return out, ValidationUnknown
  275. }
  276. xi.validation.mi.init()
  277. switch xi.validation.typ {
  278. case validationTypeMessage:
  279. if wtyp != protowire.BytesType {
  280. return out, ValidationUnknown
  281. }
  282. v, n := protowire.ConsumeBytes(b)
  283. if n < 0 {
  284. return out, ValidationUnknown
  285. }
  286. if opts.Validated() {
  287. out.initialized = true
  288. out.n = n
  289. return out, ValidationValid
  290. }
  291. out, st := xi.validation.mi.validate(v, 0, opts)
  292. out.n = n
  293. return out, st
  294. case validationTypeGroup:
  295. if wtyp != protowire.StartGroupType {
  296. return out, ValidationUnknown
  297. }
  298. out, st := xi.validation.mi.validate(b, num, opts)
  299. return out, st
  300. default:
  301. return out, ValidationUnknown
  302. }
  303. }