validate.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570
  1. // Copyright 2019 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package impl
  5. import (
  6. "fmt"
  7. "math"
  8. "math/bits"
  9. "reflect"
  10. "unicode/utf8"
  11. "google.golang.org/protobuf/encoding/protowire"
  12. "google.golang.org/protobuf/internal/encoding/messageset"
  13. "google.golang.org/protobuf/internal/flags"
  14. "google.golang.org/protobuf/internal/genid"
  15. "google.golang.org/protobuf/internal/strs"
  16. "google.golang.org/protobuf/reflect/protoreflect"
  17. "google.golang.org/protobuf/reflect/protoregistry"
  18. "google.golang.org/protobuf/runtime/protoiface"
  19. )
  20. // ValidationStatus is the result of validating the wire-format encoding of a message.
  21. type ValidationStatus int
  22. const (
  23. // ValidationUnknown indicates that unmarshaling the message might succeed or fail.
  24. // The validator was unable to render a judgement.
  25. //
  26. // The only causes of this status are an aberrant message type appearing somewhere
  27. // in the message or a failure in the extension resolver.
  28. ValidationUnknown ValidationStatus = iota + 1
  29. // ValidationInvalid indicates that unmarshaling the message will fail.
  30. ValidationInvalid
  31. // ValidationValid indicates that unmarshaling the message will succeed.
  32. ValidationValid
  33. // ValidationWrongWireType indicates that a validated field does not have
  34. // the expected wire type.
  35. ValidationWrongWireType
  36. )
  37. func (v ValidationStatus) String() string {
  38. switch v {
  39. case ValidationUnknown:
  40. return "ValidationUnknown"
  41. case ValidationInvalid:
  42. return "ValidationInvalid"
  43. case ValidationValid:
  44. return "ValidationValid"
  45. default:
  46. return fmt.Sprintf("ValidationStatus(%d)", int(v))
  47. }
  48. }
  49. // Validate determines whether the contents of the buffer are a valid wire encoding
  50. // of the message type.
  51. //
  52. // This function is exposed for testing.
  53. func Validate(mt protoreflect.MessageType, in protoiface.UnmarshalInput) (out protoiface.UnmarshalOutput, _ ValidationStatus) {
  54. mi, ok := mt.(*MessageInfo)
  55. if !ok {
  56. return out, ValidationUnknown
  57. }
  58. if in.Resolver == nil {
  59. in.Resolver = protoregistry.GlobalTypes
  60. }
  61. o, st := mi.validate(in.Buf, 0, unmarshalOptions{
  62. flags: in.Flags,
  63. resolver: in.Resolver,
  64. })
  65. if o.initialized {
  66. out.Flags |= protoiface.UnmarshalInitialized
  67. }
  68. return out, st
  69. }
// validationInfo describes how the validator checks the wire encoding of a
// single field: which validation class applies to its payload and, for
// message-valued fields, the MessageInfo used to descend into the payload.
type validationInfo struct {
	// mi is the MessageInfo of the field's message type (the element type
	// for repeated fields, the value type for maps). It may be nil; the
	// validator reports ValidationUnknown when it must descend into a
	// message or group whose mi is nil.
	mi *MessageInfo
	// typ selects how the field's wire payload is validated.
	typ validationType
	// keyType and valType are only meaningful for map fields
	// (typ == validationTypeMap) and classify the map entry's key and value.
	keyType, valType validationType

	// For non-required fields, requiredBit is 0.
	//
	// For required fields, requiredBit's nth bit is set, where n is a
	// unique index in the range [0, MessageInfo.numRequiredFields).
	//
	// Only the first 64 required fields of a message get a bit; for the
	// 65th and later required fields requiredBit is 0, and such a message
	// is always reported as potentially uninitialized.
	requiredBit uint64
}
// validationType classifies how the validator checks a field's wire payload.
// The zero value, validationTypeOther, applies no payload-specific checks:
// the field is only consumed according to its wire type.
type validationType uint8

const (
	// validationTypeOther: no type-specific checks on the payload.
	validationTypeOther validationType = iota
	// validationTypeMessage: length-delimited submessage, validated by
	// pushing a new state for its contents.
	validationTypeMessage
	// validationTypeGroup: group-encoded submessage, validated until the
	// matching end-group tag.
	validationTypeGroup
	// validationTypeMap: map field; each length-delimited entry is validated
	// as a key/value submessage using keyType and valType.
	validationTypeMap
	// validationTypeRepeatedVarint/Fixed32/Fixed64: repeated scalars. A
	// length-delimited (packed) payload must decode as whole elements.
	validationTypeRepeatedVarint
	validationTypeRepeatedFixed32
	validationTypeRepeatedFixed64
	// validationTypeVarint/Fixed32/Fixed64/Bytes: singular scalars. Their
	// payloads get no extra checks; the class is used to confirm that a
	// required field arrived with a compatible wire type.
	validationTypeVarint
	validationTypeFixed32
	validationTypeFixed64
	validationTypeBytes
	// validationTypeUTF8String: string whose payload must be valid UTF-8.
	validationTypeUTF8String
	// validationTypeMessageSetItem: legacy MessageSet item group (only
	// handled when flags.ProtoLegacy is set).
	validationTypeMessageSetItem
)
  98. func newFieldValidationInfo(mi *MessageInfo, si structInfo, fd protoreflect.FieldDescriptor, ft reflect.Type) validationInfo {
  99. var vi validationInfo
  100. switch {
  101. case fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic():
  102. switch fd.Kind() {
  103. case protoreflect.MessageKind:
  104. vi.typ = validationTypeMessage
  105. if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
  106. vi.mi = getMessageInfo(ot.Field(0).Type)
  107. }
  108. case protoreflect.GroupKind:
  109. vi.typ = validationTypeGroup
  110. if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
  111. vi.mi = getMessageInfo(ot.Field(0).Type)
  112. }
  113. case protoreflect.StringKind:
  114. if strs.EnforceUTF8(fd) {
  115. vi.typ = validationTypeUTF8String
  116. }
  117. }
  118. default:
  119. vi = newValidationInfo(fd, ft)
  120. }
  121. if fd.Cardinality() == protoreflect.Required {
  122. // Avoid overflow. The required field check is done with a 64-bit mask, with
  123. // any message containing more than 64 required fields always reported as
  124. // potentially uninitialized, so it is not important to get a precise count
  125. // of the required fields past 64.
  126. if mi.numRequiredFields < math.MaxUint8 {
  127. mi.numRequiredFields++
  128. vi.requiredBit = 1 << (mi.numRequiredFields - 1)
  129. }
  130. }
  131. return vi
  132. }
  133. func newValidationInfo(fd protoreflect.FieldDescriptor, ft reflect.Type) validationInfo {
  134. var vi validationInfo
  135. switch {
  136. case fd.IsList():
  137. switch fd.Kind() {
  138. case protoreflect.MessageKind:
  139. vi.typ = validationTypeMessage
  140. if ft.Kind() == reflect.Ptr {
  141. // Repeated opaque message fields are *[]*T.
  142. ft = ft.Elem()
  143. }
  144. if ft.Kind() == reflect.Slice {
  145. vi.mi = getMessageInfo(ft.Elem())
  146. }
  147. case protoreflect.GroupKind:
  148. vi.typ = validationTypeGroup
  149. if ft.Kind() == reflect.Ptr {
  150. // Repeated opaque message fields are *[]*T.
  151. ft = ft.Elem()
  152. }
  153. if ft.Kind() == reflect.Slice {
  154. vi.mi = getMessageInfo(ft.Elem())
  155. }
  156. case protoreflect.StringKind:
  157. vi.typ = validationTypeBytes
  158. if strs.EnforceUTF8(fd) {
  159. vi.typ = validationTypeUTF8String
  160. }
  161. default:
  162. switch wireTypes[fd.Kind()] {
  163. case protowire.VarintType:
  164. vi.typ = validationTypeRepeatedVarint
  165. case protowire.Fixed32Type:
  166. vi.typ = validationTypeRepeatedFixed32
  167. case protowire.Fixed64Type:
  168. vi.typ = validationTypeRepeatedFixed64
  169. }
  170. }
  171. case fd.IsMap():
  172. vi.typ = validationTypeMap
  173. switch fd.MapKey().Kind() {
  174. case protoreflect.StringKind:
  175. if strs.EnforceUTF8(fd) {
  176. vi.keyType = validationTypeUTF8String
  177. }
  178. }
  179. switch fd.MapValue().Kind() {
  180. case protoreflect.MessageKind:
  181. vi.valType = validationTypeMessage
  182. if ft.Kind() == reflect.Map {
  183. vi.mi = getMessageInfo(ft.Elem())
  184. }
  185. case protoreflect.StringKind:
  186. if strs.EnforceUTF8(fd) {
  187. vi.valType = validationTypeUTF8String
  188. }
  189. }
  190. default:
  191. switch fd.Kind() {
  192. case protoreflect.MessageKind:
  193. vi.typ = validationTypeMessage
  194. vi.mi = getMessageInfo(ft)
  195. case protoreflect.GroupKind:
  196. vi.typ = validationTypeGroup
  197. vi.mi = getMessageInfo(ft)
  198. case protoreflect.StringKind:
  199. vi.typ = validationTypeBytes
  200. if strs.EnforceUTF8(fd) {
  201. vi.typ = validationTypeUTF8String
  202. }
  203. default:
  204. switch wireTypes[fd.Kind()] {
  205. case protowire.VarintType:
  206. vi.typ = validationTypeVarint
  207. case protowire.Fixed32Type:
  208. vi.typ = validationTypeFixed32
  209. case protowire.Fixed64Type:
  210. vi.typ = validationTypeFixed64
  211. case protowire.BytesType:
  212. vi.typ = validationTypeBytes
  213. }
  214. }
  215. }
  216. return vi
  217. }
  218. func (mi *MessageInfo) validate(b []byte, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, result ValidationStatus) {
  219. mi.init()
  220. type validationState struct {
  221. typ validationType
  222. keyType, valType validationType
  223. endGroup protowire.Number
  224. mi *MessageInfo
  225. tail []byte
  226. requiredMask uint64
  227. }
  228. // Pre-allocate some slots to avoid repeated slice reallocation.
  229. states := make([]validationState, 0, 16)
  230. states = append(states, validationState{
  231. typ: validationTypeMessage,
  232. mi: mi,
  233. })
  234. if groupTag > 0 {
  235. states[0].typ = validationTypeGroup
  236. states[0].endGroup = groupTag
  237. }
  238. initialized := true
  239. start := len(b)
  240. State:
  241. for len(states) > 0 {
  242. st := &states[len(states)-1]
  243. for len(b) > 0 {
  244. // Parse the tag (field number and wire type).
  245. var tag uint64
  246. if b[0] < 0x80 {
  247. tag = uint64(b[0])
  248. b = b[1:]
  249. } else if len(b) >= 2 && b[1] < 128 {
  250. tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
  251. b = b[2:]
  252. } else {
  253. var n int
  254. tag, n = protowire.ConsumeVarint(b)
  255. if n < 0 {
  256. return out, ValidationInvalid
  257. }
  258. b = b[n:]
  259. }
  260. var num protowire.Number
  261. if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
  262. return out, ValidationInvalid
  263. } else {
  264. num = protowire.Number(n)
  265. }
  266. wtyp := protowire.Type(tag & 7)
  267. if wtyp == protowire.EndGroupType {
  268. if st.endGroup == num {
  269. goto PopState
  270. }
  271. return out, ValidationInvalid
  272. }
  273. var vi validationInfo
  274. switch {
  275. case st.typ == validationTypeMap:
  276. switch num {
  277. case genid.MapEntry_Key_field_number:
  278. vi.typ = st.keyType
  279. case genid.MapEntry_Value_field_number:
  280. vi.typ = st.valType
  281. vi.mi = st.mi
  282. vi.requiredBit = 1
  283. }
  284. case flags.ProtoLegacy && st.mi.isMessageSet:
  285. switch num {
  286. case messageset.FieldItem:
  287. vi.typ = validationTypeMessageSetItem
  288. }
  289. default:
  290. var f *coderFieldInfo
  291. if int(num) < len(st.mi.denseCoderFields) {
  292. f = st.mi.denseCoderFields[num]
  293. } else {
  294. f = st.mi.coderFields[num]
  295. }
  296. if f != nil {
  297. vi = f.validation
  298. break
  299. }
  300. // Possible extension field.
  301. //
  302. // TODO: We should return ValidationUnknown when:
  303. // 1. The resolver is not frozen. (More extensions may be added to it.)
  304. // 2. The resolver returns preg.NotFound.
  305. // In this case, a type added to the resolver in the future could cause
  306. // unmarshaling to begin failing. Supporting this requires some way to
  307. // determine if the resolver is frozen.
  308. xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), num)
  309. if err != nil && err != protoregistry.NotFound {
  310. return out, ValidationUnknown
  311. }
  312. if err == nil {
  313. vi = getExtensionFieldInfo(xt).validation
  314. }
  315. }
  316. if vi.requiredBit != 0 {
  317. // Check that the field has a compatible wire type.
  318. // We only need to consider non-repeated field types,
  319. // since repeated fields (and maps) can never be required.
  320. ok := false
  321. switch vi.typ {
  322. case validationTypeVarint:
  323. ok = wtyp == protowire.VarintType
  324. case validationTypeFixed32:
  325. ok = wtyp == protowire.Fixed32Type
  326. case validationTypeFixed64:
  327. ok = wtyp == protowire.Fixed64Type
  328. case validationTypeBytes, validationTypeUTF8String, validationTypeMessage:
  329. ok = wtyp == protowire.BytesType
  330. case validationTypeGroup:
  331. ok = wtyp == protowire.StartGroupType
  332. }
  333. if ok {
  334. st.requiredMask |= vi.requiredBit
  335. }
  336. }
  337. switch wtyp {
  338. case protowire.VarintType:
  339. if len(b) >= 10 {
  340. switch {
  341. case b[0] < 0x80:
  342. b = b[1:]
  343. case b[1] < 0x80:
  344. b = b[2:]
  345. case b[2] < 0x80:
  346. b = b[3:]
  347. case b[3] < 0x80:
  348. b = b[4:]
  349. case b[4] < 0x80:
  350. b = b[5:]
  351. case b[5] < 0x80:
  352. b = b[6:]
  353. case b[6] < 0x80:
  354. b = b[7:]
  355. case b[7] < 0x80:
  356. b = b[8:]
  357. case b[8] < 0x80:
  358. b = b[9:]
  359. case b[9] < 0x80 && b[9] < 2:
  360. b = b[10:]
  361. default:
  362. return out, ValidationInvalid
  363. }
  364. } else {
  365. switch {
  366. case len(b) > 0 && b[0] < 0x80:
  367. b = b[1:]
  368. case len(b) > 1 && b[1] < 0x80:
  369. b = b[2:]
  370. case len(b) > 2 && b[2] < 0x80:
  371. b = b[3:]
  372. case len(b) > 3 && b[3] < 0x80:
  373. b = b[4:]
  374. case len(b) > 4 && b[4] < 0x80:
  375. b = b[5:]
  376. case len(b) > 5 && b[5] < 0x80:
  377. b = b[6:]
  378. case len(b) > 6 && b[6] < 0x80:
  379. b = b[7:]
  380. case len(b) > 7 && b[7] < 0x80:
  381. b = b[8:]
  382. case len(b) > 8 && b[8] < 0x80:
  383. b = b[9:]
  384. case len(b) > 9 && b[9] < 2:
  385. b = b[10:]
  386. default:
  387. return out, ValidationInvalid
  388. }
  389. }
  390. continue State
  391. case protowire.BytesType:
  392. var size uint64
  393. if len(b) >= 1 && b[0] < 0x80 {
  394. size = uint64(b[0])
  395. b = b[1:]
  396. } else if len(b) >= 2 && b[1] < 128 {
  397. size = uint64(b[0]&0x7f) + uint64(b[1])<<7
  398. b = b[2:]
  399. } else {
  400. var n int
  401. size, n = protowire.ConsumeVarint(b)
  402. if n < 0 {
  403. return out, ValidationInvalid
  404. }
  405. b = b[n:]
  406. }
  407. if size > uint64(len(b)) {
  408. return out, ValidationInvalid
  409. }
  410. v := b[:size]
  411. b = b[size:]
  412. switch vi.typ {
  413. case validationTypeMessage:
  414. if vi.mi == nil {
  415. return out, ValidationUnknown
  416. }
  417. vi.mi.init()
  418. fallthrough
  419. case validationTypeMap:
  420. if vi.mi != nil {
  421. vi.mi.init()
  422. }
  423. states = append(states, validationState{
  424. typ: vi.typ,
  425. keyType: vi.keyType,
  426. valType: vi.valType,
  427. mi: vi.mi,
  428. tail: b,
  429. })
  430. b = v
  431. continue State
  432. case validationTypeRepeatedVarint:
  433. // Packed field.
  434. for len(v) > 0 {
  435. _, n := protowire.ConsumeVarint(v)
  436. if n < 0 {
  437. return out, ValidationInvalid
  438. }
  439. v = v[n:]
  440. }
  441. case validationTypeRepeatedFixed32:
  442. // Packed field.
  443. if len(v)%4 != 0 {
  444. return out, ValidationInvalid
  445. }
  446. case validationTypeRepeatedFixed64:
  447. // Packed field.
  448. if len(v)%8 != 0 {
  449. return out, ValidationInvalid
  450. }
  451. case validationTypeUTF8String:
  452. if !utf8.Valid(v) {
  453. return out, ValidationInvalid
  454. }
  455. }
  456. case protowire.Fixed32Type:
  457. if len(b) < 4 {
  458. return out, ValidationInvalid
  459. }
  460. b = b[4:]
  461. case protowire.Fixed64Type:
  462. if len(b) < 8 {
  463. return out, ValidationInvalid
  464. }
  465. b = b[8:]
  466. case protowire.StartGroupType:
  467. switch {
  468. case vi.typ == validationTypeGroup:
  469. if vi.mi == nil {
  470. return out, ValidationUnknown
  471. }
  472. vi.mi.init()
  473. states = append(states, validationState{
  474. typ: validationTypeGroup,
  475. mi: vi.mi,
  476. endGroup: num,
  477. })
  478. continue State
  479. case flags.ProtoLegacy && vi.typ == validationTypeMessageSetItem:
  480. typeid, v, n, err := messageset.ConsumeFieldValue(b, false)
  481. if err != nil {
  482. return out, ValidationInvalid
  483. }
  484. xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), typeid)
  485. switch {
  486. case err == protoregistry.NotFound:
  487. b = b[n:]
  488. case err != nil:
  489. return out, ValidationUnknown
  490. default:
  491. xvi := getExtensionFieldInfo(xt).validation
  492. if xvi.mi != nil {
  493. xvi.mi.init()
  494. }
  495. states = append(states, validationState{
  496. typ: xvi.typ,
  497. mi: xvi.mi,
  498. tail: b[n:],
  499. })
  500. b = v
  501. continue State
  502. }
  503. default:
  504. n := protowire.ConsumeFieldValue(num, wtyp, b)
  505. if n < 0 {
  506. return out, ValidationInvalid
  507. }
  508. b = b[n:]
  509. }
  510. default:
  511. return out, ValidationInvalid
  512. }
  513. }
  514. if st.endGroup != 0 {
  515. return out, ValidationInvalid
  516. }
  517. if len(b) != 0 {
  518. return out, ValidationInvalid
  519. }
  520. b = st.tail
  521. PopState:
  522. numRequiredFields := 0
  523. switch st.typ {
  524. case validationTypeMessage, validationTypeGroup:
  525. numRequiredFields = int(st.mi.numRequiredFields)
  526. case validationTypeMap:
  527. // If this is a map field with a message value that contains
  528. // required fields, require that the value be present.
  529. if st.mi != nil && st.mi.numRequiredFields > 0 {
  530. numRequiredFields = 1
  531. }
  532. }
  533. // If there are more than 64 required fields, this check will
  534. // always fail and we will report that the message is potentially
  535. // uninitialized.
  536. if numRequiredFields > 0 && bits.OnesCount64(st.requiredMask) != numRequiredFields {
  537. initialized = false
  538. }
  539. states = states[:len(states)-1]
  540. }
  541. out.n = start - len(b)
  542. if initialized {
  543. out.initialized = true
  544. }
  545. return out, ValidationValid
  546. }