123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364 |
- // Copyright 2024 The Go Authors. All rights reserved.
- // Use of this source code is governed by a BSD-style
- // license that can be found in the LICENSE file.
- // Helper code for parsing a protocol buffer
- package protolazy
- import (
- "errors"
- "fmt"
- "io"
- "google.golang.org/protobuf/encoding/protowire"
- )
// BufferReader walks an in-memory protobuf wire-format encoding, remembering
// how far it has read.
type BufferReader struct {
	// Buf holds the encoded bytes being read.
	Buf []byte
	// Pos is the offset in Buf of the next byte to read.
	Pos int
}
- // NewBufferReader creates a new BufferRead from a protobuf
- func NewBufferReader(buf []byte) BufferReader {
- return BufferReader{Buf: buf, Pos: 0}
- }
- var errOutOfBounds = errors.New("protobuf decoding: out of bounds")
- var errOverflow = errors.New("proto: integer overflow")
// DecodeVarintSlow decodes a base-128 varint starting at b.Pos. It checks the
// buffer bounds before every byte, so it is safe near the end of Buf. On
// success it advances b.Pos past the varint.
//
// At most ten bytes are read. It returns io.ErrUnexpectedEOF if Buf ends
// before a byte with the high bit clear is found, and errOverflow if ten bytes
// are read without one. On error x is 0 and b.Pos is left unchanged. Only the
// lowest bit of a tenth byte fits in the 64-bit result; its higher bits are
// discarded rather than reported as overflow.
func (b *BufferReader) DecodeVarintSlow() (x uint64, err error) {
	i := b.Pos
	l := len(b.Buf)
	for shift := uint(0); shift < 64; shift += 7 {
		if i >= l {
			// Do not hand back the partially accumulated value.
			return 0, io.ErrUnexpectedEOF
		}
		c := b.Buf[i]
		i++
		x |= uint64(c&0x7F) << shift
		if c < 0x80 {
			b.Pos = i
			return x, nil
		}
	}
	// The number is too large to represent in a 64-bit value.
	return 0, errOverflow
}
// DecodeVarint decodes a base-128 varint at b.Pos and, on success, advances
// b.Pos past it.
//
// It returns io.ErrUnexpectedEOF when no bytes remain. When fewer than ten
// bytes remain (and the first byte is not a complete varint), it delegates to
// DecodeVarintSlow and returns whatever that returns. Otherwise it decodes
// inline and returns errOverflow, with x == 0 and b.Pos unchanged, if none of
// the ten bytes has its high bit clear. Only the lowest bit of a tenth byte
// contributes to the result.
func (b *BufferReader) DecodeVarint() (x uint64, err error) {
	start := b.Pos
	switch {
	case start >= len(b.Buf):
		return 0, io.ErrUnexpectedEOF
	case b.Buf[start] < 0x80:
		// A first byte with the high bit clear is a complete varint.
		b.Pos++
		return uint64(b.Buf[start]), nil
	case len(b.Buf)-start < 10:
		// Not enough room for the unchecked path below.
		return b.DecodeVarintSlow()
	}

	// Ten bytes are guaranteed to be available. The decoding is unrolled by
	// hand; each step returns as soon as a byte has its high bit clear. The
	// first byte is already known to have its high bit set.
	p := b.Buf[start : start+10 : start+10]
	x = uint64(p[0] & 0x7f)
	var c uint64

	c = uint64(p[1])
	x |= (c & 0x7f) << 7
	if c < 0x80 {
		b.Pos = start + 2
		return x, nil
	}

	c = uint64(p[2])
	x |= (c & 0x7f) << 14
	if c < 0x80 {
		b.Pos = start + 3
		return x, nil
	}

	c = uint64(p[3])
	x |= (c & 0x7f) << 21
	if c < 0x80 {
		b.Pos = start + 4
		return x, nil
	}

	c = uint64(p[4])
	x |= (c & 0x7f) << 28
	if c < 0x80 {
		b.Pos = start + 5
		return x, nil
	}

	c = uint64(p[5])
	x |= (c & 0x7f) << 35
	if c < 0x80 {
		b.Pos = start + 6
		return x, nil
	}

	c = uint64(p[6])
	x |= (c & 0x7f) << 42
	if c < 0x80 {
		b.Pos = start + 7
		return x, nil
	}

	c = uint64(p[7])
	x |= (c & 0x7f) << 49
	if c < 0x80 {
		b.Pos = start + 8
		return x, nil
	}

	c = uint64(p[8])
	x |= (c & 0x7f) << 56
	if c < 0x80 {
		b.Pos = start + 9
		return x, nil
	}

	c = uint64(p[9])
	x |= (c & 0x7f) << 63
	if c < 0x80 {
		b.Pos = start + 10
		return x, nil
	}

	return 0, errOverflow
}
// DecodeVarint32 decodes a varint of at most five bytes at b.Pos into a
// uint32 and, on success, advances b.Pos past it.
//
// It returns io.ErrUnexpectedEOF when no bytes remain. When fewer than five
// bytes remain (and the first byte is not a complete varint), it delegates to
// DecodeVarintSlow and truncates that result to 32 bits. Otherwise it decodes
// inline and returns errOverflow, with x == 0 and b.Pos unchanged, if the
// fifth byte still has its high bit set. Bits of a fifth byte that land above
// bit 31 are discarded.
func (b *BufferReader) DecodeVarint32() (x uint32, err error) {
	start := b.Pos
	switch {
	case start >= len(b.Buf):
		return 0, io.ErrUnexpectedEOF
	case b.Buf[start] < 0x80:
		// A first byte with the high bit clear is a complete varint.
		b.Pos++
		return uint32(b.Buf[start]), nil
	case len(b.Buf)-start < 5:
		wide, slowErr := b.DecodeVarintSlow()
		return uint32(wide), slowErr
	}

	// Five bytes are guaranteed to be available; decode them unrolled. The
	// first byte is already known to have its high bit set.
	p := b.Buf[start : start+5 : start+5]
	x = uint32(p[0] & 0x7f)
	var c uint32

	c = uint32(p[1])
	x |= (c & 0x7f) << 7
	if c < 0x80 {
		b.Pos = start + 2
		return x, nil
	}

	c = uint32(p[2])
	x |= (c & 0x7f) << 14
	if c < 0x80 {
		b.Pos = start + 3
		return x, nil
	}

	c = uint32(p[3])
	x |= (c & 0x7f) << 21
	if c < 0x80 {
		b.Pos = start + 4
		return x, nil
	}

	c = uint32(p[4])
	x |= (c & 0x7f) << 28
	if c < 0x80 {
		b.Pos = start + 5
		return x, nil
	}

	return 0, errOverflow
}
- // skipValue skips a value in the protobuf, based on the specified tag
- func (b *BufferReader) SkipValue(tag uint32) (err error) {
- wireType := tag & 0x7
- switch protowire.Type(wireType) {
- case protowire.VarintType:
- err = b.SkipVarint()
- case protowire.Fixed64Type:
- err = b.SkipFixed64()
- case protowire.BytesType:
- var n uint32
- n, err = b.DecodeVarint32()
- if err == nil {
- err = b.Skip(int(n))
- }
- case protowire.StartGroupType:
- err = b.SkipGroup(tag)
- case protowire.Fixed32Type:
- err = b.SkipFixed32()
- default:
- err = fmt.Errorf("Unexpected wire type (%d)", wireType)
- }
- return
- }
// SkipGroup skips the remainder of a group. tag is the group's start-group
// tag, which the caller has already consumed. SkipGroup reads fields until it
// consumes the matching end-group tag, leaving b.Pos just past it.
//
// Nesting is tracked with an explicit stack of start tags instead of
// recursion. The stack starts with tag, and every nested start-group tag
// pushes another entry. SkipGroup returns an error in three cases:
//   - a nested value cannot be skipped;
//   - an end-group tag's field number differs from the innermost open group's;
//   - a field uses an undefined wire type.
func (b *BufferReader) SkipGroup(tag uint32) (err error) {
	tagStack := make([]uint32, 0, 16)
	tagStack = append(tagStack, tag)
	var n uint32
	for len(tagStack) > 0 {
		tag, err = b.DecodeVarint32()
		if err != nil {
			return err
		}
		switch wireType := tag & 0x7; protowire.Type(wireType) {
		case protowire.VarintType:
			err = b.SkipVarint()
		case protowire.Fixed64Type:
			err = b.SkipFixed64()
		case protowire.BytesType:
			n, err = b.DecodeVarint32()
			if err == nil {
				err = b.Skip(int(n))
			}
		case protowire.StartGroupType:
			tagStack = append(tagStack, tag)
		case protowire.Fixed32Type:
			err = b.SkipFixed32()
		case protowire.EndGroupType:
			top := tagStack[len(tagStack)-1]
			if protoFieldNumber(top) != protoFieldNumber(tag) {
				return fmt.Errorf("end group tag %d does not match begin group tag %d at pos %d",
					protoFieldNumber(tag), protoFieldNumber(top), b.Pos)
			}
			tagStack = tagStack[:len(tagStack)-1]
		default:
			// Wire types 6 and 7 are undefined. Reject them, as SkipValue
			// does, instead of silently continuing through malformed input.
			return fmt.Errorf("Unexpected wire type (%d)", wireType)
		}
		if err != nil {
			return err
		}
	}
	return nil
}
- // skipVarint effiently skips a varint
- func (b *BufferReader) SkipVarint() (err error) {
- i := b.Pos
- if len(b.Buf)-i < 10 {
- // Use DecodeVarintSlow() to check for buffer overflow, but ignore result
- if _, err := b.DecodeVarintSlow(); err != nil {
- return err
- }
- return nil
- }
- if b.Buf[i] < 0x80 {
- goto out
- }
- i++
- if b.Buf[i] < 0x80 {
- goto out
- }
- i++
- if b.Buf[i] < 0x80 {
- goto out
- }
- i++
- if b.Buf[i] < 0x80 {
- goto out
- }
- i++
- if b.Buf[i] < 0x80 {
- goto out
- }
- i++
- if b.Buf[i] < 0x80 {
- goto out
- }
- i++
- if b.Buf[i] < 0x80 {
- goto out
- }
- i++
- if b.Buf[i] < 0x80 {
- goto out
- }
- i++
- if b.Buf[i] < 0x80 {
- goto out
- }
- i++
- if b.Buf[i] < 0x80 {
- goto out
- }
- return errOverflow
- out:
- b.Pos = i + 1
- return nil
- }
// Skip advances the read position by n bytes.
//
// If fewer than n bytes remain, it returns io.ErrUnexpectedEOF and leaves
// b.Pos unchanged. A negative n is rejected the same way instead of moving the
// position backwards. Callers in this file pass int(n) for a uint32 length,
// which is negative on 32-bit platforms for lengths >= 1<<31. Such a length
// could never fit in the buffer anyway.
func (b *BufferReader) Skip(n int) (err error) {
	// Compare against the remaining length rather than computing b.Pos+n,
	// which can overflow int for large n.
	if n < 0 || n > len(b.Buf)-b.Pos {
		return io.ErrUnexpectedEOF
	}
	b.Pos += n
	return nil
}
- // skipFixed64 skips a fixed64
- func (b *BufferReader) SkipFixed64() (err error) {
- return b.Skip(8)
- }
- // skipFixed32 skips a fixed32
- func (b *BufferReader) SkipFixed32() (err error) {
- return b.Skip(4)
- }
- // skipBytes skips a set of bytes
- func (b *BufferReader) SkipBytes() (err error) {
- n, err := b.DecodeVarint32()
- if err != nil {
- return err
- }
- return b.Skip(int(n))
- }
- // Done returns whether we are at the end of the protobuf
- func (b *BufferReader) Done() bool {
- return b.Pos == len(b.Buf)
- }
- // Remaining returns how many bytes remain
- func (b *BufferReader) Remaining() int {
- return len(b.Buf) - b.Pos
- }
|