Copied mongo repo (to patch it)

This commit is contained in:
2023-06-18 15:50:55 +02:00
parent 21d241f9b1
commit d471d7c396
544 changed files with 142039 additions and 1 deletions

View File

@@ -0,0 +1,437 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on gopkg.in/mgo.v2/bson by Gustavo Niemeyer
// See THIRD-PARTY-NOTICES for original license terms.
package primitive
import (
"encoding/json"
"errors"
"fmt"
"math/big"
"regexp"
"strconv"
"strings"
)
// These constants are the maximum and minimum values for the exponent field in a decimal128 value.
const (
MaxDecimal128Exp = 6111
MinDecimal128Exp = -6176
)
// These errors are returned when an invalid value is parsed as a big.Int.
var (
ErrParseNaN = errors.New("cannot parse NaN as a *big.Int")
ErrParseInf = errors.New("cannot parse Infinity as a *big.Int")
ErrParseNegInf = errors.New("cannot parse -Infinity as a *big.Int")
)
// Decimal128 holds decimal128 BSON values.
type Decimal128 struct {
h, l uint64
}
// NewDecimal128 creates a Decimal128 using the provide high and low uint64s.
func NewDecimal128(h, l uint64) Decimal128 {
return Decimal128{h: h, l: l}
}
// GetBytes returns the underlying bytes of the BSON decimal value as two uint64 values. The first
// contains the most first 8 bytes of the value and the second contains the latter.
func (d Decimal128) GetBytes() (uint64, uint64) {
return d.h, d.l
}
// String returns a string representation of the decimal value.
func (d Decimal128) String() string {
var posSign int // positive sign
var exp int // exponent
var high, low uint64 // significand high/low
if d.h>>63&1 == 0 {
posSign = 1
}
switch d.h >> 58 & (1<<5 - 1) {
case 0x1F:
return "NaN"
case 0x1E:
return "-Infinity"[posSign:]
}
low = d.l
if d.h>>61&3 == 3 {
// Bits: 1*sign 2*ignored 14*exponent 111*significand.
// Implicit 0b100 prefix in significand.
exp = int(d.h >> 47 & (1<<14 - 1))
//high = 4<<47 | d.h&(1<<47-1)
// Spec says all of these values are out of range.
high, low = 0, 0
} else {
// Bits: 1*sign 14*exponent 113*significand
exp = int(d.h >> 49 & (1<<14 - 1))
high = d.h & (1<<49 - 1)
}
exp += MinDecimal128Exp
// Would be handled by the logic below, but that's trivial and common.
if high == 0 && low == 0 && exp == 0 {
return "-0"[posSign:]
}
var repr [48]byte // Loop 5 times over 9 digits plus dot, negative sign, and leading zero.
var last = len(repr)
var i = len(repr)
var dot = len(repr) + exp
var rem uint32
Loop:
for d9 := 0; d9 < 5; d9++ {
high, low, rem = divmod(high, low, 1e9)
for d1 := 0; d1 < 9; d1++ {
// Handle "-0.0", "0.00123400", "-1.00E-6", "1.050E+3", etc.
if i < len(repr) && (dot == i || low == 0 && high == 0 && rem > 0 && rem < 10 && (dot < i-6 || exp > 0)) {
exp += len(repr) - i
i--
repr[i] = '.'
last = i - 1
dot = len(repr) // Unmark.
}
c := '0' + byte(rem%10)
rem /= 10
i--
repr[i] = c
// Handle "0E+3", "1E+3", etc.
if low == 0 && high == 0 && rem == 0 && i == len(repr)-1 && (dot < i-5 || exp > 0) {
last = i
break Loop
}
if c != '0' {
last = i
}
// Break early. Works without it, but why.
if dot > i && low == 0 && high == 0 && rem == 0 {
break Loop
}
}
}
repr[last-1] = '-'
last--
if exp > 0 {
return string(repr[last+posSign:]) + "E+" + strconv.Itoa(exp)
}
if exp < 0 {
return string(repr[last+posSign:]) + "E" + strconv.Itoa(exp)
}
return string(repr[last+posSign:])
}
// BigInt returns significand as big.Int and exponent, bi * 10 ^ exp.
func (d Decimal128) BigInt() (*big.Int, int, error) {
high, low := d.GetBytes()
posSign := high>>63&1 == 0 // positive sign
switch high >> 58 & (1<<5 - 1) {
case 0x1F:
return nil, 0, ErrParseNaN
case 0x1E:
if posSign {
return nil, 0, ErrParseInf
}
return nil, 0, ErrParseNegInf
}
var exp int
if high>>61&3 == 3 {
// Bits: 1*sign 2*ignored 14*exponent 111*significand.
// Implicit 0b100 prefix in significand.
exp = int(high >> 47 & (1<<14 - 1))
//high = 4<<47 | d.h&(1<<47-1)
// Spec says all of these values are out of range.
high, low = 0, 0
} else {
// Bits: 1*sign 14*exponent 113*significand
exp = int(high >> 49 & (1<<14 - 1))
high = high & (1<<49 - 1)
}
exp += MinDecimal128Exp
// Would be handled by the logic below, but that's trivial and common.
if high == 0 && low == 0 && exp == 0 {
if posSign {
return new(big.Int), 0, nil
}
return new(big.Int), 0, nil
}
bi := big.NewInt(0)
const host32bit = ^uint(0)>>32 == 0
if host32bit {
bi.SetBits([]big.Word{big.Word(low), big.Word(low >> 32), big.Word(high), big.Word(high >> 32)})
} else {
bi.SetBits([]big.Word{big.Word(low), big.Word(high)})
}
if !posSign {
return bi.Neg(bi), exp, nil
}
return bi, exp, nil
}
// IsNaN returns whether d is NaN.
func (d Decimal128) IsNaN() bool {
return d.h>>58&(1<<5-1) == 0x1F
}
// IsInf returns:
//
// +1 d == Infinity
// 0 other case
// -1 d == -Infinity
func (d Decimal128) IsInf() int {
if d.h>>58&(1<<5-1) != 0x1E {
return 0
}
if d.h>>63&1 == 0 {
return 1
}
return -1
}
// IsZero returns true if d is the empty Decimal128.
func (d Decimal128) IsZero() bool {
return d.h == 0 && d.l == 0
}
// MarshalJSON returns Decimal128 as a string.
func (d Decimal128) MarshalJSON() ([]byte, error) {
return json.Marshal(d.String())
}
// UnmarshalJSON creates a primitive.Decimal128 from a JSON string, an extended JSON $numberDecimal value, or the string
// "null". If b is a JSON string or extended JSON value, d will have the value of that string, and if b is "null", d will
// be unchanged.
func (d *Decimal128) UnmarshalJSON(b []byte) error {
// Ignore "null" to keep parity with the standard library. Decoding a JSON null into a non-pointer Decimal128 field
// will leave the field unchanged. For pointer values, encoding/json will set the pointer to nil and will not
// enter the UnmarshalJSON hook.
if string(b) == "null" {
return nil
}
var res interface{}
err := json.Unmarshal(b, &res)
if err != nil {
return err
}
str, ok := res.(string)
// Extended JSON
if !ok {
m, ok := res.(map[string]interface{})
if !ok {
return errors.New("not an extended JSON Decimal128: expected document")
}
d128, ok := m["$numberDecimal"]
if !ok {
return errors.New("not an extended JSON Decimal128: expected key $numberDecimal")
}
str, ok = d128.(string)
if !ok {
return errors.New("not an extended JSON Decimal128: expected decimal to be string")
}
}
*d, err = ParseDecimal128(str)
return err
}
func divmod(h, l uint64, div uint32) (qh, ql uint64, rem uint32) {
div64 := uint64(div)
a := h >> 32
aq := a / div64
ar := a % div64
b := ar<<32 + h&(1<<32-1)
bq := b / div64
br := b % div64
c := br<<32 + l>>32
cq := c / div64
cr := c % div64
d := cr<<32 + l&(1<<32-1)
dq := d / div64
dr := d % div64
return (aq<<32 | bq), (cq<<32 | dq), uint32(dr)
}
var dNaN = Decimal128{0x1F << 58, 0}
var dPosInf = Decimal128{0x1E << 58, 0}
var dNegInf = Decimal128{0x3E << 58, 0}
func dErr(s string) (Decimal128, error) {
return dNaN, fmt.Errorf("cannot parse %q as a decimal128", s)
}
// match scientific notation number, example -10.15e-18
var normalNumber = regexp.MustCompile(`^(?P<int>[-+]?\d*)?(?:\.(?P<dec>\d*))?(?:[Ee](?P<exp>[-+]?\d+))?$`)
// ParseDecimal128 takes the given string and attempts to parse it into a valid
// Decimal128 value.
func ParseDecimal128(s string) (Decimal128, error) {
if s == "" {
return dErr(s)
}
matches := normalNumber.FindStringSubmatch(s)
if len(matches) == 0 {
orig := s
neg := s[0] == '-'
if neg || s[0] == '+' {
s = s[1:]
}
if s == "NaN" || s == "nan" || strings.EqualFold(s, "nan") {
return dNaN, nil
}
if s == "Inf" || s == "inf" || strings.EqualFold(s, "inf") || strings.EqualFold(s, "infinity") {
if neg {
return dNegInf, nil
}
return dPosInf, nil
}
return dErr(orig)
}
intPart := matches[1]
decPart := matches[2]
expPart := matches[3]
var err error
exp := 0
if expPart != "" {
exp, err = strconv.Atoi(expPart)
if err != nil {
return dErr(s)
}
}
if decPart != "" {
exp -= len(decPart)
}
if len(strings.Trim(intPart+decPart, "-0")) > 35 {
return dErr(s)
}
// Parse the significand (i.e. the non-exponent part) as a big.Int.
bi, ok := new(big.Int).SetString(intPart+decPart, 10)
if !ok {
return dErr(s)
}
d, ok := ParseDecimal128FromBigInt(bi, exp)
if !ok {
return dErr(s)
}
if bi.Sign() == 0 && s[0] == '-' {
d.h |= 1 << 63
}
return d, nil
}
var (
ten = big.NewInt(10)
zero = new(big.Int)
maxS, _ = new(big.Int).SetString("9999999999999999999999999999999999", 10)
)
// ParseDecimal128FromBigInt attempts to parse the given significand and exponent into a valid Decimal128 value.
func ParseDecimal128FromBigInt(bi *big.Int, exp int) (Decimal128, bool) {
//copy
bi = new(big.Int).Set(bi)
q := new(big.Int)
r := new(big.Int)
// If the significand is zero, the logical value will always be zero, independent of the
// exponent. However, the loops for handling out-of-range exponent values below may be extremely
// slow for zero values because the significand never changes. Limit the exponent value to the
// supported range here to prevent entering the loops below.
if bi.Cmp(zero) == 0 {
if exp > MaxDecimal128Exp {
exp = MaxDecimal128Exp
}
if exp < MinDecimal128Exp {
exp = MinDecimal128Exp
}
}
for bigIntCmpAbs(bi, maxS) == 1 {
bi, _ = q.QuoRem(bi, ten, r)
if r.Cmp(zero) != 0 {
return Decimal128{}, false
}
exp++
if exp > MaxDecimal128Exp {
return Decimal128{}, false
}
}
for exp < MinDecimal128Exp {
// Subnormal.
bi, _ = q.QuoRem(bi, ten, r)
if r.Cmp(zero) != 0 {
return Decimal128{}, false
}
exp++
}
for exp > MaxDecimal128Exp {
// Clamped.
bi.Mul(bi, ten)
if bigIntCmpAbs(bi, maxS) == 1 {
return Decimal128{}, false
}
exp--
}
b := bi.Bytes()
var h, l uint64
for i := 0; i < len(b); i++ {
if i < len(b)-8 {
h = h<<8 | uint64(b[i])
continue
}
l = l<<8 | uint64(b[i])
}
h |= uint64(exp-MinDecimal128Exp) & uint64(1<<14-1) << 49
if bi.Sign() == -1 {
h |= 1 << 63
}
return Decimal128{h: h, l: l}, true
}
// bigIntCmpAbs computes big.Int.Cmp(absoluteValue(x), absoluteValue(y)).
func bigIntCmpAbs(x, y *big.Int) int {
xAbs := bigIntAbsValue(x)
yAbs := bigIntAbsValue(y)
return xAbs.Cmp(yAbs)
}
// bigIntAbsValue returns a big.Int containing the absolute value of b.
// If b is already a non-negative number, it is returned without any changes or copies.
func bigIntAbsValue(b *big.Int) *big.Int {
if b.Sign() >= 0 {
return b // already positive
}
return new(big.Int).Abs(b)
}

View File

@@ -0,0 +1,236 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package primitive
import (
"encoding/json"
"fmt"
"math/big"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type bigIntTestCase struct {
s string
h uint64
l uint64
bi *big.Int
exp int
remark string
}
func parseBigInt(s string) *big.Int {
bi, _ := new(big.Int).SetString(s, 10)
return bi
}
var (
one = big.NewInt(1)
biMaxS = new(big.Int).SetBytes([]byte{0x1, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
biNMaxS = new(big.Int).Neg(biMaxS)
biOverflow = new(big.Int).Add(biMaxS, one)
biNOverflow = new(big.Int).Neg(biOverflow)
bi12345 = parseBigInt("12345")
biN12345 = parseBigInt("-12345")
bi9_14 = parseBigInt("90123456789012")
biN9_14 = parseBigInt("-90123456789012")
bi9_34 = parseBigInt("9999999999999999999999999999999999")
biN9_34 = parseBigInt("-9999999999999999999999999999999999")
)
var bigIntTestCases = []bigIntTestCase{
{s: "12345", h: 0x3040000000000000, l: 12345, bi: bi12345},
{s: "-12345", h: 0xB040000000000000, l: 12345, bi: biN12345},
{s: "90123456.789012", h: 0x3034000000000000, l: 90123456789012, bi: bi9_14, exp: -6},
{s: "-90123456.789012", h: 0xB034000000000000, l: 90123456789012, bi: biN9_14, exp: -6},
{s: "9.0123456789012E+22", h: 0x3052000000000000, l: 90123456789012, bi: bi9_14, exp: 9},
{s: "-9.0123456789012E+22", h: 0xB052000000000000, l: 90123456789012, bi: biN9_14, exp: 9},
{s: "9.0123456789012E-8", h: 0x3016000000000000, l: 90123456789012, bi: bi9_14, exp: -21},
{s: "-9.0123456789012E-8", h: 0xB016000000000000, l: 90123456789012, bi: biN9_14, exp: -21},
{s: "9999999999999999999999999999999999", h: 3477321013416265664, l: 4003012203950112767, bi: bi9_34},
{s: "-9999999999999999999999999999999999", h: 12700693050271041472, l: 4003012203950112767, bi: biN9_34},
{s: "0.9999999999999999999999999999999999", h: 3458180714999941056, l: 4003012203950112767, bi: bi9_34, exp: -34},
{s: "-0.9999999999999999999999999999999999", h: 12681552751854716864, l: 4003012203950112767, bi: biN9_34, exp: -34},
{s: "99999999999999999.99999999999999999", h: 3467750864208103360, l: 4003012203950112767, bi: bi9_34, exp: -17},
{s: "-99999999999999999.99999999999999999", h: 12691122901062879168, l: 4003012203950112767, bi: biN9_34, exp: -17},
{s: "9.999999999999999999999999999999999E+35", h: 3478446913323108288, l: 4003012203950112767, bi: bi9_34, exp: 2},
{s: "-9.999999999999999999999999999999999E+35", h: 12701818950177884096, l: 4003012203950112767, bi: biN9_34, exp: 2},
{s: "9.999999999999999999999999999999999E+40", h: 3481261663090214848, l: 4003012203950112767, bi: bi9_34, exp: 7},
{s: "-9.999999999999999999999999999999999E+40", h: 12704633699944990656, l: 4003012203950112767, bi: biN9_34, exp: 7},
{s: "99999999999999999999999999999.99999", h: 3474506263649159104, l: 4003012203950112767, bi: bi9_34, exp: -5},
{s: "-99999999999999999999999999999.99999", h: 12697878300503934912, l: 4003012203950112767, bi: biN9_34, exp: -5},
{s: "1.038459371706965525706099265844019E-6143", remark: "subnormal", h: 0x333333333333, l: 0x3333333333333333, bi: parseBigInt("10384593717069655257060992658440190"), exp: MinDecimal128Exp - 1},
{s: "-1.038459371706965525706099265844019E-6143", remark: "subnormal", h: 0x8000333333333333, l: 0x3333333333333333, bi: parseBigInt("-10384593717069655257060992658440190"), exp: MinDecimal128Exp - 1},
{s: "rounding overflow 1", remark: "overflow", bi: parseBigInt("103845937170696552570609926584401910"), exp: MaxDecimal128Exp},
{s: "rounding overflow 2", remark: "overflow", bi: parseBigInt("103845937170696552570609926584401910"), exp: MaxDecimal128Exp},
{s: "subnormal overflow 1", remark: "overflow", bi: biMaxS, exp: MinDecimal128Exp - 1},
{s: "subnormal overflow 2", remark: "overflow", bi: biNMaxS, exp: MinDecimal128Exp - 1},
{s: "clamped overflow 1", remark: "overflow", bi: biMaxS, exp: MaxDecimal128Exp + 1},
{s: "clamped overflow 2", remark: "overflow", bi: biNMaxS, exp: MaxDecimal128Exp + 1},
{s: "biMaxS+1 overflow", remark: "overflow", bi: biOverflow, exp: MaxDecimal128Exp},
{s: "biNMaxS-1 overflow", remark: "overflow", bi: biNOverflow, exp: MaxDecimal128Exp},
{s: "NaN", h: 0x7c00000000000000, l: 0, remark: "NaN"},
{s: "Infinity", h: 0x7800000000000000, l: 0, remark: "Infinity"},
{s: "-Infinity", h: 0xf800000000000000, l: 0, remark: "-Infinity"},
}
func TestDecimal128_BigInt(t *testing.T) {
for _, c := range bigIntTestCases {
t.Run(c.s, func(t *testing.T) {
switch c.remark {
case "NaN", "Infinity", "-Infinity":
d128 := NewDecimal128(c.h, c.l)
_, _, err := d128.BigInt()
require.Error(t, err, "case %s", c.s)
case "":
d128 := NewDecimal128(c.h, c.l)
bi, e, err := d128.BigInt()
require.NoError(t, err, "case %s", c.s)
require.Equal(t, 0, c.bi.Cmp(bi), "case %s e:%s a:%s", c.s, c.bi.String(), bi.String())
require.Equal(t, c.exp, e, "case %s", c.s, d128.String())
}
})
}
}
func TestParseDecimal128FromBigInt(t *testing.T) {
for _, c := range bigIntTestCases {
switch c.remark {
case "overflow":
d128, ok := ParseDecimal128FromBigInt(c.bi, c.exp)
require.Equal(t, false, ok, "case %s %s", c.s, d128.String(), c.remark)
case "", "rounding", "subnormal", "clamped":
d128, ok := ParseDecimal128FromBigInt(c.bi, c.exp)
require.Equal(t, true, ok, "case %s", c.s)
require.Equal(t, c.s, d128.String(), "case %s", c.s)
require.Equal(t, c.h, d128.h, "case %s", c.s, d128.l)
require.Equal(t, c.l, d128.l, "case %s", c.s, d128.h)
}
}
}
func TestParseDecimal128(t *testing.T) {
cases := make([]bigIntTestCase, 0, len(bigIntTestCases))
cases = append(cases, bigIntTestCases...)
cases = append(cases,
bigIntTestCase{s: "-0001231.453454000000565600000000E-21", h: 0xafe6000003faa269, l: 0x81cfeceaabdb1800},
bigIntTestCase{s: "12345E+21", h: 0x306a000000000000, l: 12345},
bigIntTestCase{s: "0.10000000000000000000000000000000000000000001", remark: "parse fail"},
bigIntTestCase{s: ".125e1", h: 0x303c000000000000, l: 125},
bigIntTestCase{s: ".125", h: 0x303a000000000000, l: 125},
// Test that parsing negative zero returns negative zero with a zero exponent.
bigIntTestCase{s: "-0", h: 0xb040000000000000, l: 0},
// Test that parsing negative zero with an in-range exponent returns negative zero and
// preserves the specified exponent value.
bigIntTestCase{s: "-0E999", h: 0xb80e000000000000, l: 0},
// Test that parsing zero with an out-of-range positive exponent returns zero with the
// maximum positive exponent (i.e. 0e+6111).
bigIntTestCase{s: "0E2000000000000", h: 0x5ffe000000000000, l: 0},
// Test that parsing zero with an out-of-range negative exponent returns zero with the
// minimum negative exponent (i.e. 0e-6176).
bigIntTestCase{s: "-0E2000000000000", h: 0xdffe000000000000, l: 0},
bigIntTestCase{s: "", remark: "parse fail"})
for _, c := range cases {
t.Run(c.s, func(t *testing.T) {
switch c.remark {
case "overflow", "parse fail":
_, err := ParseDecimal128(c.s)
assert.Error(t, err, "ParseDecimal128(%q) should return an error", c.s)
default:
got, err := ParseDecimal128(c.s)
require.NoError(t, err, "ParseDecimal128(%q) error", c.s)
want := Decimal128{h: c.h, l: c.l}
// Decimal128 doesn't implement an equality function, so compare the expected
// low/high uint64 values directly. Also print the string representation of each
// number to make debugging failures easier.
assert.Equal(t, want, got, "ParseDecimal128(%q) = %s, want %s", c.s, got, want)
}
})
}
}
func TestDecimal128_JSON(t *testing.T) {
t.Run("roundTrip", func(t *testing.T) {
decimal := NewDecimal128(0x3040000000000000, 12345)
bytes, err := json.Marshal(decimal)
assert.Nil(t, err, "json.Marshal error: %v", err)
got := NewDecimal128(0, 0)
err = json.Unmarshal(bytes, &got)
assert.Nil(t, err, "json.Unmarshal error: %v", err)
assert.Equal(t, decimal.h, got.h, "expected h: %v got: %v", decimal.h, got.h)
assert.Equal(t, decimal.l, got.l, "expected l: %v got: %v", decimal.l, got.l)
})
t.Run("unmarshal extendedJSON", func(t *testing.T) {
want := NewDecimal128(0x3040000000000000, 12345)
extJSON := fmt.Sprintf(`{"$numberDecimal": %q}`, want.String())
got := NewDecimal128(0, 0)
err := json.Unmarshal([]byte(extJSON), &got)
assert.Nil(t, err, "json.Unmarshal error: %v", err)
assert.Equal(t, want.h, got.h, "expected h: %v got: %v", want.h, got.h)
assert.Equal(t, want.l, got.l, "expected l: %v got: %v", want.l, got.l)
})
t.Run("unmarshal null", func(t *testing.T) {
want := NewDecimal128(0, 0)
extJSON := `null`
got := NewDecimal128(0, 0)
err := json.Unmarshal([]byte(extJSON), &got)
assert.Nil(t, err, "json.Unmarshal error: %v", err)
assert.Equal(t, want.h, got.h, "expected h: %v got: %v", want.h, got.h)
assert.Equal(t, want.l, got.l, "expected l: %v got: %v", want.l, got.l)
})
t.Run("unmarshal", func(t *testing.T) {
cases := make([]bigIntTestCase, 0, len(bigIntTestCases))
cases = append(cases, bigIntTestCases...)
cases = append(cases,
bigIntTestCase{s: "-0001231.453454000000565600000000E-21", h: 0xafe6000003faa269, l: 0x81cfeceaabdb1800},
bigIntTestCase{s: "12345E+21", h: 0x306a000000000000, l: 12345},
bigIntTestCase{s: "0.10000000000000000000000000000000000000000001", remark: "parse fail"},
bigIntTestCase{s: ".125e1", h: 0x303c000000000000, l: 125},
bigIntTestCase{s: ".125", h: 0x303a000000000000, l: 125})
for _, c := range cases {
t.Run(c.s, func(t *testing.T) {
input := fmt.Sprintf(`{"foo": %q}`, c.s)
var got map[string]Decimal128
err := json.Unmarshal([]byte(input), &got)
switch c.remark {
case "overflow", "parse fail":
assert.NotNil(t, err, "expected Unmarshal error, got nil")
default:
assert.Nil(t, err, "Unmarshal error: %v", err)
gotDecimal := got["foo"]
assert.Equal(t, c.h, gotDecimal.h, "expected h: %v got: %v", c.h, gotDecimal.l)
assert.Equal(t, c.l, gotDecimal.l, "expected l: %v got: %v", c.l, gotDecimal.h)
}
})
}
})
}

View File

@@ -0,0 +1,206 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on gopkg.in/mgo.v2/bson by Gustavo Niemeyer
// See THIRD-PARTY-NOTICES for original license terms.
package primitive
import (
"crypto/rand"
"encoding"
"encoding/binary"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"sync/atomic"
"time"
)
// ErrInvalidHex indicates that a hex string cannot be converted to an ObjectID.
var ErrInvalidHex = errors.New("the provided hex string is not a valid ObjectID")
// ObjectID is the BSON ObjectID type.
type ObjectID [12]byte
// NilObjectID is the zero value for ObjectID.
var NilObjectID ObjectID
var objectIDCounter = readRandomUint32()
var processUnique = processUniqueBytes()
var _ encoding.TextMarshaler = ObjectID{}
var _ encoding.TextUnmarshaler = &ObjectID{}
// NewObjectID generates a new ObjectID.
func NewObjectID() ObjectID {
return NewObjectIDFromTimestamp(time.Now())
}
// NewObjectIDFromTimestamp generates a new ObjectID based on the given time.
func NewObjectIDFromTimestamp(timestamp time.Time) ObjectID {
var b [12]byte
binary.BigEndian.PutUint32(b[0:4], uint32(timestamp.Unix()))
copy(b[4:9], processUnique[:])
putUint24(b[9:12], atomic.AddUint32(&objectIDCounter, 1))
return b
}
// Timestamp extracts the time part of the ObjectId.
func (id ObjectID) Timestamp() time.Time {
unixSecs := binary.BigEndian.Uint32(id[0:4])
return time.Unix(int64(unixSecs), 0).UTC()
}
// Hex returns the hex encoding of the ObjectID as a string.
func (id ObjectID) Hex() string {
var buf [24]byte
hex.Encode(buf[:], id[:])
return string(buf[:])
}
func (id ObjectID) String() string {
return fmt.Sprintf("ObjectID(%q)", id.Hex())
}
// IsZero returns true if id is the empty ObjectID.
func (id ObjectID) IsZero() bool {
return id == NilObjectID
}
// ObjectIDFromHex creates a new ObjectID from a hex string. It returns an error if the hex string is not a
// valid ObjectID.
func ObjectIDFromHex(s string) (ObjectID, error) {
if len(s) != 24 {
return NilObjectID, ErrInvalidHex
}
b, err := hex.DecodeString(s)
if err != nil {
return NilObjectID, err
}
var oid [12]byte
copy(oid[:], b)
return oid, nil
}
// IsValidObjectID returns true if the provided hex string represents a valid ObjectID and false if not.
func IsValidObjectID(s string) bool {
_, err := ObjectIDFromHex(s)
return err == nil
}
// MarshalText returns the ObjectID as UTF-8-encoded text. Implementing this allows us to use ObjectID
// as a map key when marshalling JSON. See https://pkg.go.dev/encoding#TextMarshaler
func (id ObjectID) MarshalText() ([]byte, error) {
return []byte(id.Hex()), nil
}
// UnmarshalText populates the byte slice with the ObjectID. Implementing this allows us to use ObjectID
// as a map key when unmarshalling JSON. See https://pkg.go.dev/encoding#TextUnmarshaler
func (id *ObjectID) UnmarshalText(b []byte) error {
oid, err := ObjectIDFromHex(string(b))
if err != nil {
return err
}
*id = oid
return nil
}
// MarshalJSON returns the ObjectID as a string
func (id ObjectID) MarshalJSON() ([]byte, error) {
return json.Marshal(id.Hex())
}
// UnmarshalJSON populates the byte slice with the ObjectID. If the byte slice is 24 bytes long, it
// will be populated with the hex representation of the ObjectID. If the byte slice is twelve bytes
// long, it will be populated with the BSON representation of the ObjectID. This method also accepts empty strings and
// decodes them as NilObjectID. For any other inputs, an error will be returned.
func (id *ObjectID) UnmarshalJSON(b []byte) error {
// Ignore "null" to keep parity with the standard library. Decoding a JSON null into a non-pointer ObjectID field
// will leave the field unchanged. For pointer values, encoding/json will set the pointer to nil and will not
// enter the UnmarshalJSON hook.
if string(b) == "null" {
return nil
}
var err error
switch len(b) {
case 12:
copy(id[:], b)
default:
// Extended JSON
var res interface{}
err := json.Unmarshal(b, &res)
if err != nil {
return err
}
str, ok := res.(string)
if !ok {
m, ok := res.(map[string]interface{})
if !ok {
return errors.New("not an extended JSON ObjectID")
}
oid, ok := m["$oid"]
if !ok {
return errors.New("not an extended JSON ObjectID")
}
str, ok = oid.(string)
if !ok {
return errors.New("not an extended JSON ObjectID")
}
}
// An empty string is not a valid ObjectID, but we treat it as a special value that decodes as NilObjectID.
if len(str) == 0 {
copy(id[:], NilObjectID[:])
return nil
}
if len(str) != 24 {
return fmt.Errorf("cannot unmarshal into an ObjectID, the length must be 24 but it is %d", len(str))
}
_, err = hex.Decode(id[:], []byte(str))
if err != nil {
return err
}
}
return err
}
func processUniqueBytes() [5]byte {
var b [5]byte
_, err := io.ReadFull(rand.Reader, b[:])
if err != nil {
panic(fmt.Errorf("cannot initialize objectid package with crypto.rand.Reader: %v", err))
}
return b
}
func readRandomUint32() uint32 {
var b [4]byte
_, err := io.ReadFull(rand.Reader, b[:])
if err != nil {
panic(fmt.Errorf("cannot initialize objectid package with crypto.rand.Reader: %v", err))
}
return (uint32(b[0]) << 0) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
}
func putUint24(b []byte, v uint32) {
b[0] = byte(v >> 16)
b[1] = byte(v >> 8)
b[2] = byte(v)
}

View File

@@ -0,0 +1,259 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package primitive
import (
"encoding/binary"
"encoding/hex"
"encoding/json"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/require"
"go.mongodb.org/mongo-driver/internal/testutil/assert"
)
func TestNew(t *testing.T) {
// Ensure that objectid.NewObjectID() doesn't panic.
NewObjectID()
}
func TestString(t *testing.T) {
id := NewObjectID()
require.Contains(t, id.String(), id.Hex())
}
func BenchmarkHex(b *testing.B) {
id := NewObjectID()
for i := 0; i < b.N; i++ {
id.Hex()
}
}
func TestFromHex_RoundTrip(t *testing.T) {
before := NewObjectID()
after, err := ObjectIDFromHex(before.Hex())
require.NoError(t, err)
require.Equal(t, before, after)
}
func TestFromHex_InvalidHex(t *testing.T) {
_, err := ObjectIDFromHex("this is not a valid hex string!")
require.Error(t, err)
}
func TestFromHex_WrongLength(t *testing.T) {
_, err := ObjectIDFromHex("deadbeef")
require.Equal(t, ErrInvalidHex, err)
}
func TestIsValidObjectID(t *testing.T) {
testCases := []struct {
givenID string
expected bool
}{
{
givenID: "5ef7fdd91c19e3222b41b839",
expected: true,
},
{
givenID: "5ef7fdd91c19e3222b41b83",
expected: false,
},
}
for _, testcase := range testCases {
got := IsValidObjectID(testcase.givenID)
assert.Equal(t, testcase.expected, got, "expected hex string to be valid ObjectID: %v, got %v", testcase.expected, got)
}
}
func TestTimeStamp(t *testing.T) {
testCases := []struct {
Hex string
Expected string
}{
{
"000000001111111111111111",
"1970-01-01 00:00:00 +0000 UTC",
},
{
"7FFFFFFF1111111111111111",
"2038-01-19 03:14:07 +0000 UTC",
},
{
"800000001111111111111111",
"2038-01-19 03:14:08 +0000 UTC",
},
{
"FFFFFFFF1111111111111111",
"2106-02-07 06:28:15 +0000 UTC",
},
}
for _, testcase := range testCases {
id, err := ObjectIDFromHex(testcase.Hex)
require.NoError(t, err)
secs := int64(binary.BigEndian.Uint32(id[0:4]))
timestamp := time.Unix(secs, 0).UTC()
require.Equal(t, testcase.Expected, timestamp.String())
}
}
func TestCreateFromTime(t *testing.T) {
testCases := []struct {
time string
Expected string
}{
{
"1970-01-01T00:00:00.000Z",
"00000000",
},
{
"2038-01-19T03:14:07.000Z",
"7fffffff",
},
{
"2038-01-19T03:14:08.000Z",
"80000000",
},
{
"2106-02-07T06:28:15.000Z",
"ffffffff",
},
}
layout := "2006-01-02T15:04:05.000Z"
for _, testcase := range testCases {
time, err := time.Parse(layout, testcase.time)
require.NoError(t, err)
id := NewObjectIDFromTimestamp(time)
timeStr := hex.EncodeToString(id[0:4])
require.Equal(t, testcase.Expected, timeStr)
}
}
func TestGenerationTime(t *testing.T) {
testCases := []struct {
hex string
Expected string
}{
{
"000000001111111111111111",
"1970-01-01 00:00:00 +0000 UTC",
},
{
"7FFFFFFF1111111111111111",
"2038-01-19 03:14:07 +0000 UTC",
},
{
"800000001111111111111111",
"2038-01-19 03:14:08 +0000 UTC",
},
{
"FFFFFFFF1111111111111111",
"2106-02-07 06:28:15 +0000 UTC",
},
}
for _, testcase := range testCases {
id, err := ObjectIDFromHex(testcase.hex)
require.NoError(t, err)
genTime := id.Timestamp()
require.Equal(t, testcase.Expected, genTime.String())
}
}
func TestCounterOverflow(t *testing.T) {
objectIDCounter = 0xFFFFFFFF
NewObjectID()
require.Equal(t, uint32(0), objectIDCounter)
}
func TestObjectID_MarshalJSONMap(t *testing.T) {
type mapOID struct {
Map map[ObjectID]string
}
oid := NewObjectID()
expectedJSON := []byte(fmt.Sprintf(`{"Map":{%q:"foo"}}`, oid.Hex()))
data := mapOID{
Map: map[ObjectID]string{oid: "foo"},
}
out, err := json.Marshal(&data)
require.NoError(t, err)
require.Equal(t, expectedJSON, out)
}
func TestObjectID_UnmarshalJSONMap(t *testing.T) {
type mapOID struct {
Map map[ObjectID]string
}
oid := NewObjectID()
mapOIDJSON := []byte(fmt.Sprintf(`{"Map":{%q:"foo"}}`, oid.Hex()))
expectedData := mapOID{
Map: map[ObjectID]string{oid: "foo"},
}
data := mapOID{}
err := json.Unmarshal(mapOIDJSON, &data)
require.NoError(t, err)
require.Equal(t, expectedData, data)
}
func TestObjectID_UnmarshalJSON(t *testing.T) {
oid := NewObjectID()
hexJSON := fmt.Sprintf(`{"foo": %q}`, oid.Hex())
extJSON := fmt.Sprintf(`{"foo": {"$oid": %q}}`, oid.Hex())
emptyStringJSON := `{"foo": ""}`
nullJSON := `{"foo": null}`
testCases := []struct {
name string
jsonString string
expected ObjectID
}{
{"hex bytes", hexJSON, oid},
{"extended JSON", extJSON, oid},
{"empty string", emptyStringJSON, NilObjectID},
{"null", nullJSON, NilObjectID},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var got map[string]ObjectID
err := json.Unmarshal([]byte(tc.jsonString), &got)
assert.Nil(t, err, "Unmarshal error: %v", err)
gotOid := got["foo"]
assert.Equal(t, tc.expected, gotOid, "expected ObjectID %s, got %s", tc.expected, gotOid)
})
}
}
func TestObjectID_MarshalText(t *testing.T) {
oid := ObjectID{0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xA, 0xB}
b, err := oid.MarshalText()
assert.Nil(t, err, "MarshalText error: %v", err)
want := "000102030405060708090a0b"
got := string(b)
assert.Equal(t, want, got, "want %v, got %v", want, got)
}
func TestObjectID_UnmarshalText(t *testing.T) {
var oid ObjectID
err := oid.UnmarshalText([]byte("000102030405060708090a0b"))
assert.Nil(t, err, "UnmarshalText error: %v", err)
want := ObjectID{0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xA, 0xB}
assert.Equal(t, want, oid, "want %v, got %v", want, oid)
}

View File

@@ -0,0 +1,217 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Package primitive contains types similar to Go primitives for BSON types that do not have direct
// Go primitive representations.
package primitive // import "go.mongodb.org/mongo-driver/bson/primitive"
import (
"bytes"
"encoding/json"
"fmt"
"time"
)
// Binary represents a BSON binary value.
type Binary struct {
Subtype byte
Data []byte
}
// Equal compares bp to bp2 and returns true if they are equal.
func (bp Binary) Equal(bp2 Binary) bool {
if bp.Subtype != bp2.Subtype {
return false
}
return bytes.Equal(bp.Data, bp2.Data)
}
// IsZero returns if bp is the empty Binary.
func (bp Binary) IsZero() bool {
return bp.Subtype == 0 && len(bp.Data) == 0
}
// Undefined represents the BSON undefined value type.
type Undefined struct{}
// DateTime represents the BSON datetime value.
type DateTime int64
var _ json.Marshaler = DateTime(0)
var _ json.Unmarshaler = (*DateTime)(nil)
// MarshalJSON marshal to time type.
func (d DateTime) MarshalJSON() ([]byte, error) {
return json.Marshal(d.Time())
}
// UnmarshalJSON creates a primitive.DateTime from a JSON string.
func (d *DateTime) UnmarshalJSON(data []byte) error {
// Ignore "null" to keep parity with the time.Time type and the standard library. Decoding "null" into a non-pointer
// DateTime field will leave the field unchanged. For pointer values, the encoding/json will set the pointer to nil
// and will not defer to the UnmarshalJSON hook.
if string(data) == "null" {
return nil
}
var tempTime time.Time
if err := json.Unmarshal(data, &tempTime); err != nil {
return err
}
*d = NewDateTimeFromTime(tempTime)
return nil
}
// Time returns the date as a time type.
func (d DateTime) Time() time.Time {
return time.Unix(int64(d)/1000, int64(d)%1000*1000000)
}
// NewDateTimeFromTime creates a new DateTime from a Time.
func NewDateTimeFromTime(t time.Time) DateTime {
return DateTime(t.Unix()*1e3 + int64(t.Nanosecond())/1e6)
}
// Null represents the BSON null value.
type Null struct{}
// Regex represents a BSON regex value.
type Regex struct {
Pattern string
Options string
}
func (rp Regex) String() string {
return fmt.Sprintf(`{"pattern": "%s", "options": "%s"}`, rp.Pattern, rp.Options)
}
// Equal compares rp to rp2 and returns true if they are equal.
func (rp Regex) Equal(rp2 Regex) bool {
return rp.Pattern == rp2.Pattern && rp.Options == rp2.Options
}
// IsZero returns if rp is the empty Regex.
func (rp Regex) IsZero() bool {
return rp.Pattern == "" && rp.Options == ""
}
// DBPointer represents a BSON dbpointer value.
type DBPointer struct {
DB string
Pointer ObjectID
}
func (d DBPointer) String() string {
return fmt.Sprintf(`{"db": "%s", "pointer": "%s"}`, d.DB, d.Pointer)
}
// Equal compares d to d2 and returns true if they are equal.
func (d DBPointer) Equal(d2 DBPointer) bool {
return d == d2
}
// IsZero returns if d is the empty DBPointer.
func (d DBPointer) IsZero() bool {
return d.DB == "" && d.Pointer.IsZero()
}
// JavaScript represents a BSON JavaScript code value.
type JavaScript string
// Symbol represents a BSON symbol value.
type Symbol string
// CodeWithScope represents a BSON JavaScript code with scope value.
type CodeWithScope struct {
Code JavaScript
Scope interface{}
}
func (cws CodeWithScope) String() string {
return fmt.Sprintf(`{"code": "%s", "scope": %v}`, cws.Code, cws.Scope)
}
// Timestamp represents a BSON timestamp value.
type Timestamp struct {
T uint32
I uint32
}
// Equal compares tp to tp2 and returns true if they are equal.
func (tp Timestamp) Equal(tp2 Timestamp) bool {
return tp.T == tp2.T && tp.I == tp2.I
}
// IsZero returns if tp is the zero Timestamp.
func (tp Timestamp) IsZero() bool {
return tp.T == 0 && tp.I == 0
}
// CompareTimestamp returns an integer comparing two Timestamps, where T is compared first, followed by I.
// Returns 0 if tp = tp2, 1 if tp > tp2, -1 if tp < tp2.
func CompareTimestamp(tp, tp2 Timestamp) int {
if tp.Equal(tp2) {
return 0
}
if tp.T > tp2.T {
return 1
}
if tp.T < tp2.T {
return -1
}
// Compare I values because T values are equal
if tp.I > tp2.I {
return 1
}
return -1
}
// MinKey represents the BSON minkey value.
type MinKey struct{}
// MaxKey represents the BSON maxkey value.
type MaxKey struct{}
// D is an ordered representation of a BSON document. This type should be used when the order of the elements matters,
// such as MongoDB command documents. If the order of the elements does not matter, an M should be used instead.
//
// Example usage:
//
// bson.D{{"foo", "bar"}, {"hello", "world"}, {"pi", 3.14159}}
type D []E
// Map creates a map from the elements of the D.
func (d D) Map() M {
m := make(M, len(d))
for _, e := range d {
m[e.Key] = e.Value
}
return m
}
// E represents a BSON element for a D. It is usually used inside a D.
type E struct {
Key string
Value interface{}
}
// M is an unordered representation of a BSON document. This type should be used when the order of the elements does not
// matter. This type is handled as a regular map[string]interface{} when encoding and decoding. Elements will be
// serialized in an undefined, random order. If the order of the elements matters, a D should be used instead.
//
// Example usage:
//
// bson.M{"foo": "bar", "hello": "world", "pi": 3.14159}
type M map[string]interface{}
// An A is an ordered representation of a BSON array.
//
// Example usage:
//
// bson.A{"bar", "world", 3.14159, bson.D{{"qux", 12345}}}
type A []interface{}

View File

@@ -0,0 +1,122 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package primitive
import (
"encoding/json"
"testing"
"time"
"github.com/stretchr/testify/require"
"go.mongodb.org/mongo-driver/internal/testutil/assert"
)
// The same interface as bsoncodec.Zeroer implemented for tests.
type zeroer interface {
IsZero() bool
}
func TestTimestampCompare(t *testing.T) {
testcases := []struct {
name string
tp Timestamp
tp2 Timestamp
expected int
}{
{"equal", Timestamp{T: 12345, I: 67890}, Timestamp{T: 12345, I: 67890}, 0},
{"T greater than", Timestamp{T: 12345, I: 67890}, Timestamp{T: 2345, I: 67890}, 1},
{"I greater than", Timestamp{T: 12345, I: 67890}, Timestamp{T: 12345, I: 7890}, 1},
{"T less than", Timestamp{T: 12345, I: 67890}, Timestamp{T: 112345, I: 67890}, -1},
{"I less than", Timestamp{T: 12345, I: 67890}, Timestamp{T: 12345, I: 167890}, -1},
}
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
result := CompareTimestamp(tc.tp, tc.tp2)
require.Equal(t, tc.expected, result)
})
}
}
func TestPrimitiveIsZero(t *testing.T) {
testcases := []struct {
name string
zero zeroer
nonzero zeroer
}{
{"binary", Binary{}, Binary{Data: []byte{0x01, 0x02, 0x03}, Subtype: 0xFF}},
{"decimal128", Decimal128{}, NewDecimal128(1, 2)},
{"objectID", ObjectID{}, NewObjectID()},
{"regex", Regex{}, Regex{Pattern: "foo", Options: "bar"}},
{"dbPointer", DBPointer{}, DBPointer{DB: "foobar", Pointer: ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}}},
{"timestamp", Timestamp{}, Timestamp{T: 12345, I: 67890}},
}
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
require.True(t, tc.zero.IsZero())
require.False(t, tc.nonzero.IsZero())
})
}
}
func TestRegexCompare(t *testing.T) {
testcases := []struct {
name string
r1 Regex
r2 Regex
eq bool
}{
{"equal", Regex{Pattern: "foo1", Options: "bar1"}, Regex{Pattern: "foo1", Options: "bar1"}, true},
{"not equal", Regex{Pattern: "foo1", Options: "bar1"}, Regex{Pattern: "foo2", Options: "bar2"}, false},
{"not equal", Regex{Pattern: "foo1", Options: "bar1"}, Regex{Pattern: "foo1", Options: "bar2"}, false},
{"not equal", Regex{Pattern: "foo1", Options: "bar1"}, Regex{Pattern: "foo2", Options: "bar1"}, false},
}
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
require.True(t, tc.r1.Equal(tc.r2) == tc.eq)
})
}
}
func TestDateTime(t *testing.T) {
t.Run("json", func(t *testing.T) {
t.Run("round trip", func(t *testing.T) {
original := DateTime(1000)
jsonBytes, err := json.Marshal(original)
assert.Nil(t, err, "Marshal error: %v", err)
var unmarshalled DateTime
err = json.Unmarshal(jsonBytes, &unmarshalled)
assert.Nil(t, err, "Unmarshal error: %v", err)
assert.Equal(t, original, unmarshalled, "expected DateTime %v, got %v", original, unmarshalled)
})
t.Run("decode null", func(t *testing.T) {
jsonBytes := []byte("null")
var dt DateTime
err := json.Unmarshal(jsonBytes, &dt)
assert.Nil(t, err, "Unmarshal error: %v", err)
assert.Equal(t, DateTime(0), dt, "expected DateTime value to be 0, got %v", dt)
})
})
t.Run("NewDateTimeFromTime", func(t *testing.T) {
t.Run("range is not limited", func(t *testing.T) {
// If the implementation internally calls time.Time.UnixNano(), the constructor cannot handle times after
// the year 2262.
timeFormat := "2006-01-02T15:04:05.999Z07:00"
timeString := "3001-01-01T00:00:00Z"
tt, err := time.Parse(timeFormat, timeString)
assert.Nil(t, err, "Parse error: %v", err)
dt := NewDateTimeFromTime(tt)
assert.True(t, dt > 0, "expected a valid DateTime greater than 0, got %v", dt)
})
})
}