// message_reflect.go (13 KB)
  1. // Copyright 2019 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package impl
  5. import (
  6. "fmt"
  7. "reflect"
  8. "google.golang.org/protobuf/internal/detrand"
  9. "google.golang.org/protobuf/internal/pragma"
  10. "google.golang.org/protobuf/reflect/protoreflect"
  11. )
  12. type reflectMessageInfo struct {
  13. fields map[protoreflect.FieldNumber]*fieldInfo
  14. oneofs map[protoreflect.Name]*oneofInfo
  15. // fieldTypes contains the zero value of an enum or message field.
  16. // For lists, it contains the element type.
  17. // For maps, it contains the entry value type.
  18. fieldTypes map[protoreflect.FieldNumber]any
  19. // denseFields is a subset of fields where:
  20. // 0 < fieldDesc.Number() < len(denseFields)
  21. // It provides faster access to the fieldInfo, but may be incomplete.
  22. denseFields []*fieldInfo
  23. // rangeInfos is a list of all fields (not belonging to a oneof) and oneofs.
  24. rangeInfos []any // either *fieldInfo or *oneofInfo
  25. getUnknown func(pointer) protoreflect.RawFields
  26. setUnknown func(pointer, protoreflect.RawFields)
  27. extensionMap func(pointer) *extensionMap
  28. nilMessage atomicNilMessage
  29. }
  30. // makeReflectFuncs generates the set of functions to support reflection.
  31. func (mi *MessageInfo) makeReflectFuncs(t reflect.Type, si structInfo) {
  32. mi.makeKnownFieldsFunc(si)
  33. mi.makeUnknownFieldsFunc(t, si)
  34. mi.makeExtensionFieldsFunc(t, si)
  35. mi.makeFieldTypes(si)
  36. }
  37. // makeKnownFieldsFunc generates functions for operations that can be performed
  38. // on each protobuf message field. It takes in a reflect.Type representing the
  39. // Go struct and matches message fields with struct fields.
  40. //
  41. // This code assumes that the struct is well-formed and panics if there are
  42. // any discrepancies.
  43. func (mi *MessageInfo) makeKnownFieldsFunc(si structInfo) {
  44. mi.fields = map[protoreflect.FieldNumber]*fieldInfo{}
  45. md := mi.Desc
  46. fds := md.Fields()
  47. for i := 0; i < fds.Len(); i++ {
  48. fd := fds.Get(i)
  49. fs := si.fieldsByNumber[fd.Number()]
  50. isOneof := fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic()
  51. if isOneof {
  52. fs = si.oneofsByName[fd.ContainingOneof().Name()]
  53. }
  54. var fi fieldInfo
  55. switch {
  56. case fs.Type == nil:
  57. fi = fieldInfoForMissing(fd) // never occurs for officially generated message types
  58. case isOneof:
  59. fi = fieldInfoForOneof(fd, fs, mi.Exporter, si.oneofWrappersByNumber[fd.Number()])
  60. case fd.IsMap():
  61. fi = fieldInfoForMap(fd, fs, mi.Exporter)
  62. case fd.IsList():
  63. fi = fieldInfoForList(fd, fs, mi.Exporter)
  64. case fd.Message() != nil:
  65. fi = fieldInfoForMessage(fd, fs, mi.Exporter)
  66. default:
  67. fi = fieldInfoForScalar(fd, fs, mi.Exporter)
  68. }
  69. mi.fields[fd.Number()] = &fi
  70. }
  71. mi.oneofs = map[protoreflect.Name]*oneofInfo{}
  72. for i := 0; i < md.Oneofs().Len(); i++ {
  73. od := md.Oneofs().Get(i)
  74. mi.oneofs[od.Name()] = makeOneofInfo(od, si, mi.Exporter)
  75. }
  76. mi.denseFields = make([]*fieldInfo, fds.Len()*2)
  77. for i := 0; i < fds.Len(); i++ {
  78. if fd := fds.Get(i); int(fd.Number()) < len(mi.denseFields) {
  79. mi.denseFields[fd.Number()] = mi.fields[fd.Number()]
  80. }
  81. }
  82. for i := 0; i < fds.Len(); {
  83. fd := fds.Get(i)
  84. if od := fd.ContainingOneof(); od != nil && !od.IsSynthetic() {
  85. mi.rangeInfos = append(mi.rangeInfos, mi.oneofs[od.Name()])
  86. i += od.Fields().Len()
  87. } else {
  88. mi.rangeInfos = append(mi.rangeInfos, mi.fields[fd.Number()])
  89. i++
  90. }
  91. }
  92. // Introduce instability to iteration order, but keep it deterministic.
  93. if len(mi.rangeInfos) > 1 && detrand.Bool() {
  94. i := detrand.Intn(len(mi.rangeInfos) - 1)
  95. mi.rangeInfos[i], mi.rangeInfos[i+1] = mi.rangeInfos[i+1], mi.rangeInfos[i]
  96. }
  97. }
  98. func (mi *MessageInfo) makeUnknownFieldsFunc(t reflect.Type, si structInfo) {
  99. switch {
  100. case si.unknownOffset.IsValid() && si.unknownType == unknownFieldsAType:
  101. // Handle as []byte.
  102. mi.getUnknown = func(p pointer) protoreflect.RawFields {
  103. if p.IsNil() {
  104. return nil
  105. }
  106. return *p.Apply(mi.unknownOffset).Bytes()
  107. }
  108. mi.setUnknown = func(p pointer, b protoreflect.RawFields) {
  109. if p.IsNil() {
  110. panic("invalid SetUnknown on nil Message")
  111. }
  112. *p.Apply(mi.unknownOffset).Bytes() = b
  113. }
  114. case si.unknownOffset.IsValid() && si.unknownType == unknownFieldsBType:
  115. // Handle as *[]byte.
  116. mi.getUnknown = func(p pointer) protoreflect.RawFields {
  117. if p.IsNil() {
  118. return nil
  119. }
  120. bp := p.Apply(mi.unknownOffset).BytesPtr()
  121. if *bp == nil {
  122. return nil
  123. }
  124. return **bp
  125. }
  126. mi.setUnknown = func(p pointer, b protoreflect.RawFields) {
  127. if p.IsNil() {
  128. panic("invalid SetUnknown on nil Message")
  129. }
  130. bp := p.Apply(mi.unknownOffset).BytesPtr()
  131. if *bp == nil {
  132. *bp = new([]byte)
  133. }
  134. **bp = b
  135. }
  136. default:
  137. mi.getUnknown = func(pointer) protoreflect.RawFields {
  138. return nil
  139. }
  140. mi.setUnknown = func(p pointer, _ protoreflect.RawFields) {
  141. if p.IsNil() {
  142. panic("invalid SetUnknown on nil Message")
  143. }
  144. }
  145. }
  146. }
  147. func (mi *MessageInfo) makeExtensionFieldsFunc(t reflect.Type, si structInfo) {
  148. if si.extensionOffset.IsValid() {
  149. mi.extensionMap = func(p pointer) *extensionMap {
  150. if p.IsNil() {
  151. return (*extensionMap)(nil)
  152. }
  153. v := p.Apply(si.extensionOffset).AsValueOf(extensionFieldsType)
  154. return (*extensionMap)(v.Interface().(*map[int32]ExtensionField))
  155. }
  156. } else {
  157. mi.extensionMap = func(pointer) *extensionMap {
  158. return (*extensionMap)(nil)
  159. }
  160. }
  161. }
  162. func (mi *MessageInfo) makeFieldTypes(si structInfo) {
  163. md := mi.Desc
  164. fds := md.Fields()
  165. for i := 0; i < fds.Len(); i++ {
  166. var ft reflect.Type
  167. fd := fds.Get(i)
  168. fs := si.fieldsByNumber[fd.Number()]
  169. isOneof := fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic()
  170. if isOneof {
  171. fs = si.oneofsByName[fd.ContainingOneof().Name()]
  172. }
  173. var isMessage bool
  174. switch {
  175. case fs.Type == nil:
  176. continue // never occurs for officially generated message types
  177. case isOneof:
  178. if fd.Enum() != nil || fd.Message() != nil {
  179. ft = si.oneofWrappersByNumber[fd.Number()].Field(0).Type
  180. }
  181. case fd.IsMap():
  182. if fd.MapValue().Enum() != nil || fd.MapValue().Message() != nil {
  183. ft = fs.Type.Elem()
  184. }
  185. isMessage = fd.MapValue().Message() != nil
  186. case fd.IsList():
  187. if fd.Enum() != nil || fd.Message() != nil {
  188. ft = fs.Type.Elem()
  189. if ft.Kind() == reflect.Slice {
  190. ft = ft.Elem()
  191. }
  192. }
  193. isMessage = fd.Message() != nil
  194. case fd.Enum() != nil:
  195. ft = fs.Type
  196. if fd.HasPresence() && ft.Kind() == reflect.Ptr {
  197. ft = ft.Elem()
  198. }
  199. case fd.Message() != nil:
  200. ft = fs.Type
  201. isMessage = true
  202. }
  203. if isMessage && ft != nil && ft.Kind() != reflect.Ptr {
  204. ft = reflect.PtrTo(ft) // never occurs for officially generated message types
  205. }
  206. if ft != nil {
  207. if mi.fieldTypes == nil {
  208. mi.fieldTypes = make(map[protoreflect.FieldNumber]any)
  209. }
  210. mi.fieldTypes[fd.Number()] = reflect.Zero(ft).Interface()
  211. }
  212. }
  213. }
  214. type extensionMap map[int32]ExtensionField
  215. func (m *extensionMap) Range(f func(protoreflect.FieldDescriptor, protoreflect.Value) bool) {
  216. if m != nil {
  217. for _, x := range *m {
  218. xd := x.Type().TypeDescriptor()
  219. v := x.Value()
  220. if xd.IsList() && v.List().Len() == 0 {
  221. continue
  222. }
  223. if !f(xd, v) {
  224. return
  225. }
  226. }
  227. }
  228. }
  229. func (m *extensionMap) Has(xd protoreflect.ExtensionTypeDescriptor) (ok bool) {
  230. if m == nil {
  231. return false
  232. }
  233. x, ok := (*m)[int32(xd.Number())]
  234. if !ok {
  235. return false
  236. }
  237. if x.isUnexpandedLazy() {
  238. // Avoid calling x.Value(), which triggers a lazy unmarshal.
  239. return true
  240. }
  241. switch {
  242. case xd.IsList():
  243. return x.Value().List().Len() > 0
  244. case xd.IsMap():
  245. return x.Value().Map().Len() > 0
  246. }
  247. return true
  248. }
  249. func (m *extensionMap) Clear(xd protoreflect.ExtensionTypeDescriptor) {
  250. delete(*m, int32(xd.Number()))
  251. }
  252. func (m *extensionMap) Get(xd protoreflect.ExtensionTypeDescriptor) protoreflect.Value {
  253. if m != nil {
  254. if x, ok := (*m)[int32(xd.Number())]; ok {
  255. return x.Value()
  256. }
  257. }
  258. return xd.Type().Zero()
  259. }
  260. func (m *extensionMap) Set(xd protoreflect.ExtensionTypeDescriptor, v protoreflect.Value) {
  261. xt := xd.Type()
  262. isValid := true
  263. switch {
  264. case !xt.IsValidValue(v):
  265. isValid = false
  266. case xd.IsList():
  267. isValid = v.List().IsValid()
  268. case xd.IsMap():
  269. isValid = v.Map().IsValid()
  270. case xd.Message() != nil:
  271. isValid = v.Message().IsValid()
  272. }
  273. if !isValid {
  274. panic(fmt.Sprintf("%v: assigning invalid value", xd.FullName()))
  275. }
  276. if *m == nil {
  277. *m = make(map[int32]ExtensionField)
  278. }
  279. var x ExtensionField
  280. x.Set(xt, v)
  281. (*m)[int32(xd.Number())] = x
  282. }
  283. func (m *extensionMap) Mutable(xd protoreflect.ExtensionTypeDescriptor) protoreflect.Value {
  284. if xd.Kind() != protoreflect.MessageKind && xd.Kind() != protoreflect.GroupKind && !xd.IsList() && !xd.IsMap() {
  285. panic("invalid Mutable on field with non-composite type")
  286. }
  287. if x, ok := (*m)[int32(xd.Number())]; ok {
  288. return x.Value()
  289. }
  290. v := xd.Type().New()
  291. m.Set(xd, v)
  292. return v
  293. }
  294. // MessageState is a data structure that is nested as the first field in a
  295. // concrete message. It provides a way to implement the ProtoReflect method
  296. // in an allocation-free way without needing to have a shadow Go type generated
  297. // for every message type. This technique only works using unsafe.
  298. //
  299. // Example generated code:
  300. //
  301. // type M struct {
  302. // state protoimpl.MessageState
  303. //
  304. // Field1 int32
  305. // Field2 string
  306. // Field3 *BarMessage
  307. // ...
  308. // }
  309. //
  310. // func (m *M) ProtoReflect() protoreflect.Message {
  311. // mi := &file_fizz_buzz_proto_msgInfos[5]
  312. // if protoimpl.UnsafeEnabled && m != nil {
  313. // ms := protoimpl.X.MessageStateOf(Pointer(m))
  314. // if ms.LoadMessageInfo() == nil {
  315. // ms.StoreMessageInfo(mi)
  316. // }
  317. // return ms
  318. // }
  319. // return mi.MessageOf(m)
  320. // }
  321. //
  322. // The MessageState type holds a *MessageInfo, which must be atomically set to
  323. // the message info associated with a given message instance.
  324. // By unsafely converting a *M into a *MessageState, the MessageState object
  325. // has access to all the information needed to implement protobuf reflection.
  326. // It has access to the message info as its first field, and a pointer to the
  327. // MessageState is identical to a pointer to the concrete message value.
  328. //
  329. // Requirements:
  330. // - The type M must implement protoreflect.ProtoMessage.
  331. // - The address of m must not be nil.
  332. // - The address of m and the address of m.state must be equal,
  333. // even though they are different Go types.
  334. type MessageState struct {
  335. pragma.NoUnkeyedLiterals
  336. pragma.DoNotCompare
  337. pragma.DoNotCopy
  338. atomicMessageInfo *MessageInfo
  339. }
  340. type messageState MessageState
  341. var (
  342. _ protoreflect.Message = (*messageState)(nil)
  343. _ unwrapper = (*messageState)(nil)
  344. )
  345. // messageDataType is a tuple of a pointer to the message data and
  346. // a pointer to the message type. It is a generalized way of providing a
  347. // reflective view over a message instance. The disadvantage of this approach
  348. // is the need to allocate this tuple of 16B.
  349. type messageDataType struct {
  350. p pointer
  351. mi *MessageInfo
  352. }
  353. type (
  354. messageReflectWrapper messageDataType
  355. messageIfaceWrapper messageDataType
  356. )
  357. var (
  358. _ protoreflect.Message = (*messageReflectWrapper)(nil)
  359. _ unwrapper = (*messageReflectWrapper)(nil)
  360. _ protoreflect.ProtoMessage = (*messageIfaceWrapper)(nil)
  361. _ unwrapper = (*messageIfaceWrapper)(nil)
  362. )
  363. // MessageOf returns a reflective view over a message. The input must be a
  364. // pointer to a named Go struct. If the provided type has a ProtoReflect method,
  365. // it must be implemented by calling this method.
  366. func (mi *MessageInfo) MessageOf(m any) protoreflect.Message {
  367. if reflect.TypeOf(m) != mi.GoReflectType {
  368. panic(fmt.Sprintf("type mismatch: got %T, want %v", m, mi.GoReflectType))
  369. }
  370. p := pointerOfIface(m)
  371. if p.IsNil() {
  372. return mi.nilMessage.Init(mi)
  373. }
  374. return &messageReflectWrapper{p, mi}
  375. }
  376. func (m *messageReflectWrapper) pointer() pointer { return m.p }
  377. func (m *messageReflectWrapper) messageInfo() *MessageInfo { return m.mi }
  378. // Reset implements the v1 proto.Message.Reset method.
  379. func (m *messageIfaceWrapper) Reset() {
  380. if mr, ok := m.protoUnwrap().(interface{ Reset() }); ok {
  381. mr.Reset()
  382. return
  383. }
  384. rv := reflect.ValueOf(m.protoUnwrap())
  385. if rv.Kind() == reflect.Ptr && !rv.IsNil() {
  386. rv.Elem().Set(reflect.Zero(rv.Type().Elem()))
  387. }
  388. }
  389. func (m *messageIfaceWrapper) ProtoReflect() protoreflect.Message {
  390. return (*messageReflectWrapper)(m)
  391. }
  392. func (m *messageIfaceWrapper) protoUnwrap() any {
  393. return m.p.AsIfaceOf(m.mi.GoReflectType.Elem())
  394. }
  395. // checkField verifies that the provided field descriptor is valid.
  396. // Exactly one of the returned values is populated.
  397. func (mi *MessageInfo) checkField(fd protoreflect.FieldDescriptor) (*fieldInfo, protoreflect.ExtensionTypeDescriptor) {
  398. var fi *fieldInfo
  399. if n := fd.Number(); 0 < n && int(n) < len(mi.denseFields) {
  400. fi = mi.denseFields[n]
  401. } else {
  402. fi = mi.fields[n]
  403. }
  404. if fi != nil {
  405. if fi.fieldDesc != fd {
  406. if got, want := fd.FullName(), fi.fieldDesc.FullName(); got != want {
  407. panic(fmt.Sprintf("mismatching field: got %v, want %v", got, want))
  408. }
  409. panic(fmt.Sprintf("mismatching field: %v", fd.FullName()))
  410. }
  411. return fi, nil
  412. }
  413. if fd.IsExtension() {
  414. if got, want := fd.ContainingMessage().FullName(), mi.Desc.FullName(); got != want {
  415. // TODO: Should this be exact containing message descriptor match?
  416. panic(fmt.Sprintf("extension %v has mismatching containing message: got %v, want %v", fd.FullName(), got, want))
  417. }
  418. if !mi.Desc.ExtensionRanges().Has(fd.Number()) {
  419. panic(fmt.Sprintf("extension %v extends %v outside the extension range", fd.FullName(), mi.Desc.FullName()))
  420. }
  421. xtd, ok := fd.(protoreflect.ExtensionTypeDescriptor)
  422. if !ok {
  423. panic(fmt.Sprintf("extension %v does not implement protoreflect.ExtensionTypeDescriptor", fd.FullName()))
  424. }
  425. return nil, xtd
  426. }
  427. panic(fmt.Sprintf("field %v is invalid", fd.FullName()))
  428. }