123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433 |
- // Copyright 2024 The Go Authors. All rights reserved.
- // Use of this source code is governed by a BSD-style
- // license that can be found in the LICENSE file.
- package impl
- import (
- "fmt"
- "math/bits"
- "os"
- "reflect"
- "sort"
- "sync/atomic"
- "google.golang.org/protobuf/encoding/protowire"
- "google.golang.org/protobuf/internal/errors"
- "google.golang.org/protobuf/internal/protolazy"
- "google.golang.org/protobuf/reflect/protoreflect"
- preg "google.golang.org/protobuf/reflect/protoregistry"
- piface "google.golang.org/protobuf/runtime/protoiface"
- )
// enableLazy is the package-level switch for lazy unmarshaling: 1 means
// enabled, 0 means disabled. It starts out enabled unless the GOPROTODEBUG
// environment variable equals "nolazy". EnableLazyUnmarshal and LazyEnabled
// access it through sync/atomic.
var enableLazy = func() int32 {
	if os.Getenv("GOPROTODEBUG") != "nolazy" {
		return 1
	}
	return 0
}()

// EnableLazyUnmarshal turns lazy unmarshaling on when enable is true and off
// otherwise. The setting is stored atomically in a package-level variable, so
// it affects every subsequent caller of LazyEnabled.
func EnableLazyUnmarshal(enable bool) {
	var v int32
	if enable {
		v = 1
	}
	atomic.StoreInt32(&enableLazy, v)
}

// LazyEnabled reports whether lazy unmarshaling is currently enabled.
func LazyEnabled() bool {
	flag := atomic.LoadInt32(&enableLazy)
	return flag != 0
}
- // UnmarshalField unmarshals a field in a message.
- func UnmarshalField(m interface{}, num protowire.Number) {
- switch m := m.(type) {
- case *messageState:
- m.messageInfo().lazyUnmarshal(m.pointer(), num)
- case *messageReflectWrapper:
- m.messageInfo().lazyUnmarshal(m.pointer(), num)
- default:
- panic(fmt.Sprintf("unsupported wrapper type %T", m))
- }
- }
- func (mi *MessageInfo) lazyUnmarshal(p pointer, num protoreflect.FieldNumber) {
- var f *coderFieldInfo
- if int(num) < len(mi.denseCoderFields) {
- f = mi.denseCoderFields[num]
- } else {
- f = mi.coderFields[num]
- }
- if f == nil {
- panic(fmt.Sprintf("lazyUnmarshal: field info for %v.%v", mi.Desc.FullName(), num))
- }
- lazy := *p.Apply(mi.lazyOffset).LazyInfoPtr()
- start, end, found, _, multipleEntries := lazy.FindFieldInProto(uint32(num))
- if !found && multipleEntries == nil {
- panic(fmt.Sprintf("lazyUnmarshal: can't find field data for %v.%v", mi.Desc.FullName(), num))
- }
- // The actual pointer in the message can not be set until the whole struct is filled in, otherwise we will have races.
- // Create another pointer and set it atomically, if we won the race and the pointer in the original message is still nil.
- fp := pointerOfValue(reflect.New(f.ft))
- if multipleEntries != nil {
- for _, entry := range multipleEntries {
- mi.unmarshalField(lazy.Buffer()[entry.Start:entry.End], fp, f, lazy, lazy.UnmarshalFlags())
- }
- } else {
- mi.unmarshalField(lazy.Buffer()[start:end], fp, f, lazy, lazy.UnmarshalFlags())
- }
- p.Apply(f.offset).AtomicSetPointerIfNil(fp.Elem())
- }
- func (mi *MessageInfo) unmarshalField(b []byte, p pointer, f *coderFieldInfo, lazyInfo *protolazy.XXX_lazyUnmarshalInfo, flags piface.UnmarshalInputFlags) error {
- opts := lazyUnmarshalOptions
- opts.flags |= flags
- for len(b) > 0 {
- // Parse the tag (field number and wire type).
- var tag uint64
- if b[0] < 0x80 {
- tag = uint64(b[0])
- b = b[1:]
- } else if len(b) >= 2 && b[1] < 128 {
- tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
- b = b[2:]
- } else {
- var n int
- tag, n = protowire.ConsumeVarint(b)
- if n < 0 {
- return errors.New("invalid wire data")
- }
- b = b[n:]
- }
- var num protowire.Number
- if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
- return errors.New("invalid wire data")
- } else {
- num = protowire.Number(n)
- }
- wtyp := protowire.Type(tag & 7)
- if num == f.num {
- o, err := f.funcs.unmarshal(b, p, wtyp, f, opts)
- if err == nil {
- b = b[o.n:]
- continue
- }
- if err != errUnknown {
- return err
- }
- }
- n := protowire.ConsumeFieldValue(num, wtyp, b)
- if n < 0 {
- return errors.New("invalid wire data")
- }
- b = b[n:]
- }
- return nil
- }
- func (mi *MessageInfo) skipField(b []byte, f *coderFieldInfo, wtyp protowire.Type, opts unmarshalOptions) (out unmarshalOutput, _ ValidationStatus) {
- fmi := f.validation.mi
- if fmi == nil {
- fd := mi.Desc.Fields().ByNumber(f.num)
- if fd == nil {
- return out, ValidationUnknown
- }
- messageName := fd.Message().FullName()
- messageType, err := preg.GlobalTypes.FindMessageByName(messageName)
- if err != nil {
- return out, ValidationUnknown
- }
- var ok bool
- fmi, ok = messageType.(*MessageInfo)
- if !ok {
- return out, ValidationUnknown
- }
- }
- fmi.init()
- switch f.validation.typ {
- case validationTypeMessage:
- if wtyp != protowire.BytesType {
- return out, ValidationWrongWireType
- }
- v, n := protowire.ConsumeBytes(b)
- if n < 0 {
- return out, ValidationInvalid
- }
- out, st := fmi.validate(v, 0, opts)
- out.n = n
- return out, st
- case validationTypeGroup:
- if wtyp != protowire.StartGroupType {
- return out, ValidationWrongWireType
- }
- out, st := fmi.validate(b, f.num, opts)
- return out, st
- default:
- return out, ValidationUnknown
- }
- }
- // unmarshalPointerLazy is similar to unmarshalPointerEager, but it
- // specifically handles lazy unmarshalling. it expects lazyOffset and
- // presenceOffset to both be valid.
- func (mi *MessageInfo) unmarshalPointerLazy(b []byte, p pointer, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, err error) {
- initialized := true
- var requiredMask uint64
- var lazy **protolazy.XXX_lazyUnmarshalInfo
- var presence presence
- var lazyIndex []protolazy.IndexEntry
- var lastNum protowire.Number
- outOfOrder := false
- lazyDecode := false
- presence = p.Apply(mi.presenceOffset).PresenceInfo()
- lazy = p.Apply(mi.lazyOffset).LazyInfoPtr()
- if !presence.AnyPresent(mi.presenceSize) {
- if opts.CanBeLazy() {
- // If the message contains existing data, we need to merge into it.
- // Lazy unmarshaling doesn't merge, so only enable it when the
- // message is empty (has no presence bitmap).
- lazyDecode = true
- if *lazy == nil {
- *lazy = &protolazy.XXX_lazyUnmarshalInfo{}
- }
- (*lazy).SetUnmarshalFlags(opts.flags)
- if !opts.AliasBuffer() {
- // Make a copy of the buffer for lazy unmarshaling.
- // Set the AliasBuffer flag so recursive unmarshal
- // operations reuse the copy.
- b = append([]byte{}, b...)
- opts.flags |= piface.UnmarshalAliasBuffer
- }
- (*lazy).SetBuffer(b)
- }
- }
- // Track special handling of lazy fields.
- //
- // In the common case, all fields are lazyValidateOnly (and lazyFields remains nil).
- // In the event that validation for a field fails, this map tracks handling of the field.
- type lazyAction uint8
- const (
- lazyValidateOnly lazyAction = iota // validate the field only
- lazyUnmarshalNow // eagerly unmarshal the field
- lazyUnmarshalLater // unmarshal the field after the message is fully processed
- )
- var lazyFields map[*coderFieldInfo]lazyAction
- var exts *map[int32]ExtensionField
- start := len(b)
- pos := 0
- for len(b) > 0 {
- // Parse the tag (field number and wire type).
- var tag uint64
- if b[0] < 0x80 {
- tag = uint64(b[0])
- b = b[1:]
- } else if len(b) >= 2 && b[1] < 128 {
- tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
- b = b[2:]
- } else {
- var n int
- tag, n = protowire.ConsumeVarint(b)
- if n < 0 {
- return out, errDecode
- }
- b = b[n:]
- }
- var num protowire.Number
- if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
- return out, errors.New("invalid field number")
- } else {
- num = protowire.Number(n)
- }
- wtyp := protowire.Type(tag & 7)
- if wtyp == protowire.EndGroupType {
- if num != groupTag {
- return out, errors.New("mismatching end group marker")
- }
- groupTag = 0
- break
- }
- var f *coderFieldInfo
- if int(num) < len(mi.denseCoderFields) {
- f = mi.denseCoderFields[num]
- } else {
- f = mi.coderFields[num]
- }
- var n int
- err := errUnknown
- discardUnknown := false
- Field:
- switch {
- case f != nil:
- if f.funcs.unmarshal == nil {
- break
- }
- if f.isLazy && lazyDecode {
- switch {
- case lazyFields == nil || lazyFields[f] == lazyValidateOnly:
- // Attempt to validate this field and leave it for later lazy unmarshaling.
- o, valid := mi.skipField(b, f, wtyp, opts)
- switch valid {
- case ValidationValid:
- // Skip over the valid field and continue.
- err = nil
- presence.SetPresentUnatomic(f.presenceIndex, mi.presenceSize)
- requiredMask |= f.validation.requiredBit
- if !o.initialized {
- initialized = false
- }
- n = o.n
- break Field
- case ValidationInvalid:
- return out, errors.New("invalid proto wire format")
- case ValidationWrongWireType:
- break Field
- case ValidationUnknown:
- if lazyFields == nil {
- lazyFields = make(map[*coderFieldInfo]lazyAction)
- }
- if presence.Present(f.presenceIndex) {
- // We were unable to determine if the field is valid or not,
- // and we've already skipped over at least one instance of this
- // field. Clear the presence bit (so if we stop decoding early,
- // we don't leave a partially-initialized field around) and flag
- // the field for unmarshaling before we return.
- presence.ClearPresent(f.presenceIndex)
- lazyFields[f] = lazyUnmarshalLater
- discardUnknown = true
- break Field
- } else {
- // We were unable to determine if the field is valid or not,
- // but this is the first time we've seen it. Flag it as needing
- // eager unmarshaling and fall through to the eager unmarshal case below.
- lazyFields[f] = lazyUnmarshalNow
- }
- }
- case lazyFields[f] == lazyUnmarshalLater:
- // This field will be unmarshaled in a separate pass below.
- // Skip over it here.
- discardUnknown = true
- break Field
- default:
- // Eagerly unmarshal the field.
- }
- }
- if f.isLazy && !lazyDecode && presence.Present(f.presenceIndex) {
- if p.Apply(f.offset).AtomicGetPointer().IsNil() {
- mi.lazyUnmarshal(p, f.num)
- }
- }
- var o unmarshalOutput
- o, err = f.funcs.unmarshal(b, p.Apply(f.offset), wtyp, f, opts)
- n = o.n
- if err != nil {
- break
- }
- requiredMask |= f.validation.requiredBit
- if f.funcs.isInit != nil && !o.initialized {
- initialized = false
- }
- if f.presenceIndex != noPresence {
- presence.SetPresentUnatomic(f.presenceIndex, mi.presenceSize)
- }
- default:
- // Possible extension.
- if exts == nil && mi.extensionOffset.IsValid() {
- exts = p.Apply(mi.extensionOffset).Extensions()
- if *exts == nil {
- *exts = make(map[int32]ExtensionField)
- }
- }
- if exts == nil {
- break
- }
- var o unmarshalOutput
- o, err = mi.unmarshalExtension(b, num, wtyp, *exts, opts)
- if err != nil {
- break
- }
- n = o.n
- if !o.initialized {
- initialized = false
- }
- }
- if err != nil {
- if err != errUnknown {
- return out, err
- }
- n = protowire.ConsumeFieldValue(num, wtyp, b)
- if n < 0 {
- return out, errDecode
- }
- if !discardUnknown && !opts.DiscardUnknown() && mi.unknownOffset.IsValid() {
- u := mi.mutableUnknownBytes(p)
- *u = protowire.AppendTag(*u, num, wtyp)
- *u = append(*u, b[:n]...)
- }
- }
- b = b[n:]
- end := start - len(b)
- if lazyDecode && f != nil && f.isLazy {
- if num != lastNum {
- lazyIndex = append(lazyIndex, protolazy.IndexEntry{
- FieldNum: uint32(num),
- Start: uint32(pos),
- End: uint32(end),
- })
- } else {
- i := len(lazyIndex) - 1
- lazyIndex[i].End = uint32(end)
- lazyIndex[i].MultipleContiguous = true
- }
- }
- if num < lastNum {
- outOfOrder = true
- }
- pos = end
- lastNum = num
- }
- if groupTag != 0 {
- return out, errors.New("missing end group marker")
- }
- if lazyFields != nil {
- // Some fields failed validation, and now need to be unmarshaled.
- for f, action := range lazyFields {
- if action != lazyUnmarshalLater {
- continue
- }
- initialized = false
- if *lazy == nil {
- *lazy = &protolazy.XXX_lazyUnmarshalInfo{}
- }
- if err := mi.unmarshalField((*lazy).Buffer(), p.Apply(f.offset), f, *lazy, opts.flags); err != nil {
- return out, err
- }
- presence.SetPresentUnatomic(f.presenceIndex, mi.presenceSize)
- }
- }
- if lazyDecode {
- if outOfOrder {
- sort.Slice(lazyIndex, func(i, j int) bool {
- return lazyIndex[i].FieldNum < lazyIndex[j].FieldNum ||
- (lazyIndex[i].FieldNum == lazyIndex[j].FieldNum &&
- lazyIndex[i].Start < lazyIndex[j].Start)
- })
- }
- if *lazy == nil {
- *lazy = &protolazy.XXX_lazyUnmarshalInfo{}
- }
- (*lazy).SetIndex(lazyIndex)
- }
- if mi.numRequiredFields > 0 && bits.OnesCount64(requiredMask) != int(mi.numRequiredFields) {
- initialized = false
- }
- if initialized {
- out.initialized = true
- }
- out.n = start - len(b)
- return out, nil
- }
|