diff --git a/msgp/errors.go b/msgp/errors.go index 5c24f271..8f197267 100644 --- a/msgp/errors.go +++ b/msgp/errors.go @@ -88,6 +88,21 @@ func (u UintOverflow) Error() string { // Resumable is always 'true' for overflows func (u UintOverflow) Resumable() bool { return true } +// UintBelowZero is returned when a call +// would cast a signed integer below zero +// to an unsigned integer. +type UintBelowZero struct { + Value int64 // value of the incoming int +} + +// Error implements the error interface +func (u UintBelowZero) Error() string { + return fmt.Sprintf("msgp: attempted to cast int %d to unsigned", u.Value) +} + +// Resumable is always 'true' for overflows +func (u UintBelowZero) Resumable() bool { return true } + // A TypeError is returned when a particular // decoding method is unsuitable for decoding // a particular MessagePack value. diff --git a/msgp/read.go b/msgp/read.go index 20cd1ef8..aa668c57 100644 --- a/msgp/read.go +++ b/msgp/read.go @@ -583,6 +583,14 @@ func (m *Reader) ReadInt64() (i int64, err error) { i = int64(getMint8(p)) return + case muint8: + p, err = m.R.Next(2) + if err != nil { + return + } + i = int64(getMuint8(p)) + return + case mint16: p, err = m.R.Next(3) if err != nil { @@ -591,6 +599,14 @@ func (m *Reader) ReadInt64() (i int64, err error) { i = int64(getMint16(p)) return + case muint16: + p, err = m.R.Next(3) + if err != nil { + return + } + i = int64(getMuint16(p)) + return + case mint32: p, err = m.R.Next(5) if err != nil { @@ -599,6 +615,14 @@ func (m *Reader) ReadInt64() (i int64, err error) { i = int64(getMint32(p)) return + case muint32: + p, err = m.R.Next(5) + if err != nil { + return + } + i = int64(getMuint32(p)) + return + case mint64: p, err = m.R.Next(9) if err != nil { @@ -607,6 +631,19 @@ func (m *Reader) ReadInt64() (i int64, err error) { i = getMint64(p) return + case muint64: + p, err = m.R.Next(9) + if err != nil { + return + } + u := getMuint64(p) + if u > math.MaxInt64 { + err = UintOverflow{Value: u, FailedBitsize: 64} + return + } + i = int64(u) + return + default: err = badPrefix(IntType, lead) return @@ -678,6 +715,19 @@ func (m *Reader) ReadUint64() (u uint64, err error) { return } switch lead { + case mint8: + p, err = m.R.Next(2) + if err != nil { + return + } + v := int64(getMint8(p)) + if v < 0 { + err = UintBelowZero{Value: v} + return + } + u = uint64(v) + return + case muint8: p, err = m.R.Next(2) if err != nil { @@ -686,6 +736,19 @@ func (m *Reader) ReadUint64() (u uint64, err error) { u = uint64(getMuint8(p)) return + case mint16: + p, err = m.R.Next(3) + if err != nil { + return + } + v := int64(getMint16(p)) + if v < 0 { + err = UintBelowZero{Value: v} + return + } + u = uint64(v) + return + case muint16: p, err = m.R.Next(3) if err != nil { @@ -694,6 +757,19 @@ func (m *Reader) ReadUint64() (u uint64, err error) { u = uint64(getMuint16(p)) return + case mint32: + p, err = m.R.Next(5) + if err != nil { + return + } + v := int64(getMint32(p)) + if v < 0 { + err = UintBelowZero{Value: v} + return + } + u = uint64(v) + return + case muint32: p, err = m.R.Next(5) if err != nil { @@ -702,6 +778,19 @@ func (m *Reader) ReadUint64() (u uint64, err error) { u = uint64(getMuint32(p)) return + case mint64: + p, err = m.R.Next(9) + if err != nil { + return + } + v := int64(getMint64(p)) + if v < 0 { + err = UintBelowZero{Value: v} + return + } + u = uint64(v) + return + case muint64: p, err = m.R.Next(9) if err != nil { @@ -711,7 +800,11 @@ func (m *Reader) ReadUint64() (u uint64, err error) { return default: - err = badPrefix(UintType, lead) + if isnfixint(lead) { + err = UintBelowZero{Value: int64(rnfixint(lead))} + } else { + err = badPrefix(UintType, lead) + } return } diff --git a/msgp/read_bytes.go b/msgp/read_bytes.go index 78e466fc..001baf84 100644 --- a/msgp/read_bytes.go +++ b/msgp/read_bytes.go @@ -368,6 +368,15 @@ func ReadInt64Bytes(b []byte) (i int64, o []byte, err error) { o = b[2:] return + case muint8: + if l < 2 { + err = ErrShortBytes + return + } + i = int64(getMuint8(b)) + o = b[2:] + return + case mint16: if l < 3 { err = ErrShortBytes @@ -377,6 +386,15 @@ func ReadInt64Bytes(b []byte) (i int64, o []byte, err error) { o = b[3:] return + case muint16: + if l < 3 { + err = ErrShortBytes + return + } + i = int64(getMuint16(b)) + o = b[3:] + return + case mint32: if l < 5 { err = ErrShortBytes @@ -386,12 +404,35 @@ func ReadInt64Bytes(b []byte) (i int64, o []byte, err error) { o = b[5:] return + case muint32: + if l < 5 { + err = ErrShortBytes + return + } + i = int64(getMuint32(b)) + o = b[5:] + return + case mint64: if l < 9 { err = ErrShortBytes return } - i = getMint64(b) + i = int64(getMint64(b)) + o = b[9:] + return + + case muint64: + if l < 9 { + err = ErrShortBytes + return + } + u := getMuint64(b) + if u > math.MaxInt64 { + err = UintOverflow{Value: u, FailedBitsize: 64} + return + } + i = int64(u) o = b[9:] return @@ -477,6 +518,20 @@ func ReadUint64Bytes(b []byte) (u uint64, o []byte, err error) { } switch lead { + case mint8: + if l < 2 { + err = ErrShortBytes + return + } + v := int64(getMint8(b)) + if v < 0 { + err = UintBelowZero{Value: v} + return + } + u = uint64(v) + o = b[2:] + return + case muint8: if l < 2 { err = ErrShortBytes @@ -486,6 +541,20 @@ func ReadUint64Bytes(b []byte) (u uint64, o []byte, err error) { o = b[2:] return + case mint16: + if l < 3 { + err = ErrShortBytes + return + } + v := int64(getMint16(b)) + if v < 0 { + err = UintBelowZero{Value: v} + return + } + u = uint64(v) + o = b[3:] + return + case muint16: if l < 3 { err = ErrShortBytes @@ -495,6 +564,20 @@ func ReadUint64Bytes(b []byte) (u uint64, o []byte, err error) { o = b[3:] return + case mint32: + if l < 5 { + err = ErrShortBytes + return + } + v := int64(getMint32(b)) + if v < 0 { + err = UintBelowZero{Value: v} + return + } + u = uint64(v) + o = b[5:] + return + case muint32: if l < 5 { err = ErrShortBytes @@ -504,6 +587,20 @@ func ReadUint64Bytes(b []byte) (u uint64, o []byte, err error) { o = b[5:] return + case mint64: + if l < 9 { + err = ErrShortBytes + return + } + v := int64(getMint64(b)) + if v < 0 { + err = UintBelowZero{Value: v} + return + } + u = uint64(v) + o = b[9:] + return + case muint64: if l < 9 { err = ErrShortBytes @@ -514,7 +611,11 @@ func ReadUint64Bytes(b []byte) (u uint64, o []byte, err error) { return default: - err = badPrefix(UintType, lead) + if isnfixint(lead) { + err = UintBelowZero{Value: int64(rnfixint(lead))} + } else { + err = badPrefix(UintType, lead) + } return } } diff --git a/msgp/read_bytes_test.go b/msgp/read_bytes_test.go index 0049471b..45af9f04 100644 --- a/msgp/read_bytes_test.go +++ b/msgp/read_bytes_test.go @@ -2,6 +2,9 @@ package msgp import ( "bytes" + "fmt" + "log" + "math" "reflect" "testing" "time" @@ -222,53 +225,211 @@ func BenchmarkReadBoolBytes(b *testing.B) { func TestReadInt64Bytes(t *testing.T) { var buf bytes.Buffer - en := NewWriter(&buf) + wr := NewWriter(&buf) - tests := []int64{-5, -30, 0, 1, 127, 300, 40921, 34908219} + ints := []int64{-100000, -5000, -5, 0, 8, 240, int64(tuint16), int64(tuint32), int64(tuint64), + -5, -30, 0, 1, 127, 300, 40921, 34908219} - for i, v := range tests { + uints := []uint64{0, 8, 240, uint64(tuint16), uint64(tuint32), uint64(tuint64)} + + all := make([]interface{}, 0, len(ints)+len(uints)) + for _, v := range ints { + all = append(all, v) + } + for _, v := range uints { + all = append(all, v) + } + + for i, num := range all { buf.Reset() - en.WriteInt64(v) - en.Flush() - out, left, err := ReadInt64Bytes(buf.Bytes()) + var err error + + var in int64 + switch num := num.(type) { + case int64: + err = wr.WriteInt64(num) + in = num + case uint64: + err = wr.WriteUint64(num) + in = int64(num) + default: + panic(num) + } + if err != nil { + t.Fatal(err) + } + err = wr.Flush() + if err != nil { + t.Fatal(err) + } + out, left, err := ReadInt64Bytes(buf.Bytes()) + if out != in { + t.Errorf("Test case %d: put %d in and got %d out", i, num, in) + } if err != nil { t.Errorf("test case %d: %s", i, err) } - if len(left) != 0 { t.Errorf("expected 0 bytes left; found %d", len(left)) } - - if out != v { - t.Errorf("%d in; %d out", v, out) - } } } func TestReadUint64Bytes(t *testing.T) { var buf bytes.Buffer - en := NewWriter(&buf) + wr := NewWriter(&buf) - tests := []uint64{0, 1, 127, 300, 40921, 34908219} + vs := []interface{}{ + int64(0), int64(8), int64(240), int64(tuint16), int64(tuint32), int64(tuint64), + uint64(0), uint64(8), uint64(240), uint64(tuint16), uint64(tuint32), uint64(tuint64), + uint64(math.MaxUint64), + } - for i, v := range tests { + for i, num := range vs { buf.Reset() - en.WriteUint64(v) - en.Flush() - out, left, err := ReadUint64Bytes(buf.Bytes()) + var err error + + var in uint64 + switch num := num.(type) { + case int64: + err = wr.WriteInt64(num) + in = uint64(num) + case uint64: + err = wr.WriteUint64(num) + in = (num) + default: + panic(num) + } + if err != nil { + t.Fatal(err) + } + err = wr.Flush() + if err != nil { + t.Fatal(err) + } + out, left, err := ReadUint64Bytes(buf.Bytes()) + if out != in { + t.Errorf("Test case %d: put %d in and got %d out", i, num, in) + } if err != nil { t.Errorf("test case %d: %s", i, err) } - if len(left) != 0 { t.Errorf("expected 0 bytes left; found %d", len(left)) } + } +} - if out != v { - t.Errorf("%d in; %d out", v, out) - } +func TestReadIntBytesOverflows(t *testing.T) { + var buf bytes.Buffer + wr := NewWriter(&buf) + + i8, i16, i32, i64, u8, u16, u32, u64 := 1, 2, 3, 4, 5, 6, 7, 8 + + overflowErr := func(err error, failBits int) bool { + bits := 0 + switch err := err.(type) { + case IntOverflow: + bits = err.FailedBitsize + case UintOverflow: + bits = err.FailedBitsize + } + if bits == failBits { + return true + } + log.Println("bits mismatch", bits, failBits) + return false + } + + belowZeroErr := func(err error, failBits int) bool { + switch err.(type) { + case UintBelowZero: + return true + } + return false + } + + vs := []struct { + v interface{} + rdBits int + failBits int + errCheck func(err error, failBits int) bool + }{ + {uint64(math.MaxInt64), i32, 32, overflowErr}, + {uint64(math.MaxInt64), i16, 16, overflowErr}, + {uint64(math.MaxInt64), i8, 8, overflowErr}, + + {uint64(math.MaxUint64), i64, 64, overflowErr}, + {uint64(math.MaxUint64), i32, 64, overflowErr}, + {uint64(math.MaxUint64), i16, 64, overflowErr}, + {uint64(math.MaxUint64), i8, 64, overflowErr}, + + {uint64(math.MaxUint32), i32, 32, overflowErr}, + {uint64(math.MaxUint32), i16, 16, overflowErr}, + {uint64(math.MaxUint32), i8, 8, overflowErr}, + + {int64(math.MinInt64), u64, 64, belowZeroErr}, + {int64(math.MinInt64), u32, 64, belowZeroErr}, + {int64(math.MinInt64), u16, 64, belowZeroErr}, + {int64(math.MinInt64), u8, 64, belowZeroErr}, + {int64(math.MinInt32), u64, 64, belowZeroErr}, + {int64(math.MinInt32), u32, 32, belowZeroErr}, + {int64(math.MinInt32), u16, 16, belowZeroErr}, + {int64(math.MinInt32), u8, 8, belowZeroErr}, + {int64(math.MinInt16), u64, 64, belowZeroErr}, + {int64(math.MinInt16), u32, 32, belowZeroErr}, + {int64(math.MinInt16), u16, 16, belowZeroErr}, + {int64(math.MinInt16), u8, 8, belowZeroErr}, + {int64(math.MinInt8), u64, 64, belowZeroErr}, + {int64(math.MinInt8), u32, 32, belowZeroErr}, + {int64(math.MinInt8), u16, 16, belowZeroErr}, + {int64(math.MinInt8), u8, 8, belowZeroErr}, + {-1, u64, 64, belowZeroErr}, + {-1, u32, 32, belowZeroErr}, + {-1, u16, 16, belowZeroErr}, + {-1, u8, 8, belowZeroErr}, + } + + for i, v := range vs { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + buf.Reset() + switch num := v.v.(type) { + case int: + wr.WriteInt64(int64(num)) + case int64: + wr.WriteInt64(num) + case uint64: + wr.WriteUint64(num) + default: + panic(num) + } + wr.Flush() + + var err error + switch v.rdBits { + case i64: + _, _, err = ReadInt64Bytes(buf.Bytes()) + case i32: + _, _, err = ReadInt32Bytes(buf.Bytes()) + case i16: + _, _, err = ReadInt16Bytes(buf.Bytes()) + case i8: + _, _, err = ReadInt8Bytes(buf.Bytes()) + case u64: + _, _, err = ReadUint64Bytes(buf.Bytes()) + case u32: + _, _, err = ReadUint32Bytes(buf.Bytes()) + case u16: + _, _, err = ReadUint16Bytes(buf.Bytes()) + case u8: + _, _, err = ReadUint8Bytes(buf.Bytes()) + } + if !v.errCheck(err, v.failBits) { + t.Fatal(err) + } + }) } } diff --git a/msgp/read_test.go b/msgp/read_test.go index 8e781c10..b469c1bc 100644 --- a/msgp/read_test.go +++ b/msgp/read_test.go @@ -2,6 +2,7 @@ package msgp import ( "bytes" + "fmt" "io" "math" "math/rand" @@ -299,11 +300,31 @@ func TestReadInt64(t *testing.T) { rd := NewReader(&buf) ints := []int64{-100000, -5000, -5, 0, 8, 240, int64(tuint16), int64(tuint32), int64(tuint64)} + uints := []uint64{0, 8, 240, uint64(tuint16), uint64(tuint32), uint64(tuint64)} - for i, num := range ints { + all := make([]interface{}, 0, len(ints)+len(uints)) + for _, v := range ints { + all = append(all, v) + } + for _, v := range uints { + all = append(all, v) + } + + for i, num := range all { buf.Reset() + var err error - err := wr.WriteInt64(num) + var in int64 + switch num := num.(type) { + case int64: + err = wr.WriteInt64(num) + in = num + case uint64: + err = wr.WriteUint64(num) + in = int64(num) + default: + panic(num) + } if err != nil { t.Fatal(err) } @@ -315,12 +336,122 @@ func TestReadInt64(t *testing.T) { if err != nil { t.Fatal(err) } - if out != num { - t.Errorf("Test case %d: put %d in and got %d out", i, num, out) + if out != in { + t.Errorf("Test case %d: put %d in and got %d out", i, num, in) } } } +func TestReadIntOverflows(t *testing.T) { + var buf bytes.Buffer + wr := NewWriter(&buf) + rd := NewReader(&buf) + + i8, i16, i32, i64, u8, u16, u32, u64 := 1, 2, 3, 4, 5, 6, 7, 8 + + overflowErr := func(err error, failBits int) bool { + bits := 0 + switch err := err.(type) { + case IntOverflow: + bits = err.FailedBitsize + case UintOverflow: + bits = err.FailedBitsize + } + if bits == failBits { + return true + } + return false + } + + belowZeroErr := func(err error, failBits int) bool { + switch err.(type) { + case UintBelowZero: + return true + } + return false + } + + vs := []struct { + v interface{} + rdBits int + failBits int + errCheck func(err error, failBits int) bool + }{ + {uint64(math.MaxInt64), i32, 32, overflowErr}, + {uint64(math.MaxInt64), i16, 16, overflowErr}, + {uint64(math.MaxInt64), i8, 8, overflowErr}, + + {uint64(math.MaxUint64), i64, 64, overflowErr}, + {uint64(math.MaxUint64), i32, 64, overflowErr}, + {uint64(math.MaxUint64), i16, 64, overflowErr}, + {uint64(math.MaxUint64), i8, 64, overflowErr}, + + {uint64(math.MaxUint32), i32, 32, overflowErr}, + {uint64(math.MaxUint32), i16, 16, overflowErr}, + {uint64(math.MaxUint32), i8, 8, overflowErr}, + + {int64(math.MinInt64), u64, 64, belowZeroErr}, + {int64(math.MinInt64), u32, 64, belowZeroErr}, + {int64(math.MinInt64), u16, 64, belowZeroErr}, + {int64(math.MinInt64), u8, 64, belowZeroErr}, + {int64(math.MinInt32), u64, 64, belowZeroErr}, + {int64(math.MinInt32), u32, 32, belowZeroErr}, + {int64(math.MinInt32), u16, 16, belowZeroErr}, + {int64(math.MinInt32), u8, 8, belowZeroErr}, + {int64(math.MinInt16), u64, 64, belowZeroErr}, + {int64(math.MinInt16), u32, 32, belowZeroErr}, + {int64(math.MinInt16), u16, 16, belowZeroErr}, + {int64(math.MinInt16), u8, 8, belowZeroErr}, + {int64(math.MinInt8), u64, 64, belowZeroErr}, + {int64(math.MinInt8), u32, 32, belowZeroErr}, + {int64(math.MinInt8), u16, 16, belowZeroErr}, + {int64(math.MinInt8), u8, 8, belowZeroErr}, + {-1, u64, 64, belowZeroErr}, + {-1, u32, 32, belowZeroErr}, + {-1, u16, 16, belowZeroErr}, + {-1, u8, 8, belowZeroErr}, + } + + for i, v := range vs { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + switch num := v.v.(type) { + case int: + wr.WriteInt64(int64(num)) + case int64: + wr.WriteInt64(num) + case uint64: + wr.WriteUint64(num) + default: + panic(num) + } + wr.Flush() + + var err error + switch v.rdBits { + case i64: + _, err = rd.ReadInt64() + case i32: + _, err = rd.ReadInt32() + case i16: + _, err = rd.ReadInt16() + case i8: + _, err = rd.ReadInt8() + case u64: + _, err = rd.ReadUint64() + case u32: + _, err = rd.ReadUint32() + case u16: + _, err = rd.ReadUint16() + case u8: + _, err = rd.ReadUint8() + } + if !v.errCheck(err, v.failBits) { + t.Fatal(err) + } + }) + } +} + func BenchmarkReadInt64(b *testing.B) { is := []int64{0, 1, 65000, rand.Int63()} data := make([]byte, 0, 9*len(is)) @@ -339,6 +470,24 @@ func BenchmarkReadInt64(b *testing.B) { } } +func BenchmarkReadUintWithInt64(b *testing.B) { + us := []uint64{0, 1, 10000, uint64(rand.Uint32() * 4)} + data := make([]byte, 0, 9*len(us)) + for _, n := range us { + data = AppendUint64(data, n) + } + rd := NewReader(NewEndlessReader(data, b)) + b.SetBytes(int64(len(data) / len(us))) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := rd.ReadInt64() + if err != nil { + b.Fatal(err) + } + } +} + func TestReadUint64(t *testing.T) { var buf bytes.Buffer wr := NewWriter(&buf) @@ -382,6 +531,24 @@ func BenchmarkReadUint64(b *testing.B) { } } +func BenchmarkReadIntWithUint64(b *testing.B) { + is := []int64{0, 1, 65000, rand.Int63()} + data := make([]byte, 0, 9*len(is)) + for _, n := range is { + data = AppendInt64(data, n) + } + rd := NewReader(NewEndlessReader(data, b)) + b.SetBytes(int64(len(data) / len(is))) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := rd.ReadUint64() + if err != nil { + b.Fatal(err) + } + } +} + func TestReadBytes(t *testing.T) { var buf bytes.Buffer wr := NewWriter(&buf)