codec_message.go 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. // Copyright 2019 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package impl
  5. import (
  6. "fmt"
  7. "reflect"
  8. "sort"
  9. "google.golang.org/protobuf/encoding/protowire"
  10. "google.golang.org/protobuf/internal/encoding/messageset"
  11. "google.golang.org/protobuf/internal/order"
  12. "google.golang.org/protobuf/reflect/protoreflect"
  13. "google.golang.org/protobuf/runtime/protoiface"
  14. )
  15. // coderMessageInfo contains per-message information used by the fast-path functions.
  16. // This is a different type from MessageInfo to keep MessageInfo as general-purpose as
  17. // possible.
  18. type coderMessageInfo struct {
  19. methods protoiface.Methods
  20. orderedCoderFields []*coderFieldInfo
  21. denseCoderFields []*coderFieldInfo
  22. coderFields map[protowire.Number]*coderFieldInfo
  23. sizecacheOffset offset
  24. unknownOffset offset
  25. unknownPtrKind bool
  26. extensionOffset offset
  27. needsInitCheck bool
  28. isMessageSet bool
  29. numRequiredFields uint8
  30. lazyOffset offset
  31. presenceOffset offset
  32. presenceSize presenceSize
  33. }
  34. type coderFieldInfo struct {
  35. funcs pointerCoderFuncs // fast-path per-field functions
  36. mi *MessageInfo // field's message
  37. ft reflect.Type
  38. validation validationInfo // information used by message validation
  39. num protoreflect.FieldNumber // field number
  40. offset offset // struct field offset
  41. wiretag uint64 // field tag (number + wire type)
  42. tagsize int // size of the varint-encoded tag
  43. isPointer bool // true if IsNil may be called on the struct field
  44. isRequired bool // true if field is required
  45. isLazy bool
  46. presenceIndex uint32
  47. }
  48. const noPresence = 0xffffffff
  49. func (mi *MessageInfo) makeCoderMethods(t reflect.Type, si structInfo) {
  50. mi.sizecacheOffset = invalidOffset
  51. mi.unknownOffset = invalidOffset
  52. mi.extensionOffset = invalidOffset
  53. mi.lazyOffset = invalidOffset
  54. mi.presenceOffset = si.presenceOffset
  55. if si.sizecacheOffset.IsValid() && si.sizecacheType == sizecacheType {
  56. mi.sizecacheOffset = si.sizecacheOffset
  57. }
  58. if si.unknownOffset.IsValid() && (si.unknownType == unknownFieldsAType || si.unknownType == unknownFieldsBType) {
  59. mi.unknownOffset = si.unknownOffset
  60. mi.unknownPtrKind = si.unknownType.Kind() == reflect.Ptr
  61. }
  62. if si.extensionOffset.IsValid() && si.extensionType == extensionFieldsType {
  63. mi.extensionOffset = si.extensionOffset
  64. }
  65. mi.coderFields = make(map[protowire.Number]*coderFieldInfo)
  66. fields := mi.Desc.Fields()
  67. preallocFields := make([]coderFieldInfo, fields.Len())
  68. for i := 0; i < fields.Len(); i++ {
  69. fd := fields.Get(i)
  70. fs := si.fieldsByNumber[fd.Number()]
  71. isOneof := fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic()
  72. if isOneof {
  73. fs = si.oneofsByName[fd.ContainingOneof().Name()]
  74. }
  75. ft := fs.Type
  76. var wiretag uint64
  77. if !fd.IsPacked() {
  78. wiretag = protowire.EncodeTag(fd.Number(), wireTypes[fd.Kind()])
  79. } else {
  80. wiretag = protowire.EncodeTag(fd.Number(), protowire.BytesType)
  81. }
  82. var fieldOffset offset
  83. var funcs pointerCoderFuncs
  84. var childMessage *MessageInfo
  85. switch {
  86. case ft == nil:
  87. // This never occurs for generated message types.
  88. // It implies that a hand-crafted type has missing Go fields
  89. // for specific protobuf message fields.
  90. funcs = pointerCoderFuncs{
  91. size: func(p pointer, f *coderFieldInfo, opts marshalOptions) int {
  92. return 0
  93. },
  94. marshal: func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
  95. return nil, nil
  96. },
  97. unmarshal: func(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error) {
  98. panic("missing Go struct field for " + string(fd.FullName()))
  99. },
  100. isInit: func(p pointer, f *coderFieldInfo) error {
  101. panic("missing Go struct field for " + string(fd.FullName()))
  102. },
  103. merge: func(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
  104. panic("missing Go struct field for " + string(fd.FullName()))
  105. },
  106. }
  107. case isOneof:
  108. fieldOffset = offsetOf(fs)
  109. default:
  110. fieldOffset = offsetOf(fs)
  111. childMessage, funcs = fieldCoder(fd, ft)
  112. }
  113. cf := &preallocFields[i]
  114. *cf = coderFieldInfo{
  115. num: fd.Number(),
  116. offset: fieldOffset,
  117. wiretag: wiretag,
  118. ft: ft,
  119. tagsize: protowire.SizeVarint(wiretag),
  120. funcs: funcs,
  121. mi: childMessage,
  122. validation: newFieldValidationInfo(mi, si, fd, ft),
  123. isPointer: fd.Cardinality() == protoreflect.Repeated || fd.HasPresence(),
  124. isRequired: fd.Cardinality() == protoreflect.Required,
  125. presenceIndex: noPresence,
  126. }
  127. mi.orderedCoderFields = append(mi.orderedCoderFields, cf)
  128. mi.coderFields[cf.num] = cf
  129. }
  130. for i, oneofs := 0, mi.Desc.Oneofs(); i < oneofs.Len(); i++ {
  131. if od := oneofs.Get(i); !od.IsSynthetic() {
  132. mi.initOneofFieldCoders(od, si)
  133. }
  134. }
  135. if messageset.IsMessageSet(mi.Desc) {
  136. if !mi.extensionOffset.IsValid() {
  137. panic(fmt.Sprintf("%v: MessageSet with no extensions field", mi.Desc.FullName()))
  138. }
  139. if !mi.unknownOffset.IsValid() {
  140. panic(fmt.Sprintf("%v: MessageSet with no unknown field", mi.Desc.FullName()))
  141. }
  142. mi.isMessageSet = true
  143. }
  144. sort.Slice(mi.orderedCoderFields, func(i, j int) bool {
  145. return mi.orderedCoderFields[i].num < mi.orderedCoderFields[j].num
  146. })
  147. var maxDense protoreflect.FieldNumber
  148. for _, cf := range mi.orderedCoderFields {
  149. if cf.num >= 16 && cf.num >= 2*maxDense {
  150. break
  151. }
  152. maxDense = cf.num
  153. }
  154. mi.denseCoderFields = make([]*coderFieldInfo, maxDense+1)
  155. for _, cf := range mi.orderedCoderFields {
  156. if int(cf.num) >= len(mi.denseCoderFields) {
  157. break
  158. }
  159. mi.denseCoderFields[cf.num] = cf
  160. }
  161. // To preserve compatibility with historic wire output, marshal oneofs last.
  162. if mi.Desc.Oneofs().Len() > 0 {
  163. sort.Slice(mi.orderedCoderFields, func(i, j int) bool {
  164. fi := fields.ByNumber(mi.orderedCoderFields[i].num)
  165. fj := fields.ByNumber(mi.orderedCoderFields[j].num)
  166. return order.LegacyFieldOrder(fi, fj)
  167. })
  168. }
  169. mi.needsInitCheck = needsInitCheck(mi.Desc)
  170. if mi.methods.Marshal == nil && mi.methods.Size == nil {
  171. mi.methods.Flags |= protoiface.SupportMarshalDeterministic
  172. mi.methods.Marshal = mi.marshal
  173. mi.methods.Size = mi.size
  174. }
  175. if mi.methods.Unmarshal == nil {
  176. mi.methods.Flags |= protoiface.SupportUnmarshalDiscardUnknown
  177. mi.methods.Unmarshal = mi.unmarshal
  178. }
  179. if mi.methods.CheckInitialized == nil {
  180. mi.methods.CheckInitialized = mi.checkInitialized
  181. }
  182. if mi.methods.Merge == nil {
  183. mi.methods.Merge = mi.merge
  184. }
  185. if mi.methods.Equal == nil {
  186. mi.methods.Equal = equal
  187. }
  188. }
  189. // getUnknownBytes returns a *[]byte for the unknown fields.
  190. // It is the caller's responsibility to check whether the pointer is nil.
  191. // This function is specially designed to be inlineable.
  192. func (mi *MessageInfo) getUnknownBytes(p pointer) *[]byte {
  193. if mi.unknownPtrKind {
  194. return *p.Apply(mi.unknownOffset).BytesPtr()
  195. } else {
  196. return p.Apply(mi.unknownOffset).Bytes()
  197. }
  198. }
  199. // mutableUnknownBytes returns a *[]byte for the unknown fields.
  200. // The returned pointer is guaranteed to not be nil.
  201. func (mi *MessageInfo) mutableUnknownBytes(p pointer) *[]byte {
  202. if mi.unknownPtrKind {
  203. bp := p.Apply(mi.unknownOffset).BytesPtr()
  204. if *bp == nil {
  205. *bp = new([]byte)
  206. }
  207. return *bp
  208. } else {
  209. return p.Apply(mi.unknownOffset).Bytes()
  210. }
  211. }