bufferreader.go 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364
  1. // Copyright 2024 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. // Helper code for parsing a protocol buffer
  5. package protolazy
  6. import (
  7. "errors"
  8. "fmt"
  9. "io"
  10. "google.golang.org/protobuf/encoding/protowire"
  11. )
  12. // BufferReader is a structure encapsulating a protobuf and a current position
  13. type BufferReader struct {
  14. Buf []byte
  15. Pos int
  16. }
  17. // NewBufferReader creates a new BufferRead from a protobuf
  18. func NewBufferReader(buf []byte) BufferReader {
  19. return BufferReader{Buf: buf, Pos: 0}
  20. }
  21. var errOutOfBounds = errors.New("protobuf decoding: out of bounds")
  22. var errOverflow = errors.New("proto: integer overflow")
  23. func (b *BufferReader) DecodeVarintSlow() (x uint64, err error) {
  24. i := b.Pos
  25. l := len(b.Buf)
  26. for shift := uint(0); shift < 64; shift += 7 {
  27. if i >= l {
  28. err = io.ErrUnexpectedEOF
  29. return
  30. }
  31. v := b.Buf[i]
  32. i++
  33. x |= (uint64(v) & 0x7F) << shift
  34. if v < 0x80 {
  35. b.Pos = i
  36. return
  37. }
  38. }
  39. // The number is too large to represent in a 64-bit value.
  40. err = errOverflow
  41. return
  42. }
  43. // decodeVarint decodes a varint at the current position
  44. func (b *BufferReader) DecodeVarint() (x uint64, err error) {
  45. i := b.Pos
  46. buf := b.Buf
  47. if i >= len(buf) {
  48. return 0, io.ErrUnexpectedEOF
  49. } else if buf[i] < 0x80 {
  50. b.Pos++
  51. return uint64(buf[i]), nil
  52. } else if len(buf)-i < 10 {
  53. return b.DecodeVarintSlow()
  54. }
  55. var v uint64
  56. // we already checked the first byte
  57. x = uint64(buf[i]) & 127
  58. i++
  59. v = uint64(buf[i])
  60. i++
  61. x |= (v & 127) << 7
  62. if v < 128 {
  63. goto done
  64. }
  65. v = uint64(buf[i])
  66. i++
  67. x |= (v & 127) << 14
  68. if v < 128 {
  69. goto done
  70. }
  71. v = uint64(buf[i])
  72. i++
  73. x |= (v & 127) << 21
  74. if v < 128 {
  75. goto done
  76. }
  77. v = uint64(buf[i])
  78. i++
  79. x |= (v & 127) << 28
  80. if v < 128 {
  81. goto done
  82. }
  83. v = uint64(buf[i])
  84. i++
  85. x |= (v & 127) << 35
  86. if v < 128 {
  87. goto done
  88. }
  89. v = uint64(buf[i])
  90. i++
  91. x |= (v & 127) << 42
  92. if v < 128 {
  93. goto done
  94. }
  95. v = uint64(buf[i])
  96. i++
  97. x |= (v & 127) << 49
  98. if v < 128 {
  99. goto done
  100. }
  101. v = uint64(buf[i])
  102. i++
  103. x |= (v & 127) << 56
  104. if v < 128 {
  105. goto done
  106. }
  107. v = uint64(buf[i])
  108. i++
  109. x |= (v & 127) << 63
  110. if v < 128 {
  111. goto done
  112. }
  113. return 0, errOverflow
  114. done:
  115. b.Pos = i
  116. return
  117. }
  118. // decodeVarint32 decodes a varint32 at the current position
  119. func (b *BufferReader) DecodeVarint32() (x uint32, err error) {
  120. i := b.Pos
  121. buf := b.Buf
  122. if i >= len(buf) {
  123. return 0, io.ErrUnexpectedEOF
  124. } else if buf[i] < 0x80 {
  125. b.Pos++
  126. return uint32(buf[i]), nil
  127. } else if len(buf)-i < 5 {
  128. v, err := b.DecodeVarintSlow()
  129. return uint32(v), err
  130. }
  131. var v uint32
  132. // we already checked the first byte
  133. x = uint32(buf[i]) & 127
  134. i++
  135. v = uint32(buf[i])
  136. i++
  137. x |= (v & 127) << 7
  138. if v < 128 {
  139. goto done
  140. }
  141. v = uint32(buf[i])
  142. i++
  143. x |= (v & 127) << 14
  144. if v < 128 {
  145. goto done
  146. }
  147. v = uint32(buf[i])
  148. i++
  149. x |= (v & 127) << 21
  150. if v < 128 {
  151. goto done
  152. }
  153. v = uint32(buf[i])
  154. i++
  155. x |= (v & 127) << 28
  156. if v < 128 {
  157. goto done
  158. }
  159. return 0, errOverflow
  160. done:
  161. b.Pos = i
  162. return
  163. }
  164. // skipValue skips a value in the protobuf, based on the specified tag
  165. func (b *BufferReader) SkipValue(tag uint32) (err error) {
  166. wireType := tag & 0x7
  167. switch protowire.Type(wireType) {
  168. case protowire.VarintType:
  169. err = b.SkipVarint()
  170. case protowire.Fixed64Type:
  171. err = b.SkipFixed64()
  172. case protowire.BytesType:
  173. var n uint32
  174. n, err = b.DecodeVarint32()
  175. if err == nil {
  176. err = b.Skip(int(n))
  177. }
  178. case protowire.StartGroupType:
  179. err = b.SkipGroup(tag)
  180. case protowire.Fixed32Type:
  181. err = b.SkipFixed32()
  182. default:
  183. err = fmt.Errorf("Unexpected wire type (%d)", wireType)
  184. }
  185. return
  186. }
  187. // skipGroup skips a group with the specified tag. It executes efficiently using a tag stack
  188. func (b *BufferReader) SkipGroup(tag uint32) (err error) {
  189. tagStack := make([]uint32, 0, 16)
  190. tagStack = append(tagStack, tag)
  191. var n uint32
  192. for len(tagStack) > 0 {
  193. tag, err = b.DecodeVarint32()
  194. if err != nil {
  195. return err
  196. }
  197. switch protowire.Type(tag & 0x7) {
  198. case protowire.VarintType:
  199. err = b.SkipVarint()
  200. case protowire.Fixed64Type:
  201. err = b.Skip(8)
  202. case protowire.BytesType:
  203. n, err = b.DecodeVarint32()
  204. if err == nil {
  205. err = b.Skip(int(n))
  206. }
  207. case protowire.StartGroupType:
  208. tagStack = append(tagStack, tag)
  209. case protowire.Fixed32Type:
  210. err = b.SkipFixed32()
  211. case protowire.EndGroupType:
  212. if protoFieldNumber(tagStack[len(tagStack)-1]) == protoFieldNumber(tag) {
  213. tagStack = tagStack[:len(tagStack)-1]
  214. } else {
  215. err = fmt.Errorf("end group tag %d does not match begin group tag %d at pos %d",
  216. protoFieldNumber(tag), protoFieldNumber(tagStack[len(tagStack)-1]), b.Pos)
  217. }
  218. }
  219. if err != nil {
  220. return err
  221. }
  222. }
  223. return nil
  224. }
  225. // skipVarint effiently skips a varint
  226. func (b *BufferReader) SkipVarint() (err error) {
  227. i := b.Pos
  228. if len(b.Buf)-i < 10 {
  229. // Use DecodeVarintSlow() to check for buffer overflow, but ignore result
  230. if _, err := b.DecodeVarintSlow(); err != nil {
  231. return err
  232. }
  233. return nil
  234. }
  235. if b.Buf[i] < 0x80 {
  236. goto out
  237. }
  238. i++
  239. if b.Buf[i] < 0x80 {
  240. goto out
  241. }
  242. i++
  243. if b.Buf[i] < 0x80 {
  244. goto out
  245. }
  246. i++
  247. if b.Buf[i] < 0x80 {
  248. goto out
  249. }
  250. i++
  251. if b.Buf[i] < 0x80 {
  252. goto out
  253. }
  254. i++
  255. if b.Buf[i] < 0x80 {
  256. goto out
  257. }
  258. i++
  259. if b.Buf[i] < 0x80 {
  260. goto out
  261. }
  262. i++
  263. if b.Buf[i] < 0x80 {
  264. goto out
  265. }
  266. i++
  267. if b.Buf[i] < 0x80 {
  268. goto out
  269. }
  270. i++
  271. if b.Buf[i] < 0x80 {
  272. goto out
  273. }
  274. return errOverflow
  275. out:
  276. b.Pos = i + 1
  277. return nil
  278. }
  279. // skip skips the specified number of bytes
  280. func (b *BufferReader) Skip(n int) (err error) {
  281. if len(b.Buf) < b.Pos+n {
  282. return io.ErrUnexpectedEOF
  283. }
  284. b.Pos += n
  285. return
  286. }
  287. // skipFixed64 skips a fixed64
  288. func (b *BufferReader) SkipFixed64() (err error) {
  289. return b.Skip(8)
  290. }
  291. // skipFixed32 skips a fixed32
  292. func (b *BufferReader) SkipFixed32() (err error) {
  293. return b.Skip(4)
  294. }
  295. // skipBytes skips a set of bytes
  296. func (b *BufferReader) SkipBytes() (err error) {
  297. n, err := b.DecodeVarint32()
  298. if err != nil {
  299. return err
  300. }
  301. return b.Skip(int(n))
  302. }
  303. // Done returns whether we are at the end of the protobuf
  304. func (b *BufferReader) Done() bool {
  305. return b.Pos == len(b.Buf)
  306. }
  307. // Remaining returns how many bytes remain
  308. func (b *BufferReader) Remaining() int {
  309. return len(b.Buf) - b.Pos
  310. }