codec_field_opaque.go 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264
  1. // Copyright 2024 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package impl
  5. import (
  6. "fmt"
  7. "reflect"
  8. "google.golang.org/protobuf/encoding/protowire"
  9. "google.golang.org/protobuf/internal/errors"
  10. "google.golang.org/protobuf/reflect/protoreflect"
  11. )
  12. func makeOpaqueMessageFieldCoder(fd protoreflect.FieldDescriptor, ft reflect.Type) (*MessageInfo, pointerCoderFuncs) {
  13. mi := getMessageInfo(ft)
  14. if mi == nil {
  15. panic(fmt.Sprintf("invalid field: %v: unsupported message type %v", fd.FullName(), ft))
  16. }
  17. switch fd.Kind() {
  18. case protoreflect.MessageKind:
  19. return mi, pointerCoderFuncs{
  20. size: sizeOpaqueMessage,
  21. marshal: appendOpaqueMessage,
  22. unmarshal: consumeOpaqueMessage,
  23. isInit: isInitOpaqueMessage,
  24. merge: mergeOpaqueMessage,
  25. }
  26. case protoreflect.GroupKind:
  27. return mi, pointerCoderFuncs{
  28. size: sizeOpaqueGroup,
  29. marshal: appendOpaqueGroup,
  30. unmarshal: consumeOpaqueGroup,
  31. isInit: isInitOpaqueMessage,
  32. merge: mergeOpaqueMessage,
  33. }
  34. }
  35. panic("unexpected field kind")
  36. }
  37. func sizeOpaqueMessage(p pointer, f *coderFieldInfo, opts marshalOptions) (size int) {
  38. return protowire.SizeBytes(f.mi.sizePointer(p.AtomicGetPointer(), opts)) + f.tagsize
  39. }
  40. func appendOpaqueMessage(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
  41. mp := p.AtomicGetPointer()
  42. calculatedSize := f.mi.sizePointer(mp, opts)
  43. b = protowire.AppendVarint(b, f.wiretag)
  44. b = protowire.AppendVarint(b, uint64(calculatedSize))
  45. before := len(b)
  46. b, err := f.mi.marshalAppendPointer(b, mp, opts)
  47. if measuredSize := len(b) - before; calculatedSize != measuredSize && err == nil {
  48. return nil, errors.MismatchedSizeCalculation(calculatedSize, measuredSize)
  49. }
  50. return b, err
  51. }
  52. func consumeOpaqueMessage(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
  53. if wtyp != protowire.BytesType {
  54. return out, errUnknown
  55. }
  56. v, n := protowire.ConsumeBytes(b)
  57. if n < 0 {
  58. return out, errDecode
  59. }
  60. mp := p.AtomicGetPointer()
  61. if mp.IsNil() {
  62. mp = p.AtomicSetPointerIfNil(pointerOfValue(reflect.New(f.mi.GoReflectType.Elem())))
  63. }
  64. o, err := f.mi.unmarshalPointer(v, mp, 0, opts)
  65. if err != nil {
  66. return out, err
  67. }
  68. out.n = n
  69. out.initialized = o.initialized
  70. return out, nil
  71. }
  72. func isInitOpaqueMessage(p pointer, f *coderFieldInfo) error {
  73. mp := p.AtomicGetPointer()
  74. if mp.IsNil() {
  75. return nil
  76. }
  77. return f.mi.checkInitializedPointer(mp)
  78. }
  79. func mergeOpaqueMessage(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
  80. dstmp := dst.AtomicGetPointer()
  81. if dstmp.IsNil() {
  82. dstmp = dst.AtomicSetPointerIfNil(pointerOfValue(reflect.New(f.mi.GoReflectType.Elem())))
  83. }
  84. f.mi.mergePointer(dstmp, src.AtomicGetPointer(), opts)
  85. }
  86. func sizeOpaqueGroup(p pointer, f *coderFieldInfo, opts marshalOptions) (size int) {
  87. return 2*f.tagsize + f.mi.sizePointer(p.AtomicGetPointer(), opts)
  88. }
  89. func appendOpaqueGroup(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
  90. b = protowire.AppendVarint(b, f.wiretag) // start group
  91. b, err := f.mi.marshalAppendPointer(b, p.AtomicGetPointer(), opts)
  92. b = protowire.AppendVarint(b, f.wiretag+1) // end group
  93. return b, err
  94. }
  95. func consumeOpaqueGroup(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
  96. if wtyp != protowire.StartGroupType {
  97. return out, errUnknown
  98. }
  99. mp := p.AtomicGetPointer()
  100. if mp.IsNil() {
  101. mp = p.AtomicSetPointerIfNil(pointerOfValue(reflect.New(f.mi.GoReflectType.Elem())))
  102. }
  103. o, e := f.mi.unmarshalPointer(b, mp, f.num, opts)
  104. return o, e
  105. }
  106. func makeOpaqueRepeatedMessageFieldCoder(fd protoreflect.FieldDescriptor, ft reflect.Type) (*MessageInfo, pointerCoderFuncs) {
  107. if ft.Kind() != reflect.Ptr || ft.Elem().Kind() != reflect.Slice {
  108. panic(fmt.Sprintf("invalid field: %v: unsupported type for opaque repeated message: %v", fd.FullName(), ft))
  109. }
  110. mt := ft.Elem().Elem() // *[]*T -> *T
  111. mi := getMessageInfo(mt)
  112. if mi == nil {
  113. panic(fmt.Sprintf("invalid field: %v: unsupported message type %v", fd.FullName(), mt))
  114. }
  115. switch fd.Kind() {
  116. case protoreflect.MessageKind:
  117. return mi, pointerCoderFuncs{
  118. size: sizeOpaqueMessageSlice,
  119. marshal: appendOpaqueMessageSlice,
  120. unmarshal: consumeOpaqueMessageSlice,
  121. isInit: isInitOpaqueMessageSlice,
  122. merge: mergeOpaqueMessageSlice,
  123. }
  124. case protoreflect.GroupKind:
  125. return mi, pointerCoderFuncs{
  126. size: sizeOpaqueGroupSlice,
  127. marshal: appendOpaqueGroupSlice,
  128. unmarshal: consumeOpaqueGroupSlice,
  129. isInit: isInitOpaqueMessageSlice,
  130. merge: mergeOpaqueMessageSlice,
  131. }
  132. }
  133. panic("unexpected field kind")
  134. }
  135. func sizeOpaqueMessageSlice(p pointer, f *coderFieldInfo, opts marshalOptions) (size int) {
  136. s := p.AtomicGetPointer().PointerSlice()
  137. n := 0
  138. for _, v := range s {
  139. n += protowire.SizeBytes(f.mi.sizePointer(v, opts)) + f.tagsize
  140. }
  141. return n
  142. }
  143. func appendOpaqueMessageSlice(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
  144. s := p.AtomicGetPointer().PointerSlice()
  145. var err error
  146. for _, v := range s {
  147. b = protowire.AppendVarint(b, f.wiretag)
  148. siz := f.mi.sizePointer(v, opts)
  149. b = protowire.AppendVarint(b, uint64(siz))
  150. before := len(b)
  151. b, err = f.mi.marshalAppendPointer(b, v, opts)
  152. if err != nil {
  153. return b, err
  154. }
  155. if measuredSize := len(b) - before; siz != measuredSize {
  156. return nil, errors.MismatchedSizeCalculation(siz, measuredSize)
  157. }
  158. }
  159. return b, nil
  160. }
  161. func consumeOpaqueMessageSlice(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
  162. if wtyp != protowire.BytesType {
  163. return out, errUnknown
  164. }
  165. v, n := protowire.ConsumeBytes(b)
  166. if n < 0 {
  167. return out, errDecode
  168. }
  169. mp := pointerOfValue(reflect.New(f.mi.GoReflectType.Elem()))
  170. o, err := f.mi.unmarshalPointer(v, mp, 0, opts)
  171. if err != nil {
  172. return out, err
  173. }
  174. sp := p.AtomicGetPointer()
  175. if sp.IsNil() {
  176. sp = p.AtomicSetPointerIfNil(pointerOfValue(reflect.New(f.ft.Elem())))
  177. }
  178. sp.AppendPointerSlice(mp)
  179. out.n = n
  180. out.initialized = o.initialized
  181. return out, nil
  182. }
  183. func isInitOpaqueMessageSlice(p pointer, f *coderFieldInfo) error {
  184. sp := p.AtomicGetPointer()
  185. if sp.IsNil() {
  186. return nil
  187. }
  188. s := sp.PointerSlice()
  189. for _, v := range s {
  190. if err := f.mi.checkInitializedPointer(v); err != nil {
  191. return err
  192. }
  193. }
  194. return nil
  195. }
  196. func mergeOpaqueMessageSlice(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
  197. ds := dst.AtomicGetPointer()
  198. if ds.IsNil() {
  199. ds = dst.AtomicSetPointerIfNil(pointerOfValue(reflect.New(f.ft.Elem())))
  200. }
  201. for _, sp := range src.AtomicGetPointer().PointerSlice() {
  202. dm := pointerOfValue(reflect.New(f.mi.GoReflectType.Elem()))
  203. f.mi.mergePointer(dm, sp, opts)
  204. ds.AppendPointerSlice(dm)
  205. }
  206. }
  207. func sizeOpaqueGroupSlice(p pointer, f *coderFieldInfo, opts marshalOptions) (size int) {
  208. s := p.AtomicGetPointer().PointerSlice()
  209. n := 0
  210. for _, v := range s {
  211. n += 2*f.tagsize + f.mi.sizePointer(v, opts)
  212. }
  213. return n
  214. }
  215. func appendOpaqueGroupSlice(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
  216. s := p.AtomicGetPointer().PointerSlice()
  217. var err error
  218. for _, v := range s {
  219. b = protowire.AppendVarint(b, f.wiretag) // start group
  220. b, err = f.mi.marshalAppendPointer(b, v, opts)
  221. if err != nil {
  222. return b, err
  223. }
  224. b = protowire.AppendVarint(b, f.wiretag+1) // end group
  225. }
  226. return b, nil
  227. }
  228. func consumeOpaqueGroupSlice(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
  229. if wtyp != protowire.StartGroupType {
  230. return out, errUnknown
  231. }
  232. mp := pointerOfValue(reflect.New(f.mi.GoReflectType.Elem()))
  233. out, err = f.mi.unmarshalPointer(b, mp, f.num, opts)
  234. if err != nil {
  235. return out, err
  236. }
  237. sp := p.AtomicGetPointer()
  238. if sp.IsNil() {
  239. sp = p.AtomicSetPointerIfNil(pointerOfValue(reflect.New(f.ft.Elem())))
  240. }
  241. sp.AppendPointerSlice(mp)
  242. return out, err
  243. }