Copied mongo repo (to patch it)
This commit is contained in:
6
mongo/x/README.md
Normal file
6
mongo/x/README.md
Normal file
@@ -0,0 +1,6 @@
|
||||
MongoDB Go Driver Unstable Libraries
|
||||
====================================
|
||||
This directory contains unstable MongoDB Go driver libraries and packages. The APIs of these
|
||||
packages are not stable and there is no backward compatibility guarantee.
|
||||
|
||||
**THESE PACKAGES ARE EXPERIMENTAL AND SUBJECT TO CHANGE.**
|
||||
97
mongo/x/bsonx/array.go
Normal file
97
mongo/x/bsonx/array.go
Normal file
@@ -0,0 +1,97 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bsonx // import "go.mongodb.org/mongo-driver/x/bsonx"
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"go.mongodb.org/mongo-driver/bson/bsontype"
|
||||
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
|
||||
)
|
||||
|
||||
// ErrNilArray indicates that an operation was attempted on a nil *Array.
|
||||
var ErrNilArray = errors.New("array is nil")
|
||||
|
||||
// Arr represents an array in BSON.
|
||||
type Arr []Val
|
||||
|
||||
// String implements the fmt.Stringer interface.
|
||||
func (a Arr) String() string {
|
||||
var buf bytes.Buffer
|
||||
buf.Write([]byte("bson.Array["))
|
||||
for idx, val := range a {
|
||||
if idx > 0 {
|
||||
buf.Write([]byte(", "))
|
||||
}
|
||||
fmt.Fprintf(&buf, "%s", val)
|
||||
}
|
||||
buf.WriteByte(']')
|
||||
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// MarshalBSONValue implements the bsoncodec.ValueMarshaler interface.
|
||||
func (a Arr) MarshalBSONValue() (bsontype.Type, []byte, error) {
|
||||
if a == nil {
|
||||
// TODO: Should we do this?
|
||||
return bsontype.Null, nil, nil
|
||||
}
|
||||
|
||||
idx, dst := bsoncore.ReserveLength(nil)
|
||||
for idx, value := range a {
|
||||
t, data, _ := value.MarshalBSONValue() // marshalBSONValue never returns an error.
|
||||
dst = append(dst, byte(t))
|
||||
dst = append(dst, strconv.Itoa(idx)...)
|
||||
dst = append(dst, 0x00)
|
||||
dst = append(dst, data...)
|
||||
}
|
||||
dst = append(dst, 0x00)
|
||||
dst = bsoncore.UpdateLength(dst, idx, int32(len(dst[idx:])))
|
||||
return bsontype.Array, dst, nil
|
||||
}
|
||||
|
||||
// UnmarshalBSONValue implements the bsoncodec.ValueUnmarshaler interface.
|
||||
func (a *Arr) UnmarshalBSONValue(t bsontype.Type, data []byte) error {
|
||||
if a == nil {
|
||||
return ErrNilArray
|
||||
}
|
||||
*a = (*a)[:0]
|
||||
|
||||
elements, err := bsoncore.Document(data).Elements()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, elem := range elements {
|
||||
var val Val
|
||||
rawval := elem.Value()
|
||||
err = val.UnmarshalBSONValue(rawval.Type, rawval.Data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*a = append(*a, val)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Equal compares this document to another, returning true if they are equal.
|
||||
func (a Arr) Equal(a2 Arr) bool {
|
||||
if len(a) != len(a2) {
|
||||
return false
|
||||
}
|
||||
for idx := range a {
|
||||
if !a[idx].Equal(a2[idx]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (Arr) idoc() {}
|
||||
36
mongo/x/bsonx/array_test.go
Normal file
36
mongo/x/bsonx/array_test.go
Normal file
@@ -0,0 +1,36 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bsonx
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func ExampleArray() {
|
||||
internalVersion := "1234567"
|
||||
|
||||
f := func(appName string) Arr {
|
||||
arr := make(Arr, 0)
|
||||
arr = append(arr,
|
||||
Document(Doc{{"name", String("mongo-go-driver")}, {"version", String(internalVersion)}}),
|
||||
Document(Doc{{"type", String("darwin")}, {"architecture", String("amd64")}}),
|
||||
String("go1.9.2"),
|
||||
)
|
||||
if appName != "" {
|
||||
arr = append(arr, Document(MDoc{"name": String(appName)}))
|
||||
}
|
||||
|
||||
return arr
|
||||
}
|
||||
_, buf, err := f("hello-world").MarshalBSONValue()
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
fmt.Println(buf)
|
||||
|
||||
// Output: [154 0 0 0 3 48 0 52 0 0 0 2 110 97 109 101 0 16 0 0 0 109 111 110 103 111 45 103 111 45 100 114 105 118 101 114 0 2 118 101 114 115 105 111 110 0 8 0 0 0 49 50 51 52 53 54 55 0 0 3 49 0 46 0 0 0 2 116 121 112 101 0 7 0 0 0 100 97 114 119 105 110 0 2 97 114 99 104 105 116 101 99 116 117 114 101 0 6 0 0 0 97 109 100 54 52 0 0 2 50 0 8 0 0 0 103 111 49 46 57 46 50 0 3 51 0 27 0 0 0 2 110 97 109 101 0 12 0 0 0 104 101 108 108 111 45 119 111 114 108 100 0 0 0]
|
||||
}
|
||||
52
mongo/x/bsonx/bson_test.go
Normal file
52
mongo/x/bsonx/bson_test.go
Normal file
@@ -0,0 +1,52 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bsonx
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
func noerr(t *testing.T, err error) {
|
||||
if err != nil {
|
||||
t.Helper()
|
||||
t.Errorf("Unexpected error: (%T)%v", err, err)
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
func compareDecimal128(d1, d2 primitive.Decimal128) bool {
|
||||
d1H, d1L := d1.GetBytes()
|
||||
d2H, d2L := d2.GetBytes()
|
||||
|
||||
if d1H != d2H {
|
||||
return false
|
||||
}
|
||||
|
||||
if d1L != d2L {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func compareErrors(err1, err2 error) bool {
|
||||
if err1 == nil && err2 == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
if err1 == nil || err2 == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if err1.Error() != err2.Error() {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
164
mongo/x/bsonx/bsoncore/array.go
Normal file
164
mongo/x/bsonx/bsoncore/array.go
Normal file
@@ -0,0 +1,164 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bsoncore
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// NewArrayLengthError creates and returns an error for when the length of an array exceeds the
|
||||
// bytes available.
|
||||
func NewArrayLengthError(length, rem int) error {
|
||||
return lengthError("array", length, rem)
|
||||
}
|
||||
|
||||
// Array is a raw bytes representation of a BSON array.
|
||||
type Array []byte
|
||||
|
||||
// NewArrayFromReader reads an array from r. This function will only validate the length is
|
||||
// correct and that the array ends with a null byte.
|
||||
func NewArrayFromReader(r io.Reader) (Array, error) {
|
||||
return newBufferFromReader(r)
|
||||
}
|
||||
|
||||
// Index searches for and retrieves the value at the given index. This method will panic if
|
||||
// the array is invalid or if the index is out of bounds.
|
||||
func (a Array) Index(index uint) Value {
|
||||
value, err := a.IndexErr(index)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
// IndexErr searches for and retrieves the value at the given index.
|
||||
func (a Array) IndexErr(index uint) (Value, error) {
|
||||
elem, err := indexErr(a, index)
|
||||
if err != nil {
|
||||
return Value{}, err
|
||||
}
|
||||
return elem.Value(), err
|
||||
}
|
||||
|
||||
// DebugString outputs a human readable version of Array. It will attempt to stringify the
|
||||
// valid components of the array even if the entire array is not valid.
|
||||
func (a Array) DebugString() string {
|
||||
if len(a) < 5 {
|
||||
return "<malformed>"
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
buf.WriteString("Array")
|
||||
length, rem, _ := ReadLength(a) // We know we have enough bytes to read the length
|
||||
buf.WriteByte('(')
|
||||
buf.WriteString(strconv.Itoa(int(length)))
|
||||
length -= 4
|
||||
buf.WriteString(")[")
|
||||
var elem Element
|
||||
var ok bool
|
||||
for length > 1 {
|
||||
elem, rem, ok = ReadElement(rem)
|
||||
length -= int32(len(elem))
|
||||
if !ok {
|
||||
buf.WriteString(fmt.Sprintf("<malformed (%d)>", length))
|
||||
break
|
||||
}
|
||||
fmt.Fprintf(&buf, "%s", elem.Value().DebugString())
|
||||
if length != 1 {
|
||||
buf.WriteByte(',')
|
||||
}
|
||||
}
|
||||
buf.WriteByte(']')
|
||||
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// String outputs an ExtendedJSON version of Array. If the Array is not valid, this method
|
||||
// returns an empty string.
|
||||
func (a Array) String() string {
|
||||
if len(a) < 5 {
|
||||
return ""
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
buf.WriteByte('[')
|
||||
|
||||
length, rem, _ := ReadLength(a) // We know we have enough bytes to read the length
|
||||
|
||||
length -= 4
|
||||
|
||||
var elem Element
|
||||
var ok bool
|
||||
for length > 1 {
|
||||
elem, rem, ok = ReadElement(rem)
|
||||
length -= int32(len(elem))
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
fmt.Fprintf(&buf, "%s", elem.Value().String())
|
||||
if length > 1 {
|
||||
buf.WriteByte(',')
|
||||
}
|
||||
}
|
||||
if length != 1 { // Missing final null byte or inaccurate length
|
||||
return ""
|
||||
}
|
||||
|
||||
buf.WriteByte(']')
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// Values returns this array as a slice of values. The returned slice will contain valid values.
|
||||
// If the array is not valid, the values up to the invalid point will be returned along with an
|
||||
// error.
|
||||
func (a Array) Values() ([]Value, error) {
|
||||
return values(a)
|
||||
}
|
||||
|
||||
// Validate validates the array and ensures the elements contained within are valid.
|
||||
func (a Array) Validate() error {
|
||||
length, rem, ok := ReadLength(a)
|
||||
if !ok {
|
||||
return NewInsufficientBytesError(a, rem)
|
||||
}
|
||||
if int(length) > len(a) {
|
||||
return NewArrayLengthError(int(length), len(a))
|
||||
}
|
||||
if a[length-1] != 0x00 {
|
||||
return ErrMissingNull
|
||||
}
|
||||
|
||||
length -= 4
|
||||
var elem Element
|
||||
|
||||
var keyNum int64
|
||||
for length > 1 {
|
||||
elem, rem, ok = ReadElement(rem)
|
||||
length -= int32(len(elem))
|
||||
if !ok {
|
||||
return NewInsufficientBytesError(a, rem)
|
||||
}
|
||||
|
||||
// validate element
|
||||
err := elem.Validate()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// validate keys increase numerically
|
||||
if fmt.Sprint(keyNum) != elem.Key() {
|
||||
return fmt.Errorf("array key %q is out of order or invalid", elem.Key())
|
||||
}
|
||||
keyNum++
|
||||
}
|
||||
|
||||
if len(rem) < 1 || rem[0] != 0x00 {
|
||||
return ErrMissingNull
|
||||
}
|
||||
return nil
|
||||
}
|
||||
349
mongo/x/bsonx/bsoncore/array_test.go
Normal file
349
mongo/x/bsonx/bsoncore/array_test.go
Normal file
@@ -0,0 +1,349 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bsoncore
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"go.mongodb.org/mongo-driver/bson/bsontype"
|
||||
)
|
||||
|
||||
func TestArray(t *testing.T) {
|
||||
t.Run("Validate", func(t *testing.T) {
|
||||
t.Run("TooShort", func(t *testing.T) {
|
||||
want := NewInsufficientBytesError(nil, nil)
|
||||
got := Array{'\x00', '\x00'}.Validate()
|
||||
if !compareErrors(got, want) {
|
||||
t.Errorf("Did not get expected error. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
t.Run("InvalidLength", func(t *testing.T) {
|
||||
want := NewArrayLengthError(200, 5)
|
||||
r := make(Array, 5)
|
||||
binary.LittleEndian.PutUint32(r[0:4], 200)
|
||||
got := r.Validate()
|
||||
if !compareErrors(got, want) {
|
||||
t.Errorf("Did not get expected error. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
t.Run("Invalid Element", func(t *testing.T) {
|
||||
want := NewInsufficientBytesError(nil, nil)
|
||||
r := make(Array, 7)
|
||||
binary.LittleEndian.PutUint32(r[0:4], 7)
|
||||
r[4], r[5], r[6] = 0x02, 'f', 0x00
|
||||
got := r.Validate()
|
||||
if !compareErrors(got, want) {
|
||||
t.Errorf("Did not get expected error. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
t.Run("Missing Null Terminator", func(t *testing.T) {
|
||||
want := ErrMissingNull
|
||||
r := make(Array, 6)
|
||||
binary.LittleEndian.PutUint32(r[0:4], 6)
|
||||
r[4], r[5] = 0x0A, '0'
|
||||
got := r.Validate()
|
||||
if !compareErrors(got, want) {
|
||||
t.Errorf("Did not get expected error. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
testCases := []struct {
|
||||
name string
|
||||
r Array
|
||||
want error
|
||||
}{
|
||||
{"array null", Array{'\x08', '\x00', '\x00', '\x00', '\x0A', '0', '\x00', '\x00'}, nil},
|
||||
{"array",
|
||||
Array{
|
||||
'\x1B', '\x00', '\x00', '\x00',
|
||||
'\x02',
|
||||
'0', '\x00',
|
||||
'\x04', '\x00', '\x00', '\x00',
|
||||
'\x62', '\x61', '\x72', '\x00',
|
||||
'\x02',
|
||||
'1', '\x00',
|
||||
'\x04', '\x00', '\x00', '\x00',
|
||||
'\x62', '\x61', '\x7a', '\x00',
|
||||
'\x00',
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{"subarray",
|
||||
Array{
|
||||
'\x13', '\x00', '\x00', '\x00',
|
||||
'\x04',
|
||||
'0', '\x00',
|
||||
'\x0B', '\x00', '\x00', '\x00', '\x0A', '0', '\x00',
|
||||
'\x0A', '1', '\x00', '\x00', '\x00',
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{"invalid key order",
|
||||
Array{
|
||||
'\x0B', '\x00', '\x00', '\x00', '\x0A', '2', '\x00',
|
||||
'\x0A', '0', '\x00', '\x00', '\x00',
|
||||
},
|
||||
errors.New(`array key "2" is out of order or invalid`),
|
||||
},
|
||||
{"invalid key type",
|
||||
Array{
|
||||
'\x0B', '\x00', '\x00', '\x00', '\x0A', 'p', '\x00',
|
||||
'\x0A', 'q', '\x00', '\x00', '\x00',
|
||||
},
|
||||
errors.New(`array key "p" is out of order or invalid`),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := tc.r.Validate()
|
||||
if !compareErrors(got, tc.want) {
|
||||
t.Errorf("Returned error does not match. got %v; want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
t.Run("Index", func(t *testing.T) {
|
||||
t.Run("Out of bounds", func(t *testing.T) {
|
||||
rdr := Array{0xe, 0x0, 0x0, 0x0, 0xa, '0', 0x0, 0xa, '1', 0x0, 0xa, 0x7a, 0x0, 0x0}
|
||||
_, err := rdr.IndexErr(3)
|
||||
if err != ErrOutOfBounds {
|
||||
t.Errorf("Out of bounds should be returned when accessing element beyond end of Array. got %v; want %v", err, ErrOutOfBounds)
|
||||
}
|
||||
})
|
||||
t.Run("Validation Error", func(t *testing.T) {
|
||||
rdr := Array{0x07, 0x00, 0x00, 0x00, 0x00}
|
||||
_, got := rdr.IndexErr(1)
|
||||
want := NewInsufficientBytesError(nil, nil)
|
||||
if !compareErrors(got, want) {
|
||||
t.Errorf("Did not receive expected error. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
testArray := Array{
|
||||
'\x26', '\x00', '\x00', '\x00',
|
||||
'\x02',
|
||||
'0', '\x00',
|
||||
'\x04', '\x00', '\x00', '\x00',
|
||||
'\x62', '\x61', '\x72', '\x00',
|
||||
'\x02',
|
||||
'1', '\x00',
|
||||
'\x04', '\x00', '\x00', '\x00',
|
||||
'\x62', '\x61', '\x7a', '\x00',
|
||||
'\x02',
|
||||
'2', '\x00',
|
||||
'\x04', '\x00', '\x00', '\x00',
|
||||
'\x71', '\x75', '\x78', '\x00',
|
||||
'\x00',
|
||||
}
|
||||
testCases := []struct {
|
||||
name string
|
||||
index uint
|
||||
want Value
|
||||
}{
|
||||
{"first",
|
||||
0,
|
||||
Value{
|
||||
Type: bsontype.String,
|
||||
Data: []byte{0x04, 0x00, 0x00, 0x00, 0x62, 0x61, 0x72, 0x00},
|
||||
},
|
||||
},
|
||||
{"second",
|
||||
1,
|
||||
Value{
|
||||
Type: bsontype.String,
|
||||
Data: []byte{0x04, 0x00, 0x00, 0x00, 0x62, 0x61, 0x7a, 0x00},
|
||||
},
|
||||
},
|
||||
{"third",
|
||||
2,
|
||||
Value{
|
||||
Type: bsontype.String,
|
||||
Data: []byte{0x04, 0x00, 0x00, 0x00, 0x71, 0x75, 0x78, 0x00},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Run("IndexErr", func(t *testing.T) {
|
||||
got, err := testArray.IndexErr(tc.index)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error from IndexErr: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(got, tc.want); diff != "" {
|
||||
t.Errorf("Arrays differ: (-got +want)\n%s", diff)
|
||||
}
|
||||
})
|
||||
t.Run("Index", func(t *testing.T) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
}()
|
||||
got := testArray.Index(tc.index)
|
||||
if diff := cmp.Diff(got, tc.want); diff != "" {
|
||||
t.Errorf("Arrays differ: (-got +want)\n%s", diff)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
})
|
||||
t.Run("NewArrayFromReader", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
ioReader io.Reader
|
||||
arr Array
|
||||
err error
|
||||
}{
|
||||
{
|
||||
"nil reader",
|
||||
nil,
|
||||
nil,
|
||||
ErrNilReader,
|
||||
},
|
||||
{
|
||||
"premature end of reader",
|
||||
bytes.NewBuffer([]byte{}),
|
||||
nil,
|
||||
io.EOF,
|
||||
},
|
||||
{
|
||||
"empty Array",
|
||||
bytes.NewBuffer([]byte{5, 0, 0, 0, 0}),
|
||||
[]byte{5, 0, 0, 0, 0},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"non-empty Array",
|
||||
bytes.NewBuffer([]byte{
|
||||
'\x1B', '\x00', '\x00', '\x00',
|
||||
'\x02',
|
||||
'0', '\x00',
|
||||
'\x04', '\x00', '\x00', '\x00',
|
||||
'\x62', '\x61', '\x72', '\x00',
|
||||
'\x02',
|
||||
'1', '\x00',
|
||||
'\x04', '\x00', '\x00', '\x00',
|
||||
'\x62', '\x61', '\x7a', '\x00',
|
||||
'\x00',
|
||||
}),
|
||||
[]byte{
|
||||
'\x1B', '\x00', '\x00', '\x00',
|
||||
'\x02',
|
||||
'0', '\x00',
|
||||
'\x04', '\x00', '\x00', '\x00',
|
||||
'\x62', '\x61', '\x72', '\x00',
|
||||
'\x02',
|
||||
'1', '\x00',
|
||||
'\x04', '\x00', '\x00', '\x00',
|
||||
'\x62', '\x61', '\x7a', '\x00',
|
||||
'\x00',
|
||||
},
|
||||
nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
arr, err := NewArrayFromReader(tc.ioReader)
|
||||
if !compareErrors(err, tc.err) {
|
||||
t.Errorf("errors do not match. got %v; want %v", err, tc.err)
|
||||
}
|
||||
if !bytes.Equal(tc.arr, arr) {
|
||||
t.Errorf("Arrays differ. got %v; want %v", tc.arr, arr)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
t.Run("DebugString", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
arr Array
|
||||
arrayString string
|
||||
arrayDebugString string
|
||||
}{
|
||||
{
|
||||
"array",
|
||||
Array{
|
||||
'\x1B', '\x00', '\x00', '\x00',
|
||||
'\x02',
|
||||
'0', '\x00',
|
||||
'\x04', '\x00', '\x00', '\x00',
|
||||
'\x62', '\x61', '\x72', '\x00',
|
||||
'\x02',
|
||||
'1', '\x00',
|
||||
'\x04', '\x00', '\x00', '\x00',
|
||||
'\x62', '\x61', '\x7a', '\x00',
|
||||
'\x00',
|
||||
},
|
||||
`["bar","baz"]`,
|
||||
`Array(27)["bar","baz"]`,
|
||||
},
|
||||
{
|
||||
"subarray",
|
||||
Array{
|
||||
'\x13', '\x00', '\x00', '\x00',
|
||||
'\x04',
|
||||
'0', '\x00',
|
||||
'\x0B', '\x00', '\x00', '\x00',
|
||||
'\x0A', '0', '\x00',
|
||||
'\x0A', '1', '\x00',
|
||||
'\x00', '\x00',
|
||||
},
|
||||
`[[null,null]]`,
|
||||
`Array(19)[Array(11)[null,null]]`,
|
||||
},
|
||||
{
|
||||
"malformed--length too small",
|
||||
Array{
|
||||
'\x04', '\x00', '\x00', '\x00',
|
||||
'\x00',
|
||||
},
|
||||
``,
|
||||
`Array(4)[]`,
|
||||
},
|
||||
{
|
||||
"malformed--length too large",
|
||||
Array{
|
||||
'\x13', '\x00', '\x00', '\x00',
|
||||
'\x00',
|
||||
},
|
||||
``,
|
||||
`Array(19)[<malformed (15)>]`,
|
||||
},
|
||||
{
|
||||
"malformed--missing null byte",
|
||||
Array{
|
||||
'\x06', '\x00', '\x00', '\x00',
|
||||
'\x02', '0',
|
||||
},
|
||||
``,
|
||||
`Array(6)[<malformed (2)>]`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
arrayString := tc.arr.String()
|
||||
if arrayString != tc.arrayString {
|
||||
t.Errorf("array strings do not match. got %q; want %q",
|
||||
arrayString, tc.arrayString)
|
||||
}
|
||||
|
||||
arrayDebugString := tc.arr.DebugString()
|
||||
if arrayDebugString != tc.arrayDebugString {
|
||||
t.Errorf("array debug strings do not match. got %q; want %q",
|
||||
arrayDebugString, tc.arrayDebugString)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
201
mongo/x/bsonx/bsoncore/bson_arraybuilder.go
Normal file
201
mongo/x/bsonx/bsoncore/bson_arraybuilder.go
Normal file
@@ -0,0 +1,201 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bsoncore
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"go.mongodb.org/mongo-driver/bson/bsontype"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
// ArrayBuilder builds a bson array
|
||||
type ArrayBuilder struct {
|
||||
arr []byte
|
||||
indexes []int32
|
||||
keys []int
|
||||
}
|
||||
|
||||
// NewArrayBuilder creates a new ArrayBuilder
|
||||
func NewArrayBuilder() *ArrayBuilder {
|
||||
return (&ArrayBuilder{}).startArray()
|
||||
}
|
||||
|
||||
// startArray reserves the array's length and sets the index to where the length begins
|
||||
func (a *ArrayBuilder) startArray() *ArrayBuilder {
|
||||
var index int32
|
||||
index, a.arr = AppendArrayStart(a.arr)
|
||||
a.indexes = append(a.indexes, index)
|
||||
a.keys = append(a.keys, 0)
|
||||
return a
|
||||
}
|
||||
|
||||
// Build updates the length of the array and index to the beginning of the documents length
|
||||
// bytes, then returns the array (bson bytes)
|
||||
func (a *ArrayBuilder) Build() Array {
|
||||
lastIndex := len(a.indexes) - 1
|
||||
lastKey := len(a.keys) - 1
|
||||
a.arr, _ = AppendArrayEnd(a.arr, a.indexes[lastIndex])
|
||||
a.indexes = a.indexes[:lastIndex]
|
||||
a.keys = a.keys[:lastKey]
|
||||
return a.arr
|
||||
}
|
||||
|
||||
// incrementKey() increments the value keys and returns the key to be used to a.appendArray* functions
|
||||
func (a *ArrayBuilder) incrementKey() string {
|
||||
idx := len(a.keys) - 1
|
||||
key := strconv.Itoa(a.keys[idx])
|
||||
a.keys[idx]++
|
||||
return key
|
||||
}
|
||||
|
||||
// AppendInt32 will append i32 to ArrayBuilder.arr
|
||||
func (a *ArrayBuilder) AppendInt32(i32 int32) *ArrayBuilder {
|
||||
a.arr = AppendInt32Element(a.arr, a.incrementKey(), i32)
|
||||
return a
|
||||
}
|
||||
|
||||
// AppendDocument will append doc to ArrayBuilder.arr
|
||||
func (a *ArrayBuilder) AppendDocument(doc []byte) *ArrayBuilder {
|
||||
a.arr = AppendDocumentElement(a.arr, a.incrementKey(), doc)
|
||||
return a
|
||||
}
|
||||
|
||||
// AppendArray will append arr to ArrayBuilder.arr
|
||||
func (a *ArrayBuilder) AppendArray(arr []byte) *ArrayBuilder {
|
||||
a.arr = AppendArrayElement(a.arr, a.incrementKey(), arr)
|
||||
return a
|
||||
}
|
||||
|
||||
// AppendDouble will append f to ArrayBuilder.doc
|
||||
func (a *ArrayBuilder) AppendDouble(f float64) *ArrayBuilder {
|
||||
a.arr = AppendDoubleElement(a.arr, a.incrementKey(), f)
|
||||
return a
|
||||
}
|
||||
|
||||
// AppendString will append str to ArrayBuilder.doc
|
||||
func (a *ArrayBuilder) AppendString(str string) *ArrayBuilder {
|
||||
a.arr = AppendStringElement(a.arr, a.incrementKey(), str)
|
||||
return a
|
||||
}
|
||||
|
||||
// AppendObjectID will append oid to ArrayBuilder.doc
|
||||
func (a *ArrayBuilder) AppendObjectID(oid primitive.ObjectID) *ArrayBuilder {
|
||||
a.arr = AppendObjectIDElement(a.arr, a.incrementKey(), oid)
|
||||
return a
|
||||
}
|
||||
|
||||
// AppendBinary will append a BSON binary element using subtype, and
|
||||
// b to a.arr
|
||||
func (a *ArrayBuilder) AppendBinary(subtype byte, b []byte) *ArrayBuilder {
|
||||
a.arr = AppendBinaryElement(a.arr, a.incrementKey(), subtype, b)
|
||||
return a
|
||||
}
|
||||
|
||||
// AppendUndefined will append a BSON undefined element using key to a.arr
|
||||
func (a *ArrayBuilder) AppendUndefined() *ArrayBuilder {
|
||||
a.arr = AppendUndefinedElement(a.arr, a.incrementKey())
|
||||
return a
|
||||
}
|
||||
|
||||
// AppendBoolean will append a boolean element using b to a.arr
|
||||
func (a *ArrayBuilder) AppendBoolean(b bool) *ArrayBuilder {
|
||||
a.arr = AppendBooleanElement(a.arr, a.incrementKey(), b)
|
||||
return a
|
||||
}
|
||||
|
||||
// AppendDateTime will append datetime element dt to a.arr
|
||||
func (a *ArrayBuilder) AppendDateTime(dt int64) *ArrayBuilder {
|
||||
a.arr = AppendDateTimeElement(a.arr, a.incrementKey(), dt)
|
||||
return a
|
||||
}
|
||||
|
||||
// AppendNull will append a null element to a.arr
|
||||
func (a *ArrayBuilder) AppendNull() *ArrayBuilder {
|
||||
a.arr = AppendNullElement(a.arr, a.incrementKey())
|
||||
return a
|
||||
}
|
||||
|
||||
// AppendRegex will append pattern and options to a.arr
|
||||
func (a *ArrayBuilder) AppendRegex(pattern, options string) *ArrayBuilder {
|
||||
a.arr = AppendRegexElement(a.arr, a.incrementKey(), pattern, options)
|
||||
return a
|
||||
}
|
||||
|
||||
// AppendDBPointer will append ns and oid to a.arr
|
||||
func (a *ArrayBuilder) AppendDBPointer(ns string, oid primitive.ObjectID) *ArrayBuilder {
|
||||
a.arr = AppendDBPointerElement(a.arr, a.incrementKey(), ns, oid)
|
||||
return a
|
||||
}
|
||||
|
||||
// AppendJavaScript will append js to a.arr
|
||||
func (a *ArrayBuilder) AppendJavaScript(js string) *ArrayBuilder {
|
||||
a.arr = AppendJavaScriptElement(a.arr, a.incrementKey(), js)
|
||||
return a
|
||||
}
|
||||
|
||||
// AppendSymbol will append symbol to a.arr
|
||||
func (a *ArrayBuilder) AppendSymbol(symbol string) *ArrayBuilder {
|
||||
a.arr = AppendSymbolElement(a.arr, a.incrementKey(), symbol)
|
||||
return a
|
||||
}
|
||||
|
||||
// AppendCodeWithScope will append code and scope to a.arr
|
||||
func (a *ArrayBuilder) AppendCodeWithScope(code string, scope Document) *ArrayBuilder {
|
||||
a.arr = AppendCodeWithScopeElement(a.arr, a.incrementKey(), code, scope)
|
||||
return a
|
||||
}
|
||||
|
||||
// AppendTimestamp will append t and i to a.arr
|
||||
func (a *ArrayBuilder) AppendTimestamp(t, i uint32) *ArrayBuilder {
|
||||
a.arr = AppendTimestampElement(a.arr, a.incrementKey(), t, i)
|
||||
return a
|
||||
}
|
||||
|
||||
// AppendInt64 will append i64 to a.arr
|
||||
func (a *ArrayBuilder) AppendInt64(i64 int64) *ArrayBuilder {
|
||||
a.arr = AppendInt64Element(a.arr, a.incrementKey(), i64)
|
||||
return a
|
||||
}
|
||||
|
||||
// AppendDecimal128 will append d128 to a.arr
|
||||
func (a *ArrayBuilder) AppendDecimal128(d128 primitive.Decimal128) *ArrayBuilder {
|
||||
a.arr = AppendDecimal128Element(a.arr, a.incrementKey(), d128)
|
||||
return a
|
||||
}
|
||||
|
||||
// AppendMaxKey will append a max key element to a.arr
|
||||
func (a *ArrayBuilder) AppendMaxKey() *ArrayBuilder {
|
||||
a.arr = AppendMaxKeyElement(a.arr, a.incrementKey())
|
||||
return a
|
||||
}
|
||||
|
||||
// AppendMinKey will append a min key element to a.arr
|
||||
func (a *ArrayBuilder) AppendMinKey() *ArrayBuilder {
|
||||
a.arr = AppendMinKeyElement(a.arr, a.incrementKey())
|
||||
return a
|
||||
}
|
||||
|
||||
// AppendValue appends a BSON value to the array.
|
||||
func (a *ArrayBuilder) AppendValue(val Value) *ArrayBuilder {
|
||||
a.arr = AppendValueElement(a.arr, a.incrementKey(), val)
|
||||
return a
|
||||
}
|
||||
|
||||
// StartArray starts building an inline Array. After this document is completed,
|
||||
// the user must call a.FinishArray
|
||||
func (a *ArrayBuilder) StartArray() *ArrayBuilder {
|
||||
a.arr = AppendHeader(a.arr, bsontype.Array, a.incrementKey())
|
||||
a.startArray()
|
||||
return a
|
||||
}
|
||||
|
||||
// FinishArray builds the most recent array created
|
||||
func (a *ArrayBuilder) FinishArray() *ArrayBuilder {
|
||||
a.arr = a.Build()
|
||||
return a
|
||||
}
|
||||
213
mongo/x/bsonx/bsoncore/bson_arraybuilder_test.go
Normal file
213
mongo/x/bsonx/bsoncore/bson_arraybuilder_test.go
Normal file
@@ -0,0 +1,213 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bsoncore
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"math"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
func TestArrayBuilder(t *testing.T) {
|
||||
bits := math.Float64bits(3.14159)
|
||||
pi := make([]byte, 8)
|
||||
binary.LittleEndian.PutUint64(pi, bits)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
fn interface{}
|
||||
params []interface{}
|
||||
expected []byte
|
||||
}{
|
||||
{
|
||||
"AppendInt32",
|
||||
NewArrayBuilder().AppendInt32,
|
||||
[]interface{}{int32(256)},
|
||||
BuildDocumentFromElements(nil, AppendInt32Element(nil, "0", int32(256))),
|
||||
},
|
||||
{
|
||||
"AppendDouble",
|
||||
NewArrayBuilder().AppendDouble,
|
||||
[]interface{}{float64(3.14159)},
|
||||
BuildDocumentFromElements(nil, AppendDoubleElement(nil, "0", float64(3.14159))),
|
||||
},
|
||||
{
|
||||
"AppendString",
|
||||
NewArrayBuilder().AppendString,
|
||||
[]interface{}{"x"},
|
||||
BuildDocumentFromElements(nil, AppendStringElement(nil, "0", "x")),
|
||||
},
|
||||
{
|
||||
"AppendDocument",
|
||||
NewArrayBuilder().AppendDocument,
|
||||
[]interface{}{[]byte{0x05, 0x00, 0x00, 0x00, 0x00}},
|
||||
BuildDocumentFromElements(nil, AppendDocumentElement(nil, "0", []byte{0x05, 0x00, 0x00, 0x00, 0x00})),
|
||||
},
|
||||
{
|
||||
"AppendArray",
|
||||
NewArrayBuilder().AppendArray,
|
||||
[]interface{}{[]byte{0x05, 0x00, 0x00, 0x00, 0x00}},
|
||||
BuildDocumentFromElements(nil, AppendArrayElement(nil, "0", []byte{0x05, 0x00, 0x00, 0x00, 0x00})),
|
||||
},
|
||||
{
|
||||
"AppendBinary",
|
||||
NewArrayBuilder().AppendBinary,
|
||||
[]interface{}{byte(0x02), []byte{0x01, 0x02, 0x03}},
|
||||
BuildDocumentFromElements(nil, AppendBinaryElement(nil, "0", byte(0x02), []byte{0x01, 0x02, 0x03})),
|
||||
},
|
||||
{
|
||||
"AppendObjectID",
|
||||
NewArrayBuilder().AppendObjectID,
|
||||
[]interface{}{
|
||||
primitive.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C},
|
||||
},
|
||||
BuildDocumentFromElements(nil, AppendObjectIDElement(nil, "0",
|
||||
primitive.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C})),
|
||||
},
|
||||
{
|
||||
"AppendBoolean",
|
||||
NewArrayBuilder().AppendBoolean,
|
||||
[]interface{}{true},
|
||||
BuildDocumentFromElements(nil, AppendBooleanElement(nil, "0", true)),
|
||||
},
|
||||
{
|
||||
"AppendDateTime",
|
||||
NewArrayBuilder().AppendDateTime,
|
||||
[]interface{}{int64(256)},
|
||||
BuildDocumentFromElements(nil, AppendDateTimeElement(nil, "0", int64(256))),
|
||||
},
|
||||
{
|
||||
"AppendNull",
|
||||
NewArrayBuilder().AppendNull,
|
||||
[]interface{}{},
|
||||
BuildDocumentFromElements(nil, AppendNullElement(nil, "0")),
|
||||
},
|
||||
{
|
||||
"AppendRegex",
|
||||
NewArrayBuilder().AppendRegex,
|
||||
[]interface{}{"bar", "baz"},
|
||||
BuildDocumentFromElements(nil, AppendRegexElement(nil, "0", "bar", "baz")),
|
||||
},
|
||||
{
|
||||
"AppendJavaScript",
|
||||
NewArrayBuilder().AppendJavaScript,
|
||||
[]interface{}{"barbaz"},
|
||||
BuildDocumentFromElements(nil, AppendJavaScriptElement(nil, "0", "barbaz")),
|
||||
},
|
||||
{
|
||||
"AppendCodeWithScope",
|
||||
NewArrayBuilder().AppendCodeWithScope,
|
||||
[]interface{}{"barbaz", Document([]byte{0x05, 0x00, 0x00, 0x00, 0x00})},
|
||||
BuildDocumentFromElements(nil, AppendCodeWithScopeElement(nil, "0", "barbaz", Document([]byte{0x05, 0x00, 0x00, 0x00, 0x00}))),
|
||||
},
|
||||
{
|
||||
"AppendTimestamp",
|
||||
NewArrayBuilder().AppendTimestamp,
|
||||
[]interface{}{uint32(65536), uint32(256)},
|
||||
BuildDocumentFromElements(nil, AppendTimestampElement(nil, "0", uint32(65536), uint32(256))),
|
||||
},
|
||||
{
|
||||
"AppendInt64",
|
||||
NewArrayBuilder().AppendInt64,
|
||||
[]interface{}{int64(4294967296)},
|
||||
BuildDocumentFromElements(nil, AppendInt64Element(nil, "0", int64(4294967296))),
|
||||
},
|
||||
{
|
||||
"AppendDecimal128",
|
||||
NewArrayBuilder().AppendDecimal128,
|
||||
[]interface{}{primitive.NewDecimal128(4294967296, 65536)},
|
||||
BuildDocumentFromElements(nil, AppendDecimal128Element(nil, "0", primitive.NewDecimal128(4294967296, 65536))),
|
||||
},
|
||||
{
|
||||
"AppendMaxKey",
|
||||
NewArrayBuilder().AppendMaxKey,
|
||||
[]interface{}{},
|
||||
BuildDocumentFromElements(nil, AppendMaxKeyElement(nil, "0")),
|
||||
},
|
||||
{
|
||||
"AppendMinKey",
|
||||
NewArrayBuilder().AppendMinKey,
|
||||
[]interface{}{},
|
||||
BuildDocumentFromElements(nil, AppendMinKeyElement(nil, "0")),
|
||||
},
|
||||
{
|
||||
"AppendSymbol",
|
||||
NewArrayBuilder().AppendSymbol,
|
||||
[]interface{}{"barbaz"},
|
||||
BuildDocumentFromElements(nil, AppendSymbolElement(nil, "0", "barbaz")),
|
||||
},
|
||||
{
|
||||
"AppendDBPointer",
|
||||
NewArrayBuilder().AppendDBPointer,
|
||||
[]interface{}{"barbaz",
|
||||
primitive.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}},
|
||||
BuildDocumentFromElements(nil, AppendDBPointerElement(nil, "0", "barbaz",
|
||||
primitive.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C})),
|
||||
},
|
||||
{
|
||||
"AppendUndefined",
|
||||
NewArrayBuilder().AppendUndefined,
|
||||
[]interface{}{},
|
||||
BuildDocumentFromElements(nil, AppendUndefinedElement(nil, "0")),
|
||||
},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
|
||||
fn := reflect.ValueOf(tc.fn)
|
||||
if fn.Kind() != reflect.Func {
|
||||
t.Fatalf("fn must be of kind Func but is a %v", fn.Kind())
|
||||
}
|
||||
if fn.Type().NumIn() != len(tc.params) {
|
||||
t.Fatalf("tc.params must match the number of params in tc.fn. params %d; fn %d", fn.Type().NumIn(), len(tc.params))
|
||||
}
|
||||
if fn.Type().NumOut() != 1 || fn.Type().Out(0) != reflect.TypeOf(&ArrayBuilder{}) {
|
||||
t.Fatalf("fn must have one return parameter and it must be an ArrayBuilder.")
|
||||
}
|
||||
params := make([]reflect.Value, 0, len(tc.params))
|
||||
for _, param := range tc.params {
|
||||
params = append(params, reflect.ValueOf(param))
|
||||
}
|
||||
results := fn.Call(params)
|
||||
got := results[0].Interface().(*ArrayBuilder).Build()
|
||||
want := tc.expected
|
||||
if !bytes.Equal(got, want) {
|
||||
t.Errorf("Did not receive expected bytes. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
t.Run("TestBuildTwoElementsArray", func(t *testing.T) {
|
||||
intArr := BuildDocumentFromElements(nil, AppendInt32Element(nil, "0", int32(1)))
|
||||
expected := BuildDocumentFromElements(nil, AppendArrayElement(AppendInt32Element(nil, "0", int32(3)), "1", intArr))
|
||||
elem := NewArrayBuilder().AppendInt32(int32(1)).Build()
|
||||
result := NewArrayBuilder().AppendInt32(int32(3)).AppendArray(elem).Build()
|
||||
if !bytes.Equal(result, expected) {
|
||||
t.Errorf("Arrays do not match. got %v; want %v", result, expected)
|
||||
}
|
||||
})
|
||||
t.Run("TestBuildInlineArray", func(t *testing.T) {
|
||||
docElement := BuildDocumentFromElements(nil, AppendInt32Element(nil, "0", int32(256)))
|
||||
expected := Document(BuildDocumentFromElements(nil, AppendArrayElement(nil, "0", docElement)))
|
||||
result := NewArrayBuilder().StartArray().AppendInt32(int32(256)).FinishArray().Build()
|
||||
if !bytes.Equal(result, expected) {
|
||||
t.Errorf("Documents do not match. got %v; want %v", result, expected)
|
||||
}
|
||||
})
|
||||
t.Run("TestBuildNestedInlineArray", func(t *testing.T) {
|
||||
docElement := BuildDocumentFromElements(nil, AppendDoubleElement(nil, "0", 3.14))
|
||||
docInline := BuildDocumentFromElements(nil, AppendArrayElement(nil, "0", docElement))
|
||||
expected := Document(BuildDocumentFromElements(nil, AppendArrayElement(nil, "0", docInline)))
|
||||
result := NewArrayBuilder().StartArray().StartArray().AppendDouble(3.14).FinishArray().FinishArray().Build()
|
||||
if !bytes.Equal(result, expected) {
|
||||
t.Errorf("Documents do not match. got %v; want %v", result, expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
189
mongo/x/bsonx/bsoncore/bson_documentbuilder.go
Normal file
189
mongo/x/bsonx/bsoncore/bson_documentbuilder.go
Normal file
@@ -0,0 +1,189 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bsoncore
|
||||
|
||||
import (
|
||||
"go.mongodb.org/mongo-driver/bson/bsontype"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
// DocumentBuilder builds a bson document
|
||||
type DocumentBuilder struct {
|
||||
doc []byte
|
||||
indexes []int32
|
||||
}
|
||||
|
||||
// startDocument reserves the document's length and set the index to where the length begins
|
||||
func (db *DocumentBuilder) startDocument() *DocumentBuilder {
|
||||
var index int32
|
||||
index, db.doc = AppendDocumentStart(db.doc)
|
||||
db.indexes = append(db.indexes, index)
|
||||
return db
|
||||
}
|
||||
|
||||
// NewDocumentBuilder creates a new DocumentBuilder
|
||||
func NewDocumentBuilder() *DocumentBuilder {
|
||||
return (&DocumentBuilder{}).startDocument()
|
||||
}
|
||||
|
||||
// Build updates the length of the document and index to the beginning of the documents length
|
||||
// bytes, then returns the document (bson bytes)
|
||||
func (db *DocumentBuilder) Build() Document {
|
||||
last := len(db.indexes) - 1
|
||||
db.doc, _ = AppendDocumentEnd(db.doc, db.indexes[last])
|
||||
db.indexes = db.indexes[:last]
|
||||
return db.doc
|
||||
}
|
||||
|
||||
// AppendInt32 will append an int32 element using key and i32 to DocumentBuilder.doc
|
||||
func (db *DocumentBuilder) AppendInt32(key string, i32 int32) *DocumentBuilder {
|
||||
db.doc = AppendInt32Element(db.doc, key, i32)
|
||||
return db
|
||||
}
|
||||
|
||||
// AppendDocument will append a bson embedded document element using key
|
||||
// and doc to DocumentBuilder.doc
|
||||
func (db *DocumentBuilder) AppendDocument(key string, doc []byte) *DocumentBuilder {
|
||||
db.doc = AppendDocumentElement(db.doc, key, doc)
|
||||
return db
|
||||
}
|
||||
|
||||
// AppendArray will append a bson array using key and arr to DocumentBuilder.doc
|
||||
func (db *DocumentBuilder) AppendArray(key string, arr []byte) *DocumentBuilder {
|
||||
db.doc = AppendHeader(db.doc, bsontype.Array, key)
|
||||
db.doc = AppendArray(db.doc, arr)
|
||||
return db
|
||||
}
|
||||
|
||||
// AppendDouble will append a double element using key and f to DocumentBuilder.doc
|
||||
func (db *DocumentBuilder) AppendDouble(key string, f float64) *DocumentBuilder {
|
||||
db.doc = AppendDoubleElement(db.doc, key, f)
|
||||
return db
|
||||
}
|
||||
|
||||
// AppendString will append str to DocumentBuilder.doc with the given key
|
||||
func (db *DocumentBuilder) AppendString(key string, str string) *DocumentBuilder {
|
||||
db.doc = AppendStringElement(db.doc, key, str)
|
||||
return db
|
||||
}
|
||||
|
||||
// AppendObjectID will append oid to DocumentBuilder.doc with the given key
|
||||
func (db *DocumentBuilder) AppendObjectID(key string, oid primitive.ObjectID) *DocumentBuilder {
|
||||
db.doc = AppendObjectIDElement(db.doc, key, oid)
|
||||
return db
|
||||
}
|
||||
|
||||
// AppendBinary will append a BSON binary element using key, subtype, and
|
||||
// b to db.doc
|
||||
func (db *DocumentBuilder) AppendBinary(key string, subtype byte, b []byte) *DocumentBuilder {
|
||||
db.doc = AppendBinaryElement(db.doc, key, subtype, b)
|
||||
return db
|
||||
}
|
||||
|
||||
// AppendUndefined will append a BSON undefined element using key to db.doc
|
||||
func (db *DocumentBuilder) AppendUndefined(key string) *DocumentBuilder {
|
||||
db.doc = AppendUndefinedElement(db.doc, key)
|
||||
return db
|
||||
}
|
||||
|
||||
// AppendBoolean will append a boolean element using key and b to db.doc
|
||||
func (db *DocumentBuilder) AppendBoolean(key string, b bool) *DocumentBuilder {
|
||||
db.doc = AppendBooleanElement(db.doc, key, b)
|
||||
return db
|
||||
}
|
||||
|
||||
// AppendDateTime will append a datetime element using key and dt to db.doc
|
||||
func (db *DocumentBuilder) AppendDateTime(key string, dt int64) *DocumentBuilder {
|
||||
db.doc = AppendDateTimeElement(db.doc, key, dt)
|
||||
return db
|
||||
}
|
||||
|
||||
// AppendNull will append a null element using key to db.doc
|
||||
func (db *DocumentBuilder) AppendNull(key string) *DocumentBuilder {
|
||||
db.doc = AppendNullElement(db.doc, key)
|
||||
return db
|
||||
}
|
||||
|
||||
// AppendRegex will append pattern and options using key to db.doc
|
||||
func (db *DocumentBuilder) AppendRegex(key, pattern, options string) *DocumentBuilder {
|
||||
db.doc = AppendRegexElement(db.doc, key, pattern, options)
|
||||
return db
|
||||
}
|
||||
|
||||
// AppendDBPointer will append ns and oid to using key to db.doc
|
||||
func (db *DocumentBuilder) AppendDBPointer(key string, ns string, oid primitive.ObjectID) *DocumentBuilder {
|
||||
db.doc = AppendDBPointerElement(db.doc, key, ns, oid)
|
||||
return db
|
||||
}
|
||||
|
||||
// AppendJavaScript will append js using the provided key to db.doc
|
||||
func (db *DocumentBuilder) AppendJavaScript(key, js string) *DocumentBuilder {
|
||||
db.doc = AppendJavaScriptElement(db.doc, key, js)
|
||||
return db
|
||||
}
|
||||
|
||||
// AppendSymbol will append a BSON symbol element using key and symbol db.doc
|
||||
func (db *DocumentBuilder) AppendSymbol(key, symbol string) *DocumentBuilder {
|
||||
db.doc = AppendSymbolElement(db.doc, key, symbol)
|
||||
return db
|
||||
}
|
||||
|
||||
// AppendCodeWithScope will append code and scope using key to db.doc
|
||||
func (db *DocumentBuilder) AppendCodeWithScope(key string, code string, scope Document) *DocumentBuilder {
|
||||
db.doc = AppendCodeWithScopeElement(db.doc, key, code, scope)
|
||||
return db
|
||||
}
|
||||
|
||||
// AppendTimestamp will append t and i to db.doc using provided key
|
||||
func (db *DocumentBuilder) AppendTimestamp(key string, t, i uint32) *DocumentBuilder {
|
||||
db.doc = AppendTimestampElement(db.doc, key, t, i)
|
||||
return db
|
||||
}
|
||||
|
||||
// AppendInt64 will append i64 to dst using key to db.doc
|
||||
func (db *DocumentBuilder) AppendInt64(key string, i64 int64) *DocumentBuilder {
|
||||
db.doc = AppendInt64Element(db.doc, key, i64)
|
||||
return db
|
||||
}
|
||||
|
||||
// AppendDecimal128 will append d128 to db.doc using provided key
|
||||
func (db *DocumentBuilder) AppendDecimal128(key string, d128 primitive.Decimal128) *DocumentBuilder {
|
||||
db.doc = AppendDecimal128Element(db.doc, key, d128)
|
||||
return db
|
||||
}
|
||||
|
||||
// AppendMaxKey will append a max key element using key to db.doc
|
||||
func (db *DocumentBuilder) AppendMaxKey(key string) *DocumentBuilder {
|
||||
db.doc = AppendMaxKeyElement(db.doc, key)
|
||||
return db
|
||||
}
|
||||
|
||||
// AppendMinKey will append a min key element using key to db.doc
|
||||
func (db *DocumentBuilder) AppendMinKey(key string) *DocumentBuilder {
|
||||
db.doc = AppendMinKeyElement(db.doc, key)
|
||||
return db
|
||||
}
|
||||
|
||||
// AppendValue will append a BSON element with the provided key and value to the document.
|
||||
func (db *DocumentBuilder) AppendValue(key string, val Value) *DocumentBuilder {
|
||||
db.doc = AppendValueElement(db.doc, key, val)
|
||||
return db
|
||||
}
|
||||
|
||||
// StartDocument starts building an inline document element with the provided key
|
||||
// After this document is completed, the user must call finishDocument
|
||||
func (db *DocumentBuilder) StartDocument(key string) *DocumentBuilder {
|
||||
db.doc = AppendHeader(db.doc, bsontype.EmbeddedDocument, key)
|
||||
db = db.startDocument()
|
||||
return db
|
||||
}
|
||||
|
||||
// FinishDocument builds the most recent document created
|
||||
func (db *DocumentBuilder) FinishDocument() *DocumentBuilder {
|
||||
db.doc = db.Build()
|
||||
return db
|
||||
}
|
||||
215
mongo/x/bsonx/bsoncore/bson_documentbuilder_test.go
Normal file
215
mongo/x/bsonx/bsoncore/bson_documentbuilder_test.go
Normal file
@@ -0,0 +1,215 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bsoncore
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"math"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
func TestDocumentBuilder(t *testing.T) {
|
||||
bits := math.Float64bits(3.14159)
|
||||
pi := make([]byte, 8)
|
||||
binary.LittleEndian.PutUint64(pi, bits)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
fn interface{}
|
||||
params []interface{}
|
||||
expected []byte
|
||||
}{
|
||||
{
|
||||
"AppendInt32",
|
||||
NewDocumentBuilder().AppendInt32,
|
||||
[]interface{}{"foobar", int32(256)},
|
||||
BuildDocumentFromElements(nil, AppendInt32Element(nil, "foobar", 256)),
|
||||
},
|
||||
{
|
||||
"AppendDouble",
|
||||
NewDocumentBuilder().AppendDouble,
|
||||
[]interface{}{"foobar", float64(3.14159)},
|
||||
BuildDocumentFromElements(nil, AppendDoubleElement(nil, "foobar", float64(3.14159))),
|
||||
},
|
||||
{
|
||||
"AppendString",
|
||||
NewDocumentBuilder().AppendString,
|
||||
[]interface{}{"foobar", "x"},
|
||||
BuildDocumentFromElements(nil, AppendStringElement(nil, "foobar", "x")),
|
||||
},
|
||||
{
|
||||
"AppendDocument",
|
||||
NewDocumentBuilder().AppendDocument,
|
||||
[]interface{}{"foobar", []byte{0x05, 0x00, 0x00, 0x00, 0x00}},
|
||||
BuildDocumentFromElements(nil, AppendDocumentElement(nil, "foobar", []byte{0x05, 0x00, 0x00, 0x00, 0x00})),
|
||||
},
|
||||
{
|
||||
"AppendArray",
|
||||
NewDocumentBuilder().AppendArray,
|
||||
[]interface{}{"foobar", []byte{0x05, 0x00, 0x00, 0x00, 0x00}},
|
||||
BuildDocumentFromElements(nil, AppendArrayElement(nil, "foobar", []byte{0x05, 0x00, 0x00, 0x00, 0x00})),
|
||||
},
|
||||
{
|
||||
"AppendBinary",
|
||||
NewDocumentBuilder().AppendBinary,
|
||||
[]interface{}{"foobar", byte(0x02), []byte{0x01, 0x02, 0x03}},
|
||||
BuildDocumentFromElements(nil, AppendBinaryElement(nil, "foobar", byte(0x02), []byte{0x01, 0x02, 0x03})),
|
||||
},
|
||||
{
|
||||
"AppendObjectID",
|
||||
NewDocumentBuilder().AppendObjectID,
|
||||
[]interface{}{
|
||||
"foobar",
|
||||
primitive.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C},
|
||||
},
|
||||
BuildDocumentFromElements(nil, AppendObjectIDElement(nil, "foobar",
|
||||
primitive.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C})),
|
||||
},
|
||||
{
|
||||
"AppendBoolean",
|
||||
NewDocumentBuilder().AppendBoolean,
|
||||
[]interface{}{"foobar", true},
|
||||
BuildDocumentFromElements(nil, AppendBooleanElement(nil, "foobar", true)),
|
||||
},
|
||||
{
|
||||
"AppendDateTime",
|
||||
NewDocumentBuilder().AppendDateTime,
|
||||
[]interface{}{"foobar", int64(256)},
|
||||
BuildDocumentFromElements(nil, AppendDateTimeElement(nil, "foobar", int64(256))),
|
||||
},
|
||||
{
|
||||
"AppendNull",
|
||||
NewDocumentBuilder().AppendNull,
|
||||
[]interface{}{"foobar"},
|
||||
BuildDocumentFromElements(nil, AppendNullElement(nil, "foobar")),
|
||||
},
|
||||
{
|
||||
"AppendRegex",
|
||||
NewDocumentBuilder().AppendRegex,
|
||||
[]interface{}{"foobar", "bar", "baz"},
|
||||
BuildDocumentFromElements(nil, AppendRegexElement(nil, "foobar", "bar", "baz")),
|
||||
},
|
||||
{
|
||||
"AppendJavaScript",
|
||||
NewDocumentBuilder().AppendJavaScript,
|
||||
[]interface{}{"foobar", "barbaz"},
|
||||
BuildDocumentFromElements(nil, AppendJavaScriptElement(nil, "foobar", "barbaz")),
|
||||
},
|
||||
{
|
||||
"AppendCodeWithScope",
|
||||
NewDocumentBuilder().AppendCodeWithScope,
|
||||
[]interface{}{"foobar", "barbaz", Document([]byte{0x05, 0x00, 0x00, 0x00, 0x00})},
|
||||
BuildDocumentFromElements(nil, AppendCodeWithScopeElement(nil, "foobar", "barbaz", Document([]byte{0x05, 0x00, 0x00, 0x00, 0x00}))),
|
||||
},
|
||||
{
|
||||
"AppendTimestamp",
|
||||
NewDocumentBuilder().AppendTimestamp,
|
||||
[]interface{}{"foobar", uint32(65536), uint32(256)},
|
||||
BuildDocumentFromElements(nil, AppendTimestampElement(nil, "foobar", uint32(65536), uint32(256))),
|
||||
},
|
||||
{
|
||||
"AppendInt64",
|
||||
NewDocumentBuilder().AppendInt64,
|
||||
[]interface{}{"foobar", int64(4294967296)},
|
||||
BuildDocumentFromElements(nil, AppendInt64Element(nil, "foobar", int64(4294967296))),
|
||||
},
|
||||
{
|
||||
"AppendDecimal128",
|
||||
NewDocumentBuilder().AppendDecimal128,
|
||||
[]interface{}{"foobar", primitive.NewDecimal128(4294967296, 65536)},
|
||||
BuildDocumentFromElements(nil, AppendDecimal128Element(nil, "foobar", primitive.NewDecimal128(4294967296, 65536))),
|
||||
},
|
||||
{
|
||||
"AppendMaxKey",
|
||||
NewDocumentBuilder().AppendMaxKey,
|
||||
[]interface{}{"foobar"},
|
||||
BuildDocumentFromElements(nil, AppendMaxKeyElement(nil, "foobar")),
|
||||
},
|
||||
{
|
||||
"AppendMinKey",
|
||||
NewDocumentBuilder().AppendMinKey,
|
||||
[]interface{}{"foobar"},
|
||||
BuildDocumentFromElements(nil, AppendMinKeyElement(nil, "foobar")),
|
||||
},
|
||||
{
|
||||
"AppendSymbol",
|
||||
NewDocumentBuilder().AppendSymbol,
|
||||
[]interface{}{"foobar", "barbaz"},
|
||||
BuildDocumentFromElements(nil, AppendSymbolElement(nil, "foobar", "barbaz")),
|
||||
},
|
||||
{
|
||||
"AppendDBPointer",
|
||||
NewDocumentBuilder().AppendDBPointer,
|
||||
[]interface{}{"foobar", "barbaz",
|
||||
primitive.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}},
|
||||
BuildDocumentFromElements(nil, AppendDBPointerElement(nil, "foobar", "barbaz",
|
||||
primitive.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C})),
|
||||
},
|
||||
{
|
||||
"AppendUndefined",
|
||||
NewDocumentBuilder().AppendUndefined,
|
||||
[]interface{}{"foobar"},
|
||||
BuildDocumentFromElements(nil, AppendUndefinedElement(nil, "foobar")),
|
||||
},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
|
||||
fn := reflect.ValueOf(tc.fn)
|
||||
if fn.Kind() != reflect.Func {
|
||||
t.Fatalf("fn must be of kind Func but is a %v", fn.Kind())
|
||||
}
|
||||
if fn.Type().NumIn() != len(tc.params) {
|
||||
t.Fatalf("tc.params must match the number of params in tc.fn. params %d; fn %d", fn.Type().NumIn(), len(tc.params))
|
||||
}
|
||||
if fn.Type().NumOut() != 1 || fn.Type().Out(0) != reflect.TypeOf(&DocumentBuilder{}) {
|
||||
t.Fatalf("fn must have one return parameter and it must be a DocumentBuilder.")
|
||||
}
|
||||
params := make([]reflect.Value, 0, len(tc.params))
|
||||
for _, param := range tc.params {
|
||||
params = append(params, reflect.ValueOf(param))
|
||||
}
|
||||
results := fn.Call(params)
|
||||
got := results[0].Interface().(*DocumentBuilder)
|
||||
doc := got.Build()
|
||||
want := tc.expected
|
||||
if !bytes.Equal(doc, want) {
|
||||
t.Errorf("Did not receive expected bytes. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
t.Run("TestBuildTwoElements", func(t *testing.T) {
|
||||
intArr := BuildDocumentFromElements(nil, AppendInt32Element(nil, "0", int32(1)))
|
||||
expected := BuildDocumentFromElements(nil, AppendArrayElement(AppendInt32Element(nil, "x", int32(3)), "y", intArr))
|
||||
elem := NewArrayBuilder().AppendInt32(int32(1)).Build()
|
||||
result := NewDocumentBuilder().AppendInt32("x", int32(3)).AppendArray("y", elem).Build()
|
||||
if !bytes.Equal(result, expected) {
|
||||
t.Errorf("Documents do not match. got %v; want %v", result, expected)
|
||||
}
|
||||
})
|
||||
t.Run("TestBuildInlineDocument", func(t *testing.T) {
|
||||
docElement := BuildDocumentFromElements(nil, AppendInt32Element(nil, "x", int32(256)))
|
||||
expected := Document(BuildDocumentFromElements(nil, AppendDocumentElement(nil, "y", docElement)))
|
||||
result := NewDocumentBuilder().StartDocument("y").AppendInt32("x", int32(256)).FinishDocument().Build()
|
||||
if !bytes.Equal(result, expected) {
|
||||
t.Errorf("Documents do not match. got %v; want %v", result, expected)
|
||||
}
|
||||
})
|
||||
t.Run("TestBuildNestedInlineDocument", func(t *testing.T) {
|
||||
docElement := BuildDocumentFromElements(nil, AppendDoubleElement(nil, "x", 3.14))
|
||||
docInline := BuildDocumentFromElements(nil, AppendDocumentElement(nil, "y", docElement))
|
||||
expected := Document(BuildDocumentFromElements(nil, AppendDocumentElement(nil, "z", docInline)))
|
||||
result := NewDocumentBuilder().StartDocument("z").StartDocument("y").AppendDouble("x", 3.14).FinishDocument().FinishDocument().Build()
|
||||
if !bytes.Equal(result, expected) {
|
||||
t.Errorf("Documents do not match. got %v; want %v", result, expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
862
mongo/x/bsonx/bsoncore/bsoncore.go
Normal file
862
mongo/x/bsonx/bsoncore/bsoncore.go
Normal file
@@ -0,0 +1,862 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
// Package bsoncore contains functions that can be used to encode and decode BSON
|
||||
// elements and values to or from a slice of bytes. These functions are aimed at
|
||||
// allowing low level manipulation of BSON and can be used to build a higher
|
||||
// level BSON library.
|
||||
//
|
||||
// The Read* functions within this package return the values of the element and
|
||||
// a boolean indicating if the values are valid. A boolean was used instead of
|
||||
// an error because any error that would be returned would be the same: not
|
||||
// enough bytes. This library attempts to do no validation, it will only return
|
||||
// false if there are not enough bytes for an item to be read. For example, the
|
||||
// ReadDocument function checks the length, if that length is larger than the
|
||||
// number of bytes available, it will return false, if there are enough bytes, it
|
||||
// will return those bytes and true. It is the consumers responsibility to
|
||||
// validate those bytes.
|
||||
//
|
||||
// The Append* functions within this package will append the type value to the
|
||||
// given dst slice. If the slice has enough capacity, it will not grow the
|
||||
// slice. The Append*Element functions within this package operate in the same
|
||||
// way, but additionally append the BSON type and the key before the value.
|
||||
package bsoncore // import "go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.mongodb.org/mongo-driver/bson/bsontype"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
const (
|
||||
// EmptyDocumentLength is the length of a document that has been started/ended but has no elements.
|
||||
EmptyDocumentLength = 5
|
||||
// nullTerminator is a string version of the 0 byte that is appended at the end of cstrings.
|
||||
nullTerminator = string(byte(0))
|
||||
invalidKeyPanicMsg = "BSON element keys cannot contain null bytes"
|
||||
invalidRegexPanicMsg = "BSON regex values cannot contain null bytes"
|
||||
)
|
||||
|
||||
// AppendType will append t to dst and return the extended buffer.
|
||||
func AppendType(dst []byte, t bsontype.Type) []byte { return append(dst, byte(t)) }
|
||||
|
||||
// AppendKey will append key to dst and return the extended buffer.
|
||||
func AppendKey(dst []byte, key string) []byte { return append(dst, key+nullTerminator...) }
|
||||
|
||||
// AppendHeader will append Type t and key to dst and return the extended
|
||||
// buffer.
|
||||
func AppendHeader(dst []byte, t bsontype.Type, key string) []byte {
|
||||
if !isValidCString(key) {
|
||||
panic(invalidKeyPanicMsg)
|
||||
}
|
||||
|
||||
dst = AppendType(dst, t)
|
||||
dst = append(dst, key...)
|
||||
return append(dst, 0x00)
|
||||
// return append(AppendType(dst, t), key+string(0x00)...)
|
||||
}
|
||||
|
||||
// TODO(skriptble): All of the Read* functions should return src resliced to start just after what was read.
|
||||
|
||||
// ReadType will return the first byte of the provided []byte as a type. If
|
||||
// there is no available byte, false is returned.
|
||||
func ReadType(src []byte) (bsontype.Type, []byte, bool) {
|
||||
if len(src) < 1 {
|
||||
return 0, src, false
|
||||
}
|
||||
return bsontype.Type(src[0]), src[1:], true
|
||||
}
|
||||
|
||||
// ReadKey will read a key from src. The 0x00 byte will not be present
|
||||
// in the returned string. If there are not enough bytes available, false is
|
||||
// returned.
|
||||
func ReadKey(src []byte) (string, []byte, bool) { return readcstring(src) }
|
||||
|
||||
// ReadKeyBytes will read a key from src as bytes. The 0x00 byte will
|
||||
// not be present in the returned string. If there are not enough bytes
|
||||
// available, false is returned.
|
||||
func ReadKeyBytes(src []byte) ([]byte, []byte, bool) { return readcstringbytes(src) }
|
||||
|
||||
// ReadHeader will read a type byte and a key from src. If both of these
|
||||
// values cannot be read, false is returned.
|
||||
func ReadHeader(src []byte) (t bsontype.Type, key string, rem []byte, ok bool) {
|
||||
t, rem, ok = ReadType(src)
|
||||
if !ok {
|
||||
return 0, "", src, false
|
||||
}
|
||||
key, rem, ok = ReadKey(rem)
|
||||
if !ok {
|
||||
return 0, "", src, false
|
||||
}
|
||||
|
||||
return t, key, rem, true
|
||||
}
|
||||
|
||||
// ReadHeaderBytes will read a type and a key from src and the remainder of the bytes
|
||||
// are returned as rem. If either the type or key cannot be red, ok will be false.
|
||||
func ReadHeaderBytes(src []byte) (header []byte, rem []byte, ok bool) {
|
||||
if len(src) < 1 {
|
||||
return nil, src, false
|
||||
}
|
||||
idx := bytes.IndexByte(src[1:], 0x00)
|
||||
if idx == -1 {
|
||||
return nil, src, false
|
||||
}
|
||||
return src[:idx], src[idx+1:], true
|
||||
}
|
||||
|
||||
// ReadElement reads the next full element from src. It returns the element, the remaining bytes in
|
||||
// the slice, and a boolean indicating if the read was successful.
|
||||
func ReadElement(src []byte) (Element, []byte, bool) {
|
||||
if len(src) < 1 {
|
||||
return nil, src, false
|
||||
}
|
||||
t := bsontype.Type(src[0])
|
||||
idx := bytes.IndexByte(src[1:], 0x00)
|
||||
if idx == -1 {
|
||||
return nil, src, false
|
||||
}
|
||||
length, ok := valueLength(src[idx+2:], t) // We add 2 here because we called IndexByte with src[1:]
|
||||
if !ok {
|
||||
return nil, src, false
|
||||
}
|
||||
elemLength := 1 + idx + 1 + int(length)
|
||||
if elemLength > len(src) {
|
||||
return nil, src, false
|
||||
}
|
||||
if elemLength < 0 {
|
||||
return nil, src, false
|
||||
}
|
||||
return src[:elemLength], src[elemLength:], true
|
||||
}
|
||||
|
||||
// AppendValueElement appends value to dst as an element using key as the element's key.
|
||||
func AppendValueElement(dst []byte, key string, value Value) []byte {
|
||||
dst = AppendHeader(dst, value.Type, key)
|
||||
dst = append(dst, value.Data...)
|
||||
return dst
|
||||
}
|
||||
|
||||
// ReadValue reads the next value as the provided types and returns a Value, the remaining bytes,
|
||||
// and a boolean indicating if the read was successful.
|
||||
func ReadValue(src []byte, t bsontype.Type) (Value, []byte, bool) {
|
||||
data, rem, ok := readValue(src, t)
|
||||
if !ok {
|
||||
return Value{}, src, false
|
||||
}
|
||||
return Value{Type: t, Data: data}, rem, true
|
||||
}
|
||||
|
||||
// AppendDouble will append f to dst and return the extended buffer.
|
||||
func AppendDouble(dst []byte, f float64) []byte {
|
||||
return appendu64(dst, math.Float64bits(f))
|
||||
}
|
||||
|
||||
// AppendDoubleElement will append a BSON double element using key and f to dst
|
||||
// and return the extended buffer.
|
||||
func AppendDoubleElement(dst []byte, key string, f float64) []byte {
|
||||
return AppendDouble(AppendHeader(dst, bsontype.Double, key), f)
|
||||
}
|
||||
|
||||
// ReadDouble will read a float64 from src. If there are not enough bytes it
|
||||
// will return false.
|
||||
func ReadDouble(src []byte) (float64, []byte, bool) {
|
||||
bits, src, ok := readu64(src)
|
||||
if !ok {
|
||||
return 0, src, false
|
||||
}
|
||||
return math.Float64frombits(bits), src, true
|
||||
}
|
||||
|
||||
// AppendString will append s to dst and return the extended buffer.
|
||||
func AppendString(dst []byte, s string) []byte {
|
||||
return appendstring(dst, s)
|
||||
}
|
||||
|
||||
// AppendStringElement will append a BSON string element using key and val to dst
|
||||
// and return the extended buffer.
|
||||
func AppendStringElement(dst []byte, key, val string) []byte {
|
||||
return AppendString(AppendHeader(dst, bsontype.String, key), val)
|
||||
}
|
||||
|
||||
// ReadString will read a string from src. If there are not enough bytes it
|
||||
// will return false.
|
||||
func ReadString(src []byte) (string, []byte, bool) {
|
||||
return readstring(src)
|
||||
}
|
||||
|
||||
// AppendDocumentStart reserves a document's length and returns the index where the length begins.
|
||||
// This index can later be used to write the length of the document.
|
||||
func AppendDocumentStart(dst []byte) (index int32, b []byte) {
|
||||
// TODO(skriptble): We really need AppendDocumentStart and AppendDocumentEnd. AppendDocumentStart would handle calling
|
||||
// TODO ReserveLength and providing the index of the start of the document. AppendDocumentEnd would handle taking that
|
||||
// TODO start index, adding the null byte, calculating the length, and filling in the length at the start of the
|
||||
// TODO document.
|
||||
return ReserveLength(dst)
|
||||
}
|
||||
|
||||
// AppendDocumentStartInline functions the same as AppendDocumentStart but takes a pointer to the
|
||||
// index int32 which allows this function to be used inline.
|
||||
func AppendDocumentStartInline(dst []byte, index *int32) []byte {
|
||||
idx, doc := AppendDocumentStart(dst)
|
||||
*index = idx
|
||||
return doc
|
||||
}
|
||||
|
||||
// AppendDocumentElementStart writes a document element header and then reserves the length bytes.
|
||||
func AppendDocumentElementStart(dst []byte, key string) (index int32, b []byte) {
|
||||
return AppendDocumentStart(AppendHeader(dst, bsontype.EmbeddedDocument, key))
|
||||
}
|
||||
|
||||
// AppendDocumentEnd writes the null byte for a document and updates the length of the document.
|
||||
// The index should be the beginning of the document's length bytes.
|
||||
func AppendDocumentEnd(dst []byte, index int32) ([]byte, error) {
|
||||
if int(index) > len(dst)-4 {
|
||||
return dst, fmt.Errorf("not enough bytes available after index to write length")
|
||||
}
|
||||
dst = append(dst, 0x00)
|
||||
dst = UpdateLength(dst, index, int32(len(dst[index:])))
|
||||
return dst, nil
|
||||
}
|
||||
|
||||
// AppendDocument will append doc to dst and return the extended buffer.
|
||||
func AppendDocument(dst []byte, doc []byte) []byte { return append(dst, doc...) }
|
||||
|
||||
// AppendDocumentElement will append a BSON embedded document element using key
|
||||
// and doc to dst and return the extended buffer.
|
||||
func AppendDocumentElement(dst []byte, key string, doc []byte) []byte {
|
||||
return AppendDocument(AppendHeader(dst, bsontype.EmbeddedDocument, key), doc)
|
||||
}
|
||||
|
||||
// BuildDocument will create a document with the given slice of elements and will append
|
||||
// it to dst and return the extended buffer.
|
||||
func BuildDocument(dst []byte, elems ...[]byte) []byte {
|
||||
idx, dst := ReserveLength(dst)
|
||||
for _, elem := range elems {
|
||||
dst = append(dst, elem...)
|
||||
}
|
||||
dst = append(dst, 0x00)
|
||||
dst = UpdateLength(dst, idx, int32(len(dst[idx:])))
|
||||
return dst
|
||||
}
|
||||
|
||||
// BuildDocumentValue creates an Embedded Document value from the given elements.
|
||||
func BuildDocumentValue(elems ...[]byte) Value {
|
||||
return Value{Type: bsontype.EmbeddedDocument, Data: BuildDocument(nil, elems...)}
|
||||
}
|
||||
|
||||
// BuildDocumentElement will append a BSON embedded document elemnt using key and the provided
|
||||
// elements and return the extended buffer.
|
||||
func BuildDocumentElement(dst []byte, key string, elems ...[]byte) []byte {
|
||||
return BuildDocument(AppendHeader(dst, bsontype.EmbeddedDocument, key), elems...)
|
||||
}
|
||||
|
||||
// BuildDocumentFromElements is an alaias for the BuildDocument function.
|
||||
var BuildDocumentFromElements = BuildDocument
|
||||
|
||||
// ReadDocument will read a document from src. If there are not enough bytes it
|
||||
// will return false.
|
||||
func ReadDocument(src []byte) (doc Document, rem []byte, ok bool) { return readLengthBytes(src) }
|
||||
|
||||
// AppendArrayStart appends the length bytes to an array and then returns the index of the start
|
||||
// of those length bytes.
|
||||
func AppendArrayStart(dst []byte) (index int32, b []byte) { return ReserveLength(dst) }
|
||||
|
||||
// AppendArrayElementStart appends an array element header and then the length bytes for an array,
|
||||
// returning the index where the length starts.
|
||||
func AppendArrayElementStart(dst []byte, key string) (index int32, b []byte) {
|
||||
return AppendArrayStart(AppendHeader(dst, bsontype.Array, key))
|
||||
}
|
||||
|
||||
// AppendArrayEnd appends the null byte to an array and calculates the length, inserting that
|
||||
// calculated length starting at index.
|
||||
func AppendArrayEnd(dst []byte, index int32) ([]byte, error) { return AppendDocumentEnd(dst, index) }
|
||||
|
||||
// AppendArray will append arr to dst and return the extended buffer.
|
||||
func AppendArray(dst []byte, arr []byte) []byte { return append(dst, arr...) }
|
||||
|
||||
// AppendArrayElement will append a BSON array element using key and arr to dst
|
||||
// and return the extended buffer.
|
||||
func AppendArrayElement(dst []byte, key string, arr []byte) []byte {
|
||||
return AppendArray(AppendHeader(dst, bsontype.Array, key), arr)
|
||||
}
|
||||
|
||||
// BuildArray will append a BSON array to dst built from values.
|
||||
func BuildArray(dst []byte, values ...Value) []byte {
|
||||
idx, dst := ReserveLength(dst)
|
||||
for pos, val := range values {
|
||||
dst = AppendValueElement(dst, strconv.Itoa(pos), val)
|
||||
}
|
||||
dst = append(dst, 0x00)
|
||||
dst = UpdateLength(dst, idx, int32(len(dst[idx:])))
|
||||
return dst
|
||||
}
|
||||
|
||||
// BuildArrayElement will create an array element using the provided values.
|
||||
func BuildArrayElement(dst []byte, key string, values ...Value) []byte {
|
||||
return BuildArray(AppendHeader(dst, bsontype.Array, key), values...)
|
||||
}
|
||||
|
||||
// ReadArray will read an array from src. If there are not enough bytes it
|
||||
// will return false.
|
||||
func ReadArray(src []byte) (arr Array, rem []byte, ok bool) { return readLengthBytes(src) }
|
||||
|
||||
// AppendBinary will append subtype and b to dst and return the extended buffer.
|
||||
func AppendBinary(dst []byte, subtype byte, b []byte) []byte {
|
||||
if subtype == 0x02 {
|
||||
return appendBinarySubtype2(dst, subtype, b)
|
||||
}
|
||||
dst = append(appendLength(dst, int32(len(b))), subtype)
|
||||
return append(dst, b...)
|
||||
}
|
||||
|
||||
// AppendBinaryElement will append a BSON binary element using key, subtype, and
|
||||
// b to dst and return the extended buffer.
|
||||
func AppendBinaryElement(dst []byte, key string, subtype byte, b []byte) []byte {
|
||||
return AppendBinary(AppendHeader(dst, bsontype.Binary, key), subtype, b)
|
||||
}
|
||||
|
||||
// ReadBinary will read a subtype and bin from src. If there are not enough bytes it
|
||||
// will return false.
|
||||
func ReadBinary(src []byte) (subtype byte, bin []byte, rem []byte, ok bool) {
|
||||
length, rem, ok := ReadLength(src)
|
||||
if !ok {
|
||||
return 0x00, nil, src, false
|
||||
}
|
||||
if len(rem) < 1 { // subtype
|
||||
return 0x00, nil, src, false
|
||||
}
|
||||
subtype, rem = rem[0], rem[1:]
|
||||
|
||||
if len(rem) < int(length) {
|
||||
return 0x00, nil, src, false
|
||||
}
|
||||
|
||||
if subtype == 0x02 {
|
||||
length, rem, ok = ReadLength(rem)
|
||||
if !ok || len(rem) < int(length) {
|
||||
return 0x00, nil, src, false
|
||||
}
|
||||
}
|
||||
|
||||
return subtype, rem[:length], rem[length:], true
|
||||
}
|
||||
|
||||
// AppendUndefinedElement will append a BSON undefined element using key to dst
|
||||
// and return the extended buffer.
|
||||
func AppendUndefinedElement(dst []byte, key string) []byte {
|
||||
return AppendHeader(dst, bsontype.Undefined, key)
|
||||
}
|
||||
|
||||
// AppendObjectID will append oid to dst and return the extended buffer.
|
||||
func AppendObjectID(dst []byte, oid primitive.ObjectID) []byte { return append(dst, oid[:]...) }
|
||||
|
||||
// AppendObjectIDElement will append a BSON ObjectID element using key and oid to dst
|
||||
// and return the extended buffer.
|
||||
func AppendObjectIDElement(dst []byte, key string, oid primitive.ObjectID) []byte {
|
||||
return AppendObjectID(AppendHeader(dst, bsontype.ObjectID, key), oid)
|
||||
}
|
||||
|
||||
// ReadObjectID will read an ObjectID from src. If there are not enough bytes it
|
||||
// will return false.
|
||||
func ReadObjectID(src []byte) (primitive.ObjectID, []byte, bool) {
|
||||
if len(src) < 12 {
|
||||
return primitive.ObjectID{}, src, false
|
||||
}
|
||||
var oid primitive.ObjectID
|
||||
copy(oid[:], src[0:12])
|
||||
return oid, src[12:], true
|
||||
}
|
||||
|
||||
// AppendBoolean will append b to dst and return the extended buffer.
|
||||
func AppendBoolean(dst []byte, b bool) []byte {
|
||||
if b {
|
||||
return append(dst, 0x01)
|
||||
}
|
||||
return append(dst, 0x00)
|
||||
}
|
||||
|
||||
// AppendBooleanElement will append a BSON boolean element using key and b to dst
|
||||
// and return the extended buffer.
|
||||
func AppendBooleanElement(dst []byte, key string, b bool) []byte {
|
||||
return AppendBoolean(AppendHeader(dst, bsontype.Boolean, key), b)
|
||||
}
|
||||
|
||||
// ReadBoolean will read a bool from src. If there are not enough bytes it
|
||||
// will return false.
|
||||
func ReadBoolean(src []byte) (bool, []byte, bool) {
|
||||
if len(src) < 1 {
|
||||
return false, src, false
|
||||
}
|
||||
|
||||
return src[0] == 0x01, src[1:], true
|
||||
}
|
||||
|
||||
// AppendDateTime will append dt to dst and return the extended buffer.
|
||||
func AppendDateTime(dst []byte, dt int64) []byte { return appendi64(dst, dt) }
|
||||
|
||||
// AppendDateTimeElement will append a BSON datetime element using key and dt to dst
|
||||
// and return the extended buffer.
|
||||
func AppendDateTimeElement(dst []byte, key string, dt int64) []byte {
|
||||
return AppendDateTime(AppendHeader(dst, bsontype.DateTime, key), dt)
|
||||
}
|
||||
|
||||
// ReadDateTime will read an int64 datetime from src. If there are not enough bytes it
|
||||
// will return false.
|
||||
func ReadDateTime(src []byte) (int64, []byte, bool) { return readi64(src) }
|
||||
|
||||
// AppendTime will append time as a BSON DateTime to dst and return the extended buffer.
|
||||
func AppendTime(dst []byte, t time.Time) []byte {
|
||||
return AppendDateTime(dst, t.Unix()*1000+int64(t.Nanosecond()/1e6))
|
||||
}
|
||||
|
||||
// AppendTimeElement will append a BSON datetime element using key and dt to dst
|
||||
// and return the extended buffer.
|
||||
func AppendTimeElement(dst []byte, key string, t time.Time) []byte {
|
||||
return AppendTime(AppendHeader(dst, bsontype.DateTime, key), t)
|
||||
}
|
||||
|
||||
// ReadTime will read an time.Time datetime from src. If there are not enough bytes it
|
||||
// will return false.
|
||||
func ReadTime(src []byte) (time.Time, []byte, bool) {
|
||||
dt, rem, ok := readi64(src)
|
||||
return time.Unix(dt/1e3, dt%1e3*1e6), rem, ok
|
||||
}
|
||||
|
||||
// AppendNullElement will append a BSON null element using key to dst
|
||||
// and return the extended buffer.
|
||||
func AppendNullElement(dst []byte, key string) []byte { return AppendHeader(dst, bsontype.Null, key) }
|
||||
|
||||
// AppendRegex will append pattern and options to dst and return the extended buffer.
|
||||
func AppendRegex(dst []byte, pattern, options string) []byte {
|
||||
if !isValidCString(pattern) || !isValidCString(options) {
|
||||
panic(invalidRegexPanicMsg)
|
||||
}
|
||||
|
||||
return append(dst, pattern+nullTerminator+options+nullTerminator...)
|
||||
}
|
||||
|
||||
// AppendRegexElement will append a BSON regex element using key, pattern, and
|
||||
// options to dst and return the extended buffer.
|
||||
func AppendRegexElement(dst []byte, key, pattern, options string) []byte {
|
||||
return AppendRegex(AppendHeader(dst, bsontype.Regex, key), pattern, options)
|
||||
}
|
||||
|
||||
// ReadRegex will read a pattern and options from src. If there are not enough bytes it
|
||||
// will return false.
|
||||
func ReadRegex(src []byte) (pattern, options string, rem []byte, ok bool) {
|
||||
pattern, rem, ok = readcstring(src)
|
||||
if !ok {
|
||||
return "", "", src, false
|
||||
}
|
||||
options, rem, ok = readcstring(rem)
|
||||
if !ok {
|
||||
return "", "", src, false
|
||||
}
|
||||
return pattern, options, rem, true
|
||||
}
|
||||
|
||||
// AppendDBPointer will append ns and oid to dst and return the extended buffer.
|
||||
func AppendDBPointer(dst []byte, ns string, oid primitive.ObjectID) []byte {
|
||||
return append(appendstring(dst, ns), oid[:]...)
|
||||
}
|
||||
|
||||
// AppendDBPointerElement will append a BSON DBPointer element using key, ns,
|
||||
// and oid to dst and return the extended buffer.
|
||||
func AppendDBPointerElement(dst []byte, key, ns string, oid primitive.ObjectID) []byte {
|
||||
return AppendDBPointer(AppendHeader(dst, bsontype.DBPointer, key), ns, oid)
|
||||
}
|
||||
|
||||
// ReadDBPointer will read a ns and oid from src. If there are not enough bytes it
|
||||
// will return false.
|
||||
func ReadDBPointer(src []byte) (ns string, oid primitive.ObjectID, rem []byte, ok bool) {
|
||||
ns, rem, ok = readstring(src)
|
||||
if !ok {
|
||||
return "", primitive.ObjectID{}, src, false
|
||||
}
|
||||
oid, rem, ok = ReadObjectID(rem)
|
||||
if !ok {
|
||||
return "", primitive.ObjectID{}, src, false
|
||||
}
|
||||
return ns, oid, rem, true
|
||||
}
|
||||
|
||||
// AppendJavaScript will append js to dst and return the extended buffer.
|
||||
func AppendJavaScript(dst []byte, js string) []byte { return appendstring(dst, js) }
|
||||
|
||||
// AppendJavaScriptElement will append a BSON JavaScript element using key and
|
||||
// js to dst and return the extended buffer.
|
||||
func AppendJavaScriptElement(dst []byte, key, js string) []byte {
|
||||
return AppendJavaScript(AppendHeader(dst, bsontype.JavaScript, key), js)
|
||||
}
|
||||
|
||||
// ReadJavaScript will read a js string from src. If there are not enough bytes it
|
||||
// will return false.
|
||||
func ReadJavaScript(src []byte) (js string, rem []byte, ok bool) { return readstring(src) }
|
||||
|
||||
// AppendSymbol will append symbol to dst and return the extended buffer.
|
||||
func AppendSymbol(dst []byte, symbol string) []byte { return appendstring(dst, symbol) }
|
||||
|
||||
// AppendSymbolElement will append a BSON symbol element using key and symbol to dst
|
||||
// and return the extended buffer.
|
||||
func AppendSymbolElement(dst []byte, key, symbol string) []byte {
|
||||
return AppendSymbol(AppendHeader(dst, bsontype.Symbol, key), symbol)
|
||||
}
|
||||
|
||||
// ReadSymbol will read a symbol string from src. If there are not enough bytes it
|
||||
// will return false.
|
||||
func ReadSymbol(src []byte) (symbol string, rem []byte, ok bool) { return readstring(src) }
|
||||
|
||||
// AppendCodeWithScope will append code and scope to dst and return the extended buffer.
|
||||
func AppendCodeWithScope(dst []byte, code string, scope []byte) []byte {
|
||||
length := int32(4 + 4 + len(code) + 1 + len(scope)) // length of cws, length of code, code, 0x00, scope
|
||||
dst = appendLength(dst, length)
|
||||
|
||||
return append(appendstring(dst, code), scope...)
|
||||
}
|
||||
|
||||
// AppendCodeWithScopeElement will append a BSON code with scope element using
|
||||
// key, code, and scope to dst
|
||||
// and return the extended buffer.
|
||||
func AppendCodeWithScopeElement(dst []byte, key, code string, scope []byte) []byte {
|
||||
return AppendCodeWithScope(AppendHeader(dst, bsontype.CodeWithScope, key), code, scope)
|
||||
}
|
||||
|
||||
// ReadCodeWithScope will read code and scope from src. If there are not enough bytes it
|
||||
// will return false.
|
||||
func ReadCodeWithScope(src []byte) (code string, scope []byte, rem []byte, ok bool) {
|
||||
length, rem, ok := ReadLength(src)
|
||||
if !ok || len(src) < int(length) {
|
||||
return "", nil, src, false
|
||||
}
|
||||
|
||||
code, rem, ok = readstring(rem)
|
||||
if !ok {
|
||||
return "", nil, src, false
|
||||
}
|
||||
|
||||
scope, rem, ok = ReadDocument(rem)
|
||||
if !ok {
|
||||
return "", nil, src, false
|
||||
}
|
||||
return code, scope, rem, true
|
||||
}
|
||||
|
||||
// AppendInt32 will append i32 to dst and return the extended buffer.
|
||||
func AppendInt32(dst []byte, i32 int32) []byte { return appendi32(dst, i32) }
|
||||
|
||||
// AppendInt32Element will append a BSON int32 element using key and i32 to dst
|
||||
// and return the extended buffer.
|
||||
func AppendInt32Element(dst []byte, key string, i32 int32) []byte {
|
||||
return AppendInt32(AppendHeader(dst, bsontype.Int32, key), i32)
|
||||
}
|
||||
|
||||
// ReadInt32 will read an int32 from src. If there are not enough bytes it
|
||||
// will return false.
|
||||
func ReadInt32(src []byte) (int32, []byte, bool) { return readi32(src) }
|
||||
|
||||
// AppendTimestamp will append t and i to dst and return the extended buffer.
|
||||
func AppendTimestamp(dst []byte, t, i uint32) []byte {
|
||||
return appendu32(appendu32(dst, i), t) // i is the lower 4 bytes, t is the higher 4 bytes
|
||||
}
|
||||
|
||||
// AppendTimestampElement will append a BSON timestamp element using key, t, and
|
||||
// i to dst and return the extended buffer.
|
||||
func AppendTimestampElement(dst []byte, key string, t, i uint32) []byte {
|
||||
return AppendTimestamp(AppendHeader(dst, bsontype.Timestamp, key), t, i)
|
||||
}
|
||||
|
||||
// ReadTimestamp will read t and i from src. If there are not enough bytes it
|
||||
// will return false.
|
||||
func ReadTimestamp(src []byte) (t, i uint32, rem []byte, ok bool) {
|
||||
i, rem, ok = readu32(src)
|
||||
if !ok {
|
||||
return 0, 0, src, false
|
||||
}
|
||||
t, rem, ok = readu32(rem)
|
||||
if !ok {
|
||||
return 0, 0, src, false
|
||||
}
|
||||
return t, i, rem, true
|
||||
}
|
||||
|
||||
// AppendInt64 will append i64 to dst and return the extended buffer.
|
||||
func AppendInt64(dst []byte, i64 int64) []byte { return appendi64(dst, i64) }
|
||||
|
||||
// AppendInt64Element will append a BSON int64 element using key and i64 to dst
|
||||
// and return the extended buffer.
|
||||
func AppendInt64Element(dst []byte, key string, i64 int64) []byte {
|
||||
return AppendInt64(AppendHeader(dst, bsontype.Int64, key), i64)
|
||||
}
|
||||
|
||||
// ReadInt64 will read an int64 from src. If there are not enough bytes it
|
||||
// will return false.
|
||||
func ReadInt64(src []byte) (int64, []byte, bool) { return readi64(src) }
|
||||
|
||||
// AppendDecimal128 will append d128 to dst and return the extended buffer.
|
||||
func AppendDecimal128(dst []byte, d128 primitive.Decimal128) []byte {
|
||||
high, low := d128.GetBytes()
|
||||
return appendu64(appendu64(dst, low), high)
|
||||
}
|
||||
|
||||
// AppendDecimal128Element will append a BSON primitive.28 element using key and
|
||||
// d128 to dst and return the extended buffer.
|
||||
func AppendDecimal128Element(dst []byte, key string, d128 primitive.Decimal128) []byte {
|
||||
return AppendDecimal128(AppendHeader(dst, bsontype.Decimal128, key), d128)
|
||||
}
|
||||
|
||||
// ReadDecimal128 will read a primitive.Decimal128 from src. If there are not enough bytes it
|
||||
// will return false.
|
||||
func ReadDecimal128(src []byte) (primitive.Decimal128, []byte, bool) {
|
||||
l, rem, ok := readu64(src)
|
||||
if !ok {
|
||||
return primitive.Decimal128{}, src, false
|
||||
}
|
||||
|
||||
h, rem, ok := readu64(rem)
|
||||
if !ok {
|
||||
return primitive.Decimal128{}, src, false
|
||||
}
|
||||
|
||||
return primitive.NewDecimal128(h, l), rem, true
|
||||
}
|
||||
|
||||
// AppendMaxKeyElement will append a BSON max key element using key to dst
|
||||
// and return the extended buffer.
|
||||
func AppendMaxKeyElement(dst []byte, key string) []byte {
|
||||
return AppendHeader(dst, bsontype.MaxKey, key)
|
||||
}
|
||||
|
||||
// AppendMinKeyElement will append a BSON min key element using key to dst
|
||||
// and return the extended buffer.
|
||||
func AppendMinKeyElement(dst []byte, key string) []byte {
|
||||
return AppendHeader(dst, bsontype.MinKey, key)
|
||||
}
|
||||
|
||||
// EqualValue will return true if the two values are equal.
|
||||
func EqualValue(t1, t2 bsontype.Type, v1, v2 []byte) bool {
|
||||
if t1 != t2 {
|
||||
return false
|
||||
}
|
||||
v1, _, ok := readValue(v1, t1)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
v2, _, ok = readValue(v2, t2)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return bytes.Equal(v1, v2)
|
||||
}
|
||||
|
||||
// valueLength will determine the length of the next value contained in src as if it
|
||||
// is type t. The returned bool will be false if there are not enough bytes in src for
|
||||
// a value of type t.
|
||||
func valueLength(src []byte, t bsontype.Type) (int32, bool) {
|
||||
var length int32
|
||||
ok := true
|
||||
switch t {
|
||||
case bsontype.Array, bsontype.EmbeddedDocument, bsontype.CodeWithScope:
|
||||
length, _, ok = ReadLength(src)
|
||||
case bsontype.Binary:
|
||||
length, _, ok = ReadLength(src)
|
||||
length += 4 + 1 // binary length + subtype byte
|
||||
case bsontype.Boolean:
|
||||
length = 1
|
||||
case bsontype.DBPointer:
|
||||
length, _, ok = ReadLength(src)
|
||||
length += 4 + 12 // string length + ObjectID length
|
||||
case bsontype.DateTime, bsontype.Double, bsontype.Int64, bsontype.Timestamp:
|
||||
length = 8
|
||||
case bsontype.Decimal128:
|
||||
length = 16
|
||||
case bsontype.Int32:
|
||||
length = 4
|
||||
case bsontype.JavaScript, bsontype.String, bsontype.Symbol:
|
||||
length, _, ok = ReadLength(src)
|
||||
length += 4
|
||||
case bsontype.MaxKey, bsontype.MinKey, bsontype.Null, bsontype.Undefined:
|
||||
length = 0
|
||||
case bsontype.ObjectID:
|
||||
length = 12
|
||||
case bsontype.Regex:
|
||||
regex := bytes.IndexByte(src, 0x00)
|
||||
if regex < 0 {
|
||||
ok = false
|
||||
break
|
||||
}
|
||||
pattern := bytes.IndexByte(src[regex+1:], 0x00)
|
||||
if pattern < 0 {
|
||||
ok = false
|
||||
break
|
||||
}
|
||||
length = int32(int64(regex) + 1 + int64(pattern) + 1)
|
||||
default:
|
||||
ok = false
|
||||
}
|
||||
|
||||
return length, ok
|
||||
}
|
||||
|
||||
func readValue(src []byte, t bsontype.Type) ([]byte, []byte, bool) {
|
||||
length, ok := valueLength(src, t)
|
||||
if !ok || int(length) > len(src) {
|
||||
return nil, src, false
|
||||
}
|
||||
|
||||
return src[:length], src[length:], true
|
||||
}
|
||||
|
||||
// ReserveLength reserves the space required for length and returns the index where to write the length
|
||||
// and the []byte with reserved space.
|
||||
func ReserveLength(dst []byte) (int32, []byte) {
|
||||
index := len(dst)
|
||||
return int32(index), append(dst, 0x00, 0x00, 0x00, 0x00)
|
||||
}
|
||||
|
||||
// UpdateLength updates the length at index with length and returns the []byte.
|
||||
func UpdateLength(dst []byte, index, length int32) []byte {
|
||||
dst[index] = byte(length)
|
||||
dst[index+1] = byte(length >> 8)
|
||||
dst[index+2] = byte(length >> 16)
|
||||
dst[index+3] = byte(length >> 24)
|
||||
return dst
|
||||
}
|
||||
|
||||
func appendLength(dst []byte, l int32) []byte { return appendi32(dst, l) }
|
||||
|
||||
func appendi32(dst []byte, i32 int32) []byte {
|
||||
return append(dst, byte(i32), byte(i32>>8), byte(i32>>16), byte(i32>>24))
|
||||
}
|
||||
|
||||
// ReadLength reads an int32 length from src and returns the length and the remaining bytes. If
|
||||
// there aren't enough bytes to read a valid length, src is returned unomdified and the returned
|
||||
// bool will be false.
|
||||
func ReadLength(src []byte) (int32, []byte, bool) {
|
||||
ln, src, ok := readi32(src)
|
||||
if ln < 0 {
|
||||
return ln, src, false
|
||||
}
|
||||
return ln, src, ok
|
||||
}
|
||||
|
||||
func readi32(src []byte) (int32, []byte, bool) {
|
||||
if len(src) < 4 {
|
||||
return 0, src, false
|
||||
}
|
||||
return (int32(src[0]) | int32(src[1])<<8 | int32(src[2])<<16 | int32(src[3])<<24), src[4:], true
|
||||
}
|
||||
|
||||
func appendi64(dst []byte, i64 int64) []byte {
|
||||
return append(dst,
|
||||
byte(i64), byte(i64>>8), byte(i64>>16), byte(i64>>24),
|
||||
byte(i64>>32), byte(i64>>40), byte(i64>>48), byte(i64>>56),
|
||||
)
|
||||
}
|
||||
|
||||
func readi64(src []byte) (int64, []byte, bool) {
|
||||
if len(src) < 8 {
|
||||
return 0, src, false
|
||||
}
|
||||
i64 := (int64(src[0]) | int64(src[1])<<8 | int64(src[2])<<16 | int64(src[3])<<24 |
|
||||
int64(src[4])<<32 | int64(src[5])<<40 | int64(src[6])<<48 | int64(src[7])<<56)
|
||||
return i64, src[8:], true
|
||||
}
|
||||
|
||||
func appendu32(dst []byte, u32 uint32) []byte {
|
||||
return append(dst, byte(u32), byte(u32>>8), byte(u32>>16), byte(u32>>24))
|
||||
}
|
||||
|
||||
func readu32(src []byte) (uint32, []byte, bool) {
|
||||
if len(src) < 4 {
|
||||
return 0, src, false
|
||||
}
|
||||
|
||||
return (uint32(src[0]) | uint32(src[1])<<8 | uint32(src[2])<<16 | uint32(src[3])<<24), src[4:], true
|
||||
}
|
||||
|
||||
func appendu64(dst []byte, u64 uint64) []byte {
|
||||
return append(dst,
|
||||
byte(u64), byte(u64>>8), byte(u64>>16), byte(u64>>24),
|
||||
byte(u64>>32), byte(u64>>40), byte(u64>>48), byte(u64>>56),
|
||||
)
|
||||
}
|
||||
|
||||
func readu64(src []byte) (uint64, []byte, bool) {
|
||||
if len(src) < 8 {
|
||||
return 0, src, false
|
||||
}
|
||||
u64 := (uint64(src[0]) | uint64(src[1])<<8 | uint64(src[2])<<16 | uint64(src[3])<<24 |
|
||||
uint64(src[4])<<32 | uint64(src[5])<<40 | uint64(src[6])<<48 | uint64(src[7])<<56)
|
||||
return u64, src[8:], true
|
||||
}
|
||||
|
||||
// keep in sync with readcstringbytes
|
||||
func readcstring(src []byte) (string, []byte, bool) {
|
||||
idx := bytes.IndexByte(src, 0x00)
|
||||
if idx < 0 {
|
||||
return "", src, false
|
||||
}
|
||||
return string(src[:idx]), src[idx+1:], true
|
||||
}
|
||||
|
||||
// keep in sync with readcstring
|
||||
func readcstringbytes(src []byte) ([]byte, []byte, bool) {
|
||||
idx := bytes.IndexByte(src, 0x00)
|
||||
if idx < 0 {
|
||||
return nil, src, false
|
||||
}
|
||||
return src[:idx], src[idx+1:], true
|
||||
}
|
||||
|
||||
func appendstring(dst []byte, s string) []byte {
|
||||
l := int32(len(s) + 1)
|
||||
dst = appendLength(dst, l)
|
||||
dst = append(dst, s...)
|
||||
return append(dst, 0x00)
|
||||
}
|
||||
|
||||
func readstring(src []byte) (string, []byte, bool) {
|
||||
l, rem, ok := ReadLength(src)
|
||||
if !ok {
|
||||
return "", src, false
|
||||
}
|
||||
if len(src[4:]) < int(l) || l == 0 {
|
||||
return "", src, false
|
||||
}
|
||||
|
||||
return string(rem[:l-1]), rem[l:], true
|
||||
}
|
||||
|
||||
// readLengthBytes attempts to read a length and that number of bytes. This
|
||||
// function requires that the length include the four bytes for itself.
|
||||
func readLengthBytes(src []byte) ([]byte, []byte, bool) {
|
||||
l, _, ok := ReadLength(src)
|
||||
if !ok {
|
||||
return nil, src, false
|
||||
}
|
||||
if len(src) < int(l) {
|
||||
return nil, src, false
|
||||
}
|
||||
return src[:l], src[l:], true
|
||||
}
|
||||
|
||||
func appendBinarySubtype2(dst []byte, subtype byte, b []byte) []byte {
|
||||
dst = appendLength(dst, int32(len(b)+4)) // The bytes we'll encode need to be 4 larger for the length bytes
|
||||
dst = append(dst, subtype)
|
||||
dst = appendLength(dst, int32(len(b)))
|
||||
return append(dst, b...)
|
||||
}
|
||||
|
||||
func isValidCString(cs string) bool {
|
||||
return !strings.ContainsRune(cs, '\x00')
|
||||
}
|
||||
959
mongo/x/bsonx/bsoncore/bsoncore_test.go
Normal file
959
mongo/x/bsonx/bsoncore/bsoncore_test.go
Normal file
@@ -0,0 +1,959 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bsoncore
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"math"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"go.mongodb.org/mongo-driver/bson/bsontype"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/internal/testutil/assert"
|
||||
)
|
||||
|
||||
func compareErrors(err1, err2 error) bool {
|
||||
if err1 == nil && err2 == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
if err1 == nil || err2 == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if err1.Error() != err2.Error() {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func TestAppend(t *testing.T) {
|
||||
bits := math.Float64bits(3.14159)
|
||||
pi := make([]byte, 8)
|
||||
binary.LittleEndian.PutUint64(pi, bits)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
fn interface{}
|
||||
params []interface{}
|
||||
expected []byte
|
||||
}{
|
||||
{
|
||||
"AppendType",
|
||||
AppendType,
|
||||
[]interface{}{make([]byte, 0), bsontype.Null},
|
||||
[]byte{byte(bsontype.Null)},
|
||||
},
|
||||
{
|
||||
"AppendKey",
|
||||
AppendKey,
|
||||
[]interface{}{make([]byte, 0), "foobar"},
|
||||
[]byte{'f', 'o', 'o', 'b', 'a', 'r', 0x00},
|
||||
},
|
||||
{
|
||||
"AppendHeader",
|
||||
AppendHeader,
|
||||
[]interface{}{make([]byte, 0), bsontype.Null, "foobar"},
|
||||
[]byte{byte(bsontype.Null), 'f', 'o', 'o', 'b', 'a', 'r', 0x00},
|
||||
},
|
||||
{
|
||||
"AppendValueElement",
|
||||
AppendValueElement,
|
||||
[]interface{}{make([]byte, 0), "testing", Value{Type: bsontype.Boolean, Data: []byte{0x01}}},
|
||||
[]byte{byte(bsontype.Boolean), 't', 'e', 's', 't', 'i', 'n', 'g', 0x00, 0x01},
|
||||
},
|
||||
{
|
||||
"AppendDouble",
|
||||
AppendDouble,
|
||||
[]interface{}{make([]byte, 0), float64(3.14159)},
|
||||
pi,
|
||||
},
|
||||
{
|
||||
"AppendDoubleElement",
|
||||
AppendDoubleElement,
|
||||
[]interface{}{make([]byte, 0), "foobar", float64(3.14159)},
|
||||
append([]byte{byte(bsontype.Double), 'f', 'o', 'o', 'b', 'a', 'r', 0x00}, pi...),
|
||||
},
|
||||
{
|
||||
"AppendString",
|
||||
AppendString,
|
||||
[]interface{}{make([]byte, 0), "barbaz"},
|
||||
[]byte{0x07, 0x00, 0x00, 0x00, 'b', 'a', 'r', 'b', 'a', 'z', 0x00},
|
||||
},
|
||||
{
|
||||
"AppendStringElement",
|
||||
AppendStringElement,
|
||||
[]interface{}{make([]byte, 0), "foobar", "barbaz"},
|
||||
[]byte{byte(bsontype.String),
|
||||
'f', 'o', 'o', 'b', 'a', 'r', 0x00,
|
||||
0x07, 0x00, 0x00, 0x00, 'b', 'a', 'r', 'b', 'a', 'z', 0x00,
|
||||
},
|
||||
},
|
||||
{
|
||||
"AppendDocument",
|
||||
AppendDocument,
|
||||
[]interface{}{[]byte{0x05, 0x00, 0x00, 0x00, 0x00}, []byte{0x05, 0x00, 0x00, 0x00, 0x00}},
|
||||
[]byte{0x05, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x00},
|
||||
},
|
||||
{
|
||||
"AppendDocumentElement",
|
||||
AppendDocumentElement,
|
||||
[]interface{}{make([]byte, 0), "foobar", []byte{0x05, 0x00, 0x00, 0x00, 0x00}},
|
||||
[]byte{byte(bsontype.EmbeddedDocument),
|
||||
'f', 'o', 'o', 'b', 'a', 'r', 0x00,
|
||||
0x05, 0x00, 0x00, 0x00, 0x00,
|
||||
},
|
||||
},
|
||||
{
|
||||
"AppendArray",
|
||||
AppendArray,
|
||||
[]interface{}{[]byte{0x05, 0x00, 0x00, 0x00, 0x00}, []byte{0x05, 0x00, 0x00, 0x00, 0x00}},
|
||||
[]byte{0x05, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x00},
|
||||
},
|
||||
{
|
||||
"AppendArrayElement",
|
||||
AppendArrayElement,
|
||||
[]interface{}{make([]byte, 0), "foobar", []byte{0x05, 0x00, 0x00, 0x00, 0x00}},
|
||||
[]byte{byte(bsontype.Array),
|
||||
'f', 'o', 'o', 'b', 'a', 'r', 0x00,
|
||||
0x05, 0x00, 0x00, 0x00, 0x00,
|
||||
},
|
||||
},
|
||||
{
|
||||
"BuildArray",
|
||||
BuildArray,
|
||||
[]interface{}{make([]byte, 0), Value{Type: bsontype.Double, Data: AppendDouble(nil, 3.14159)}},
|
||||
[]byte{
|
||||
0x10, 0x00, 0x00, 0x00,
|
||||
byte(bsontype.Double), '0', 0x00,
|
||||
pi[0], pi[1], pi[2], pi[3], pi[4], pi[5], pi[6], pi[7],
|
||||
0x00,
|
||||
},
|
||||
},
|
||||
{
|
||||
"BuildArrayElement",
|
||||
BuildArrayElement,
|
||||
[]interface{}{make([]byte, 0), "foobar", Value{Type: bsontype.Double, Data: AppendDouble(nil, 3.14159)}},
|
||||
[]byte{byte(bsontype.Array),
|
||||
'f', 'o', 'o', 'b', 'a', 'r', 0x00,
|
||||
0x10, 0x00, 0x00, 0x00,
|
||||
byte(bsontype.Double), '0', 0x00,
|
||||
pi[0], pi[1], pi[2], pi[3], pi[4], pi[5], pi[6], pi[7],
|
||||
0x00,
|
||||
},
|
||||
},
|
||||
{
|
||||
"AppendBinary Subtype2",
|
||||
AppendBinary,
|
||||
[]interface{}{make([]byte, 0), byte(0x02), []byte{0x01, 0x02, 0x03}},
|
||||
[]byte{0x07, 0x00, 0x00, 0x00, 0x02, 0x03, 0x00, 0x00, 0x00, 0x01, 0x02, 0x03},
|
||||
},
|
||||
{
|
||||
"AppendBinaryElement Subtype 2",
|
||||
AppendBinaryElement,
|
||||
[]interface{}{make([]byte, 0), "foobar", byte(0x02), []byte{0x01, 0x02, 0x03}},
|
||||
[]byte{byte(bsontype.Binary),
|
||||
'f', 'o', 'o', 'b', 'a', 'r', 0x00,
|
||||
0x07, 0x00, 0x00, 0x00,
|
||||
0x02,
|
||||
0x03, 0x00, 0x00, 0x00, 0x01, 0x02, 0x03,
|
||||
},
|
||||
},
|
||||
{
|
||||
"AppendBinary",
|
||||
AppendBinary,
|
||||
[]interface{}{make([]byte, 0), byte(0xFF), []byte{0x01, 0x02, 0x03}},
|
||||
[]byte{0x03, 0x00, 0x00, 0x00, 0xFF, 0x01, 0x02, 0x03},
|
||||
},
|
||||
{
|
||||
"AppendBinaryElement",
|
||||
AppendBinaryElement,
|
||||
[]interface{}{make([]byte, 0), "foobar", byte(0xFF), []byte{0x01, 0x02, 0x03}},
|
||||
[]byte{byte(bsontype.Binary),
|
||||
'f', 'o', 'o', 'b', 'a', 'r', 0x00,
|
||||
0x03, 0x00, 0x00, 0x00,
|
||||
0xFF,
|
||||
0x01, 0x02, 0x03,
|
||||
},
|
||||
},
|
||||
{
|
||||
"AppendUndefinedElement",
|
||||
AppendUndefinedElement,
|
||||
[]interface{}{make([]byte, 0), "foobar"},
|
||||
[]byte{byte(bsontype.Undefined), 'f', 'o', 'o', 'b', 'a', 'r', 0x00},
|
||||
},
|
||||
{
|
||||
"AppendObjectID",
|
||||
AppendObjectID,
|
||||
[]interface{}{
|
||||
make([]byte, 0),
|
||||
primitive.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C},
|
||||
},
|
||||
[]byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C},
|
||||
},
|
||||
{
|
||||
"AppendObjectIDElement",
|
||||
AppendObjectIDElement,
|
||||
[]interface{}{
|
||||
make([]byte, 0), "foobar",
|
||||
primitive.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C},
|
||||
},
|
||||
[]byte{byte(bsontype.ObjectID),
|
||||
'f', 'o', 'o', 'b', 'a', 'r', 0x00,
|
||||
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C,
|
||||
},
|
||||
},
|
||||
{
|
||||
"AppendBoolean (true)",
|
||||
AppendBoolean,
|
||||
[]interface{}{make([]byte, 0), true},
|
||||
[]byte{0x01},
|
||||
},
|
||||
{
|
||||
"AppendBoolean (false)",
|
||||
AppendBoolean,
|
||||
[]interface{}{make([]byte, 0), false},
|
||||
[]byte{0x00},
|
||||
},
|
||||
{
|
||||
"AppendBooleanElement",
|
||||
AppendBooleanElement,
|
||||
[]interface{}{make([]byte, 0), "foobar", true},
|
||||
[]byte{byte(bsontype.Boolean), 'f', 'o', 'o', 'b', 'a', 'r', 0x00, 0x01},
|
||||
},
|
||||
{
|
||||
"AppendDateTime",
|
||||
AppendDateTime,
|
||||
[]interface{}{make([]byte, 0), int64(256)},
|
||||
[]byte{0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
|
||||
},
|
||||
{
|
||||
"AppendDateTimeElement",
|
||||
AppendDateTimeElement,
|
||||
[]interface{}{make([]byte, 0), "foobar", int64(256)},
|
||||
[]byte{byte(bsontype.DateTime), 'f', 'o', 'o', 'b', 'a', 'r', 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
|
||||
},
|
||||
{
|
||||
"AppendNullElement",
|
||||
AppendNullElement,
|
||||
[]interface{}{make([]byte, 0), "foobar"},
|
||||
[]byte{byte(bsontype.Null), 'f', 'o', 'o', 'b', 'a', 'r', 0x00},
|
||||
},
|
||||
{
|
||||
"AppendRegex",
|
||||
AppendRegex,
|
||||
[]interface{}{make([]byte, 0), "bar", "baz"},
|
||||
[]byte{'b', 'a', 'r', 0x00, 'b', 'a', 'z', 0x00},
|
||||
},
|
||||
{
|
||||
"AppendRegexElement",
|
||||
AppendRegexElement,
|
||||
[]interface{}{make([]byte, 0), "foobar", "bar", "baz"},
|
||||
[]byte{byte(bsontype.Regex),
|
||||
'f', 'o', 'o', 'b', 'a', 'r', 0x00,
|
||||
'b', 'a', 'r', 0x00, 'b', 'a', 'z', 0x00,
|
||||
},
|
||||
},
|
||||
{
|
||||
"AppendDBPointer",
|
||||
AppendDBPointer,
|
||||
[]interface{}{
|
||||
make([]byte, 0),
|
||||
"foobar",
|
||||
primitive.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C},
|
||||
},
|
||||
[]byte{
|
||||
0x07, 0x00, 0x00, 0x00, 'f', 'o', 'o', 'b', 'a', 'r', 0x00,
|
||||
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C,
|
||||
},
|
||||
},
|
||||
{
|
||||
"AppendDBPointerElement",
|
||||
AppendDBPointerElement,
|
||||
[]interface{}{
|
||||
make([]byte, 0), "foobar",
|
||||
"barbaz",
|
||||
primitive.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C},
|
||||
},
|
||||
[]byte{byte(bsontype.DBPointer),
|
||||
'f', 'o', 'o', 'b', 'a', 'r', 0x00,
|
||||
0x07, 0x00, 0x00, 0x00, 'b', 'a', 'r', 'b', 'a', 'z', 0x00,
|
||||
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C,
|
||||
},
|
||||
},
|
||||
{
|
||||
"AppendJavaScript",
|
||||
AppendJavaScript,
|
||||
[]interface{}{make([]byte, 0), "barbaz"},
|
||||
[]byte{0x07, 0x00, 0x00, 0x00, 'b', 'a', 'r', 'b', 'a', 'z', 0x00},
|
||||
},
|
||||
{
|
||||
"AppendJavaScriptElement",
|
||||
AppendJavaScriptElement,
|
||||
[]interface{}{make([]byte, 0), "foobar", "barbaz"},
|
||||
[]byte{byte(bsontype.JavaScript),
|
||||
'f', 'o', 'o', 'b', 'a', 'r', 0x00,
|
||||
0x07, 0x00, 0x00, 0x00, 'b', 'a', 'r', 'b', 'a', 'z', 0x00,
|
||||
},
|
||||
},
|
||||
{
|
||||
"AppendSymbol",
|
||||
AppendSymbol,
|
||||
[]interface{}{make([]byte, 0), "barbaz"},
|
||||
[]byte{0x07, 0x00, 0x00, 0x00, 'b', 'a', 'r', 'b', 'a', 'z', 0x00},
|
||||
},
|
||||
{
|
||||
"AppendSymbolElement",
|
||||
AppendSymbolElement,
|
||||
[]interface{}{make([]byte, 0), "foobar", "barbaz"},
|
||||
[]byte{byte(bsontype.Symbol),
|
||||
'f', 'o', 'o', 'b', 'a', 'r', 0x00,
|
||||
0x07, 0x00, 0x00, 0x00, 'b', 'a', 'r', 'b', 'a', 'z', 0x00,
|
||||
},
|
||||
},
|
||||
{
|
||||
"AppendCodeWithScope",
|
||||
AppendCodeWithScope,
|
||||
[]interface{}{[]byte{0x05, 0x00, 0x00, 0x00, 0x00}, "foobar", []byte{0x05, 0x00, 0x00, 0x00, 0x00}},
|
||||
[]byte{0x05, 0x00, 0x00, 0x00, 0x00,
|
||||
0x14, 0x00, 0x00, 0x00,
|
||||
0x07, 0x00, 0x00, 0x00, 'f', 'o', 'o', 'b', 'a', 'r', 0x00,
|
||||
0x05, 0x00, 0x00, 0x00, 0x00,
|
||||
},
|
||||
},
|
||||
{
|
||||
"AppendCodeWithScopeElement",
|
||||
AppendCodeWithScopeElement,
|
||||
[]interface{}{make([]byte, 0), "foobar", "barbaz", []byte{0x05, 0x00, 0x00, 0x00, 0x00}},
|
||||
[]byte{byte(bsontype.CodeWithScope),
|
||||
'f', 'o', 'o', 'b', 'a', 'r', 0x00,
|
||||
0x14, 0x00, 0x00, 0x00,
|
||||
0x07, 0x00, 0x00, 0x00, 'b', 'a', 'r', 'b', 'a', 'z', 0x00,
|
||||
0x05, 0x00, 0x00, 0x00, 0x00,
|
||||
},
|
||||
},
|
||||
{
|
||||
"AppendInt32",
|
||||
AppendInt32,
|
||||
[]interface{}{make([]byte, 0), int32(256)},
|
||||
[]byte{0x00, 0x01, 0x00, 0x00},
|
||||
},
|
||||
{
|
||||
"AppendInt32Element",
|
||||
AppendInt32Element,
|
||||
[]interface{}{make([]byte, 0), "foobar", int32(256)},
|
||||
[]byte{byte(bsontype.Int32), 'f', 'o', 'o', 'b', 'a', 'r', 0x00, 0x00, 0x01, 0x00, 0x00},
|
||||
},
|
||||
{
|
||||
"AppendTimestamp",
|
||||
AppendTimestamp,
|
||||
[]interface{}{make([]byte, 0), uint32(65536), uint32(256)},
|
||||
[]byte{0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00},
|
||||
},
|
||||
{
|
||||
"AppendTimestampElement",
|
||||
AppendTimestampElement,
|
||||
[]interface{}{make([]byte, 0), "foobar", uint32(65536), uint32(256)},
|
||||
[]byte{byte(bsontype.Timestamp), 'f', 'o', 'o', 'b', 'a', 'r', 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00},
|
||||
},
|
||||
{
|
||||
"AppendInt64",
|
||||
AppendInt64,
|
||||
[]interface{}{make([]byte, 0), int64(4294967296)},
|
||||
[]byte{0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00},
|
||||
},
|
||||
{
|
||||
"AppendInt64Element",
|
||||
AppendInt64Element,
|
||||
[]interface{}{make([]byte, 0), "foobar", int64(4294967296)},
|
||||
[]byte{byte(bsontype.Int64), 'f', 'o', 'o', 'b', 'a', 'r', 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00},
|
||||
},
|
||||
{
|
||||
"AppendDecimal128",
|
||||
AppendDecimal128,
|
||||
[]interface{}{make([]byte, 0), primitive.NewDecimal128(4294967296, 65536)},
|
||||
[]byte{
|
||||
0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
|
||||
},
|
||||
},
|
||||
{
|
||||
"AppendDecimal128Element",
|
||||
AppendDecimal128Element,
|
||||
[]interface{}{make([]byte, 0), "foobar", primitive.NewDecimal128(4294967296, 65536)},
|
||||
[]byte{
|
||||
byte(bsontype.Decimal128), 'f', 'o', 'o', 'b', 'a', 'r', 0x00,
|
||||
0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
|
||||
},
|
||||
},
|
||||
{
|
||||
"AppendMaxKeyElement",
|
||||
AppendMaxKeyElement,
|
||||
[]interface{}{make([]byte, 0), "foobar"},
|
||||
[]byte{byte(bsontype.MaxKey), 'f', 'o', 'o', 'b', 'a', 'r', 0x00},
|
||||
},
|
||||
{
|
||||
"AppendMinKeyElement",
|
||||
AppendMinKeyElement,
|
||||
[]interface{}{make([]byte, 0), "foobar"},
|
||||
[]byte{byte(bsontype.MinKey), 'f', 'o', 'o', 'b', 'a', 'r', 0x00},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
fn := reflect.ValueOf(tc.fn)
|
||||
if fn.Kind() != reflect.Func {
|
||||
t.Fatalf("fn must be of kind Func but is a %v", fn.Kind())
|
||||
}
|
||||
if fn.Type().NumIn() != len(tc.params) {
|
||||
t.Fatalf("tc.params must match the number of params in tc.fn. params %d; fn %d", fn.Type().NumIn(), len(tc.params))
|
||||
}
|
||||
if fn.Type().NumOut() != 1 || fn.Type().Out(0) != reflect.TypeOf([]byte{}) {
|
||||
t.Fatalf("fn must have one return parameter and it must be a []byte.")
|
||||
}
|
||||
params := make([]reflect.Value, 0, len(tc.params))
|
||||
for _, param := range tc.params {
|
||||
params = append(params, reflect.ValueOf(param))
|
||||
}
|
||||
results := fn.Call(params)
|
||||
got := results[0].Interface().([]byte)
|
||||
want := tc.expected
|
||||
if !bytes.Equal(got, want) {
|
||||
t.Errorf("Did not receive expected bytes. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRead(t *testing.T) {
|
||||
bits := math.Float64bits(3.14159)
|
||||
pi := make([]byte, 8)
|
||||
binary.LittleEndian.PutUint64(pi, bits)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
fn interface{}
|
||||
param []byte
|
||||
expected []interface{}
|
||||
}{
|
||||
{
|
||||
"ReadType/not enough bytes",
|
||||
ReadType,
|
||||
[]byte{},
|
||||
[]interface{}{bsontype.Type(0), []byte{}, false},
|
||||
},
|
||||
{
|
||||
"ReadType/success",
|
||||
ReadType,
|
||||
[]byte{0x0A},
|
||||
[]interface{}{bsontype.Null, []byte{}, true},
|
||||
},
|
||||
{
|
||||
"ReadKey/not enough bytes",
|
||||
ReadKey,
|
||||
[]byte{},
|
||||
[]interface{}{"", []byte{}, false},
|
||||
},
|
||||
{
|
||||
"ReadKey/success",
|
||||
ReadKey,
|
||||
[]byte{'f', 'o', 'o', 'b', 'a', 'r', 0x00},
|
||||
[]interface{}{"foobar", []byte{}, true},
|
||||
},
|
||||
{
|
||||
"ReadHeader/not enough bytes (type)",
|
||||
ReadHeader,
|
||||
[]byte{},
|
||||
[]interface{}{bsontype.Type(0), "", []byte{}, false},
|
||||
},
|
||||
{
|
||||
"ReadHeader/not enough bytes (key)",
|
||||
ReadHeader,
|
||||
[]byte{0x0A, 'f', 'o', 'o'},
|
||||
[]interface{}{bsontype.Type(0), "", []byte{0x0A, 'f', 'o', 'o'}, false},
|
||||
},
|
||||
{
|
||||
"ReadHeader/success",
|
||||
ReadHeader,
|
||||
[]byte{0x0A, 'f', 'o', 'o', 'b', 'a', 'r', 0x00},
|
||||
[]interface{}{bsontype.Null, "foobar", []byte{}, true},
|
||||
},
|
||||
{
|
||||
"ReadDouble/not enough bytes",
|
||||
ReadDouble,
|
||||
[]byte{0x01, 0x02, 0x03, 0x04},
|
||||
[]interface{}{float64(0.00), []byte{0x01, 0x02, 0x03, 0x04}, false},
|
||||
},
|
||||
{
|
||||
"ReadDouble/success",
|
||||
ReadDouble,
|
||||
pi,
|
||||
[]interface{}{float64(3.14159), []byte{}, true},
|
||||
},
|
||||
{
|
||||
"ReadString/not enough bytes (length)",
|
||||
ReadString,
|
||||
[]byte{},
|
||||
[]interface{}{"", []byte{}, false},
|
||||
},
|
||||
{
|
||||
"ReadString/not enough bytes (value)",
|
||||
ReadString,
|
||||
[]byte{0x0F, 0x00, 0x00, 0x00},
|
||||
[]interface{}{"", []byte{0x0F, 0x00, 0x00, 0x00}, false},
|
||||
},
|
||||
{
|
||||
"ReadString/success",
|
||||
ReadString,
|
||||
[]byte{0x07, 0x00, 0x00, 0x00, 'f', 'o', 'o', 'b', 'a', 'r', 0x00},
|
||||
[]interface{}{"foobar", []byte{}, true},
|
||||
},
|
||||
{
|
||||
"ReadDocument/not enough bytes (length)",
|
||||
ReadDocument,
|
||||
[]byte{},
|
||||
[]interface{}{Document(nil), []byte{}, false},
|
||||
},
|
||||
{
|
||||
"ReadDocument/not enough bytes (value)",
|
||||
ReadDocument,
|
||||
[]byte{0x0F, 0x00, 0x00, 0x00},
|
||||
[]interface{}{Document(nil), []byte{0x0F, 0x00, 0x00, 0x00}, false},
|
||||
},
|
||||
{
|
||||
"ReadDocument/success",
|
||||
ReadDocument,
|
||||
[]byte{0x0A, 0x00, 0x00, 0x00, 0x0A, 'f', 'o', 'o', 0x00, 0x00},
|
||||
[]interface{}{Document{0x0A, 0x00, 0x00, 0x00, 0x0A, 'f', 'o', 'o', 0x00, 0x00}, []byte{}, true},
|
||||
},
|
||||
{
|
||||
"ReadArray/not enough bytes (length)",
|
||||
ReadArray,
|
||||
[]byte{},
|
||||
[]interface{}{Array(nil), []byte{}, false},
|
||||
},
|
||||
{
|
||||
"ReadArray/not enough bytes (value)",
|
||||
ReadArray,
|
||||
[]byte{0x0F, 0x00, 0x00, 0x00},
|
||||
[]interface{}{Array(nil), []byte{0x0F, 0x00, 0x00, 0x00}, false},
|
||||
},
|
||||
{
|
||||
"ReadArray/success",
|
||||
ReadArray,
|
||||
[]byte{0x08, 0x00, 0x00, 0x00, 0x0A, '0', 0x00, 0x00},
|
||||
[]interface{}{Array{0x08, 0x00, 0x00, 0x00, 0x0A, '0', 0x00, 0x00}, []byte{}, true},
|
||||
},
|
||||
{
|
||||
"ReadBinary/not enough bytes (length)",
|
||||
ReadBinary,
|
||||
[]byte{},
|
||||
[]interface{}{byte(0), []byte(nil), []byte{}, false},
|
||||
},
|
||||
{
|
||||
"ReadBinary/not enough bytes (subtype)",
|
||||
ReadBinary,
|
||||
[]byte{0x0F, 0x00, 0x00, 0x00},
|
||||
[]interface{}{byte(0), []byte(nil), []byte{0x0F, 0x00, 0x00, 0x00}, false},
|
||||
},
|
||||
{
|
||||
"ReadBinary/not enough bytes (value)",
|
||||
ReadBinary,
|
||||
[]byte{0x0F, 0x00, 0x00, 0x00, 0x00},
|
||||
[]interface{}{byte(0), []byte(nil), []byte{0x0F, 0x00, 0x00, 0x00, 0x00}, false},
|
||||
},
|
||||
{
|
||||
"ReadBinary/not enough bytes (subtype 2 length)",
|
||||
ReadBinary,
|
||||
[]byte{0x03, 0x00, 0x00, 0x00, 0x02, 0x0F, 0x00, 0x00},
|
||||
[]interface{}{byte(0), []byte(nil), []byte{0x03, 0x00, 0x00, 0x00, 0x02, 0x0F, 0x00, 0x00}, false},
|
||||
},
|
||||
{
|
||||
"ReadBinary/not enough bytes (subtype 2 value)",
|
||||
ReadBinary,
|
||||
[]byte{0x0F, 0x00, 0x00, 0x00, 0x02, 0x0F, 0x00, 0x00, 0x00, 0x01, 0x02},
|
||||
[]interface{}{
|
||||
byte(0), []byte(nil),
|
||||
[]byte{0x0F, 0x00, 0x00, 0x00, 0x02, 0x0F, 0x00, 0x00, 0x00, 0x01, 0x02}, false,
|
||||
},
|
||||
},
|
||||
{
|
||||
"ReadBinary/success (subtype 2)",
|
||||
ReadBinary,
|
||||
[]byte{0x06, 0x00, 0x00, 0x00, 0x02, 0x02, 0x00, 0x00, 0x00, 0x01, 0x02},
|
||||
[]interface{}{byte(0x02), []byte{0x01, 0x02}, []byte{}, true},
|
||||
},
|
||||
{
|
||||
"ReadBinary/success",
|
||||
ReadBinary,
|
||||
[]byte{0x03, 0x00, 0x00, 0x00, 0xFF, 0x01, 0x02, 0x03},
|
||||
[]interface{}{byte(0xFF), []byte{0x01, 0x02, 0x03}, []byte{}, true},
|
||||
},
|
||||
{
|
||||
"ReadObjectID/not enough bytes",
|
||||
ReadObjectID,
|
||||
[]byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06},
|
||||
[]interface{}{primitive.ObjectID{}, []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06}, false},
|
||||
},
|
||||
{
|
||||
"ReadObjectID/success",
|
||||
ReadObjectID,
|
||||
[]byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C},
|
||||
[]interface{}{
|
||||
primitive.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C},
|
||||
[]byte{}, true,
|
||||
},
|
||||
},
|
||||
{
|
||||
"ReadBoolean/not enough bytes",
|
||||
ReadBoolean,
|
||||
[]byte{},
|
||||
[]interface{}{false, []byte{}, false},
|
||||
},
|
||||
{
|
||||
"ReadBoolean/success",
|
||||
ReadBoolean,
|
||||
[]byte{0x01},
|
||||
[]interface{}{true, []byte{}, true},
|
||||
},
|
||||
{
|
||||
"ReadDateTime/not enough bytes",
|
||||
ReadDateTime,
|
||||
[]byte{0x01, 0x02, 0x03, 0x04},
|
||||
[]interface{}{int64(0), []byte{0x01, 0x02, 0x03, 0x04}, false},
|
||||
},
|
||||
{
|
||||
"ReadDateTime/success",
|
||||
ReadDateTime,
|
||||
[]byte{0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00},
|
||||
[]interface{}{int64(65536), []byte{}, true},
|
||||
},
|
||||
{
|
||||
"ReadRegex/not enough bytes (pattern)",
|
||||
ReadRegex,
|
||||
[]byte{},
|
||||
[]interface{}{"", "", []byte{}, false},
|
||||
},
|
||||
{
|
||||
"ReadRegex/not enough bytes (options)",
|
||||
ReadRegex,
|
||||
[]byte{'f', 'o', 'o', 0x00},
|
||||
[]interface{}{"", "", []byte{'f', 'o', 'o', 0x00}, false},
|
||||
},
|
||||
{
|
||||
"ReadRegex/success",
|
||||
ReadRegex,
|
||||
[]byte{'f', 'o', 'o', 0x00, 'b', 'a', 'r', 0x00},
|
||||
[]interface{}{"foo", "bar", []byte{}, true},
|
||||
},
|
||||
{
|
||||
"ReadDBPointer/not enough bytes (ns)",
|
||||
ReadDBPointer,
|
||||
[]byte{},
|
||||
[]interface{}{"", primitive.ObjectID{}, []byte{}, false},
|
||||
},
|
||||
{
|
||||
"ReadDBPointer/not enough bytes (objectID)",
|
||||
ReadDBPointer,
|
||||
[]byte{0x04, 0x00, 0x00, 0x00, 'f', 'o', 'o', 0x00},
|
||||
[]interface{}{"", primitive.ObjectID{}, []byte{0x04, 0x00, 0x00, 0x00, 'f', 'o', 'o', 0x00}, false},
|
||||
},
|
||||
{
|
||||
"ReadDBPointer/success",
|
||||
ReadDBPointer,
|
||||
[]byte{
|
||||
0x04, 0x00, 0x00, 0x00, 'f', 'o', 'o', 0x00,
|
||||
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C,
|
||||
},
|
||||
[]interface{}{
|
||||
"foo", primitive.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C},
|
||||
[]byte{}, true,
|
||||
},
|
||||
},
|
||||
{
|
||||
"ReadJavaScript/not enough bytes (length)",
|
||||
ReadJavaScript,
|
||||
[]byte{},
|
||||
[]interface{}{"", []byte{}, false},
|
||||
},
|
||||
{
|
||||
"ReadJavaScript/not enough bytes (value)",
|
||||
ReadJavaScript,
|
||||
[]byte{0x0F, 0x00, 0x00, 0x00},
|
||||
[]interface{}{"", []byte{0x0F, 0x00, 0x00, 0x00}, false},
|
||||
},
|
||||
{
|
||||
"ReadJavaScript/success",
|
||||
ReadJavaScript,
|
||||
[]byte{0x07, 0x00, 0x00, 0x00, 'f', 'o', 'o', 'b', 'a', 'r', 0x00},
|
||||
[]interface{}{"foobar", []byte{}, true},
|
||||
},
|
||||
{
|
||||
"ReadSymbol/not enough bytes (length)",
|
||||
ReadSymbol,
|
||||
[]byte{},
|
||||
[]interface{}{"", []byte{}, false},
|
||||
},
|
||||
{
|
||||
"ReadSymbol/not enough bytes (value)",
|
||||
ReadSymbol,
|
||||
[]byte{0x0F, 0x00, 0x00, 0x00},
|
||||
[]interface{}{"", []byte{0x0F, 0x00, 0x00, 0x00}, false},
|
||||
},
|
||||
{
|
||||
"ReadSymbol/success",
|
||||
ReadSymbol,
|
||||
[]byte{0x07, 0x00, 0x00, 0x00, 'f', 'o', 'o', 'b', 'a', 'r', 0x00},
|
||||
[]interface{}{"foobar", []byte{}, true},
|
||||
},
|
||||
{
|
||||
"ReadCodeWithScope/ not enough bytes (length)",
|
||||
ReadCodeWithScope,
|
||||
[]byte{},
|
||||
[]interface{}{"", []byte(nil), []byte{}, false},
|
||||
},
|
||||
{
|
||||
"ReadCodeWithScope/ not enough bytes (value)",
|
||||
ReadCodeWithScope,
|
||||
[]byte{0x0F, 0x00, 0x00, 0x00},
|
||||
[]interface{}{"", []byte(nil), []byte{0x0F, 0x00, 0x00, 0x00}, false},
|
||||
},
|
||||
{
|
||||
"ReadCodeWithScope/not enough bytes (code value)",
|
||||
ReadCodeWithScope,
|
||||
[]byte{
|
||||
0x0C, 0x00, 0x00, 0x00,
|
||||
0x0F, 0x00, 0x00, 0x00,
|
||||
'f', 'o', 'o', 0x00,
|
||||
},
|
||||
[]interface{}{
|
||||
"", []byte(nil),
|
||||
[]byte{
|
||||
0x0C, 0x00, 0x00, 0x00,
|
||||
0x0F, 0x00, 0x00, 0x00,
|
||||
'f', 'o', 'o', 0x00,
|
||||
},
|
||||
false,
|
||||
},
|
||||
},
|
||||
{
|
||||
"ReadCodeWithScope/success",
|
||||
ReadCodeWithScope,
|
||||
[]byte{
|
||||
0x19, 0x00, 0x00, 0x00,
|
||||
0x07, 0x00, 0x00, 0x00, 'f', 'o', 'o', 'b', 'a', 'r', 0x00,
|
||||
0x0A, 0x00, 0x00, 0x00, 0x0A, 'f', 'o', 'o', 0x00, 0x00,
|
||||
},
|
||||
[]interface{}{
|
||||
"foobar", []byte{0x0A, 0x00, 0x00, 0x00, 0x0A, 'f', 'o', 'o', 0x00, 0x00},
|
||||
[]byte{}, true,
|
||||
},
|
||||
},
|
||||
{
|
||||
"ReadInt32/not enough bytes",
|
||||
ReadInt32,
|
||||
[]byte{0x01},
|
||||
[]interface{}{int32(0), []byte{0x01}, false},
|
||||
},
|
||||
{
|
||||
"ReadInt32/success",
|
||||
ReadInt32,
|
||||
[]byte{0x00, 0x01, 0x00, 0x00},
|
||||
[]interface{}{int32(256), []byte{}, true},
|
||||
},
|
||||
{
|
||||
"ReadTimestamp/not enough bytes (increment)",
|
||||
ReadTimestamp,
|
||||
[]byte{},
|
||||
[]interface{}{uint32(0), uint32(0), []byte{}, false},
|
||||
},
|
||||
{
|
||||
"ReadTimestamp/not enough bytes (timestamp)",
|
||||
ReadTimestamp,
|
||||
[]byte{0x00, 0x01, 0x00, 0x00},
|
||||
[]interface{}{uint32(0), uint32(0), []byte{0x00, 0x01, 0x00, 0x00}, false},
|
||||
},
|
||||
{
|
||||
"ReadTimestamp/success",
|
||||
ReadTimestamp,
|
||||
[]byte{0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00},
|
||||
[]interface{}{uint32(65536), uint32(256), []byte{}, true},
|
||||
},
|
||||
{
|
||||
"ReadInt64/not enough bytes",
|
||||
ReadInt64,
|
||||
[]byte{0x01},
|
||||
[]interface{}{int64(0), []byte{0x01}, false},
|
||||
},
|
||||
{
|
||||
"ReadInt64/success",
|
||||
ReadInt64,
|
||||
[]byte{0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00},
|
||||
[]interface{}{int64(4294967296), []byte{}, true},
|
||||
},
|
||||
{
|
||||
"ReadDecimal128/not enough bytes (low)",
|
||||
ReadDecimal128,
|
||||
[]byte{},
|
||||
[]interface{}{primitive.Decimal128{}, []byte{}, false},
|
||||
},
|
||||
{
|
||||
"ReadDecimal128/not enough bytes (high)",
|
||||
ReadDecimal128,
|
||||
[]byte{0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00},
|
||||
[]interface{}{primitive.Decimal128{}, []byte{0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00}, false},
|
||||
},
|
||||
{
|
||||
"ReadDecimal128/success",
|
||||
ReadDecimal128,
|
||||
[]byte{
|
||||
0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
|
||||
},
|
||||
[]interface{}{primitive.NewDecimal128(4294967296, 16777216), []byte{}, true},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
fn := reflect.ValueOf(tc.fn)
|
||||
if fn.Kind() != reflect.Func {
|
||||
t.Fatalf("fn must be of kind Func but it is a %v", fn.Kind())
|
||||
}
|
||||
if fn.Type().NumIn() != 1 || fn.Type().In(0) != reflect.TypeOf([]byte{}) {
|
||||
t.Fatalf("fn must have one parameter and it must be a []byte.")
|
||||
}
|
||||
results := fn.Call([]reflect.Value{reflect.ValueOf(tc.param)})
|
||||
if len(results) != len(tc.expected) {
|
||||
t.Fatalf("Length of results does not match. got %d; want %d", len(results), len(tc.expected))
|
||||
}
|
||||
for idx := range results {
|
||||
got := results[idx].Interface()
|
||||
want := tc.expected[idx]
|
||||
if !cmp.Equal(got, want, cmp.Comparer(compareDecimal128)) {
|
||||
t.Errorf("Result %d does not match. got %v; want %v", idx, got, want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuild(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
elems [][]byte
|
||||
want []byte
|
||||
}{
|
||||
{
|
||||
"one element",
|
||||
[][]byte{AppendDoubleElement(nil, "pi", 3.14159)},
|
||||
[]byte{0x11, 0x00, 0x00, 0x00, 0x1, 0x70, 0x69, 0x00, 0x6e, 0x86, 0x1b, 0xf0, 0xf9, 0x21, 0x9, 0x40, 0x00},
|
||||
},
|
||||
{
|
||||
"two elements",
|
||||
[][]byte{AppendDoubleElement(nil, "pi", 3.14159), AppendStringElement(nil, "hello", "world!!")},
|
||||
[]byte{
|
||||
0x24, 0x00, 0x00, 0x00, 0x01, 0x70, 0x69, 0x00, 0x6e, 0x86, 0x1b, 0xf0,
|
||||
0xf9, 0x21, 0x09, 0x40, 0x02, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x00, 0x08,
|
||||
0x00, 0x00, 0x00, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x21, 0x21, 0x00, 0x00,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Run("BuildDocument", func(t *testing.T) {
|
||||
elems := make([]byte, 0)
|
||||
for _, elem := range tc.elems {
|
||||
elems = append(elems, elem...)
|
||||
}
|
||||
got := BuildDocument(nil, elems)
|
||||
if !bytes.Equal(got, tc.want) {
|
||||
t.Errorf("Documents do not match. got %v; want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
t.Run("BuildDocumentFromElements", func(t *testing.T) {
|
||||
got := BuildDocumentFromElements(nil, tc.elems...)
|
||||
if !bytes.Equal(got, tc.want) {
|
||||
t.Errorf("Documents do not match. got %v; want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNullBytes(t *testing.T) {
|
||||
// Helper function to execute the provided callback and assert that it panics with the expected message. The
|
||||
// createBSONFn callback should create a BSON document/array/value and return the stringified version.
|
||||
assertBSONCreationPanics := func(t *testing.T, createBSONFn func(), expected string) {
|
||||
t.Helper()
|
||||
|
||||
defer func() {
|
||||
got := recover()
|
||||
assert.Equal(t, expected, got, "expected panic with error %v, got error %v", expected, got)
|
||||
}()
|
||||
createBSONFn()
|
||||
}
|
||||
|
||||
t.Run("element keys", func(t *testing.T) {
|
||||
createDocFn := func() {
|
||||
NewDocumentBuilder().AppendString("a\x00", "foo")
|
||||
}
|
||||
assertBSONCreationPanics(t, createDocFn, invalidKeyPanicMsg)
|
||||
})
|
||||
t.Run("regex values", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
pattern string
|
||||
options string
|
||||
}{
|
||||
{"null bytes in pattern", "a\x00", "i"},
|
||||
{"null bytes in options", "pattern", "i\x00"},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name+"-AppendRegexElement", func(t *testing.T) {
|
||||
createDocFn := func() {
|
||||
AppendRegexElement(nil, "foo", tc.pattern, tc.options)
|
||||
}
|
||||
assertBSONCreationPanics(t, createDocFn, invalidRegexPanicMsg)
|
||||
})
|
||||
t.Run(tc.name+"-AppendRegex", func(t *testing.T) {
|
||||
createValFn := func() {
|
||||
AppendRegex(nil, tc.pattern, tc.options)
|
||||
}
|
||||
assertBSONCreationPanics(t, createValFn, invalidRegexPanicMsg)
|
||||
})
|
||||
}
|
||||
})
|
||||
t.Run("sub document field name", func(t *testing.T) {
|
||||
createDocFn := func() {
|
||||
NewDocumentBuilder().StartDocument("foobar").AppendDocument("a\x00", []byte("foo")).FinishDocument()
|
||||
}
|
||||
assertBSONCreationPanics(t, createDocFn, invalidKeyPanicMsg)
|
||||
})
|
||||
}
|
||||
|
||||
func compareDecimal128(d1, d2 primitive.Decimal128) bool {
|
||||
d1H, d1L := d1.GetBytes()
|
||||
d2H, d2L := d2.GetBytes()
|
||||
|
||||
if d1H != d2H {
|
||||
return false
|
||||
}
|
||||
|
||||
if d1L != d2L {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
386
mongo/x/bsonx/bsoncore/document.go
Normal file
386
mongo/x/bsonx/bsoncore/document.go
Normal file
@@ -0,0 +1,386 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bsoncore
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
|
||||
"go.mongodb.org/mongo-driver/bson/bsontype"
|
||||
)
|
||||
|
||||
// ValidationError is an error type returned when attempting to validate a document or array.
|
||||
type ValidationError string
|
||||
|
||||
func (ve ValidationError) Error() string { return string(ve) }
|
||||
|
||||
// NewDocumentLengthError creates and returns an error for when the length of a document exceeds the
|
||||
// bytes available.
|
||||
func NewDocumentLengthError(length, rem int) error {
|
||||
return lengthError("document", length, rem)
|
||||
}
|
||||
|
||||
func lengthError(bufferType string, length, rem int) error {
|
||||
return ValidationError(fmt.Sprintf("%v length exceeds available bytes. length=%d remainingBytes=%d",
|
||||
bufferType, length, rem))
|
||||
}
|
||||
|
||||
// InsufficientBytesError indicates that there were not enough bytes to read the next component.
|
||||
type InsufficientBytesError struct {
|
||||
Source []byte
|
||||
Remaining []byte
|
||||
}
|
||||
|
||||
// NewInsufficientBytesError creates a new InsufficientBytesError with the given Document and
|
||||
// remaining bytes.
|
||||
func NewInsufficientBytesError(src, rem []byte) InsufficientBytesError {
|
||||
return InsufficientBytesError{Source: src, Remaining: rem}
|
||||
}
|
||||
|
||||
// Error implements the error interface.
|
||||
func (ibe InsufficientBytesError) Error() string {
|
||||
return "too few bytes to read next component"
|
||||
}
|
||||
|
||||
// Equal checks that err2 also is an ErrTooSmall.
|
||||
func (ibe InsufficientBytesError) Equal(err2 error) bool {
|
||||
switch err2.(type) {
|
||||
case InsufficientBytesError:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// InvalidDepthTraversalError is returned when attempting a recursive Lookup when one component of
|
||||
// the path is neither an embedded document nor an array.
|
||||
type InvalidDepthTraversalError struct {
|
||||
Key string
|
||||
Type bsontype.Type
|
||||
}
|
||||
|
||||
func (idte InvalidDepthTraversalError) Error() string {
|
||||
return fmt.Sprintf(
|
||||
"attempt to traverse into %s, but it's type is %s, not %s nor %s",
|
||||
idte.Key, idte.Type, bsontype.EmbeddedDocument, bsontype.Array,
|
||||
)
|
||||
}
|
||||
|
||||
// ErrMissingNull is returned when a document or array's last byte is not null.
|
||||
const ErrMissingNull ValidationError = "document or array end is missing null byte"
|
||||
|
||||
// ErrInvalidLength indicates that a length in a binary representation of a BSON document or array
|
||||
// is invalid.
|
||||
const ErrInvalidLength ValidationError = "document or array length is invalid"
|
||||
|
||||
// ErrNilReader indicates that an operation was attempted on a nil io.Reader.
|
||||
var ErrNilReader = errors.New("nil reader")
|
||||
|
||||
// ErrEmptyKey indicates that no key was provided to a Lookup method.
|
||||
var ErrEmptyKey = errors.New("empty key provided")
|
||||
|
||||
// ErrElementNotFound indicates that an Element matching a certain condition does not exist.
|
||||
var ErrElementNotFound = errors.New("element not found")
|
||||
|
||||
// ErrOutOfBounds indicates that an index provided to access something was invalid.
|
||||
var ErrOutOfBounds = errors.New("out of bounds")
|
||||
|
||||
// Document is a raw bytes representation of a BSON document.
|
||||
type Document []byte
|
||||
|
||||
// NewDocumentFromReader reads a document from r. This function will only validate the length is
|
||||
// correct and that the document ends with a null byte.
|
||||
func NewDocumentFromReader(r io.Reader) (Document, error) {
|
||||
return newBufferFromReader(r)
|
||||
}
|
||||
|
||||
func newBufferFromReader(r io.Reader) ([]byte, error) {
|
||||
if r == nil {
|
||||
return nil, ErrNilReader
|
||||
}
|
||||
|
||||
var lengthBytes [4]byte
|
||||
|
||||
// ReadFull guarantees that we will have read at least len(lengthBytes) if err == nil
|
||||
_, err := io.ReadFull(r, lengthBytes[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
length, _, _ := readi32(lengthBytes[:]) // ignore ok since we always have enough bytes to read a length
|
||||
if length < 0 {
|
||||
return nil, ErrInvalidLength
|
||||
}
|
||||
buffer := make([]byte, length)
|
||||
|
||||
copy(buffer, lengthBytes[:])
|
||||
|
||||
_, err = io.ReadFull(r, buffer[4:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if buffer[length-1] != 0x00 {
|
||||
return nil, ErrMissingNull
|
||||
}
|
||||
|
||||
return buffer, nil
|
||||
}
|
||||
|
||||
// Lookup searches the document, potentially recursively, for the given key. If there are multiple
|
||||
// keys provided, this method will recurse down, as long as the top and intermediate nodes are
|
||||
// either documents or arrays. If an error occurs or if the value doesn't exist, an empty Value is
|
||||
// returned.
|
||||
func (d Document) Lookup(key ...string) Value {
|
||||
val, _ := d.LookupErr(key...)
|
||||
return val
|
||||
}
|
||||
|
||||
// LookupErr is the same as Lookup, except it returns an error in addition to an empty Value.
|
||||
func (d Document) LookupErr(key ...string) (Value, error) {
|
||||
if len(key) < 1 {
|
||||
return Value{}, ErrEmptyKey
|
||||
}
|
||||
length, rem, ok := ReadLength(d)
|
||||
if !ok {
|
||||
return Value{}, NewInsufficientBytesError(d, rem)
|
||||
}
|
||||
|
||||
length -= 4
|
||||
|
||||
var elem Element
|
||||
for length > 1 {
|
||||
elem, rem, ok = ReadElement(rem)
|
||||
length -= int32(len(elem))
|
||||
if !ok {
|
||||
return Value{}, NewInsufficientBytesError(d, rem)
|
||||
}
|
||||
// We use `KeyBytes` rather than `Key` to avoid a needless string alloc.
|
||||
if string(elem.KeyBytes()) != key[0] {
|
||||
continue
|
||||
}
|
||||
if len(key) > 1 {
|
||||
tt := bsontype.Type(elem[0])
|
||||
switch tt {
|
||||
case bsontype.EmbeddedDocument:
|
||||
val, err := elem.Value().Document().LookupErr(key[1:]...)
|
||||
if err != nil {
|
||||
return Value{}, err
|
||||
}
|
||||
return val, nil
|
||||
case bsontype.Array:
|
||||
// Convert to Document to continue Lookup recursion.
|
||||
val, err := Document(elem.Value().Array()).LookupErr(key[1:]...)
|
||||
if err != nil {
|
||||
return Value{}, err
|
||||
}
|
||||
return val, nil
|
||||
default:
|
||||
return Value{}, InvalidDepthTraversalError{Key: elem.Key(), Type: tt}
|
||||
}
|
||||
}
|
||||
return elem.ValueErr()
|
||||
}
|
||||
return Value{}, ErrElementNotFound
|
||||
}
|
||||
|
||||
// Index searches for and retrieves the element at the given index. This method will panic if
|
||||
// the document is invalid or if the index is out of bounds.
|
||||
func (d Document) Index(index uint) Element {
|
||||
elem, err := d.IndexErr(index)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return elem
|
||||
}
|
||||
|
||||
// IndexErr searches for and retrieves the element at the given index.
|
||||
func (d Document) IndexErr(index uint) (Element, error) {
|
||||
return indexErr(d, index)
|
||||
}
|
||||
|
||||
func indexErr(b []byte, index uint) (Element, error) {
|
||||
length, rem, ok := ReadLength(b)
|
||||
if !ok {
|
||||
return nil, NewInsufficientBytesError(b, rem)
|
||||
}
|
||||
|
||||
length -= 4
|
||||
|
||||
var current uint
|
||||
var elem Element
|
||||
for length > 1 {
|
||||
elem, rem, ok = ReadElement(rem)
|
||||
length -= int32(len(elem))
|
||||
if !ok {
|
||||
return nil, NewInsufficientBytesError(b, rem)
|
||||
}
|
||||
if current != index {
|
||||
current++
|
||||
continue
|
||||
}
|
||||
return elem, nil
|
||||
}
|
||||
return nil, ErrOutOfBounds
|
||||
}
|
||||
|
||||
// DebugString outputs a human readable version of Document. It will attempt to stringify the
|
||||
// valid components of the document even if the entire document is not valid.
|
||||
func (d Document) DebugString() string {
|
||||
if len(d) < 5 {
|
||||
return "<malformed>"
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
buf.WriteString("Document")
|
||||
length, rem, _ := ReadLength(d) // We know we have enough bytes to read the length
|
||||
buf.WriteByte('(')
|
||||
buf.WriteString(strconv.Itoa(int(length)))
|
||||
length -= 4
|
||||
buf.WriteString("){")
|
||||
var elem Element
|
||||
var ok bool
|
||||
for length > 1 {
|
||||
elem, rem, ok = ReadElement(rem)
|
||||
length -= int32(len(elem))
|
||||
if !ok {
|
||||
buf.WriteString(fmt.Sprintf("<malformed (%d)>", length))
|
||||
break
|
||||
}
|
||||
fmt.Fprintf(&buf, "%s ", elem.DebugString())
|
||||
}
|
||||
buf.WriteByte('}')
|
||||
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// String outputs an ExtendedJSON version of Document. If the document is not valid, this method
|
||||
// returns an empty string.
|
||||
func (d Document) String() string {
|
||||
if len(d) < 5 {
|
||||
return ""
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
buf.WriteByte('{')
|
||||
|
||||
length, rem, _ := ReadLength(d) // We know we have enough bytes to read the length
|
||||
|
||||
length -= 4
|
||||
|
||||
var elem Element
|
||||
var ok bool
|
||||
first := true
|
||||
for length > 1 {
|
||||
if !first {
|
||||
buf.WriteByte(',')
|
||||
}
|
||||
elem, rem, ok = ReadElement(rem)
|
||||
length -= int32(len(elem))
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
fmt.Fprintf(&buf, "%s", elem.String())
|
||||
first = false
|
||||
}
|
||||
buf.WriteByte('}')
|
||||
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// Elements returns this document as a slice of elements. The returned slice will contain valid
|
||||
// elements. If the document is not valid, the elements up to the invalid point will be returned
|
||||
// along with an error.
|
||||
func (d Document) Elements() ([]Element, error) {
|
||||
length, rem, ok := ReadLength(d)
|
||||
if !ok {
|
||||
return nil, NewInsufficientBytesError(d, rem)
|
||||
}
|
||||
|
||||
length -= 4
|
||||
|
||||
var elem Element
|
||||
var elems []Element
|
||||
for length > 1 {
|
||||
elem, rem, ok = ReadElement(rem)
|
||||
length -= int32(len(elem))
|
||||
if !ok {
|
||||
return elems, NewInsufficientBytesError(d, rem)
|
||||
}
|
||||
if err := elem.Validate(); err != nil {
|
||||
return elems, err
|
||||
}
|
||||
elems = append(elems, elem)
|
||||
}
|
||||
return elems, nil
|
||||
}
|
||||
|
||||
// Values returns this document as a slice of values. The returned slice will contain valid values.
|
||||
// If the document is not valid, the values up to the invalid point will be returned along with an
|
||||
// error.
|
||||
func (d Document) Values() ([]Value, error) {
|
||||
return values(d)
|
||||
}
|
||||
|
||||
func values(b []byte) ([]Value, error) {
|
||||
length, rem, ok := ReadLength(b)
|
||||
if !ok {
|
||||
return nil, NewInsufficientBytesError(b, rem)
|
||||
}
|
||||
|
||||
length -= 4
|
||||
|
||||
var elem Element
|
||||
var vals []Value
|
||||
for length > 1 {
|
||||
elem, rem, ok = ReadElement(rem)
|
||||
length -= int32(len(elem))
|
||||
if !ok {
|
||||
return vals, NewInsufficientBytesError(b, rem)
|
||||
}
|
||||
if err := elem.Value().Validate(); err != nil {
|
||||
return vals, err
|
||||
}
|
||||
vals = append(vals, elem.Value())
|
||||
}
|
||||
return vals, nil
|
||||
}
|
||||
|
||||
// Validate validates the document and ensures the elements contained within are valid.
|
||||
func (d Document) Validate() error {
|
||||
length, rem, ok := ReadLength(d)
|
||||
if !ok {
|
||||
return NewInsufficientBytesError(d, rem)
|
||||
}
|
||||
if int(length) > len(d) {
|
||||
return NewDocumentLengthError(int(length), len(d))
|
||||
}
|
||||
if d[length-1] != 0x00 {
|
||||
return ErrMissingNull
|
||||
}
|
||||
|
||||
length -= 4
|
||||
var elem Element
|
||||
|
||||
for length > 1 {
|
||||
elem, rem, ok = ReadElement(rem)
|
||||
length -= int32(len(elem))
|
||||
if !ok {
|
||||
return NewInsufficientBytesError(d, rem)
|
||||
}
|
||||
err := elem.Validate()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(rem) < 1 || rem[0] != 0x00 {
|
||||
return ErrMissingNull
|
||||
}
|
||||
return nil
|
||||
}
|
||||
189
mongo/x/bsonx/bsoncore/document_sequence.go
Normal file
189
mongo/x/bsonx/bsoncore/document_sequence.go
Normal file
@@ -0,0 +1,189 @@
|
||||
// Copyright (C) MongoDB, Inc. 2022-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bsoncore
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
|
||||
"go.mongodb.org/mongo-driver/bson/bsontype"
|
||||
)
|
||||
|
||||
// DocumentSequenceStyle is used to represent how a document sequence is laid out in a slice of
|
||||
// bytes.
|
||||
type DocumentSequenceStyle uint32
|
||||
|
||||
// These constants are the valid styles for a DocumentSequence.
|
||||
const (
|
||||
_ DocumentSequenceStyle = iota
|
||||
SequenceStyle
|
||||
ArrayStyle
|
||||
)
|
||||
|
||||
// DocumentSequence represents a sequence of documents. The Style field indicates how the documents
|
||||
// are laid out inside of the Data field.
|
||||
type DocumentSequence struct {
|
||||
Style DocumentSequenceStyle
|
||||
Data []byte
|
||||
Pos int
|
||||
}
|
||||
|
||||
// ErrCorruptedDocument is returned when a full document couldn't be read from the sequence.
|
||||
var ErrCorruptedDocument = errors.New("invalid DocumentSequence: corrupted document")
|
||||
|
||||
// ErrNonDocument is returned when a DocumentSequence contains a non-document BSON value.
|
||||
var ErrNonDocument = errors.New("invalid DocumentSequence: a non-document value was found in sequence")
|
||||
|
||||
// ErrInvalidDocumentSequenceStyle is returned when an unknown DocumentSequenceStyle is set on a
|
||||
// DocumentSequence.
|
||||
var ErrInvalidDocumentSequenceStyle = errors.New("invalid DocumentSequenceStyle")
|
||||
|
||||
// DocumentCount returns the number of documents in the sequence.
|
||||
func (ds *DocumentSequence) DocumentCount() int {
|
||||
if ds == nil {
|
||||
return 0
|
||||
}
|
||||
switch ds.Style {
|
||||
case SequenceStyle:
|
||||
var count int
|
||||
var ok bool
|
||||
rem := ds.Data
|
||||
for len(rem) > 0 {
|
||||
_, rem, ok = ReadDocument(rem)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
count++
|
||||
}
|
||||
return count
|
||||
case ArrayStyle:
|
||||
_, rem, ok := ReadLength(ds.Data)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
|
||||
var count int
|
||||
for len(rem) > 1 {
|
||||
_, rem, ok = ReadElement(rem)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
count++
|
||||
}
|
||||
return count
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
// Empty returns true if the sequence is empty. It always returns true for unknown sequence styles.
|
||||
func (ds *DocumentSequence) Empty() bool {
|
||||
if ds == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
switch ds.Style {
|
||||
case SequenceStyle:
|
||||
return len(ds.Data) == 0
|
||||
case ArrayStyle:
|
||||
return len(ds.Data) <= 5
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// ResetIterator resets the iteration point for the Next method to the beginning of the document
|
||||
// sequence.
|
||||
func (ds *DocumentSequence) ResetIterator() {
|
||||
if ds == nil {
|
||||
return
|
||||
}
|
||||
ds.Pos = 0
|
||||
}
|
||||
|
||||
// Documents returns a slice of the documents. If nil either the Data field is also nil or could not
|
||||
// be properly read.
|
||||
func (ds *DocumentSequence) Documents() ([]Document, error) {
|
||||
if ds == nil {
|
||||
return nil, nil
|
||||
}
|
||||
switch ds.Style {
|
||||
case SequenceStyle:
|
||||
rem := ds.Data
|
||||
var docs []Document
|
||||
var doc Document
|
||||
var ok bool
|
||||
for {
|
||||
doc, rem, ok = ReadDocument(rem)
|
||||
if !ok {
|
||||
if len(rem) == 0 {
|
||||
break
|
||||
}
|
||||
return nil, ErrCorruptedDocument
|
||||
}
|
||||
docs = append(docs, doc)
|
||||
}
|
||||
return docs, nil
|
||||
case ArrayStyle:
|
||||
if len(ds.Data) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
vals, err := Document(ds.Data).Values()
|
||||
if err != nil {
|
||||
return nil, ErrCorruptedDocument
|
||||
}
|
||||
docs := make([]Document, 0, len(vals))
|
||||
for _, v := range vals {
|
||||
if v.Type != bsontype.EmbeddedDocument {
|
||||
return nil, ErrNonDocument
|
||||
}
|
||||
docs = append(docs, v.Data)
|
||||
}
|
||||
return docs, nil
|
||||
default:
|
||||
return nil, ErrInvalidDocumentSequenceStyle
|
||||
}
|
||||
}
|
||||
|
||||
// Next retrieves the next document from this sequence and returns it. This method will return
|
||||
// io.EOF when it has reached the end of the sequence.
|
||||
func (ds *DocumentSequence) Next() (Document, error) {
|
||||
if ds == nil || ds.Pos >= len(ds.Data) {
|
||||
return nil, io.EOF
|
||||
}
|
||||
switch ds.Style {
|
||||
case SequenceStyle:
|
||||
doc, _, ok := ReadDocument(ds.Data[ds.Pos:])
|
||||
if !ok {
|
||||
return nil, ErrCorruptedDocument
|
||||
}
|
||||
ds.Pos += len(doc)
|
||||
return doc, nil
|
||||
case ArrayStyle:
|
||||
if ds.Pos < 4 {
|
||||
if len(ds.Data) < 4 {
|
||||
return nil, ErrCorruptedDocument
|
||||
}
|
||||
ds.Pos = 4 // Skip the length of the document
|
||||
}
|
||||
if len(ds.Data[ds.Pos:]) == 1 && ds.Data[ds.Pos] == 0x00 {
|
||||
return nil, io.EOF // At the end of the document
|
||||
}
|
||||
elem, _, ok := ReadElement(ds.Data[ds.Pos:])
|
||||
if !ok {
|
||||
return nil, ErrCorruptedDocument
|
||||
}
|
||||
ds.Pos += len(elem)
|
||||
val := elem.Value()
|
||||
if val.Type != bsontype.EmbeddedDocument {
|
||||
return nil, ErrNonDocument
|
||||
}
|
||||
return val.Data, nil
|
||||
default:
|
||||
return nil, ErrInvalidDocumentSequenceStyle
|
||||
}
|
||||
}
|
||||
421
mongo/x/bsonx/bsoncore/document_sequence_test.go
Normal file
421
mongo/x/bsonx/bsoncore/document_sequence_test.go
Normal file
@@ -0,0 +1,421 @@
|
||||
// Copyright (C) MongoDB, Inc. 2022-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bsoncore
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestDocumentSequence(t *testing.T) {
|
||||
|
||||
genArrayStyle := func(num int) []byte {
|
||||
idx, seq := AppendDocumentStart(nil)
|
||||
for i := 0; i < num; i++ {
|
||||
seq = AppendDocumentElement(
|
||||
seq, strconv.Itoa(i),
|
||||
BuildDocument(nil, AppendDoubleElement(nil, "pi", 3.14159)),
|
||||
)
|
||||
}
|
||||
seq, _ = AppendDocumentEnd(seq, idx)
|
||||
return seq
|
||||
}
|
||||
genSequenceStyle := func(num int) []byte {
|
||||
var seq []byte
|
||||
for i := 0; i < num; i++ {
|
||||
seq = append(seq, BuildDocument(nil, AppendDoubleElement(nil, "pi", 3.14159))...)
|
||||
}
|
||||
return seq
|
||||
}
|
||||
|
||||
idx, arrayStyle := AppendDocumentStart(nil)
|
||||
idx2, arrayStyle := AppendDocumentElementStart(arrayStyle, "0")
|
||||
arrayStyle = AppendDoubleElement(arrayStyle, "pi", 3.14159)
|
||||
arrayStyle, _ = AppendDocumentEnd(arrayStyle, idx2)
|
||||
idx2, arrayStyle = AppendDocumentElementStart(arrayStyle, "1")
|
||||
arrayStyle = AppendStringElement(arrayStyle, "hello", "world")
|
||||
arrayStyle, _ = AppendDocumentEnd(arrayStyle, idx2)
|
||||
arrayStyle, _ = AppendDocumentEnd(arrayStyle, idx)
|
||||
|
||||
t.Run("Documents", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
style DocumentSequenceStyle
|
||||
data []byte
|
||||
documents []Document
|
||||
err error
|
||||
}{
|
||||
{
|
||||
"SequenceStle/corrupted document",
|
||||
SequenceStyle,
|
||||
[]byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08},
|
||||
nil,
|
||||
ErrCorruptedDocument,
|
||||
},
|
||||
{
|
||||
"SequenceStyle/success",
|
||||
SequenceStyle,
|
||||
BuildDocument(
|
||||
BuildDocument(
|
||||
nil,
|
||||
AppendStringElement(AppendDoubleElement(nil, "pi", 3.14159), "hello", "world"),
|
||||
),
|
||||
AppendDoubleElement(AppendStringElement(nil, "hello", "world"), "pi", 3.14159),
|
||||
),
|
||||
[]Document{
|
||||
BuildDocument(nil, AppendStringElement(AppendDoubleElement(nil, "pi", 3.14159), "hello", "world")),
|
||||
BuildDocument(nil, AppendDoubleElement(AppendStringElement(nil, "hello", "world"), "pi", 3.14159)),
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"ArrayStyle/insufficient bytes",
|
||||
ArrayStyle,
|
||||
[]byte{0x01, 0x02, 0x03, 0x04, 0x05},
|
||||
nil,
|
||||
ErrCorruptedDocument,
|
||||
},
|
||||
{
|
||||
"ArrayStyle/non-document",
|
||||
ArrayStyle,
|
||||
BuildDocument(nil, AppendDoubleElement(nil, "0", 12345.67890)),
|
||||
nil,
|
||||
ErrNonDocument,
|
||||
},
|
||||
{
|
||||
"ArrayStyle/success",
|
||||
ArrayStyle,
|
||||
arrayStyle,
|
||||
[]Document{
|
||||
BuildDocument(nil, AppendDoubleElement(nil, "pi", 3.14159)),
|
||||
BuildDocument(nil, AppendStringElement(nil, "hello", "world")),
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{"Invalid DocumentSequenceStyle", 0, nil, nil, ErrInvalidDocumentSequenceStyle},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ds := &DocumentSequence{
|
||||
Style: tc.style,
|
||||
Data: tc.data,
|
||||
}
|
||||
documents, err := ds.Documents()
|
||||
if !cmp.Equal(documents, tc.documents) {
|
||||
t.Errorf("Documents do not match. got %v; want %v", documents, tc.documents)
|
||||
}
|
||||
if err != tc.err {
|
||||
t.Errorf("Errors do not match. got %v; want %v", err, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
t.Run("Next", func(t *testing.T) {
|
||||
seqDoc := BuildDocument(
|
||||
BuildDocument(
|
||||
nil,
|
||||
AppendDoubleElement(nil, "pi", 3.14159),
|
||||
),
|
||||
AppendStringElement(nil, "hello", "world"),
|
||||
)
|
||||
|
||||
idx, arrayStyle := AppendDocumentStart(nil)
|
||||
idx2, arrayStyle := AppendDocumentElementStart(arrayStyle, "0")
|
||||
arrayStyle = AppendDoubleElement(arrayStyle, "pi", 3.14159)
|
||||
arrayStyle, _ = AppendDocumentEnd(arrayStyle, idx2)
|
||||
idx2, arrayStyle = AppendDocumentElementStart(arrayStyle, "1")
|
||||
arrayStyle = AppendStringElement(arrayStyle, "hello", "world")
|
||||
arrayStyle, _ = AppendDocumentEnd(arrayStyle, idx2)
|
||||
arrayStyle, _ = AppendDocumentEnd(arrayStyle, idx)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
style DocumentSequenceStyle
|
||||
data []byte
|
||||
pos int
|
||||
document Document
|
||||
err error
|
||||
}{
|
||||
{"io.EOF", 0, make([]byte, 10), 10, nil, io.EOF},
|
||||
{
|
||||
"SequenceStyle/corrupted document",
|
||||
SequenceStyle,
|
||||
[]byte{0x01, 0x02, 0x03, 0x04},
|
||||
0,
|
||||
nil,
|
||||
ErrCorruptedDocument,
|
||||
},
|
||||
{
|
||||
"SequenceStyle/success/first",
|
||||
SequenceStyle,
|
||||
seqDoc,
|
||||
0,
|
||||
BuildDocument(nil, AppendDoubleElement(nil, "pi", 3.14159)),
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"SequenceStyle/success/second",
|
||||
SequenceStyle,
|
||||
seqDoc,
|
||||
17,
|
||||
BuildDocument(nil, AppendStringElement(nil, "hello", "world")),
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"ArrayStyle/corrupted document/too short",
|
||||
ArrayStyle,
|
||||
[]byte{0x01, 0x02, 0x03},
|
||||
0,
|
||||
nil,
|
||||
ErrCorruptedDocument,
|
||||
},
|
||||
{
|
||||
"ArrayStyle/corrupted document/invalid element",
|
||||
ArrayStyle,
|
||||
[]byte{0x00, 0x00, 0x00, 0x00, 0x01, '0', 0x00, 0x01, 0x02},
|
||||
0,
|
||||
nil,
|
||||
ErrCorruptedDocument,
|
||||
},
|
||||
{
|
||||
"ArrayStyle/non-document",
|
||||
ArrayStyle,
|
||||
BuildDocument(nil, AppendDoubleElement(nil, "0", 12345.67890)),
|
||||
0,
|
||||
nil,
|
||||
ErrNonDocument,
|
||||
},
|
||||
{
|
||||
"ArrayStyle/success/first",
|
||||
ArrayStyle,
|
||||
arrayStyle,
|
||||
0,
|
||||
BuildDocument(nil, AppendDoubleElement(nil, "pi", 3.14159)),
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"ArrayStyle/success/second",
|
||||
ArrayStyle,
|
||||
arrayStyle,
|
||||
24,
|
||||
BuildDocument(nil, AppendStringElement(nil, "hello", "world")),
|
||||
nil,
|
||||
},
|
||||
{"Invalid DocumentSequenceStyle", 0, make([]byte, 4), 0, nil, ErrInvalidDocumentSequenceStyle},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ds := &DocumentSequence{
|
||||
Style: tc.style,
|
||||
Data: tc.data,
|
||||
Pos: tc.pos,
|
||||
}
|
||||
document, err := ds.Next()
|
||||
if !bytes.Equal(document, tc.document) {
|
||||
t.Errorf("Documents do not match. got %v; want %v", document, tc.document)
|
||||
}
|
||||
if err != tc.err {
|
||||
t.Errorf("Errors do not match. got %v; want %v", err, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Full Iteration", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
style DocumentSequenceStyle
|
||||
data []byte
|
||||
count int
|
||||
}{
|
||||
{"SequenceStyle/success/nil", SequenceStyle, nil, 0},
|
||||
{"SequenceStyle/success/0", SequenceStyle, []byte{}, 0},
|
||||
{"SequenceStyle/success/1", SequenceStyle, genSequenceStyle(1), 1},
|
||||
{"SequenceStyle/success/2", SequenceStyle, genSequenceStyle(2), 2},
|
||||
{"SequenceStyle/success/10", SequenceStyle, genSequenceStyle(10), 10},
|
||||
{"SequenceStyle/success/100", SequenceStyle, genSequenceStyle(100), 100},
|
||||
{"ArrayStyle/success/nil", ArrayStyle, nil, 0},
|
||||
{"ArrayStyle/success/0", ArrayStyle, []byte{0x05, 0x00, 0x00, 0x00, 0x00}, 0},
|
||||
{"ArrayStyle/success/1", ArrayStyle, genArrayStyle(1), 1},
|
||||
{"ArrayStyle/success/2", ArrayStyle, genArrayStyle(2), 2},
|
||||
{"ArrayStyle/success/10", ArrayStyle, genArrayStyle(10), 10},
|
||||
{"ArrayStyle/success/100", ArrayStyle, genArrayStyle(100), 100},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run("Documents/"+tc.name, func(t *testing.T) {
|
||||
ds := &DocumentSequence{
|
||||
Style: tc.style,
|
||||
Data: tc.data,
|
||||
}
|
||||
docs, err := ds.Documents()
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
count := len(docs)
|
||||
if count != tc.count {
|
||||
t.Errorf("Coun't fully iterate documents, wrong count. got %v; want %v", count, tc.count)
|
||||
}
|
||||
})
|
||||
t.Run("Next/"+tc.name, func(t *testing.T) {
|
||||
ds := &DocumentSequence{
|
||||
Style: tc.style,
|
||||
Data: tc.data,
|
||||
}
|
||||
var docs []Document
|
||||
for {
|
||||
doc, err := ds.Next()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
docs = append(docs, doc)
|
||||
}
|
||||
count := len(docs)
|
||||
if count != tc.count {
|
||||
t.Errorf("Coun't fully iterate documents, wrong count. got %v; want %v", count, tc.count)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
t.Run("DocumentCount", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
style DocumentSequenceStyle
|
||||
data []byte
|
||||
count int
|
||||
}{
|
||||
{
|
||||
"SequenceStyle/corrupt document/first",
|
||||
SequenceStyle,
|
||||
[]byte{0x01, 0x02, 0x03},
|
||||
0,
|
||||
},
|
||||
{
|
||||
"SequenceStyle/corrupt document/second",
|
||||
SequenceStyle,
|
||||
[]byte{0x05, 0x00, 0x00, 0x00, 0x00, 0x01, 0x02, 0x03},
|
||||
0,
|
||||
},
|
||||
{"SequenceStyle/success/nil", SequenceStyle, nil, 0},
|
||||
{"SequenceStyle/success/0", SequenceStyle, []byte{}, 0},
|
||||
{"SequenceStyle/success/1", SequenceStyle, genSequenceStyle(1), 1},
|
||||
{"SequenceStyle/success/2", SequenceStyle, genSequenceStyle(2), 2},
|
||||
{"SequenceStyle/success/10", SequenceStyle, genSequenceStyle(10), 10},
|
||||
{"SequenceStyle/success/100", SequenceStyle, genSequenceStyle(100), 100},
|
||||
{
|
||||
"ArrayStyle/corrupt document/length",
|
||||
ArrayStyle,
|
||||
[]byte{0x01, 0x02, 0x03},
|
||||
0,
|
||||
},
|
||||
{
|
||||
"ArrayStyle/corrupt element/first",
|
||||
ArrayStyle,
|
||||
BuildDocument(nil, []byte{0x01, 0x00, 0x03, 0x04, 0x05}),
|
||||
0,
|
||||
},
|
||||
{
|
||||
"ArrayStyle/corrupt element/second",
|
||||
ArrayStyle,
|
||||
BuildDocument(nil, []byte{0x0A, 0x00, 0x01, 0x00, 0x03, 0x04, 0x05}),
|
||||
0,
|
||||
},
|
||||
{"ArrayStyle/success/nil", ArrayStyle, nil, 0},
|
||||
{"ArrayStyle/success/0", ArrayStyle, []byte{0x05, 0x00, 0x00, 0x00, 0x00}, 0},
|
||||
{"ArrayStyle/success/1", ArrayStyle, genArrayStyle(1), 1},
|
||||
{"ArrayStyle/success/2", ArrayStyle, genArrayStyle(2), 2},
|
||||
{"ArrayStyle/success/10", ArrayStyle, genArrayStyle(10), 10},
|
||||
{"ArrayStyle/success/100", ArrayStyle, genArrayStyle(100), 100},
|
||||
{"Invalid DocumentSequenceStyle", 0, nil, 0},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ds := &DocumentSequence{
|
||||
Style: tc.style,
|
||||
Data: tc.data,
|
||||
}
|
||||
count := ds.DocumentCount()
|
||||
if count != tc.count {
|
||||
t.Errorf("Document counts don't match. got %v; want %v", count, tc.count)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
t.Run("Empty", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
ds *DocumentSequence
|
||||
isEmpty bool
|
||||
}{
|
||||
{"ArrayStyle/is empty/nil", nil, true},
|
||||
{"ArrayStyle/is empty/0", &DocumentSequence{Style: ArrayStyle, Data: []byte{0x05, 0x00, 0x00, 0x00, 0x00}}, true},
|
||||
{"ArrayStyle/is not empty/non-0", &DocumentSequence{Style: ArrayStyle, Data: genArrayStyle(10)}, false},
|
||||
{"SequenceStyle/is empty/nil", nil, true},
|
||||
{"SequenceStyle/is empty/0", &DocumentSequence{Style: SequenceStyle, Data: []byte{}}, true},
|
||||
{"SequenceStyle/is not empty/non-0", &DocumentSequence{Style: SequenceStyle, Data: genSequenceStyle(10)}, false},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
isEmpty := tc.ds.Empty()
|
||||
if isEmpty != tc.isEmpty {
|
||||
t.Errorf("Unexpected Empty result. got %v; want %v", isEmpty, tc.isEmpty)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
t.Run("ResetIterator", func(t *testing.T) {
|
||||
ds := &DocumentSequence{Pos: 1234567890}
|
||||
want := 0
|
||||
ds.ResetIterator()
|
||||
if ds.Pos != want {
|
||||
t.Errorf("Unexpected position after ResetIterator. got %d; want %d", ds.Pos, want)
|
||||
}
|
||||
})
|
||||
t.Run("no panic on nil", func(t *testing.T) {
|
||||
capturePanic := func() {
|
||||
if err := recover(); err != nil {
|
||||
t.Errorf("Unexpected panic. got %v; want <nil>", err)
|
||||
}
|
||||
}
|
||||
t.Run("DocumentCount", func(t *testing.T) {
|
||||
defer capturePanic()
|
||||
var ds *DocumentSequence
|
||||
_ = ds.DocumentCount()
|
||||
})
|
||||
t.Run("Empty", func(t *testing.T) {
|
||||
defer capturePanic()
|
||||
var ds *DocumentSequence
|
||||
_ = ds.Empty()
|
||||
})
|
||||
t.Run("ResetIterator", func(t *testing.T) {
|
||||
defer capturePanic()
|
||||
var ds *DocumentSequence
|
||||
ds.ResetIterator()
|
||||
})
|
||||
t.Run("Documents", func(t *testing.T) {
|
||||
defer capturePanic()
|
||||
var ds *DocumentSequence
|
||||
_, _ = ds.Documents()
|
||||
})
|
||||
t.Run("Next", func(t *testing.T) {
|
||||
defer capturePanic()
|
||||
var ds *DocumentSequence
|
||||
_, _ = ds.Next()
|
||||
})
|
||||
})
|
||||
}
|
||||
412
mongo/x/bsonx/bsoncore/document_test.go
Normal file
412
mongo/x/bsonx/bsoncore/document_test.go
Normal file
@@ -0,0 +1,412 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bsoncore
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"go.mongodb.org/mongo-driver/bson/bsontype"
|
||||
)
|
||||
|
||||
func ExampleDocument_Validate() {
|
||||
doc := make(Document, 500)
|
||||
doc[250], doc[251], doc[252], doc[253], doc[254] = 0x05, 0x00, 0x00, 0x00, 0x00
|
||||
err := doc[250:].Validate()
|
||||
fmt.Println(err)
|
||||
|
||||
// Output: <nil>
|
||||
}
|
||||
|
||||
func BenchmarkDocumentValidate(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
doc := make(Document, 500)
|
||||
doc[250], doc[251], doc[252], doc[253], doc[254] = 0x05, 0x00, 0x00, 0x00, 0x00
|
||||
_ = doc[250:].Validate()
|
||||
}
|
||||
}
|
||||
|
||||
func TestDocument(t *testing.T) {
|
||||
t.Run("Validate", func(t *testing.T) {
|
||||
t.Run("TooShort", func(t *testing.T) {
|
||||
want := NewInsufficientBytesError(nil, nil)
|
||||
got := Document{'\x00', '\x00'}.Validate()
|
||||
if !compareErrors(got, want) {
|
||||
t.Errorf("Did not get expected error. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
t.Run("InvalidLength", func(t *testing.T) {
|
||||
want := NewDocumentLengthError(200, 5)
|
||||
r := make(Document, 5)
|
||||
binary.LittleEndian.PutUint32(r[0:4], 200)
|
||||
got := r.Validate()
|
||||
if !compareErrors(got, want) {
|
||||
t.Errorf("Did not get expected error. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
t.Run("Invalid Element", func(t *testing.T) {
|
||||
want := NewInsufficientBytesError(nil, nil)
|
||||
r := make(Document, 9)
|
||||
binary.LittleEndian.PutUint32(r[0:4], 9)
|
||||
r[4], r[5], r[6], r[7], r[8] = 0x02, 'f', 'o', 'o', 0x00
|
||||
got := r.Validate()
|
||||
if !compareErrors(got, want) {
|
||||
t.Errorf("Did not get expected error. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
t.Run("Missing Null Terminator", func(t *testing.T) {
|
||||
want := ErrMissingNull
|
||||
r := make(Document, 8)
|
||||
binary.LittleEndian.PutUint32(r[0:4], 8)
|
||||
r[4], r[5], r[6], r[7] = 0x0A, 'f', 'o', 'o'
|
||||
got := r.Validate()
|
||||
if !compareErrors(got, want) {
|
||||
t.Errorf("Did not get expected error. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
testCases := []struct {
|
||||
name string
|
||||
r Document
|
||||
want error
|
||||
}{
|
||||
{"null", Document{'\x08', '\x00', '\x00', '\x00', '\x0A', 'x', '\x00', '\x00'}, nil},
|
||||
{"subdocument",
|
||||
Document{
|
||||
'\x15', '\x00', '\x00', '\x00',
|
||||
'\x03',
|
||||
'f', 'o', 'o', '\x00',
|
||||
'\x0B', '\x00', '\x00', '\x00', '\x0A', 'a', '\x00',
|
||||
'\x0A', 'b', '\x00', '\x00', '\x00',
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{"array",
|
||||
Document{
|
||||
'\x15', '\x00', '\x00', '\x00',
|
||||
'\x04',
|
||||
'f', 'o', 'o', '\x00',
|
||||
'\x0B', '\x00', '\x00', '\x00', '\x0A', '1', '\x00',
|
||||
'\x0A', '2', '\x00', '\x00', '\x00',
|
||||
},
|
||||
nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := tc.r.Validate()
|
||||
if !compareErrors(got, tc.want) {
|
||||
t.Errorf("Returned error does not match. got %v; want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
t.Run("Lookup", func(t *testing.T) {
|
||||
t.Run("empty-key", func(t *testing.T) {
|
||||
rdr := Document{'\x05', '\x00', '\x00', '\x00', '\x00'}
|
||||
_, err := rdr.LookupErr()
|
||||
if err != ErrEmptyKey {
|
||||
t.Errorf("Empty key lookup did not return expected result. got %v; want %v", err, ErrEmptyKey)
|
||||
}
|
||||
})
|
||||
t.Run("corrupted-subdocument", func(t *testing.T) {
|
||||
rdr := Document{
|
||||
'\x0D', '\x00', '\x00', '\x00',
|
||||
'\x03', 'x', '\x00',
|
||||
'\x06', '\x00', '\x00', '\x00',
|
||||
'\x01',
|
||||
'\x00',
|
||||
'\x00',
|
||||
}
|
||||
_, got := rdr.LookupErr("x", "y")
|
||||
want := NewInsufficientBytesError(nil, nil)
|
||||
if !cmp.Equal(got, want) {
|
||||
t.Errorf("Empty key lookup did not return expected result. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
t.Run("corrupted-array", func(t *testing.T) {
|
||||
rdr := Document{
|
||||
'\x0D', '\x00', '\x00', '\x00',
|
||||
'\x04', 'x', '\x00',
|
||||
'\x06', '\x00', '\x00', '\x00',
|
||||
'\x01',
|
||||
'\x00',
|
||||
'\x00',
|
||||
}
|
||||
_, got := rdr.LookupErr("x", "y")
|
||||
want := NewInsufficientBytesError(nil, nil)
|
||||
if !cmp.Equal(got, want) {
|
||||
t.Errorf("Empty key lookup did not return expected result. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
t.Run("invalid-traversal", func(t *testing.T) {
|
||||
rdr := Document{'\x08', '\x00', '\x00', '\x00', '\x0A', 'x', '\x00', '\x00'}
|
||||
_, got := rdr.LookupErr("x", "y")
|
||||
want := InvalidDepthTraversalError{Key: "x", Type: bsontype.Null}
|
||||
if !compareErrors(got, want) {
|
||||
t.Errorf("Empty key lookup did not return expected result. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
testCases := []struct {
|
||||
name string
|
||||
r Document
|
||||
key []string
|
||||
want Value
|
||||
err error
|
||||
}{
|
||||
{"first",
|
||||
Document{
|
||||
'\x08', '\x00', '\x00', '\x00', '\x0A', 'x', '\x00', '\x00',
|
||||
},
|
||||
[]string{"x"},
|
||||
Value{Type: bsontype.Null, Data: []byte{}},
|
||||
nil,
|
||||
},
|
||||
{"first-second",
|
||||
Document{
|
||||
'\x15', '\x00', '\x00', '\x00',
|
||||
'\x03',
|
||||
'f', 'o', 'o', '\x00',
|
||||
'\x0B', '\x00', '\x00', '\x00', '\x0A', 'a', '\x00',
|
||||
'\x0A', 'b', '\x00', '\x00', '\x00',
|
||||
},
|
||||
[]string{"foo", "b"},
|
||||
Value{Type: bsontype.Null, Data: []byte{}},
|
||||
nil,
|
||||
},
|
||||
{"first-second-array",
|
||||
Document{
|
||||
'\x15', '\x00', '\x00', '\x00',
|
||||
'\x04',
|
||||
'f', 'o', 'o', '\x00',
|
||||
'\x0B', '\x00', '\x00', '\x00', '\x0A', '1', '\x00',
|
||||
'\x0A', '2', '\x00', '\x00', '\x00',
|
||||
},
|
||||
[]string{"foo", "2"},
|
||||
Value{Type: bsontype.Null, Data: []byte{}},
|
||||
nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Run("Lookup", func(t *testing.T) {
|
||||
got := tc.r.Lookup(tc.key...)
|
||||
if !cmp.Equal(got, tc.want) {
|
||||
t.Errorf("Returned value does not match expected element. got %v; want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
t.Run("LookupErr", func(t *testing.T) {
|
||||
got, err := tc.r.LookupErr(tc.key...)
|
||||
if err != tc.err {
|
||||
t.Errorf("Returned error does not match. got %v; want %v", err, tc.err)
|
||||
}
|
||||
if !cmp.Equal(got, tc.want) {
|
||||
t.Errorf("Returned value does not match expected element. got %v; want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
})
|
||||
t.Run("Index", func(t *testing.T) {
|
||||
t.Run("Out of bounds", func(t *testing.T) {
|
||||
rdr := Document{0xe, 0x0, 0x0, 0x0, 0xa, 0x78, 0x0, 0xa, 0x79, 0x0, 0xa, 0x7a, 0x0, 0x0}
|
||||
_, err := rdr.IndexErr(3)
|
||||
if err != ErrOutOfBounds {
|
||||
t.Errorf("Out of bounds should be returned when accessing element beyond end of document. got %v; want %v", err, ErrOutOfBounds)
|
||||
}
|
||||
})
|
||||
t.Run("Validation Error", func(t *testing.T) {
|
||||
rdr := Document{0x07, 0x00, 0x00, 0x00, 0x00}
|
||||
_, got := rdr.IndexErr(1)
|
||||
want := NewInsufficientBytesError(nil, nil)
|
||||
if !compareErrors(got, want) {
|
||||
t.Errorf("Did not receive expected error. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
testCases := []struct {
|
||||
name string
|
||||
rdr Document
|
||||
index uint
|
||||
want Element
|
||||
}{
|
||||
{"first",
|
||||
Document{0xe, 0x0, 0x0, 0x0, 0xa, 0x78, 0x0, 0xa, 0x79, 0x0, 0xa, 0x7a, 0x0, 0x0},
|
||||
0, Element{0x0a, 0x78, 0x00},
|
||||
},
|
||||
{"second",
|
||||
Document{0xe, 0x0, 0x0, 0x0, 0xa, 0x78, 0x0, 0xa, 0x79, 0x0, 0xa, 0x7a, 0x0, 0x0},
|
||||
1, Element{0x0a, 0x79, 0x00},
|
||||
},
|
||||
{"third",
|
||||
Document{0xe, 0x0, 0x0, 0x0, 0xa, 0x78, 0x0, 0xa, 0x79, 0x0, 0xa, 0x7a, 0x0, 0x0},
|
||||
2, Element{0x0a, 0x7a, 0x00},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Run("IndexErr", func(t *testing.T) {
|
||||
got, err := tc.rdr.IndexErr(tc.index)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error from IndexErr: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(got, tc.want); diff != "" {
|
||||
t.Errorf("Documents differ: (-got +want)\n%s", diff)
|
||||
}
|
||||
})
|
||||
t.Run("Index", func(t *testing.T) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
}()
|
||||
got := tc.rdr.Index(tc.index)
|
||||
if diff := cmp.Diff(got, tc.want); diff != "" {
|
||||
t.Errorf("Documents differ: (-got +want)\n%s", diff)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
})
|
||||
t.Run("NewDocumentFromReader", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
ioReader io.Reader
|
||||
doc Document
|
||||
err error
|
||||
}{
|
||||
{
|
||||
"nil reader",
|
||||
nil,
|
||||
nil,
|
||||
ErrNilReader,
|
||||
},
|
||||
{
|
||||
"premature end of reader",
|
||||
bytes.NewBuffer([]byte{}),
|
||||
nil,
|
||||
io.EOF,
|
||||
},
|
||||
{
|
||||
"empty document",
|
||||
bytes.NewBuffer([]byte{5, 0, 0, 0, 0}),
|
||||
[]byte{5, 0, 0, 0, 0},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"non-empty document",
|
||||
bytes.NewBuffer([]byte{
|
||||
// length
|
||||
0x17, 0x0, 0x0, 0x0,
|
||||
|
||||
// type - string
|
||||
0x2,
|
||||
// key - "foo"
|
||||
0x66, 0x6f, 0x6f, 0x0,
|
||||
// value - string length
|
||||
0x4, 0x0, 0x0, 0x0,
|
||||
// value - string "bar"
|
||||
0x62, 0x61, 0x72, 0x0,
|
||||
|
||||
// type - null
|
||||
0xa,
|
||||
// key - "baz"
|
||||
0x62, 0x61, 0x7a, 0x0,
|
||||
|
||||
// null terminator
|
||||
0x0,
|
||||
}),
|
||||
[]byte{
|
||||
// length
|
||||
0x17, 0x0, 0x0, 0x0,
|
||||
|
||||
// type - string
|
||||
0x2,
|
||||
// key - "foo"
|
||||
0x66, 0x6f, 0x6f, 0x0,
|
||||
// value - string length
|
||||
0x4, 0x0, 0x0, 0x0,
|
||||
// value - string "bar"
|
||||
0x62, 0x61, 0x72, 0x0,
|
||||
|
||||
// type - null
|
||||
0xa,
|
||||
// key - "baz"
|
||||
0x62, 0x61, 0x7a, 0x0,
|
||||
|
||||
// null terminator
|
||||
0x0,
|
||||
},
|
||||
nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
doc, err := NewDocumentFromReader(tc.ioReader)
|
||||
if !compareErrors(err, tc.err) {
|
||||
t.Errorf("errors do not match. got %v; want %v", err, tc.err)
|
||||
}
|
||||
if !bytes.Equal(tc.doc, doc) {
|
||||
t.Errorf("documents differ. got %v; want %v", tc.doc, doc)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
t.Run("Elements", func(t *testing.T) {
|
||||
invalidElem := BuildDocument(nil, AppendHeader(nil, bsontype.Double, "foo"))
|
||||
invalidTwoElem := BuildDocument(nil,
|
||||
AppendHeader(
|
||||
AppendDoubleElement(nil, "pi", 3.14159),
|
||||
bsontype.Double, "foo",
|
||||
),
|
||||
)
|
||||
oneElem := BuildDocument(nil, AppendDoubleElement(nil, "pi", 3.14159))
|
||||
twoElems := BuildDocument(nil,
|
||||
AppendStringElement(
|
||||
AppendDoubleElement(nil, "pi", 3.14159),
|
||||
"hello", "world!",
|
||||
),
|
||||
)
|
||||
testCases := []struct {
|
||||
name string
|
||||
doc Document
|
||||
elems []Element
|
||||
err error
|
||||
}{
|
||||
{"Insufficient Bytes Length", Document{0x03, 0x00, 0x00}, nil, NewInsufficientBytesError(nil, nil)},
|
||||
{"Insufficient Bytes First Element", invalidElem, nil, NewInsufficientBytesError(nil, nil)},
|
||||
{"Insufficient Bytes Second Element", invalidTwoElem, []Element{AppendDoubleElement(nil, "pi", 3.14159)}, NewInsufficientBytesError(nil, nil)},
|
||||
{"Success One Element", oneElem, []Element{AppendDoubleElement(nil, "pi", 3.14159)}, nil},
|
||||
{"Success Two Elements", twoElems, []Element{AppendDoubleElement(nil, "pi", 3.14159), AppendStringElement(nil, "hello", "world!")}, nil},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
elems, err := tc.doc.Elements()
|
||||
if !compareErrors(err, tc.err) {
|
||||
t.Errorf("errors do not match. got %v; want %v", err, tc.err)
|
||||
}
|
||||
if len(elems) != len(tc.elems) {
|
||||
t.Fatalf("number of elements returned does not match. got %d; want %d", len(elems), len(tc.elems))
|
||||
}
|
||||
|
||||
for idx := range elems {
|
||||
got, want := elems[idx], tc.elems[idx]
|
||||
if !bytes.Equal(got, want) {
|
||||
t.Errorf("Elements at index %d differ. got %v; want %v", idx, got.DebugString(), want.DebugString())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
152
mongo/x/bsonx/bsoncore/element.go
Normal file
152
mongo/x/bsonx/bsoncore/element.go
Normal file
@@ -0,0 +1,152 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bsoncore
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"go.mongodb.org/mongo-driver/bson/bsontype"
|
||||
)
|
||||
|
||||
// MalformedElementError represents a class of errors that RawElement methods return.
|
||||
type MalformedElementError string
|
||||
|
||||
func (mee MalformedElementError) Error() string { return string(mee) }
|
||||
|
||||
// ErrElementMissingKey is returned when a RawElement is missing a key.
|
||||
const ErrElementMissingKey MalformedElementError = "element is missing key"
|
||||
|
||||
// ErrElementMissingType is returned when a RawElement is missing a type.
|
||||
const ErrElementMissingType MalformedElementError = "element is missing type"
|
||||
|
||||
// Element is a raw bytes representation of a BSON element.
|
||||
type Element []byte
|
||||
|
||||
// Key returns the key for this element. If the element is not valid, this method returns an empty
|
||||
// string. If knowing if the element is valid is important, use KeyErr.
|
||||
func (e Element) Key() string {
|
||||
key, _ := e.KeyErr()
|
||||
return key
|
||||
}
|
||||
|
||||
// KeyBytes returns the key for this element as a []byte. If the element is not valid, this method
|
||||
// returns an empty string. If knowing if the element is valid is important, use KeyErr. This method
|
||||
// will not include the null byte at the end of the key in the slice of bytes.
|
||||
func (e Element) KeyBytes() []byte {
|
||||
key, _ := e.KeyBytesErr()
|
||||
return key
|
||||
}
|
||||
|
||||
// KeyErr returns the key for this element, returning an error if the element is not valid.
|
||||
func (e Element) KeyErr() (string, error) {
|
||||
key, err := e.KeyBytesErr()
|
||||
return string(key), err
|
||||
}
|
||||
|
||||
// KeyBytesErr returns the key for this element as a []byte, returning an error if the element is
|
||||
// not valid.
|
||||
func (e Element) KeyBytesErr() ([]byte, error) {
|
||||
if len(e) <= 0 {
|
||||
return nil, ErrElementMissingType
|
||||
}
|
||||
idx := bytes.IndexByte(e[1:], 0x00)
|
||||
if idx == -1 {
|
||||
return nil, ErrElementMissingKey
|
||||
}
|
||||
return e[1 : idx+1], nil
|
||||
}
|
||||
|
||||
// Validate ensures the element is a valid BSON element.
|
||||
func (e Element) Validate() error {
|
||||
if len(e) < 1 {
|
||||
return ErrElementMissingType
|
||||
}
|
||||
idx := bytes.IndexByte(e[1:], 0x00)
|
||||
if idx == -1 {
|
||||
return ErrElementMissingKey
|
||||
}
|
||||
return Value{Type: bsontype.Type(e[0]), Data: e[idx+2:]}.Validate()
|
||||
}
|
||||
|
||||
// CompareKey will compare this element's key to key. This method makes it easy to compare keys
|
||||
// without needing to allocate a string. The key may be null terminated. If a valid key cannot be
|
||||
// read this method will return false.
|
||||
func (e Element) CompareKey(key []byte) bool {
|
||||
if len(e) < 2 {
|
||||
return false
|
||||
}
|
||||
idx := bytes.IndexByte(e[1:], 0x00)
|
||||
if idx == -1 {
|
||||
return false
|
||||
}
|
||||
if index := bytes.IndexByte(key, 0x00); index > -1 {
|
||||
key = key[:index]
|
||||
}
|
||||
return bytes.Equal(e[1:idx+1], key)
|
||||
}
|
||||
|
||||
// Value returns the value of this element. If the element is not valid, this method returns an
|
||||
// empty Value. If knowing if the element is valid is important, use ValueErr.
|
||||
func (e Element) Value() Value {
|
||||
val, _ := e.ValueErr()
|
||||
return val
|
||||
}
|
||||
|
||||
// ValueErr returns the value for this element, returning an error if the element is not valid.
|
||||
func (e Element) ValueErr() (Value, error) {
|
||||
if len(e) <= 0 {
|
||||
return Value{}, ErrElementMissingType
|
||||
}
|
||||
idx := bytes.IndexByte(e[1:], 0x00)
|
||||
if idx == -1 {
|
||||
return Value{}, ErrElementMissingKey
|
||||
}
|
||||
|
||||
val, rem, exists := ReadValue(e[idx+2:], bsontype.Type(e[0]))
|
||||
if !exists {
|
||||
return Value{}, NewInsufficientBytesError(e, rem)
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// String implements the fmt.String interface. The output will be in extended JSON format.
|
||||
func (e Element) String() string {
|
||||
if len(e) <= 0 {
|
||||
return ""
|
||||
}
|
||||
t := bsontype.Type(e[0])
|
||||
idx := bytes.IndexByte(e[1:], 0x00)
|
||||
if idx == -1 {
|
||||
return ""
|
||||
}
|
||||
key, valBytes := []byte(e[1:idx+1]), []byte(e[idx+2:])
|
||||
val, _, valid := ReadValue(valBytes, t)
|
||||
if !valid {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf(`"%s": %v`, key, val)
|
||||
}
|
||||
|
||||
// DebugString outputs a human readable version of RawElement. It will attempt to stringify the
|
||||
// valid components of the element even if the entire element is not valid.
|
||||
func (e Element) DebugString() string {
|
||||
if len(e) <= 0 {
|
||||
return "<malformed>"
|
||||
}
|
||||
t := bsontype.Type(e[0])
|
||||
idx := bytes.IndexByte(e[1:], 0x00)
|
||||
if idx == -1 {
|
||||
return fmt.Sprintf(`bson.Element{[%s]<malformed>}`, t)
|
||||
}
|
||||
key, valBytes := []byte(e[1:idx+1]), []byte(e[idx+2:])
|
||||
val, _, valid := ReadValue(valBytes, t)
|
||||
if !valid {
|
||||
return fmt.Sprintf(`bson.Element{[%s]"%s": <malformed>}`, t, key)
|
||||
}
|
||||
return fmt.Sprintf(`bson.Element{[%s]"%s": %v}`, t, key, val)
|
||||
}
|
||||
127
mongo/x/bsonx/bsoncore/element_test.go
Normal file
127
mongo/x/bsonx/bsoncore/element_test.go
Normal file
@@ -0,0 +1,127 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bsoncore
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"go.mongodb.org/mongo-driver/bson/bsontype"
|
||||
)
|
||||
|
||||
func TestElement(t *testing.T) {
|
||||
t.Run("KeyErr", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
elem Element
|
||||
str string
|
||||
err error
|
||||
}{
|
||||
{"No Type", Element{}, "", ErrElementMissingType},
|
||||
{"No Key", Element{0x01, 'f', 'o', 'o'}, "", ErrElementMissingKey},
|
||||
{"Success", AppendHeader(nil, bsontype.Double, "foo"), "foo", nil},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Run("Key", func(t *testing.T) {
|
||||
str := tc.elem.Key()
|
||||
if str != tc.str {
|
||||
t.Errorf("returned strings do not match. got %s; want %s", str, tc.str)
|
||||
}
|
||||
})
|
||||
t.Run("KeyErr", func(t *testing.T) {
|
||||
str, err := tc.elem.KeyErr()
|
||||
if !compareErrors(err, tc.err) {
|
||||
t.Errorf("errors do not match. got %v; want %v", err, tc.err)
|
||||
}
|
||||
if str != tc.str {
|
||||
t.Errorf("returned strings do not match. got %s; want %s", str, tc.str)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
})
|
||||
t.Run("Validate", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
elem Element
|
||||
err error
|
||||
}{
|
||||
{"No Type", Element{}, ErrElementMissingType},
|
||||
{"No Key", Element{0x01, 'f', 'o', 'o'}, ErrElementMissingKey},
|
||||
{"Insufficient Bytes", AppendHeader(nil, bsontype.Double, "foo"), NewInsufficientBytesError(nil, nil)},
|
||||
{"Success", AppendDoubleElement(nil, "foo", 3.14159), nil},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := tc.elem.Validate()
|
||||
if !compareErrors(err, tc.err) {
|
||||
t.Errorf("errors do not match. got %v; want %v", err, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
t.Run("CompareKey", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
elem Element
|
||||
key []byte
|
||||
equal bool
|
||||
}{
|
||||
{"Element Too Short", Element{0x02}, nil, false},
|
||||
{"Element Invalid Key", Element{0x02, 'f', 'o', 'o'}, nil, false},
|
||||
{"Key With Null Byte", AppendHeader(nil, bsontype.Double, "foo"), []byte{'f', 'o', 'o', 0x00}, true},
|
||||
{"Key Without Null Byte", AppendHeader(nil, bsontype.Double, "pi"), []byte{'p', 'i'}, true},
|
||||
{"Key With Null Byte With Extra", AppendHeader(nil, bsontype.Double, "foo"), []byte{'f', 'o', 'o', 0x00, 'b', 'a', 'r'}, true},
|
||||
{"Prefix Key No Match", AppendHeader(nil, bsontype.Double, "foo"), []byte{'f', 'o', 'o', 'b', 'a', 'r'}, false},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
equal := tc.elem.CompareKey(tc.key)
|
||||
if equal != tc.equal {
|
||||
t.Errorf("Did not get expected equality result. got %t; want %t", equal, tc.equal)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
t.Run("Value & ValueErr", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
elem Element
|
||||
val Value
|
||||
err error
|
||||
}{
|
||||
{"No Type", Element{}, Value{}, ErrElementMissingType},
|
||||
{"No Key", Element{0x01, 'f', 'o', 'o'}, Value{}, ErrElementMissingKey},
|
||||
{"Insufficient Bytes", AppendHeader(nil, bsontype.Double, "foo"), Value{}, NewInsufficientBytesError(nil, nil)},
|
||||
{"Success", AppendDoubleElement(nil, "foo", 3.14159), Value{Type: bsontype.Double, Data: AppendDouble(nil, 3.14159)}, nil},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Run("Value", func(t *testing.T) {
|
||||
val := tc.elem.Value()
|
||||
if !cmp.Equal(val, tc.val) {
|
||||
t.Errorf("Values do not match. got %v; want %v", val, tc.val)
|
||||
}
|
||||
})
|
||||
t.Run("ValueErr", func(t *testing.T) {
|
||||
val, err := tc.elem.ValueErr()
|
||||
if !compareErrors(err, tc.err) {
|
||||
t.Errorf("errors do not match. got %v; want %v", err, tc.err)
|
||||
}
|
||||
if !cmp.Equal(val, tc.val) {
|
||||
t.Errorf("Values do not match. got %v; want %v", val, tc.val)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
223
mongo/x/bsonx/bsoncore/tables.go
Normal file
223
mongo/x/bsonx/bsoncore/tables.go
Normal file
@@ -0,0 +1,223 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Based on github.com/golang/go by The Go Authors
|
||||
// See THIRD-PARTY-NOTICES for original license terms.
|
||||
|
||||
package bsoncore
|
||||
|
||||
import "unicode/utf8"
|
||||
|
||||
// safeSet holds the value true if the ASCII character with the given array
|
||||
// position can be represented inside a JSON string without any further
|
||||
// escaping.
|
||||
//
|
||||
// All values are true except for the ASCII control characters (0-31), the
|
||||
// double quote ("), and the backslash character ("\").
|
||||
var safeSet = [utf8.RuneSelf]bool{
|
||||
' ': true,
|
||||
'!': true,
|
||||
'"': false,
|
||||
'#': true,
|
||||
'$': true,
|
||||
'%': true,
|
||||
'&': true,
|
||||
'\'': true,
|
||||
'(': true,
|
||||
')': true,
|
||||
'*': true,
|
||||
'+': true,
|
||||
',': true,
|
||||
'-': true,
|
||||
'.': true,
|
||||
'/': true,
|
||||
'0': true,
|
||||
'1': true,
|
||||
'2': true,
|
||||
'3': true,
|
||||
'4': true,
|
||||
'5': true,
|
||||
'6': true,
|
||||
'7': true,
|
||||
'8': true,
|
||||
'9': true,
|
||||
':': true,
|
||||
';': true,
|
||||
'<': true,
|
||||
'=': true,
|
||||
'>': true,
|
||||
'?': true,
|
||||
'@': true,
|
||||
'A': true,
|
||||
'B': true,
|
||||
'C': true,
|
||||
'D': true,
|
||||
'E': true,
|
||||
'F': true,
|
||||
'G': true,
|
||||
'H': true,
|
||||
'I': true,
|
||||
'J': true,
|
||||
'K': true,
|
||||
'L': true,
|
||||
'M': true,
|
||||
'N': true,
|
||||
'O': true,
|
||||
'P': true,
|
||||
'Q': true,
|
||||
'R': true,
|
||||
'S': true,
|
||||
'T': true,
|
||||
'U': true,
|
||||
'V': true,
|
||||
'W': true,
|
||||
'X': true,
|
||||
'Y': true,
|
||||
'Z': true,
|
||||
'[': true,
|
||||
'\\': false,
|
||||
']': true,
|
||||
'^': true,
|
||||
'_': true,
|
||||
'`': true,
|
||||
'a': true,
|
||||
'b': true,
|
||||
'c': true,
|
||||
'd': true,
|
||||
'e': true,
|
||||
'f': true,
|
||||
'g': true,
|
||||
'h': true,
|
||||
'i': true,
|
||||
'j': true,
|
||||
'k': true,
|
||||
'l': true,
|
||||
'm': true,
|
||||
'n': true,
|
||||
'o': true,
|
||||
'p': true,
|
||||
'q': true,
|
||||
'r': true,
|
||||
's': true,
|
||||
't': true,
|
||||
'u': true,
|
||||
'v': true,
|
||||
'w': true,
|
||||
'x': true,
|
||||
'y': true,
|
||||
'z': true,
|
||||
'{': true,
|
||||
'|': true,
|
||||
'}': true,
|
||||
'~': true,
|
||||
'\u007f': true,
|
||||
}
|
||||
|
||||
// htmlSafeSet holds the value true if the ASCII character with the given
|
||||
// array position can be safely represented inside a JSON string, embedded
|
||||
// inside of HTML <script> tags, without any additional escaping.
|
||||
//
|
||||
// All values are true except for the ASCII control characters (0-31), the
|
||||
// double quote ("), the backslash character ("\"), HTML opening and closing
|
||||
// tags ("<" and ">"), and the ampersand ("&").
|
||||
var htmlSafeSet = [utf8.RuneSelf]bool{
|
||||
' ': true,
|
||||
'!': true,
|
||||
'"': false,
|
||||
'#': true,
|
||||
'$': true,
|
||||
'%': true,
|
||||
'&': false,
|
||||
'\'': true,
|
||||
'(': true,
|
||||
')': true,
|
||||
'*': true,
|
||||
'+': true,
|
||||
',': true,
|
||||
'-': true,
|
||||
'.': true,
|
||||
'/': true,
|
||||
'0': true,
|
||||
'1': true,
|
||||
'2': true,
|
||||
'3': true,
|
||||
'4': true,
|
||||
'5': true,
|
||||
'6': true,
|
||||
'7': true,
|
||||
'8': true,
|
||||
'9': true,
|
||||
':': true,
|
||||
';': true,
|
||||
'<': false,
|
||||
'=': true,
|
||||
'>': false,
|
||||
'?': true,
|
||||
'@': true,
|
||||
'A': true,
|
||||
'B': true,
|
||||
'C': true,
|
||||
'D': true,
|
||||
'E': true,
|
||||
'F': true,
|
||||
'G': true,
|
||||
'H': true,
|
||||
'I': true,
|
||||
'J': true,
|
||||
'K': true,
|
||||
'L': true,
|
||||
'M': true,
|
||||
'N': true,
|
||||
'O': true,
|
||||
'P': true,
|
||||
'Q': true,
|
||||
'R': true,
|
||||
'S': true,
|
||||
'T': true,
|
||||
'U': true,
|
||||
'V': true,
|
||||
'W': true,
|
||||
'X': true,
|
||||
'Y': true,
|
||||
'Z': true,
|
||||
'[': true,
|
||||
'\\': false,
|
||||
']': true,
|
||||
'^': true,
|
||||
'_': true,
|
||||
'`': true,
|
||||
'a': true,
|
||||
'b': true,
|
||||
'c': true,
|
||||
'd': true,
|
||||
'e': true,
|
||||
'f': true,
|
||||
'g': true,
|
||||
'h': true,
|
||||
'i': true,
|
||||
'j': true,
|
||||
'k': true,
|
||||
'l': true,
|
||||
'm': true,
|
||||
'n': true,
|
||||
'o': true,
|
||||
'p': true,
|
||||
'q': true,
|
||||
'r': true,
|
||||
's': true,
|
||||
't': true,
|
||||
'u': true,
|
||||
'v': true,
|
||||
'w': true,
|
||||
'x': true,
|
||||
'y': true,
|
||||
'z': true,
|
||||
'{': true,
|
||||
'|': true,
|
||||
'}': true,
|
||||
'~': true,
|
||||
'\u007f': true,
|
||||
}
|
||||
972
mongo/x/bsonx/bsoncore/value.go
Normal file
972
mongo/x/bsonx/bsoncore/value.go
Normal file
@@ -0,0 +1,972 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bsoncore
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"math"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"go.mongodb.org/mongo-driver/bson/bsontype"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
// ElementTypeError specifies that a method to obtain a BSON value an incorrect type was called on a bson.Value.
|
||||
type ElementTypeError struct {
|
||||
Method string
|
||||
Type bsontype.Type
|
||||
}
|
||||
|
||||
// Error implements the error interface.
|
||||
func (ete ElementTypeError) Error() string {
|
||||
return "Call of " + ete.Method + " on " + ete.Type.String() + " type"
|
||||
}
|
||||
|
||||
// Value represents a BSON value with a type and raw bytes.
|
||||
type Value struct {
|
||||
Type bsontype.Type
|
||||
Data []byte
|
||||
}
|
||||
|
||||
// Validate ensures the value is a valid BSON value.
|
||||
func (v Value) Validate() error {
|
||||
_, _, valid := readValue(v.Data, v.Type)
|
||||
if !valid {
|
||||
return NewInsufficientBytesError(v.Data, v.Data)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsNumber returns true if the type of v is a numeric BSON type.
|
||||
func (v Value) IsNumber() bool {
|
||||
switch v.Type {
|
||||
case bsontype.Double, bsontype.Int32, bsontype.Int64, bsontype.Decimal128:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// AsInt32 returns a BSON number as an int32. If the BSON type is not a numeric one, this method
|
||||
// will panic.
|
||||
func (v Value) AsInt32() int32 {
|
||||
if !v.IsNumber() {
|
||||
panic(ElementTypeError{"bsoncore.Value.AsInt32", v.Type})
|
||||
}
|
||||
var i32 int32
|
||||
switch v.Type {
|
||||
case bsontype.Double:
|
||||
f64, _, ok := ReadDouble(v.Data)
|
||||
if !ok {
|
||||
panic(NewInsufficientBytesError(v.Data, v.Data))
|
||||
}
|
||||
i32 = int32(f64)
|
||||
case bsontype.Int32:
|
||||
var ok bool
|
||||
i32, _, ok = ReadInt32(v.Data)
|
||||
if !ok {
|
||||
panic(NewInsufficientBytesError(v.Data, v.Data))
|
||||
}
|
||||
case bsontype.Int64:
|
||||
i64, _, ok := ReadInt64(v.Data)
|
||||
if !ok {
|
||||
panic(NewInsufficientBytesError(v.Data, v.Data))
|
||||
}
|
||||
i32 = int32(i64)
|
||||
case bsontype.Decimal128:
|
||||
panic(ElementTypeError{"bsoncore.Value.AsInt32", v.Type})
|
||||
}
|
||||
return i32
|
||||
}
|
||||
|
||||
// AsInt32OK functions the same as AsInt32 but returns a boolean instead of panicking. False
|
||||
// indicates an error.
|
||||
func (v Value) AsInt32OK() (int32, bool) {
|
||||
if !v.IsNumber() {
|
||||
return 0, false
|
||||
}
|
||||
var i32 int32
|
||||
switch v.Type {
|
||||
case bsontype.Double:
|
||||
f64, _, ok := ReadDouble(v.Data)
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
i32 = int32(f64)
|
||||
case bsontype.Int32:
|
||||
var ok bool
|
||||
i32, _, ok = ReadInt32(v.Data)
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
case bsontype.Int64:
|
||||
i64, _, ok := ReadInt64(v.Data)
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
i32 = int32(i64)
|
||||
case bsontype.Decimal128:
|
||||
return 0, false
|
||||
}
|
||||
return i32, true
|
||||
}
|
||||
|
||||
// AsInt64 returns a BSON number as an int64. If the BSON type is not a numeric one, this method
|
||||
// will panic.
|
||||
func (v Value) AsInt64() int64 {
|
||||
if !v.IsNumber() {
|
||||
panic(ElementTypeError{"bsoncore.Value.AsInt64", v.Type})
|
||||
}
|
||||
var i64 int64
|
||||
switch v.Type {
|
||||
case bsontype.Double:
|
||||
f64, _, ok := ReadDouble(v.Data)
|
||||
if !ok {
|
||||
panic(NewInsufficientBytesError(v.Data, v.Data))
|
||||
}
|
||||
i64 = int64(f64)
|
||||
case bsontype.Int32:
|
||||
var ok bool
|
||||
i32, _, ok := ReadInt32(v.Data)
|
||||
if !ok {
|
||||
panic(NewInsufficientBytesError(v.Data, v.Data))
|
||||
}
|
||||
i64 = int64(i32)
|
||||
case bsontype.Int64:
|
||||
var ok bool
|
||||
i64, _, ok = ReadInt64(v.Data)
|
||||
if !ok {
|
||||
panic(NewInsufficientBytesError(v.Data, v.Data))
|
||||
}
|
||||
case bsontype.Decimal128:
|
||||
panic(ElementTypeError{"bsoncore.Value.AsInt64", v.Type})
|
||||
}
|
||||
return i64
|
||||
}
|
||||
|
||||
// AsInt64OK functions the same as AsInt64 but returns a boolean instead of panicking. False
|
||||
// indicates an error.
|
||||
func (v Value) AsInt64OK() (int64, bool) {
|
||||
if !v.IsNumber() {
|
||||
return 0, false
|
||||
}
|
||||
var i64 int64
|
||||
switch v.Type {
|
||||
case bsontype.Double:
|
||||
f64, _, ok := ReadDouble(v.Data)
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
i64 = int64(f64)
|
||||
case bsontype.Int32:
|
||||
var ok bool
|
||||
i32, _, ok := ReadInt32(v.Data)
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
i64 = int64(i32)
|
||||
case bsontype.Int64:
|
||||
var ok bool
|
||||
i64, _, ok = ReadInt64(v.Data)
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
case bsontype.Decimal128:
|
||||
return 0, false
|
||||
}
|
||||
return i64, true
|
||||
}
|
||||
|
||||
// AsFloat64 returns a BSON number as an float64. If the BSON type is not a numeric one, this method
|
||||
// will panic.
|
||||
//
|
||||
// TODO(skriptble): Add support for Decimal128.
|
||||
func (v Value) AsFloat64() float64 { return 0 }
|
||||
|
||||
// AsFloat64OK functions the same as AsFloat64 but returns a boolean instead of panicking. False
|
||||
// indicates an error.
|
||||
//
|
||||
// TODO(skriptble): Add support for Decimal128.
|
||||
func (v Value) AsFloat64OK() (float64, bool) { return 0, false }
|
||||
|
||||
// Add will add this value to another. This is currently only implemented for strings and numbers.
|
||||
// If either value is a string, the other type is coerced into a string and added to the other.
|
||||
//
|
||||
// This method will alter v and will attempt to reuse the []byte of v. If the []byte is too small,
|
||||
// it will be expanded.
|
||||
func (v *Value) Add(v2 Value) error { return nil }
|
||||
|
||||
// Equal compaes v to v2 and returns true if they are equal.
|
||||
func (v Value) Equal(v2 Value) bool {
|
||||
if v.Type != v2.Type {
|
||||
return false
|
||||
}
|
||||
|
||||
return bytes.Equal(v.Data, v2.Data)
|
||||
}
|
||||
|
||||
// String implements the fmt.String interface. This method will return values in extended JSON
|
||||
// format. If the value is not valid, this returns an empty string
|
||||
func (v Value) String() string {
|
||||
switch v.Type {
|
||||
case bsontype.Double:
|
||||
f64, ok := v.DoubleOK()
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf(`{"$numberDouble":"%s"}`, formatDouble(f64))
|
||||
case bsontype.String:
|
||||
str, ok := v.StringValueOK()
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return escapeString(str)
|
||||
case bsontype.EmbeddedDocument:
|
||||
doc, ok := v.DocumentOK()
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return doc.String()
|
||||
case bsontype.Array:
|
||||
arr, ok := v.ArrayOK()
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return arr.String()
|
||||
case bsontype.Binary:
|
||||
subtype, data, ok := v.BinaryOK()
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf(`{"$binary":{"base64":"%s","subType":"%02x"}}`, base64.StdEncoding.EncodeToString(data), subtype)
|
||||
case bsontype.Undefined:
|
||||
return `{"$undefined":true}`
|
||||
case bsontype.ObjectID:
|
||||
oid, ok := v.ObjectIDOK()
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf(`{"$oid":"%s"}`, oid.Hex())
|
||||
case bsontype.Boolean:
|
||||
b, ok := v.BooleanOK()
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return strconv.FormatBool(b)
|
||||
case bsontype.DateTime:
|
||||
dt, ok := v.DateTimeOK()
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf(`{"$date":{"$numberLong":"%d"}}`, dt)
|
||||
case bsontype.Null:
|
||||
return "null"
|
||||
case bsontype.Regex:
|
||||
pattern, options, ok := v.RegexOK()
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf(
|
||||
`{"$regularExpression":{"pattern":%s,"options":"%s"}}`,
|
||||
escapeString(pattern), sortStringAlphebeticAscending(options),
|
||||
)
|
||||
case bsontype.DBPointer:
|
||||
ns, pointer, ok := v.DBPointerOK()
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf(`{"$dbPointer":{"$ref":%s,"$id":{"$oid":"%s"}}}`, escapeString(ns), pointer.Hex())
|
||||
case bsontype.JavaScript:
|
||||
js, ok := v.JavaScriptOK()
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf(`{"$code":%s}`, escapeString(js))
|
||||
case bsontype.Symbol:
|
||||
symbol, ok := v.SymbolOK()
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf(`{"$symbol":%s}`, escapeString(symbol))
|
||||
case bsontype.CodeWithScope:
|
||||
code, scope, ok := v.CodeWithScopeOK()
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf(`{"$code":%s,"$scope":%s}`, code, scope)
|
||||
case bsontype.Int32:
|
||||
i32, ok := v.Int32OK()
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf(`{"$numberInt":"%d"}`, i32)
|
||||
case bsontype.Timestamp:
|
||||
t, i, ok := v.TimestampOK()
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf(`{"$timestamp":{"t":%v,"i":%v}}`, t, i)
|
||||
case bsontype.Int64:
|
||||
i64, ok := v.Int64OK()
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf(`{"$numberLong":"%d"}`, i64)
|
||||
case bsontype.Decimal128:
|
||||
d128, ok := v.Decimal128OK()
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf(`{"$numberDecimal":"%s"}`, d128.String())
|
||||
case bsontype.MinKey:
|
||||
return `{"$minKey":1}`
|
||||
case bsontype.MaxKey:
|
||||
return `{"$maxKey":1}`
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// DebugString outputs a human readable version of Document. It will attempt to stringify the
|
||||
// valid components of the document even if the entire document is not valid.
|
||||
func (v Value) DebugString() string {
|
||||
switch v.Type {
|
||||
case bsontype.String:
|
||||
str, ok := v.StringValueOK()
|
||||
if !ok {
|
||||
return "<malformed>"
|
||||
}
|
||||
return escapeString(str)
|
||||
case bsontype.EmbeddedDocument:
|
||||
doc, ok := v.DocumentOK()
|
||||
if !ok {
|
||||
return "<malformed>"
|
||||
}
|
||||
return doc.DebugString()
|
||||
case bsontype.Array:
|
||||
arr, ok := v.ArrayOK()
|
||||
if !ok {
|
||||
return "<malformed>"
|
||||
}
|
||||
return arr.DebugString()
|
||||
case bsontype.CodeWithScope:
|
||||
code, scope, ok := v.CodeWithScopeOK()
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf(`{"$code":%s,"$scope":%s}`, code, scope.DebugString())
|
||||
default:
|
||||
str := v.String()
|
||||
if str == "" {
|
||||
return "<malformed>"
|
||||
}
|
||||
return str
|
||||
}
|
||||
}
|
||||
|
||||
// Double returns the float64 value for this element.
|
||||
// It panics if e's BSON type is not bsontype.Double.
|
||||
func (v Value) Double() float64 {
|
||||
if v.Type != bsontype.Double {
|
||||
panic(ElementTypeError{"bsoncore.Value.Double", v.Type})
|
||||
}
|
||||
f64, _, ok := ReadDouble(v.Data)
|
||||
if !ok {
|
||||
panic(NewInsufficientBytesError(v.Data, v.Data))
|
||||
}
|
||||
return f64
|
||||
}
|
||||
|
||||
// DoubleOK is the same as Double, but returns a boolean instead of panicking.
|
||||
func (v Value) DoubleOK() (float64, bool) {
|
||||
if v.Type != bsontype.Double {
|
||||
return 0, false
|
||||
}
|
||||
f64, _, ok := ReadDouble(v.Data)
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
return f64, true
|
||||
}
|
||||
|
||||
// StringValue returns the string balue for this element.
|
||||
// It panics if e's BSON type is not bsontype.String.
|
||||
//
|
||||
// NOTE: This method is called StringValue to avoid a collision with the String method which
|
||||
// implements the fmt.Stringer interface.
|
||||
func (v Value) StringValue() string {
|
||||
if v.Type != bsontype.String {
|
||||
panic(ElementTypeError{"bsoncore.Value.StringValue", v.Type})
|
||||
}
|
||||
str, _, ok := ReadString(v.Data)
|
||||
if !ok {
|
||||
panic(NewInsufficientBytesError(v.Data, v.Data))
|
||||
}
|
||||
return str
|
||||
}
|
||||
|
||||
// StringValueOK is the same as StringValue, but returns a boolean instead of
|
||||
// panicking.
|
||||
func (v Value) StringValueOK() (string, bool) {
|
||||
if v.Type != bsontype.String {
|
||||
return "", false
|
||||
}
|
||||
str, _, ok := ReadString(v.Data)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
return str, true
|
||||
}
|
||||
|
||||
// Document returns the BSON document the Value represents as a Document. It panics if the
|
||||
// value is a BSON type other than document.
|
||||
func (v Value) Document() Document {
|
||||
if v.Type != bsontype.EmbeddedDocument {
|
||||
panic(ElementTypeError{"bsoncore.Value.Document", v.Type})
|
||||
}
|
||||
doc, _, ok := ReadDocument(v.Data)
|
||||
if !ok {
|
||||
panic(NewInsufficientBytesError(v.Data, v.Data))
|
||||
}
|
||||
return doc
|
||||
}
|
||||
|
||||
// DocumentOK is the same as Document, except it returns a boolean
|
||||
// instead of panicking.
|
||||
func (v Value) DocumentOK() (Document, bool) {
|
||||
if v.Type != bsontype.EmbeddedDocument {
|
||||
return nil, false
|
||||
}
|
||||
doc, _, ok := ReadDocument(v.Data)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return doc, true
|
||||
}
|
||||
|
||||
// Array returns the BSON array the Value represents as an Array. It panics if the
|
||||
// value is a BSON type other than array.
|
||||
func (v Value) Array() Array {
|
||||
if v.Type != bsontype.Array {
|
||||
panic(ElementTypeError{"bsoncore.Value.Array", v.Type})
|
||||
}
|
||||
arr, _, ok := ReadArray(v.Data)
|
||||
if !ok {
|
||||
panic(NewInsufficientBytesError(v.Data, v.Data))
|
||||
}
|
||||
return arr
|
||||
}
|
||||
|
||||
// ArrayOK is the same as Array, except it returns a boolean instead
|
||||
// of panicking.
|
||||
func (v Value) ArrayOK() (Array, bool) {
|
||||
if v.Type != bsontype.Array {
|
||||
return nil, false
|
||||
}
|
||||
arr, _, ok := ReadArray(v.Data)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return arr, true
|
||||
}
|
||||
|
||||
// Binary returns the BSON binary value the Value represents. It panics if the value is a BSON type
|
||||
// other than binary.
|
||||
func (v Value) Binary() (subtype byte, data []byte) {
|
||||
if v.Type != bsontype.Binary {
|
||||
panic(ElementTypeError{"bsoncore.Value.Binary", v.Type})
|
||||
}
|
||||
subtype, data, _, ok := ReadBinary(v.Data)
|
||||
if !ok {
|
||||
panic(NewInsufficientBytesError(v.Data, v.Data))
|
||||
}
|
||||
return subtype, data
|
||||
}
|
||||
|
||||
// BinaryOK is the same as Binary, except it returns a boolean instead of
|
||||
// panicking.
|
||||
func (v Value) BinaryOK() (subtype byte, data []byte, ok bool) {
|
||||
if v.Type != bsontype.Binary {
|
||||
return 0x00, nil, false
|
||||
}
|
||||
subtype, data, _, ok = ReadBinary(v.Data)
|
||||
if !ok {
|
||||
return 0x00, nil, false
|
||||
}
|
||||
return subtype, data, true
|
||||
}
|
||||
|
||||
// ObjectID returns the BSON objectid value the Value represents. It panics if the value is a BSON
|
||||
// type other than objectid.
|
||||
func (v Value) ObjectID() primitive.ObjectID {
|
||||
if v.Type != bsontype.ObjectID {
|
||||
panic(ElementTypeError{"bsoncore.Value.ObjectID", v.Type})
|
||||
}
|
||||
oid, _, ok := ReadObjectID(v.Data)
|
||||
if !ok {
|
||||
panic(NewInsufficientBytesError(v.Data, v.Data))
|
||||
}
|
||||
return oid
|
||||
}
|
||||
|
||||
// ObjectIDOK is the same as ObjectID, except it returns a boolean instead of
|
||||
// panicking.
|
||||
func (v Value) ObjectIDOK() (primitive.ObjectID, bool) {
|
||||
if v.Type != bsontype.ObjectID {
|
||||
return primitive.ObjectID{}, false
|
||||
}
|
||||
oid, _, ok := ReadObjectID(v.Data)
|
||||
if !ok {
|
||||
return primitive.ObjectID{}, false
|
||||
}
|
||||
return oid, true
|
||||
}
|
||||
|
||||
// Boolean returns the boolean value the Value represents. It panics if the
|
||||
// value is a BSON type other than boolean.
|
||||
func (v Value) Boolean() bool {
|
||||
if v.Type != bsontype.Boolean {
|
||||
panic(ElementTypeError{"bsoncore.Value.Boolean", v.Type})
|
||||
}
|
||||
b, _, ok := ReadBoolean(v.Data)
|
||||
if !ok {
|
||||
panic(NewInsufficientBytesError(v.Data, v.Data))
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// BooleanOK is the same as Boolean, except it returns a boolean instead of
|
||||
// panicking.
|
||||
func (v Value) BooleanOK() (bool, bool) {
|
||||
if v.Type != bsontype.Boolean {
|
||||
return false, false
|
||||
}
|
||||
b, _, ok := ReadBoolean(v.Data)
|
||||
if !ok {
|
||||
return false, false
|
||||
}
|
||||
return b, true
|
||||
}
|
||||
|
||||
// DateTime returns the BSON datetime value the Value represents as a
|
||||
// unix timestamp. It panics if the value is a BSON type other than datetime.
|
||||
func (v Value) DateTime() int64 {
|
||||
if v.Type != bsontype.DateTime {
|
||||
panic(ElementTypeError{"bsoncore.Value.DateTime", v.Type})
|
||||
}
|
||||
dt, _, ok := ReadDateTime(v.Data)
|
||||
if !ok {
|
||||
panic(NewInsufficientBytesError(v.Data, v.Data))
|
||||
}
|
||||
return dt
|
||||
}
|
||||
|
||||
// DateTimeOK is the same as DateTime, except it returns a boolean instead of
|
||||
// panicking.
|
||||
func (v Value) DateTimeOK() (int64, bool) {
|
||||
if v.Type != bsontype.DateTime {
|
||||
return 0, false
|
||||
}
|
||||
dt, _, ok := ReadDateTime(v.Data)
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
return dt, true
|
||||
}
|
||||
|
||||
// Time returns the BSON datetime value the Value represents. It panics if the value is a BSON
|
||||
// type other than datetime.
|
||||
func (v Value) Time() time.Time {
|
||||
if v.Type != bsontype.DateTime {
|
||||
panic(ElementTypeError{"bsoncore.Value.Time", v.Type})
|
||||
}
|
||||
dt, _, ok := ReadDateTime(v.Data)
|
||||
if !ok {
|
||||
panic(NewInsufficientBytesError(v.Data, v.Data))
|
||||
}
|
||||
return time.Unix(dt/1000, dt%1000*1000000)
|
||||
}
|
||||
|
||||
// TimeOK is the same as Time, except it returns a boolean instead of
|
||||
// panicking.
|
||||
func (v Value) TimeOK() (time.Time, bool) {
|
||||
if v.Type != bsontype.DateTime {
|
||||
return time.Time{}, false
|
||||
}
|
||||
dt, _, ok := ReadDateTime(v.Data)
|
||||
if !ok {
|
||||
return time.Time{}, false
|
||||
}
|
||||
return time.Unix(dt/1000, dt%1000*1000000), true
|
||||
}
|
||||
|
||||
// Regex returns the BSON regex value the Value represents. It panics if the value is a BSON
|
||||
// type other than regex.
|
||||
func (v Value) Regex() (pattern, options string) {
|
||||
if v.Type != bsontype.Regex {
|
||||
panic(ElementTypeError{"bsoncore.Value.Regex", v.Type})
|
||||
}
|
||||
pattern, options, _, ok := ReadRegex(v.Data)
|
||||
if !ok {
|
||||
panic(NewInsufficientBytesError(v.Data, v.Data))
|
||||
}
|
||||
return pattern, options
|
||||
}
|
||||
|
||||
// RegexOK is the same as Regex, except it returns a boolean instead of
|
||||
// panicking.
|
||||
func (v Value) RegexOK() (pattern, options string, ok bool) {
|
||||
if v.Type != bsontype.Regex {
|
||||
return "", "", false
|
||||
}
|
||||
pattern, options, _, ok = ReadRegex(v.Data)
|
||||
if !ok {
|
||||
return "", "", false
|
||||
}
|
||||
return pattern, options, true
|
||||
}
|
||||
|
||||
// DBPointer returns the BSON dbpointer value the Value represents. It panics if the value is a BSON
|
||||
// type other than DBPointer.
|
||||
func (v Value) DBPointer() (string, primitive.ObjectID) {
|
||||
if v.Type != bsontype.DBPointer {
|
||||
panic(ElementTypeError{"bsoncore.Value.DBPointer", v.Type})
|
||||
}
|
||||
ns, pointer, _, ok := ReadDBPointer(v.Data)
|
||||
if !ok {
|
||||
panic(NewInsufficientBytesError(v.Data, v.Data))
|
||||
}
|
||||
return ns, pointer
|
||||
}
|
||||
|
||||
// DBPointerOK is the same as DBPoitner, except that it returns a boolean
|
||||
// instead of panicking.
|
||||
func (v Value) DBPointerOK() (string, primitive.ObjectID, bool) {
|
||||
if v.Type != bsontype.DBPointer {
|
||||
return "", primitive.ObjectID{}, false
|
||||
}
|
||||
ns, pointer, _, ok := ReadDBPointer(v.Data)
|
||||
if !ok {
|
||||
return "", primitive.ObjectID{}, false
|
||||
}
|
||||
return ns, pointer, true
|
||||
}
|
||||
|
||||
// JavaScript returns the BSON JavaScript code value the Value represents. It panics if the value is
|
||||
// a BSON type other than JavaScript code.
|
||||
func (v Value) JavaScript() string {
|
||||
if v.Type != bsontype.JavaScript {
|
||||
panic(ElementTypeError{"bsoncore.Value.JavaScript", v.Type})
|
||||
}
|
||||
js, _, ok := ReadJavaScript(v.Data)
|
||||
if !ok {
|
||||
panic(NewInsufficientBytesError(v.Data, v.Data))
|
||||
}
|
||||
return js
|
||||
}
|
||||
|
||||
// JavaScriptOK is the same as Javascript, excepti that it returns a boolean
|
||||
// instead of panicking.
|
||||
func (v Value) JavaScriptOK() (string, bool) {
|
||||
if v.Type != bsontype.JavaScript {
|
||||
return "", false
|
||||
}
|
||||
js, _, ok := ReadJavaScript(v.Data)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
return js, true
|
||||
}
|
||||
|
||||
// Symbol returns the BSON symbol value the Value represents. It panics if the value is a BSON
|
||||
// type other than symbol.
|
||||
func (v Value) Symbol() string {
|
||||
if v.Type != bsontype.Symbol {
|
||||
panic(ElementTypeError{"bsoncore.Value.Symbol", v.Type})
|
||||
}
|
||||
symbol, _, ok := ReadSymbol(v.Data)
|
||||
if !ok {
|
||||
panic(NewInsufficientBytesError(v.Data, v.Data))
|
||||
}
|
||||
return symbol
|
||||
}
|
||||
|
||||
// SymbolOK is the same as Symbol, excepti that it returns a boolean
|
||||
// instead of panicking.
|
||||
func (v Value) SymbolOK() (string, bool) {
|
||||
if v.Type != bsontype.Symbol {
|
||||
return "", false
|
||||
}
|
||||
symbol, _, ok := ReadSymbol(v.Data)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
return symbol, true
|
||||
}
|
||||
|
||||
// CodeWithScope returns the BSON JavaScript code with scope the Value represents.
|
||||
// It panics if the value is a BSON type other than JavaScript code with scope.
|
||||
func (v Value) CodeWithScope() (string, Document) {
|
||||
if v.Type != bsontype.CodeWithScope {
|
||||
panic(ElementTypeError{"bsoncore.Value.CodeWithScope", v.Type})
|
||||
}
|
||||
code, scope, _, ok := ReadCodeWithScope(v.Data)
|
||||
if !ok {
|
||||
panic(NewInsufficientBytesError(v.Data, v.Data))
|
||||
}
|
||||
return code, scope
|
||||
}
|
||||
|
||||
// CodeWithScopeOK is the same as CodeWithScope, except that it returns a boolean instead of
|
||||
// panicking.
|
||||
func (v Value) CodeWithScopeOK() (string, Document, bool) {
|
||||
if v.Type != bsontype.CodeWithScope {
|
||||
return "", nil, false
|
||||
}
|
||||
code, scope, _, ok := ReadCodeWithScope(v.Data)
|
||||
if !ok {
|
||||
return "", nil, false
|
||||
}
|
||||
return code, scope, true
|
||||
}
|
||||
|
||||
// Int32 returns the int32 the Value represents. It panics if the value is a BSON type other than
|
||||
// int32.
|
||||
func (v Value) Int32() int32 {
|
||||
if v.Type != bsontype.Int32 {
|
||||
panic(ElementTypeError{"bsoncore.Value.Int32", v.Type})
|
||||
}
|
||||
i32, _, ok := ReadInt32(v.Data)
|
||||
if !ok {
|
||||
panic(NewInsufficientBytesError(v.Data, v.Data))
|
||||
}
|
||||
return i32
|
||||
}
|
||||
|
||||
// Int32OK is the same as Int32, except that it returns a boolean instead of
|
||||
// panicking.
|
||||
func (v Value) Int32OK() (int32, bool) {
|
||||
if v.Type != bsontype.Int32 {
|
||||
return 0, false
|
||||
}
|
||||
i32, _, ok := ReadInt32(v.Data)
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
return i32, true
|
||||
}
|
||||
|
||||
// Timestamp returns the BSON timestamp value the Value represents. It panics if the value is a
|
||||
// BSON type other than timestamp.
|
||||
func (v Value) Timestamp() (t, i uint32) {
|
||||
if v.Type != bsontype.Timestamp {
|
||||
panic(ElementTypeError{"bsoncore.Value.Timestamp", v.Type})
|
||||
}
|
||||
t, i, _, ok := ReadTimestamp(v.Data)
|
||||
if !ok {
|
||||
panic(NewInsufficientBytesError(v.Data, v.Data))
|
||||
}
|
||||
return t, i
|
||||
}
|
||||
|
||||
// TimestampOK is the same as Timestamp, except that it returns a boolean
|
||||
// instead of panicking.
|
||||
func (v Value) TimestampOK() (t, i uint32, ok bool) {
|
||||
if v.Type != bsontype.Timestamp {
|
||||
return 0, 0, false
|
||||
}
|
||||
t, i, _, ok = ReadTimestamp(v.Data)
|
||||
if !ok {
|
||||
return 0, 0, false
|
||||
}
|
||||
return t, i, true
|
||||
}
|
||||
|
||||
// Int64 returns the int64 the Value represents. It panics if the value is a BSON type other than
|
||||
// int64.
|
||||
func (v Value) Int64() int64 {
|
||||
if v.Type != bsontype.Int64 {
|
||||
panic(ElementTypeError{"bsoncore.Value.Int64", v.Type})
|
||||
}
|
||||
i64, _, ok := ReadInt64(v.Data)
|
||||
if !ok {
|
||||
panic(NewInsufficientBytesError(v.Data, v.Data))
|
||||
}
|
||||
return i64
|
||||
}
|
||||
|
||||
// Int64OK is the same as Int64, except that it returns a boolean instead of
|
||||
// panicking.
|
||||
func (v Value) Int64OK() (int64, bool) {
|
||||
if v.Type != bsontype.Int64 {
|
||||
return 0, false
|
||||
}
|
||||
i64, _, ok := ReadInt64(v.Data)
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
return i64, true
|
||||
}
|
||||
|
||||
// Decimal128 returns the decimal the Value represents. It panics if the value is a BSON type other than
|
||||
// decimal.
|
||||
func (v Value) Decimal128() primitive.Decimal128 {
|
||||
if v.Type != bsontype.Decimal128 {
|
||||
panic(ElementTypeError{"bsoncore.Value.Decimal128", v.Type})
|
||||
}
|
||||
d128, _, ok := ReadDecimal128(v.Data)
|
||||
if !ok {
|
||||
panic(NewInsufficientBytesError(v.Data, v.Data))
|
||||
}
|
||||
return d128
|
||||
}
|
||||
|
||||
// Decimal128OK is the same as Decimal128, except that it returns a boolean
|
||||
// instead of panicking.
|
||||
func (v Value) Decimal128OK() (primitive.Decimal128, bool) {
|
||||
if v.Type != bsontype.Decimal128 {
|
||||
return primitive.Decimal128{}, false
|
||||
}
|
||||
d128, _, ok := ReadDecimal128(v.Data)
|
||||
if !ok {
|
||||
return primitive.Decimal128{}, false
|
||||
}
|
||||
return d128, true
|
||||
}
|
||||
|
||||
var hexChars = "0123456789abcdef"
|
||||
|
||||
func escapeString(s string) string {
|
||||
escapeHTML := true
|
||||
var buf bytes.Buffer
|
||||
buf.WriteByte('"')
|
||||
start := 0
|
||||
for i := 0; i < len(s); {
|
||||
if b := s[i]; b < utf8.RuneSelf {
|
||||
if htmlSafeSet[b] || (!escapeHTML && safeSet[b]) {
|
||||
i++
|
||||
continue
|
||||
}
|
||||
if start < i {
|
||||
buf.WriteString(s[start:i])
|
||||
}
|
||||
switch b {
|
||||
case '\\', '"':
|
||||
buf.WriteByte('\\')
|
||||
buf.WriteByte(b)
|
||||
case '\n':
|
||||
buf.WriteByte('\\')
|
||||
buf.WriteByte('n')
|
||||
case '\r':
|
||||
buf.WriteByte('\\')
|
||||
buf.WriteByte('r')
|
||||
case '\t':
|
||||
buf.WriteByte('\\')
|
||||
buf.WriteByte('t')
|
||||
case '\b':
|
||||
buf.WriteByte('\\')
|
||||
buf.WriteByte('b')
|
||||
case '\f':
|
||||
buf.WriteByte('\\')
|
||||
buf.WriteByte('f')
|
||||
default:
|
||||
// This encodes bytes < 0x20 except for \t, \n and \r.
|
||||
// If escapeHTML is set, it also escapes <, >, and &
|
||||
// because they can lead to security holes when
|
||||
// user-controlled strings are rendered into JSON
|
||||
// and served to some browsers.
|
||||
buf.WriteString(`\u00`)
|
||||
buf.WriteByte(hexChars[b>>4])
|
||||
buf.WriteByte(hexChars[b&0xF])
|
||||
}
|
||||
i++
|
||||
start = i
|
||||
continue
|
||||
}
|
||||
c, size := utf8.DecodeRuneInString(s[i:])
|
||||
if c == utf8.RuneError && size == 1 {
|
||||
if start < i {
|
||||
buf.WriteString(s[start:i])
|
||||
}
|
||||
buf.WriteString(`\ufffd`)
|
||||
i += size
|
||||
start = i
|
||||
continue
|
||||
}
|
||||
// U+2028 is LINE SEPARATOR.
|
||||
// U+2029 is PARAGRAPH SEPARATOR.
|
||||
// They are both technically valid characters in JSON strings,
|
||||
// but don't work in JSONP, which has to be evaluated as JavaScript,
|
||||
// and can lead to security holes there. It is valid JSON to
|
||||
// escape them, so we do so unconditionally.
|
||||
// See http://timelessrepo.com/json-isnt-a-javascript-subset for discussion.
|
||||
if c == '\u2028' || c == '\u2029' {
|
||||
if start < i {
|
||||
buf.WriteString(s[start:i])
|
||||
}
|
||||
buf.WriteString(`\u202`)
|
||||
buf.WriteByte(hexChars[c&0xF])
|
||||
i += size
|
||||
start = i
|
||||
continue
|
||||
}
|
||||
i += size
|
||||
}
|
||||
if start < len(s) {
|
||||
buf.WriteString(s[start:])
|
||||
}
|
||||
buf.WriteByte('"')
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func formatDouble(f float64) string {
|
||||
var s string
|
||||
if math.IsInf(f, 1) {
|
||||
s = "Infinity"
|
||||
} else if math.IsInf(f, -1) {
|
||||
s = "-Infinity"
|
||||
} else if math.IsNaN(f) {
|
||||
s = "NaN"
|
||||
} else {
|
||||
// Print exactly one decimalType place for integers; otherwise, print as many are necessary to
|
||||
// perfectly represent it.
|
||||
s = strconv.FormatFloat(f, 'G', -1, 64)
|
||||
if !strings.ContainsRune(s, '.') {
|
||||
s += ".0"
|
||||
}
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
type sortableString []rune
|
||||
|
||||
func (ss sortableString) Len() int {
|
||||
return len(ss)
|
||||
}
|
||||
|
||||
func (ss sortableString) Less(i, j int) bool {
|
||||
return ss[i] < ss[j]
|
||||
}
|
||||
|
||||
func (ss sortableString) Swap(i, j int) {
|
||||
oldI := ss[i]
|
||||
ss[i] = ss[j]
|
||||
ss[j] = oldI
|
||||
}
|
||||
|
||||
func sortStringAlphebeticAscending(s string) string {
|
||||
ss := sortableString([]rune(s))
|
||||
sort.Sort(ss)
|
||||
return string([]rune(ss))
|
||||
}
|
||||
678
mongo/x/bsonx/bsoncore/value_test.go
Normal file
678
mongo/x/bsonx/bsoncore/value_test.go
Normal file
@@ -0,0 +1,678 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bsoncore
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"go.mongodb.org/mongo-driver/bson/bsontype"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
func TestValue(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("Validate", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("invalid", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
v := Value{Type: bsontype.Double, Data: []byte{0x01, 0x02, 0x03, 0x04}}
|
||||
want := NewInsufficientBytesError(v.Data, v.Data)
|
||||
got := v.Validate()
|
||||
if !compareErrors(got, want) {
|
||||
t.Errorf("Errors do not match. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
t.Run("value", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
v := Value{Type: bsontype.Double, Data: AppendDouble(nil, 3.14159)}
|
||||
var want error
|
||||
got := v.Validate()
|
||||
if !compareErrors(got, want) {
|
||||
t.Errorf("Errors do not match. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("IsNumber", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
val Value
|
||||
isnum bool
|
||||
}{
|
||||
{"double", Value{Type: bsontype.Double}, true},
|
||||
{"int32", Value{Type: bsontype.Int32}, true},
|
||||
{"int64", Value{Type: bsontype.Int64}, true},
|
||||
{"decimal128", Value{Type: bsontype.Decimal128}, true},
|
||||
{"string", Value{Type: bsontype.String}, false},
|
||||
{"regex", Value{Type: bsontype.Regex}, false},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
isnum := tc.val.IsNumber()
|
||||
if isnum != tc.isnum {
|
||||
t.Errorf("IsNumber did not return the expected boolean. got %t; want %t", isnum, tc.isnum)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
now := time.Now().Truncate(time.Millisecond)
|
||||
oid := primitive.NewObjectID()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
fn interface{}
|
||||
val Value
|
||||
panicErr error
|
||||
ret []interface{}
|
||||
}{
|
||||
{
|
||||
"Double/Not Double", Value.Double, Value{Type: bsontype.String},
|
||||
ElementTypeError{"bsoncore.Value.Double", bsontype.String},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Double/Insufficient Bytes", Value.Double, Value{Type: bsontype.Double, Data: []byte{0x01, 0x02, 0x03, 0x04}},
|
||||
NewInsufficientBytesError([]byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}),
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Double/Success", Value.Double, Value{Type: bsontype.Double, Data: AppendDouble(nil, 3.14159)},
|
||||
nil,
|
||||
[]interface{}{float64(3.14159)},
|
||||
},
|
||||
{
|
||||
"DoubleOK/Not Double", Value.DoubleOK, Value{Type: bsontype.String},
|
||||
nil,
|
||||
[]interface{}{float64(0), false},
|
||||
},
|
||||
{
|
||||
"DoubleOK/Insufficient Bytes", Value.DoubleOK, Value{Type: bsontype.Double, Data: []byte{0x01, 0x02, 0x03, 0x04}},
|
||||
nil,
|
||||
[]interface{}{float64(0), false},
|
||||
},
|
||||
{
|
||||
"DoubleOK/Success", Value.DoubleOK, Value{Type: bsontype.Double, Data: AppendDouble(nil, 3.14159)},
|
||||
nil,
|
||||
[]interface{}{float64(3.14159), true},
|
||||
},
|
||||
{
|
||||
"StringValue/Not String", Value.StringValue, Value{Type: bsontype.Double},
|
||||
ElementTypeError{"bsoncore.Value.StringValue", bsontype.Double},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"StringValue/Insufficient Bytes", Value.StringValue, Value{Type: bsontype.String, Data: []byte{0x01, 0x02, 0x03, 0x04}},
|
||||
NewInsufficientBytesError([]byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}),
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"StringValue/Zero Length", Value.StringValue, Value{Type: bsontype.String, Data: []byte{0x00, 0x00, 0x00, 0x00}},
|
||||
NewInsufficientBytesError([]byte{0x00, 0x00, 0x00, 0x00}, []byte{0x00, 0x00, 0x00, 0x00}),
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"StringValue/Success", Value.StringValue, Value{Type: bsontype.String, Data: AppendString(nil, "hello, world!")},
|
||||
nil,
|
||||
[]interface{}{"hello, world!"},
|
||||
},
|
||||
{
|
||||
"StringValueOK/Not String", Value.StringValueOK, Value{Type: bsontype.Double},
|
||||
nil,
|
||||
[]interface{}{"", false},
|
||||
},
|
||||
{
|
||||
"StringValueOK/Insufficient Bytes", Value.StringValueOK, Value{Type: bsontype.String, Data: []byte{0x01, 0x02, 0x03, 0x04}},
|
||||
nil,
|
||||
[]interface{}{"", false},
|
||||
},
|
||||
{
|
||||
"StringValueOK/Zero Length", Value.StringValueOK, Value{Type: bsontype.String, Data: []byte{0x00, 0x00, 0x00, 0x00}},
|
||||
nil,
|
||||
[]interface{}{"", false},
|
||||
},
|
||||
{
|
||||
"StringValueOK/Success", Value.StringValueOK, Value{Type: bsontype.String, Data: AppendString(nil, "hello, world!")},
|
||||
nil,
|
||||
[]interface{}{"hello, world!", true},
|
||||
},
|
||||
{
|
||||
"Document/Not Document", Value.Document, Value{Type: bsontype.String},
|
||||
ElementTypeError{"bsoncore.Value.Document", bsontype.String},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Document/Insufficient Bytes", Value.Document, Value{Type: bsontype.EmbeddedDocument, Data: []byte{0x01, 0x02, 0x03, 0x04}},
|
||||
NewInsufficientBytesError([]byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}),
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Document/Success", Value.Document, Value{Type: bsontype.EmbeddedDocument, Data: []byte{0x05, 0x00, 0x00, 0x00, 0x00}},
|
||||
nil,
|
||||
[]interface{}{Document{0x05, 0x00, 0x00, 0x00, 0x00}},
|
||||
},
|
||||
{
|
||||
"DocumentOK/Not Document", Value.DocumentOK, Value{Type: bsontype.String},
|
||||
nil,
|
||||
[]interface{}{Document(nil), false},
|
||||
},
|
||||
{
|
||||
"DocumentOK/Insufficient Bytes", Value.DocumentOK, Value{Type: bsontype.EmbeddedDocument, Data: []byte{0x01, 0x02, 0x03, 0x04}},
|
||||
nil,
|
||||
[]interface{}{Document(nil), false},
|
||||
},
|
||||
{
|
||||
"DocumentOK/Success", Value.DocumentOK, Value{Type: bsontype.EmbeddedDocument, Data: []byte{0x05, 0x00, 0x00, 0x00, 0x00}},
|
||||
nil,
|
||||
[]interface{}{Document{0x05, 0x00, 0x00, 0x00, 0x00}, true},
|
||||
},
|
||||
{
|
||||
"Array/Not Array", Value.Array, Value{Type: bsontype.String},
|
||||
ElementTypeError{"bsoncore.Value.Array", bsontype.String},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Array/Insufficient Bytes", Value.Array, Value{Type: bsontype.Array, Data: []byte{0x01, 0x02, 0x03, 0x04}},
|
||||
NewInsufficientBytesError([]byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}),
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Array/Success", Value.Array, Value{Type: bsontype.Array, Data: []byte{0x05, 0x00, 0x00, 0x00, 0x00}},
|
||||
nil,
|
||||
[]interface{}{Array{0x05, 0x00, 0x00, 0x00, 0x00}},
|
||||
},
|
||||
{
|
||||
"ArrayOK/Not Array", Value.ArrayOK, Value{Type: bsontype.String},
|
||||
nil,
|
||||
[]interface{}{Array(nil), false},
|
||||
},
|
||||
{
|
||||
"ArrayOK/Insufficient Bytes", Value.ArrayOK, Value{Type: bsontype.Array, Data: []byte{0x01, 0x02, 0x03, 0x04}},
|
||||
nil,
|
||||
[]interface{}{Array(nil), false},
|
||||
},
|
||||
{
|
||||
"ArrayOK/Success", Value.ArrayOK, Value{Type: bsontype.Array, Data: []byte{0x05, 0x00, 0x00, 0x00, 0x00}},
|
||||
nil,
|
||||
[]interface{}{Array{0x05, 0x00, 0x00, 0x00, 0x00}, true},
|
||||
},
|
||||
{
|
||||
"Binary/Not Binary", Value.Binary, Value{Type: bsontype.String},
|
||||
ElementTypeError{"bsoncore.Value.Binary", bsontype.String},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Binary/Insufficient Bytes", Value.Binary, Value{Type: bsontype.Binary, Data: []byte{0x01, 0x02, 0x03, 0x04}},
|
||||
NewInsufficientBytesError([]byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}),
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Binary/Success", Value.Binary, Value{Type: bsontype.Binary, Data: AppendBinary(nil, 0xFF, []byte{0x01, 0x02, 0x03})},
|
||||
nil,
|
||||
[]interface{}{byte(0xFF), []byte{0x01, 0x02, 0x03}},
|
||||
},
|
||||
{
|
||||
"BinaryOK/Not Binary", Value.BinaryOK, Value{Type: bsontype.String},
|
||||
nil,
|
||||
[]interface{}{byte(0x00), []byte(nil), false},
|
||||
},
|
||||
{
|
||||
"BinaryOK/Insufficient Bytes", Value.BinaryOK, Value{Type: bsontype.Binary, Data: []byte{0x01, 0x02, 0x03, 0x04}},
|
||||
nil,
|
||||
[]interface{}{byte(0x00), []byte(nil), false},
|
||||
},
|
||||
{
|
||||
"BinaryOK/Success", Value.BinaryOK, Value{Type: bsontype.Binary, Data: AppendBinary(nil, 0xFF, []byte{0x01, 0x02, 0x03})},
|
||||
nil,
|
||||
[]interface{}{byte(0xFF), []byte{0x01, 0x02, 0x03}, true},
|
||||
},
|
||||
{
|
||||
"ObjectID/Not ObjectID", Value.ObjectID, Value{Type: bsontype.String},
|
||||
ElementTypeError{"bsoncore.Value.ObjectID", bsontype.String},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"ObjectID/Insufficient Bytes", Value.ObjectID, Value{Type: bsontype.ObjectID, Data: []byte{0x01, 0x02, 0x03, 0x04}},
|
||||
NewInsufficientBytesError([]byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}),
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"ObjectID/Success", Value.ObjectID, Value{Type: bsontype.ObjectID, Data: AppendObjectID(nil, primitive.ObjectID{0x01, 0x02})},
|
||||
nil,
|
||||
[]interface{}{primitive.ObjectID{0x01, 0x02}},
|
||||
},
|
||||
{
|
||||
"ObjectIDOK/Not ObjectID", Value.ObjectIDOK, Value{Type: bsontype.String},
|
||||
nil,
|
||||
[]interface{}{primitive.ObjectID{}, false},
|
||||
},
|
||||
{
|
||||
"ObjectIDOK/Insufficient Bytes", Value.ObjectIDOK, Value{Type: bsontype.ObjectID, Data: []byte{0x01, 0x02, 0x03, 0x04}},
|
||||
nil,
|
||||
[]interface{}{primitive.ObjectID{}, false},
|
||||
},
|
||||
{
|
||||
"ObjectIDOK/Success", Value.ObjectIDOK, Value{Type: bsontype.ObjectID, Data: AppendObjectID(nil, primitive.ObjectID{0x01, 0x02})},
|
||||
nil,
|
||||
[]interface{}{primitive.ObjectID{0x01, 0x02}, true},
|
||||
},
|
||||
{
|
||||
"Boolean/Not Boolean", Value.Boolean, Value{Type: bsontype.String},
|
||||
ElementTypeError{"bsoncore.Value.Boolean", bsontype.String},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Boolean/Insufficient Bytes", Value.Boolean, Value{Type: bsontype.Boolean, Data: []byte{}},
|
||||
NewInsufficientBytesError([]byte{}, []byte{}),
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Boolean/Success", Value.Boolean, Value{Type: bsontype.Boolean, Data: AppendBoolean(nil, true)},
|
||||
nil,
|
||||
[]interface{}{true},
|
||||
},
|
||||
{
|
||||
"BooleanOK/Not Boolean", Value.BooleanOK, Value{Type: bsontype.String},
|
||||
nil,
|
||||
[]interface{}{false, false},
|
||||
},
|
||||
{
|
||||
"BooleanOK/Insufficient Bytes", Value.BooleanOK, Value{Type: bsontype.Boolean, Data: []byte{}},
|
||||
nil,
|
||||
[]interface{}{false, false},
|
||||
},
|
||||
{
|
||||
"BooleanOK/Success", Value.BooleanOK, Value{Type: bsontype.Boolean, Data: AppendBoolean(nil, true)},
|
||||
nil,
|
||||
[]interface{}{true, true},
|
||||
},
|
||||
{
|
||||
"DateTime/Not DateTime", Value.DateTime, Value{Type: bsontype.String},
|
||||
ElementTypeError{"bsoncore.Value.DateTime", bsontype.String},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"DateTime/Insufficient Bytes", Value.DateTime, Value{Type: bsontype.DateTime, Data: []byte{0x01, 0x02, 0x03}},
|
||||
NewInsufficientBytesError([]byte{}, []byte{}),
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"DateTime/Success", Value.DateTime, Value{Type: bsontype.DateTime, Data: AppendDateTime(nil, 12345)},
|
||||
nil,
|
||||
[]interface{}{int64(12345)},
|
||||
},
|
||||
{
|
||||
"DateTimeOK/Not DateTime", Value.DateTimeOK, Value{Type: bsontype.String},
|
||||
nil,
|
||||
[]interface{}{int64(0), false},
|
||||
},
|
||||
{
|
||||
"DateTimeOK/Insufficient Bytes", Value.DateTimeOK, Value{Type: bsontype.DateTime, Data: []byte{0x01, 0x02, 0x03}},
|
||||
nil,
|
||||
[]interface{}{int64(0), false},
|
||||
},
|
||||
{
|
||||
"DateTimeOK/Success", Value.DateTimeOK, Value{Type: bsontype.DateTime, Data: AppendDateTime(nil, 12345)},
|
||||
nil,
|
||||
[]interface{}{int64(12345), true},
|
||||
},
|
||||
{
|
||||
"Time/Not DateTime", Value.Time, Value{Type: bsontype.String},
|
||||
ElementTypeError{"bsoncore.Value.Time", bsontype.String},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Time/Insufficient Bytes", Value.Time, Value{Type: bsontype.DateTime, Data: []byte{0x01, 0x02, 0x03}},
|
||||
NewInsufficientBytesError([]byte{0x01, 0x02, 0x03}, []byte{0x01, 0x02, 0x03}),
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Time/Success", Value.Time, Value{Type: bsontype.DateTime, Data: AppendTime(nil, now)},
|
||||
nil,
|
||||
[]interface{}{now},
|
||||
},
|
||||
{
|
||||
"TimeOK/Not DateTime", Value.TimeOK, Value{Type: bsontype.String},
|
||||
nil,
|
||||
[]interface{}{time.Time{}, false},
|
||||
},
|
||||
{
|
||||
"TimeOK/Insufficient Bytes", Value.TimeOK, Value{Type: bsontype.DateTime, Data: []byte{0x01, 0x02, 0x03}},
|
||||
nil,
|
||||
[]interface{}{time.Time{}, false},
|
||||
},
|
||||
{
|
||||
"TimeOK/Success", Value.TimeOK, Value{Type: bsontype.DateTime, Data: AppendTime(nil, now)},
|
||||
nil,
|
||||
[]interface{}{now, true},
|
||||
},
|
||||
{
|
||||
"Regex/Not Regex", Value.Regex, Value{Type: bsontype.String},
|
||||
ElementTypeError{"bsoncore.Value.Regex", bsontype.String},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Regex/Insufficient Bytes", Value.Regex, Value{Type: bsontype.Regex, Data: []byte{0x01, 0x02, 0x03}},
|
||||
NewInsufficientBytesError([]byte{0x01, 0x02, 0x03}, []byte{0x01, 0x02, 0x03}),
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Regex/Success", Value.Regex, Value{Type: bsontype.Regex, Data: AppendRegex(nil, "/abcdefg/", "hijkl")},
|
||||
nil,
|
||||
[]interface{}{"/abcdefg/", "hijkl"},
|
||||
},
|
||||
{
|
||||
"RegexOK/Not Regex", Value.RegexOK, Value{Type: bsontype.String},
|
||||
nil,
|
||||
[]interface{}{"", "", false},
|
||||
},
|
||||
{
|
||||
"RegexOK/Insufficient Bytes", Value.RegexOK, Value{Type: bsontype.Regex, Data: []byte{0x01, 0x02, 0x03}},
|
||||
nil,
|
||||
[]interface{}{"", "", false},
|
||||
},
|
||||
{
|
||||
"RegexOK/Success", Value.RegexOK, Value{Type: bsontype.Regex, Data: AppendRegex(nil, "/abcdefg/", "hijkl")},
|
||||
nil,
|
||||
[]interface{}{"/abcdefg/", "hijkl", true},
|
||||
},
|
||||
{
|
||||
"DBPointer/Not DBPointer", Value.DBPointer, Value{Type: bsontype.String},
|
||||
ElementTypeError{"bsoncore.Value.DBPointer", bsontype.String},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"DBPointer/Insufficient Bytes", Value.DBPointer, Value{Type: bsontype.DBPointer, Data: []byte{0x01, 0x02, 0x03}},
|
||||
NewInsufficientBytesError([]byte{0x01, 0x02, 0x03}, []byte{0x01, 0x02, 0x03}),
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"DBPointer/Success", Value.DBPointer, Value{Type: bsontype.DBPointer, Data: AppendDBPointer(nil, "foobar", oid)},
|
||||
nil,
|
||||
[]interface{}{"foobar", oid},
|
||||
},
|
||||
{
|
||||
"DBPointerOK/Not DBPointer", Value.DBPointerOK, Value{Type: bsontype.String},
|
||||
nil,
|
||||
[]interface{}{"", primitive.ObjectID{}, false},
|
||||
},
|
||||
{
|
||||
"DBPointerOK/Insufficient Bytes", Value.DBPointerOK, Value{Type: bsontype.DBPointer, Data: []byte{0x01, 0x02, 0x03}},
|
||||
nil,
|
||||
[]interface{}{"", primitive.ObjectID{}, false},
|
||||
},
|
||||
{
|
||||
"DBPointerOK/Success", Value.DBPointerOK, Value{Type: bsontype.DBPointer, Data: AppendDBPointer(nil, "foobar", oid)},
|
||||
nil,
|
||||
[]interface{}{"foobar", oid, true},
|
||||
},
|
||||
{
|
||||
"JavaScript/Not JavaScript", Value.JavaScript, Value{Type: bsontype.String},
|
||||
ElementTypeError{"bsoncore.Value.JavaScript", bsontype.String},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"JavaScript/Insufficient Bytes", Value.JavaScript, Value{Type: bsontype.JavaScript, Data: []byte{0x01, 0x02, 0x03}},
|
||||
NewInsufficientBytesError([]byte{0x01, 0x02, 0x03}, []byte{0x01, 0x02, 0x03}),
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"JavaScript/Success", Value.JavaScript, Value{Type: bsontype.JavaScript, Data: AppendJavaScript(nil, "var hello = 'world';")},
|
||||
nil,
|
||||
[]interface{}{"var hello = 'world';"},
|
||||
},
|
||||
{
|
||||
"JavaScriptOK/Not JavaScript", Value.JavaScriptOK, Value{Type: bsontype.String},
|
||||
nil,
|
||||
[]interface{}{"", false},
|
||||
},
|
||||
{
|
||||
"JavaScriptOK/Insufficient Bytes", Value.JavaScriptOK, Value{Type: bsontype.JavaScript, Data: []byte{0x01, 0x02, 0x03}},
|
||||
nil,
|
||||
[]interface{}{"", false},
|
||||
},
|
||||
{
|
||||
"JavaScriptOK/Success", Value.JavaScriptOK, Value{Type: bsontype.JavaScript, Data: AppendJavaScript(nil, "var hello = 'world';")},
|
||||
nil,
|
||||
[]interface{}{"var hello = 'world';", true},
|
||||
},
|
||||
{
|
||||
"Symbol/Not Symbol", Value.Symbol, Value{Type: bsontype.String},
|
||||
ElementTypeError{"bsoncore.Value.Symbol", bsontype.String},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Symbol/Insufficient Bytes", Value.Symbol, Value{Type: bsontype.Symbol, Data: []byte{0x01, 0x02, 0x03}},
|
||||
NewInsufficientBytesError([]byte{0x01, 0x02, 0x03}, []byte{0x01, 0x02, 0x03}),
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Symbol/Success", Value.Symbol, Value{Type: bsontype.Symbol, Data: AppendSymbol(nil, "symbol123456")},
|
||||
nil,
|
||||
[]interface{}{"symbol123456"},
|
||||
},
|
||||
{
|
||||
"SymbolOK/Not Symbol", Value.SymbolOK, Value{Type: bsontype.String},
|
||||
nil,
|
||||
[]interface{}{"", false},
|
||||
},
|
||||
{
|
||||
"SymbolOK/Insufficient Bytes", Value.SymbolOK, Value{Type: bsontype.Symbol, Data: []byte{0x01, 0x02, 0x03}},
|
||||
nil,
|
||||
[]interface{}{"", false},
|
||||
},
|
||||
{
|
||||
"SymbolOK/Success", Value.SymbolOK, Value{Type: bsontype.Symbol, Data: AppendSymbol(nil, "symbol123456")},
|
||||
nil,
|
||||
[]interface{}{"symbol123456", true},
|
||||
},
|
||||
{
|
||||
"CodeWithScope/Not CodeWithScope", Value.CodeWithScope, Value{Type: bsontype.String},
|
||||
ElementTypeError{"bsoncore.Value.CodeWithScope", bsontype.String},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"CodeWithScope/Insufficient Bytes", Value.CodeWithScope, Value{Type: bsontype.CodeWithScope, Data: []byte{0x01, 0x02, 0x03}},
|
||||
NewInsufficientBytesError([]byte{0x01, 0x02, 0x03}, []byte{0x01, 0x02, 0x03}),
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"CodeWithScope/Success", Value.CodeWithScope, Value{Type: bsontype.CodeWithScope, Data: AppendCodeWithScope(nil, "var hello = 'world';", Document{0x05, 0x00, 0x00, 0x00, 0x00})},
|
||||
nil,
|
||||
[]interface{}{"var hello = 'world';", Document{0x05, 0x00, 0x00, 0x00, 0x00}},
|
||||
},
|
||||
{
|
||||
"CodeWithScopeOK/Not CodeWithScope", Value.CodeWithScopeOK, Value{Type: bsontype.String},
|
||||
nil,
|
||||
[]interface{}{"", Document(nil), false},
|
||||
},
|
||||
{
|
||||
"CodeWithScopeOK/Insufficient Bytes", Value.CodeWithScopeOK, Value{Type: bsontype.CodeWithScope, Data: []byte{0x01, 0x02, 0x03}},
|
||||
nil,
|
||||
[]interface{}{"", Document(nil), false},
|
||||
},
|
||||
{
|
||||
"CodeWithScopeOK/Success", Value.CodeWithScopeOK, Value{Type: bsontype.CodeWithScope, Data: AppendCodeWithScope(nil, "var hello = 'world';", Document{0x05, 0x00, 0x00, 0x00, 0x00})},
|
||||
nil,
|
||||
[]interface{}{"var hello = 'world';", Document{0x05, 0x00, 0x00, 0x00, 0x00}, true},
|
||||
},
|
||||
{
|
||||
"Int32/Not Int32", Value.Int32, Value{Type: bsontype.String},
|
||||
ElementTypeError{"bsoncore.Value.Int32", bsontype.String},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Int32/Insufficient Bytes", Value.Int32, Value{Type: bsontype.Int32, Data: []byte{0x01, 0x02, 0x03}},
|
||||
NewInsufficientBytesError([]byte{0x01, 0x02, 0x03}, []byte{0x01, 0x02, 0x03}),
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Int32/Success", Value.Int32, Value{Type: bsontype.Int32, Data: AppendInt32(nil, 1234)},
|
||||
nil,
|
||||
[]interface{}{int32(1234)},
|
||||
},
|
||||
{
|
||||
"Int32OK/Not Int32", Value.Int32OK, Value{Type: bsontype.String},
|
||||
nil,
|
||||
[]interface{}{int32(0), false},
|
||||
},
|
||||
{
|
||||
"Int32OK/Insufficient Bytes", Value.Int32OK, Value{Type: bsontype.Int32, Data: []byte{0x01, 0x02, 0x03}},
|
||||
nil,
|
||||
[]interface{}{int32(0), false},
|
||||
},
|
||||
{
|
||||
"Int32OK/Success", Value.Int32OK, Value{Type: bsontype.Int32, Data: AppendInt32(nil, 1234)},
|
||||
nil,
|
||||
[]interface{}{int32(1234), true},
|
||||
},
|
||||
{
|
||||
"Timestamp/Not Timestamp", Value.Timestamp, Value{Type: bsontype.String},
|
||||
ElementTypeError{"bsoncore.Value.Timestamp", bsontype.String},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Timestamp/Insufficient Bytes", Value.Timestamp, Value{Type: bsontype.Timestamp, Data: []byte{0x01, 0x02, 0x03}},
|
||||
NewInsufficientBytesError([]byte{0x01, 0x02, 0x03}, []byte{0x01, 0x02, 0x03}),
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Timestamp/Success", Value.Timestamp, Value{Type: bsontype.Timestamp, Data: AppendTimestamp(nil, 12345, 67890)},
|
||||
nil,
|
||||
[]interface{}{uint32(12345), uint32(67890)},
|
||||
},
|
||||
{
|
||||
"TimestampOK/Not Timestamp", Value.TimestampOK, Value{Type: bsontype.String},
|
||||
nil,
|
||||
[]interface{}{uint32(0), uint32(0), false},
|
||||
},
|
||||
{
|
||||
"TimestampOK/Insufficient Bytes", Value.TimestampOK, Value{Type: bsontype.Timestamp, Data: []byte{0x01, 0x02, 0x03}},
|
||||
nil,
|
||||
[]interface{}{uint32(0), uint32(0), false},
|
||||
},
|
||||
{
|
||||
"TimestampOK/Success", Value.TimestampOK, Value{Type: bsontype.Timestamp, Data: AppendTimestamp(nil, 12345, 67890)},
|
||||
nil,
|
||||
[]interface{}{uint32(12345), uint32(67890), true},
|
||||
},
|
||||
{
|
||||
"Int64/Not Int64", Value.Int64, Value{Type: bsontype.String},
|
||||
ElementTypeError{"bsoncore.Value.Int64", bsontype.String},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Int64/Insufficient Bytes", Value.Int64, Value{Type: bsontype.Int64, Data: []byte{0x01, 0x02, 0x03}},
|
||||
NewInsufficientBytesError([]byte{0x01, 0x02, 0x03}, []byte{0x01, 0x02, 0x03}),
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Int64/Success", Value.Int64, Value{Type: bsontype.Int64, Data: AppendInt64(nil, 1234567890)},
|
||||
nil,
|
||||
[]interface{}{int64(1234567890)},
|
||||
},
|
||||
{
|
||||
"Int64OK/Not Int64", Value.Int64OK, Value{Type: bsontype.String},
|
||||
nil,
|
||||
[]interface{}{int64(0), false},
|
||||
},
|
||||
{
|
||||
"Int64OK/Insufficient Bytes", Value.Int64OK, Value{Type: bsontype.Int64, Data: []byte{0x01, 0x02, 0x03}},
|
||||
nil,
|
||||
[]interface{}{int64(0), false},
|
||||
},
|
||||
{
|
||||
"Int64OK/Success", Value.Int64OK, Value{Type: bsontype.Int64, Data: AppendInt64(nil, 1234567890)},
|
||||
nil,
|
||||
[]interface{}{int64(1234567890), true},
|
||||
},
|
||||
{
|
||||
"Decimal128/Not Decimal128", Value.Decimal128, Value{Type: bsontype.String},
|
||||
ElementTypeError{"bsoncore.Value.Decimal128", bsontype.String},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Decimal128/Insufficient Bytes", Value.Decimal128, Value{Type: bsontype.Decimal128, Data: []byte{0x01, 0x02, 0x03}},
|
||||
NewInsufficientBytesError([]byte{0x01, 0x02, 0x03}, []byte{0x01, 0x02, 0x03}),
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"Decimal128/Success", Value.Decimal128, Value{Type: bsontype.Decimal128, Data: AppendDecimal128(nil, primitive.NewDecimal128(12345, 67890))},
|
||||
nil,
|
||||
[]interface{}{primitive.NewDecimal128(12345, 67890)},
|
||||
},
|
||||
{
|
||||
"Decimal128OK/Not Decimal128", Value.Decimal128OK, Value{Type: bsontype.String},
|
||||
nil,
|
||||
[]interface{}{primitive.Decimal128{}, false},
|
||||
},
|
||||
{
|
||||
"Decimal128OK/Insufficient Bytes", Value.Decimal128OK, Value{Type: bsontype.Decimal128, Data: []byte{0x01, 0x02, 0x03}},
|
||||
nil,
|
||||
[]interface{}{primitive.Decimal128{}, false},
|
||||
},
|
||||
{
|
||||
"Decimal128OK/Success", Value.Decimal128OK, Value{Type: bsontype.Decimal128, Data: AppendDecimal128(nil, primitive.NewDecimal128(12345, 67890))},
|
||||
nil,
|
||||
[]interface{}{primitive.NewDecimal128(12345, 67890), true},
|
||||
},
|
||||
{
|
||||
"Timestamp.String/Success", Value.String, Value{Type: bsontype.Timestamp, Data: AppendTimestamp(nil, 12345, 67890)},
|
||||
nil,
|
||||
[]interface{}{"{\"$timestamp\":{\"t\":12345,\"i\":67890}}"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
defer func() {
|
||||
err := recover()
|
||||
if !cmp.Equal(err, tc.panicErr, cmp.Comparer(compareErrors)) {
|
||||
t.Errorf("Did not receive expected panic error. got %v; want %v", err, tc.panicErr)
|
||||
}
|
||||
}()
|
||||
|
||||
fn := reflect.ValueOf(tc.fn)
|
||||
if fn.Kind() != reflect.Func || fn.Type().NumIn() != 1 || fn.Type().In(0) != reflect.TypeOf(Value{}) {
|
||||
t.Fatalf("test case field fn must be a function with 1 parameter that is a Value, but it is %v", fn.Type())
|
||||
}
|
||||
got := fn.Call([]reflect.Value{reflect.ValueOf(tc.val)})
|
||||
want := make([]reflect.Value, 0, len(tc.ret))
|
||||
for _, ret := range tc.ret {
|
||||
want = append(want, reflect.ValueOf(ret))
|
||||
}
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("incorrect number of values returned. got %d; want %d", len(got), len(want))
|
||||
}
|
||||
|
||||
for idx := range got {
|
||||
gotv, wantv := got[idx].Interface(), want[idx].Interface()
|
||||
if !cmp.Equal(gotv, wantv, cmp.Comparer(compareDecimal128)) {
|
||||
t.Errorf("return values at index %d are not equal. got %v; want %v", idx, gotv, wantv)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
166
mongo/x/bsonx/constructor.go
Normal file
166
mongo/x/bsonx/constructor.go
Normal file
@@ -0,0 +1,166 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bsonx
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"go.mongodb.org/mongo-driver/bson/bsontype"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
// IDoc is the interface implemented by Doc and MDoc. It allows either of these types to be provided
|
||||
// to the Document function to create a Value.
|
||||
type IDoc interface {
|
||||
idoc()
|
||||
}
|
||||
|
||||
// Double constructs a BSON double Value.
|
||||
func Double(f64 float64) Val {
|
||||
v := Val{t: bsontype.Double}
|
||||
binary.LittleEndian.PutUint64(v.bootstrap[0:8], math.Float64bits(f64))
|
||||
return v
|
||||
}
|
||||
|
||||
// String constructs a BSON string Value.
|
||||
func String(str string) Val { return Val{t: bsontype.String}.writestring(str) }
|
||||
|
||||
// Document constructs a Value from the given IDoc. If nil is provided, a BSON Null value will be
|
||||
// returned.
|
||||
func Document(doc IDoc) Val {
|
||||
var v Val
|
||||
switch tt := doc.(type) {
|
||||
case Doc:
|
||||
if tt == nil {
|
||||
v.t = bsontype.Null
|
||||
break
|
||||
}
|
||||
v.t = bsontype.EmbeddedDocument
|
||||
v.primitive = tt
|
||||
case MDoc:
|
||||
if tt == nil {
|
||||
v.t = bsontype.Null
|
||||
break
|
||||
}
|
||||
v.t = bsontype.EmbeddedDocument
|
||||
v.primitive = tt
|
||||
default:
|
||||
v.t = bsontype.Null
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// Array constructs a Value from arr. If arr is nil, a BSON Null value is returned.
|
||||
func Array(arr Arr) Val {
|
||||
if arr == nil {
|
||||
return Val{t: bsontype.Null}
|
||||
}
|
||||
return Val{t: bsontype.Array, primitive: arr}
|
||||
}
|
||||
|
||||
// Binary constructs a BSON binary Value.
|
||||
func Binary(subtype byte, data []byte) Val {
|
||||
return Val{t: bsontype.Binary, primitive: primitive.Binary{Subtype: subtype, Data: data}}
|
||||
}
|
||||
|
||||
// Undefined constructs a BSON binary Value.
|
||||
func Undefined() Val { return Val{t: bsontype.Undefined} }
|
||||
|
||||
// ObjectID constructs a BSON objectid Value.
|
||||
func ObjectID(oid primitive.ObjectID) Val {
|
||||
v := Val{t: bsontype.ObjectID}
|
||||
copy(v.bootstrap[0:12], oid[:])
|
||||
return v
|
||||
}
|
||||
|
||||
// Boolean constructs a BSON boolean Value.
|
||||
func Boolean(b bool) Val {
|
||||
v := Val{t: bsontype.Boolean}
|
||||
if b {
|
||||
v.bootstrap[0] = 0x01
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// DateTime constructs a BSON datetime Value.
|
||||
func DateTime(dt int64) Val { return Val{t: bsontype.DateTime}.writei64(dt) }
|
||||
|
||||
// Time constructs a BSON datetime Value.
|
||||
func Time(t time.Time) Val {
|
||||
return Val{t: bsontype.DateTime}.writei64(t.Unix()*1e3 + int64(t.Nanosecond()/1e6))
|
||||
}
|
||||
|
||||
// Null constructs a BSON binary Value.
|
||||
func Null() Val { return Val{t: bsontype.Null} }
|
||||
|
||||
// Regex constructs a BSON regex Value.
|
||||
func Regex(pattern, options string) Val {
|
||||
regex := primitive.Regex{Pattern: pattern, Options: options}
|
||||
return Val{t: bsontype.Regex, primitive: regex}
|
||||
}
|
||||
|
||||
// DBPointer constructs a BSON dbpointer Value.
|
||||
func DBPointer(ns string, ptr primitive.ObjectID) Val {
|
||||
dbptr := primitive.DBPointer{DB: ns, Pointer: ptr}
|
||||
return Val{t: bsontype.DBPointer, primitive: dbptr}
|
||||
}
|
||||
|
||||
// JavaScript constructs a BSON javascript Value.
|
||||
func JavaScript(js string) Val {
|
||||
return Val{t: bsontype.JavaScript}.writestring(js)
|
||||
}
|
||||
|
||||
// Symbol constructs a BSON symbol Value.
|
||||
func Symbol(symbol string) Val {
|
||||
return Val{t: bsontype.Symbol}.writestring(symbol)
|
||||
}
|
||||
|
||||
// CodeWithScope constructs a BSON code with scope Value.
|
||||
func CodeWithScope(code string, scope IDoc) Val {
|
||||
cws := primitive.CodeWithScope{Code: primitive.JavaScript(code), Scope: scope}
|
||||
return Val{t: bsontype.CodeWithScope, primitive: cws}
|
||||
}
|
||||
|
||||
// Int32 constructs a BSON int32 Value.
|
||||
func Int32(i32 int32) Val {
|
||||
v := Val{t: bsontype.Int32}
|
||||
v.bootstrap[0] = byte(i32)
|
||||
v.bootstrap[1] = byte(i32 >> 8)
|
||||
v.bootstrap[2] = byte(i32 >> 16)
|
||||
v.bootstrap[3] = byte(i32 >> 24)
|
||||
return v
|
||||
}
|
||||
|
||||
// Timestamp constructs a BSON timestamp Value.
|
||||
func Timestamp(t, i uint32) Val {
|
||||
v := Val{t: bsontype.Timestamp}
|
||||
v.bootstrap[0] = byte(i)
|
||||
v.bootstrap[1] = byte(i >> 8)
|
||||
v.bootstrap[2] = byte(i >> 16)
|
||||
v.bootstrap[3] = byte(i >> 24)
|
||||
v.bootstrap[4] = byte(t)
|
||||
v.bootstrap[5] = byte(t >> 8)
|
||||
v.bootstrap[6] = byte(t >> 16)
|
||||
v.bootstrap[7] = byte(t >> 24)
|
||||
return v
|
||||
}
|
||||
|
||||
// Int64 constructs a BSON int64 Value.
|
||||
func Int64(i64 int64) Val { return Val{t: bsontype.Int64}.writei64(i64) }
|
||||
|
||||
// Decimal128 constructs a BSON decimal128 Value.
|
||||
func Decimal128(d128 primitive.Decimal128) Val {
|
||||
return Val{t: bsontype.Decimal128, primitive: d128}
|
||||
}
|
||||
|
||||
// MinKey constructs a BSON minkey Value.
|
||||
func MinKey() Val { return Val{t: bsontype.MinKey} }
|
||||
|
||||
// MaxKey constructs a BSON maxkey Value.
|
||||
func MaxKey() Val { return Val{t: bsontype.MaxKey} }
|
||||
305
mongo/x/bsonx/document.go
Normal file
305
mongo/x/bsonx/document.go
Normal file
@@ -0,0 +1,305 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bsonx
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"go.mongodb.org/mongo-driver/bson/bsontype"
|
||||
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
|
||||
)
|
||||
|
||||
// ErrNilDocument indicates that an operation was attempted on a nil *bson.Document.
|
||||
var ErrNilDocument = errors.New("document is nil")
|
||||
|
||||
// KeyNotFound is an error type returned from the Lookup methods on Document. This type contains
|
||||
// information about which key was not found and if it was actually not found or if a component of
|
||||
// the key except the last was not a document nor array.
|
||||
type KeyNotFound struct {
|
||||
Key []string // The keys that were searched for.
|
||||
Depth uint // Which key either was not found or was an incorrect type.
|
||||
Type bsontype.Type // The type of the key that was found but was an incorrect type.
|
||||
}
|
||||
|
||||
func (knf KeyNotFound) Error() string {
|
||||
depth := knf.Depth
|
||||
if depth >= uint(len(knf.Key)) {
|
||||
depth = uint(len(knf.Key)) - 1
|
||||
}
|
||||
|
||||
if len(knf.Key) == 0 {
|
||||
return "no keys were provided for lookup"
|
||||
}
|
||||
|
||||
if knf.Type != bsontype.Type(0) {
|
||||
return fmt.Sprintf(`key "%s" was found but was not valid to traverse BSON type %s`, knf.Key[depth], knf.Type)
|
||||
}
|
||||
|
||||
return fmt.Sprintf(`key "%s" was not found`, knf.Key[depth])
|
||||
}
|
||||
|
||||
// Doc is a type safe, concise BSON document representation.
|
||||
type Doc []Elem
|
||||
|
||||
// ReadDoc will create a Document using the provided slice of bytes. If the
|
||||
// slice of bytes is not a valid BSON document, this method will return an error.
|
||||
func ReadDoc(b []byte) (Doc, error) {
|
||||
doc := make(Doc, 0)
|
||||
err := doc.UnmarshalBSON(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return doc, nil
|
||||
}
|
||||
|
||||
// Copy makes a shallow copy of this document.
|
||||
func (d Doc) Copy() Doc {
|
||||
d2 := make(Doc, len(d))
|
||||
copy(d2, d)
|
||||
return d2
|
||||
}
|
||||
|
||||
// Append adds an element to the end of the document, creating it from the key and value provided.
|
||||
func (d Doc) Append(key string, val Val) Doc {
|
||||
return append(d, Elem{Key: key, Value: val})
|
||||
}
|
||||
|
||||
// Prepend adds an element to the beginning of the document, creating it from the key and value provided.
|
||||
func (d Doc) Prepend(key string, val Val) Doc {
|
||||
// TODO: should we just modify d itself instead of doing an alloc here?
|
||||
return append(Doc{{Key: key, Value: val}}, d...)
|
||||
}
|
||||
|
||||
// Set replaces an element of a document. If an element with a matching key is
|
||||
// found, the element will be replaced with the one provided. If the document
|
||||
// does not have an element with that key, the element is appended to the
|
||||
// document instead.
|
||||
func (d Doc) Set(key string, val Val) Doc {
|
||||
idx := d.IndexOf(key)
|
||||
if idx == -1 {
|
||||
return append(d, Elem{Key: key, Value: val})
|
||||
}
|
||||
d[idx] = Elem{Key: key, Value: val}
|
||||
return d
|
||||
}
|
||||
|
||||
// IndexOf returns the index of the first element with a key of key, or -1 if no element with a key
|
||||
// was found.
|
||||
func (d Doc) IndexOf(key string) int {
|
||||
for i, e := range d {
|
||||
if e.Key == key {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// Delete removes the element with key if it exists and returns the updated Doc.
|
||||
func (d Doc) Delete(key string) Doc {
|
||||
idx := d.IndexOf(key)
|
||||
if idx == -1 {
|
||||
return d
|
||||
}
|
||||
return append(d[:idx], d[idx+1:]...)
|
||||
}
|
||||
|
||||
// Lookup searches the document and potentially subdocuments or arrays for the
|
||||
// provided key. Each key provided to this method represents a layer of depth.
|
||||
//
|
||||
// This method will return an empty Value if they key does not exist. To know if they key actually
|
||||
// exists, use LookupErr.
|
||||
func (d Doc) Lookup(key ...string) Val {
|
||||
val, _ := d.LookupErr(key...)
|
||||
return val
|
||||
}
|
||||
|
||||
// LookupErr searches the document and potentially subdocuments or arrays for the
|
||||
// provided key. Each key provided to this method represents a layer of depth.
|
||||
func (d Doc) LookupErr(key ...string) (Val, error) {
|
||||
elem, err := d.LookupElementErr(key...)
|
||||
return elem.Value, err
|
||||
}
|
||||
|
||||
// LookupElement searches the document and potentially subdocuments or arrays for the
|
||||
// provided key. Each key provided to this method represents a layer of depth.
|
||||
//
|
||||
// This method will return an empty Element if they key does not exist. To know if they key actually
|
||||
// exists, use LookupElementErr.
|
||||
func (d Doc) LookupElement(key ...string) Elem {
|
||||
elem, _ := d.LookupElementErr(key...)
|
||||
return elem
|
||||
}
|
||||
|
||||
// LookupElementErr searches the document and potentially subdocuments for the
|
||||
// provided key. Each key provided to this method represents a layer of depth.
|
||||
func (d Doc) LookupElementErr(key ...string) (Elem, error) {
|
||||
// KeyNotFound operates by being created where the error happens and then the depth is
|
||||
// incremented by 1 as each function unwinds. Whenever this function returns, it also assigns
|
||||
// the Key slice to the key slice it has. This ensures that the proper depth is identified and
|
||||
// the proper keys.
|
||||
if len(key) == 0 {
|
||||
return Elem{}, KeyNotFound{Key: key}
|
||||
}
|
||||
|
||||
var elem Elem
|
||||
var err error
|
||||
idx := d.IndexOf(key[0])
|
||||
if idx == -1 {
|
||||
return Elem{}, KeyNotFound{Key: key}
|
||||
}
|
||||
|
||||
elem = d[idx]
|
||||
if len(key) == 1 {
|
||||
return elem, nil
|
||||
}
|
||||
|
||||
switch elem.Value.Type() {
|
||||
case bsontype.EmbeddedDocument:
|
||||
switch tt := elem.Value.primitive.(type) {
|
||||
case Doc:
|
||||
elem, err = tt.LookupElementErr(key[1:]...)
|
||||
case MDoc:
|
||||
elem, err = tt.LookupElementErr(key[1:]...)
|
||||
}
|
||||
default:
|
||||
return Elem{}, KeyNotFound{Type: elem.Value.Type()}
|
||||
}
|
||||
switch tt := err.(type) {
|
||||
case KeyNotFound:
|
||||
tt.Depth++
|
||||
tt.Key = key
|
||||
return Elem{}, tt
|
||||
case nil:
|
||||
return elem, nil
|
||||
default:
|
||||
return Elem{}, err // We can't actually hit this.
|
||||
}
|
||||
}
|
||||
|
||||
// MarshalBSONValue implements the bsoncodec.ValueMarshaler interface.
|
||||
//
|
||||
// This method will never return an error.
|
||||
func (d Doc) MarshalBSONValue() (bsontype.Type, []byte, error) {
|
||||
if d == nil {
|
||||
// TODO: Should we do this?
|
||||
return bsontype.Null, nil, nil
|
||||
}
|
||||
data, _ := d.MarshalBSON()
|
||||
return bsontype.EmbeddedDocument, data, nil
|
||||
}
|
||||
|
||||
// MarshalBSON implements the Marshaler interface.
|
||||
//
|
||||
// This method will never return an error.
|
||||
func (d Doc) MarshalBSON() ([]byte, error) { return d.AppendMarshalBSON(nil) }
|
||||
|
||||
// AppendMarshalBSON marshals Doc to BSON bytes, appending to dst.
|
||||
//
|
||||
// This method will never return an error.
|
||||
func (d Doc) AppendMarshalBSON(dst []byte) ([]byte, error) {
|
||||
idx, dst := bsoncore.ReserveLength(dst)
|
||||
for _, elem := range d {
|
||||
t, data, _ := elem.Value.MarshalBSONValue() // Value.MarshalBSONValue never returns an error.
|
||||
dst = append(dst, byte(t))
|
||||
dst = append(dst, elem.Key...)
|
||||
dst = append(dst, 0x00)
|
||||
dst = append(dst, data...)
|
||||
}
|
||||
dst = append(dst, 0x00)
|
||||
dst = bsoncore.UpdateLength(dst, idx, int32(len(dst[idx:])))
|
||||
return dst, nil
|
||||
}
|
||||
|
||||
// UnmarshalBSON implements the Unmarshaler interface.
|
||||
func (d *Doc) UnmarshalBSON(b []byte) error {
|
||||
if d == nil {
|
||||
return ErrNilDocument
|
||||
}
|
||||
|
||||
if err := bsoncore.Document(b).Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
elems, err := bsoncore.Document(b).Elements()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var val Val
|
||||
for _, elem := range elems {
|
||||
rawv := elem.Value()
|
||||
err = val.UnmarshalBSONValue(rawv.Type, rawv.Data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*d = d.Append(elem.Key(), val)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalBSONValue implements the bson.ValueUnmarshaler interface.
|
||||
func (d *Doc) UnmarshalBSONValue(t bsontype.Type, data []byte) error {
|
||||
if t != bsontype.EmbeddedDocument {
|
||||
return fmt.Errorf("cannot unmarshal %s into a bsonx.Doc", t)
|
||||
}
|
||||
return d.UnmarshalBSON(data)
|
||||
}
|
||||
|
||||
// Equal compares this document to another, returning true if they are equal.
|
||||
func (d Doc) Equal(id IDoc) bool {
|
||||
switch tt := id.(type) {
|
||||
case Doc:
|
||||
d2 := tt
|
||||
if len(d) != len(d2) {
|
||||
return false
|
||||
}
|
||||
for idx := range d {
|
||||
if !d[idx].Equal(d2[idx]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
case MDoc:
|
||||
unique := make(map[string]struct{})
|
||||
for _, elem := range d {
|
||||
unique[elem.Key] = struct{}{}
|
||||
val, ok := tt[elem.Key]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if !val.Equal(elem.Value) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if len(unique) != len(tt) {
|
||||
return false
|
||||
}
|
||||
case nil:
|
||||
return d == nil
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// String implements the fmt.Stringer interface.
|
||||
func (d Doc) String() string {
|
||||
var buf bytes.Buffer
|
||||
buf.Write([]byte("bson.Document{"))
|
||||
for idx, elem := range d {
|
||||
if idx > 0 {
|
||||
buf.Write([]byte(", "))
|
||||
}
|
||||
fmt.Fprintf(&buf, "%v", elem)
|
||||
}
|
||||
buf.WriteByte('}')
|
||||
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func (Doc) idoc() {}
|
||||
295
mongo/x/bsonx/document_test.go
Normal file
295
mongo/x/bsonx/document_test.go
Normal file
@@ -0,0 +1,295 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bsonx
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"go.mongodb.org/mongo-driver/bson/bsontype"
|
||||
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
|
||||
)
|
||||
|
||||
func ExampleDocument() {
|
||||
internalVersion := "1234567"
|
||||
|
||||
f := func(appName string) Doc {
|
||||
doc := Doc{
|
||||
{"driver", Document(Doc{{"name", String("mongo-go-driver")}, {"version", String(internalVersion)}})},
|
||||
{"os", Document(Doc{{"type", String("darwin")}, {"architecture", String("amd64")}})},
|
||||
{"platform", String("go1.11.1")},
|
||||
}
|
||||
if appName != "" {
|
||||
doc = append(doc, Elem{"application", Document(MDoc{"name": String(appName)})})
|
||||
}
|
||||
|
||||
return doc
|
||||
}
|
||||
buf, err := f("hello-world").MarshalBSON()
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
fmt.Println(buf)
|
||||
|
||||
// Output: [178 0 0 0 3 100 114 105 118 101 114 0 52 0 0 0 2 110 97 109 101 0 16 0 0 0 109 111 110 103 111 45 103 111 45 100 114 105 118 101 114 0 2 118 101 114 115 105 111 110 0 8 0 0 0 49 50 51 52 53 54 55 0 0 3 111 115 0 46 0 0 0 2 116 121 112 101 0 7 0 0 0 100 97 114 119 105 110 0 2 97 114 99 104 105 116 101 99 116 117 114 101 0 6 0 0 0 97 109 100 54 52 0 0 2 112 108 97 116 102 111 114 109 0 9 0 0 0 103 111 49 46 49 49 46 49 0 3 97 112 112 108 105 99 97 116 105 111 110 0 27 0 0 0 2 110 97 109 101 0 12 0 0 0 104 101 108 108 111 45 119 111 114 108 100 0 0 0]
|
||||
}
|
||||
|
||||
func BenchmarkDocument(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
internalVersion := "1234567"
|
||||
for i := 0; i < b.N; i++ {
|
||||
doc := Doc{
|
||||
{"driver", Document(Doc{{"name", String("mongo-go-driver")}, {"version", String(internalVersion)}})},
|
||||
{"os", Document(Doc{{"type", String("darwin")}, {"architecture", String("amd64")}})},
|
||||
{"platform", String("go1.11.1")},
|
||||
}
|
||||
_, _ = doc.MarshalBSON()
|
||||
}
|
||||
}
|
||||
|
||||
func TestDocument(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("ReadDocument", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("UnmarshalingError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
testCases := []struct {
|
||||
name string
|
||||
invalid []byte
|
||||
}{
|
||||
{"base", []byte{0x01, 0x02}},
|
||||
{"fuzzed1", []byte("0\x990\xc4")}, // fuzzed
|
||||
{"fuzzed2", []byte("\x10\x00\x00\x00\x10\x000000\x0600\x00\x05\x00\xff\xff\xff\u007f")}, // fuzzed
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
want := bsoncore.NewInsufficientBytesError(nil, nil)
|
||||
_, got := ReadDoc(tc.invalid)
|
||||
if !compareErrors(got, want) {
|
||||
t.Errorf("Expected errors to match. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
t.Run("success", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
valid := bsoncore.BuildDocument(nil, bsoncore.AppendNullElement(nil, "foobar"))
|
||||
var want error
|
||||
wantDoc := Doc{{"foobar", Null()}}
|
||||
gotDoc, got := ReadDoc(valid)
|
||||
if !compareErrors(got, want) {
|
||||
t.Errorf("Expected errors to match. got %v; want %v", got, want)
|
||||
}
|
||||
if !cmp.Equal(gotDoc, wantDoc) {
|
||||
t.Errorf("Expected returned documents to match. got %v; want %v", gotDoc, wantDoc)
|
||||
}
|
||||
})
|
||||
})
|
||||
t.Run("Copy", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
testCases := []struct {
|
||||
name string
|
||||
start Doc
|
||||
copy Doc
|
||||
}{
|
||||
{"nil", nil, Doc{}},
|
||||
{"not-nil", Doc{{"foobar", Null()}}, Doc{{"foobar", Null()}}},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
copy := tc.start.Copy()
|
||||
if !cmp.Equal(copy, tc.copy) {
|
||||
t.Errorf("Expected copies to be equal. got %v; want %v", copy, tc.copy)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
testCases := []struct {
|
||||
name string
|
||||
fn interface{} // method to call
|
||||
params []interface{} // parameters
|
||||
rets []interface{} // returns
|
||||
}{
|
||||
{
|
||||
"Append", Doc{}.Append,
|
||||
[]interface{}{"foo", Null()},
|
||||
[]interface{}{Doc{{"foo", Null()}}},
|
||||
},
|
||||
{
|
||||
"Prepend", Doc{}.Prepend,
|
||||
[]interface{}{"foo", Null()},
|
||||
[]interface{}{Doc{{"foo", Null()}}},
|
||||
},
|
||||
{
|
||||
"Set/append", Doc{{"foo", Null()}}.Set,
|
||||
[]interface{}{"bar", Null()},
|
||||
[]interface{}{Doc{{"foo", Null()}, {"bar", Null()}}},
|
||||
},
|
||||
{
|
||||
"Set/replace", Doc{{"foo", Null()}, {"bar", Null()}, {"baz", Double(3.14159)}}.Set,
|
||||
[]interface{}{"bar", Int64(1234567890)},
|
||||
[]interface{}{Doc{{"foo", Null()}, {"bar", Int64(1234567890)}, {"baz", Double(3.14159)}}},
|
||||
},
|
||||
{
|
||||
"Delete/doesn't exist", Doc{{"foo", Null()}, {"bar", Null()}, {"baz", Double(3.14159)}}.Delete,
|
||||
[]interface{}{"qux"},
|
||||
[]interface{}{Doc{{"foo", Null()}, {"bar", Null()}, {"baz", Double(3.14159)}}},
|
||||
},
|
||||
{
|
||||
"Delete/exists", Doc{{"foo", Null()}, {"bar", Null()}, {"baz", Double(3.14159)}}.Delete,
|
||||
[]interface{}{"bar"},
|
||||
[]interface{}{Doc{{"foo", Null()}, {"baz", Double(3.14159)}}},
|
||||
},
|
||||
{
|
||||
"Lookup/err", Doc{}.Lookup,
|
||||
[]interface{}{[]string{}},
|
||||
[]interface{}{Val{}},
|
||||
},
|
||||
{
|
||||
"Lookup/success", Doc{{"pi", Double(3.14159)}}.Lookup,
|
||||
[]interface{}{[]string{"pi"}},
|
||||
[]interface{}{Double(3.14159)},
|
||||
},
|
||||
{
|
||||
"LookupErr/err", Doc{}.LookupErr,
|
||||
[]interface{}{[]string{}},
|
||||
[]interface{}{Val{}, KeyNotFound{Key: []string{}}},
|
||||
},
|
||||
{
|
||||
"LookupErr/success", Doc{{"pi", Double(3.14159)}}.LookupErr,
|
||||
[]interface{}{[]string{"pi"}},
|
||||
[]interface{}{Double(3.14159), error(nil)},
|
||||
},
|
||||
{
|
||||
"LookupElem/err", Doc{}.LookupElement,
|
||||
[]interface{}{[]string{}},
|
||||
[]interface{}{Elem{}},
|
||||
},
|
||||
{
|
||||
"LookupElem/success", Doc{{"pi", Double(3.14159)}}.LookupElement,
|
||||
[]interface{}{[]string{"pi"}},
|
||||
[]interface{}{Elem{"pi", Double(3.14159)}},
|
||||
},
|
||||
{
|
||||
"LookupElementErr/zero length key", Doc{}.LookupElementErr,
|
||||
[]interface{}{[]string{}},
|
||||
[]interface{}{Elem{}, KeyNotFound{Key: []string{}}},
|
||||
},
|
||||
{
|
||||
"LookupElementErr/key not found", Doc{}.LookupElementErr,
|
||||
[]interface{}{[]string{"foo"}},
|
||||
[]interface{}{Elem{}, KeyNotFound{Key: []string{"foo"}}},
|
||||
},
|
||||
{
|
||||
"LookupElementErr/key not found/depth 2", Doc{{"foo", Document(Doc{})}}.LookupElementErr,
|
||||
[]interface{}{[]string{"foo", "bar"}},
|
||||
[]interface{}{Elem{}, KeyNotFound{Key: []string{"foo", "bar"}, Depth: 1}},
|
||||
},
|
||||
{
|
||||
"LookupElementErr/invalid depth 2 type", Doc{{"foo", Document(Doc{{"pi", Double(3.14159)}})}}.LookupElementErr,
|
||||
[]interface{}{[]string{"foo", "pi", "baz"}},
|
||||
[]interface{}{Elem{}, KeyNotFound{Key: []string{"foo", "pi", "baz"}, Depth: 1, Type: bsontype.Double}},
|
||||
},
|
||||
{
|
||||
"LookupElementErr/success", Doc{{"pi", Double(3.14159)}}.LookupElementErr,
|
||||
[]interface{}{[]string{"pi"}},
|
||||
[]interface{}{Elem{"pi", Double(3.14159)}, error(nil)},
|
||||
},
|
||||
{
|
||||
"LookupElementErr/success/depth 2", Doc{{"foo", Document(Doc{{"pi", Double(3.14159)}})}}.LookupElementErr,
|
||||
[]interface{}{[]string{"foo", "pi"}},
|
||||
[]interface{}{Elem{"pi", Double(3.14159)}, error(nil)},
|
||||
},
|
||||
{
|
||||
"MarshalBSONValue/nil", Doc(nil).MarshalBSONValue,
|
||||
nil,
|
||||
[]interface{}{bsontype.Null, []byte(nil), error(nil)},
|
||||
},
|
||||
{
|
||||
"MarshalBSONValue/success", Doc{{"pi", Double(3.14159)}}.MarshalBSONValue, nil,
|
||||
[]interface{}{
|
||||
bsontype.EmbeddedDocument,
|
||||
bsoncore.BuildDocument(nil, bsoncore.AppendDoubleElement(nil, "pi", 3.14159)),
|
||||
error(nil),
|
||||
},
|
||||
},
|
||||
{
|
||||
"MarshalBSON", Doc{{"pi", Double(3.14159)}}.MarshalBSON, nil,
|
||||
[]interface{}{bsoncore.BuildDocument(nil, bsoncore.AppendDoubleElement(nil, "pi", 3.14159)), error(nil)},
|
||||
},
|
||||
{
|
||||
"MarshalBSON/empty", Doc{}.MarshalBSON, nil,
|
||||
[]interface{}{bsoncore.BuildDocument(nil, nil), error(nil)},
|
||||
},
|
||||
{
|
||||
"AppendMarshalBSON", Doc{{"pi", Double(3.14159)}}.AppendMarshalBSON, []interface{}{[]byte{0x01, 0x02, 0x03}},
|
||||
[]interface{}{bsoncore.BuildDocument([]byte{0x01, 0x02, 0x03}, bsoncore.AppendDoubleElement(nil, "pi", 3.14159)), error(nil)},
|
||||
},
|
||||
{
|
||||
"AppendMarshalBSON/empty", Doc{}.AppendMarshalBSON, []interface{}{[]byte{0x01, 0x02, 0x03}},
|
||||
[]interface{}{bsoncore.BuildDocument([]byte{0x01, 0x02, 0x03}, nil), error(nil)},
|
||||
},
|
||||
{"Equal/IDoc nil", Doc(nil).Equal, []interface{}{IDoc(nil)}, []interface{}{true}},
|
||||
{"Equal/MDoc nil", Doc(nil).Equal, []interface{}{MDoc(nil)}, []interface{}{true}},
|
||||
{"Equal/Doc/different length", Doc{{"pi", Double(3.14159)}}.Equal, []interface{}{Doc{}}, []interface{}{false}},
|
||||
{"Equal/Doc/elems not equal", Doc{{"pi", Double(3.14159)}}.Equal, []interface{}{Doc{{"pi", Int32(1)}}}, []interface{}{false}},
|
||||
{"Equal/Doc/success", Doc{{"pi", Double(3.14159)}}.Equal, []interface{}{Doc{{"pi", Double(3.14159)}}}, []interface{}{true}},
|
||||
{"Equal/MDoc/elems not equal", Doc{{"pi", Double(3.14159)}}.Equal, []interface{}{MDoc{"pi": Int32(1)}}, []interface{}{false}},
|
||||
{"Equal/MDoc/elems not found", Doc{{"pi", Double(3.14159)}}.Equal, []interface{}{MDoc{"foo": Int32(1)}}, []interface{}{false}},
|
||||
{
|
||||
"Equal/MDoc/duplicate",
|
||||
Doc{{"a", Int32(1)}, {"a", Int32(1)}}.Equal, []interface{}{MDoc{"a": Int32(1), "b": Int32(2)}},
|
||||
[]interface{}{false},
|
||||
},
|
||||
{"Equal/MDoc/success", Doc{{"pi", Double(3.14159)}}.Equal, []interface{}{MDoc{"pi": Double(3.14159)}}, []interface{}{true}},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
fn := reflect.ValueOf(tc.fn)
|
||||
if fn.Kind() != reflect.Func {
|
||||
t.Fatalf("property fn must be a function, but it is a %v", fn.Kind())
|
||||
}
|
||||
if fn.Type().NumIn() != len(tc.params) && !fn.Type().IsVariadic() {
|
||||
t.Fatalf("number of parameters does not match. fn takes %d, but was provided %d", fn.Type().NumIn(), len(tc.params))
|
||||
}
|
||||
params := make([]reflect.Value, 0, len(tc.params))
|
||||
for idx, param := range tc.params {
|
||||
if param == nil {
|
||||
params = append(params, reflect.New(fn.Type().In(idx)).Elem())
|
||||
continue
|
||||
}
|
||||
params = append(params, reflect.ValueOf(param))
|
||||
}
|
||||
var rets []reflect.Value
|
||||
if fn.Type().IsVariadic() {
|
||||
rets = fn.CallSlice(params)
|
||||
} else {
|
||||
rets = fn.Call(params)
|
||||
}
|
||||
if len(rets) != len(tc.rets) {
|
||||
t.Fatalf("mismatched number of returns. received %d; expected %d", len(rets), len(tc.rets))
|
||||
}
|
||||
for idx := range rets {
|
||||
got, want := rets[idx].Interface(), tc.rets[idx]
|
||||
if !cmp.Equal(got, want) {
|
||||
t.Errorf("Return %d does not match. got %v; want %v", idx, got, want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
51
mongo/x/bsonx/element.go
Normal file
51
mongo/x/bsonx/element.go
Normal file
@@ -0,0 +1,51 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bsonx
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"go.mongodb.org/mongo-driver/bson/bsontype"
|
||||
)
|
||||
|
||||
// ElementTypeError specifies that a method to obtain a BSON value an incorrect type was called on a bson.Value.
|
||||
//
|
||||
// TODO: rename this ValueTypeError.
|
||||
type ElementTypeError struct {
|
||||
Method string
|
||||
Type bsontype.Type
|
||||
}
|
||||
|
||||
// Error implements the error interface.
|
||||
func (ete ElementTypeError) Error() string {
|
||||
return "Call of " + ete.Method + " on " + ete.Type.String() + " type"
|
||||
}
|
||||
|
||||
// Elem represents a BSON element.
|
||||
//
|
||||
// NOTE: Element cannot be the value of a map nor a property of a struct without special handling.
|
||||
// The default encoders and decoders will not process Element correctly. To do so would require
|
||||
// information loss since an Element contains a key, but the keys used when encoding a struct are
|
||||
// the struct field names. Instead of using an Element, use a Value as a value in a map or a
|
||||
// property of a struct.
|
||||
type Elem struct {
|
||||
Key string
|
||||
Value Val
|
||||
}
|
||||
|
||||
// Equal compares e and e2 and returns true if they are equal.
|
||||
func (e Elem) Equal(e2 Elem) bool {
|
||||
if e.Key != e2.Key {
|
||||
return false
|
||||
}
|
||||
return e.Value.Equal(e2.Value)
|
||||
}
|
||||
|
||||
func (e Elem) String() string {
|
||||
// TODO(GODRIVER-612): When bsoncore has appenders for extended JSON use that here.
|
||||
return fmt.Sprintf(`bson.Element{"%s": %v}`, e.Key, e.Value)
|
||||
}
|
||||
7
mongo/x/bsonx/element_test.go
Normal file
7
mongo/x/bsonx/element_test.go
Normal file
@@ -0,0 +1,7 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bsonx
|
||||
231
mongo/x/bsonx/mdocument.go
Normal file
231
mongo/x/bsonx/mdocument.go
Normal file
@@ -0,0 +1,231 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bsonx
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"go.mongodb.org/mongo-driver/bson/bsontype"
|
||||
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
|
||||
)
|
||||
|
||||
// MDoc is an unordered, type safe, concise BSON document representation. This type should not be
|
||||
// used if you require ordering of values or duplicate keys.
|
||||
type MDoc map[string]Val
|
||||
|
||||
// ReadMDoc will create a Doc using the provided slice of bytes. If the
|
||||
// slice of bytes is not a valid BSON document, this method will return an error.
|
||||
func ReadMDoc(b []byte) (MDoc, error) {
|
||||
doc := make(MDoc)
|
||||
err := doc.UnmarshalBSON(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return doc, nil
|
||||
}
|
||||
|
||||
// Copy makes a shallow copy of this document.
|
||||
func (d MDoc) Copy() MDoc {
|
||||
d2 := make(MDoc, len(d))
|
||||
for k, v := range d {
|
||||
d2[k] = v
|
||||
}
|
||||
return d2
|
||||
}
|
||||
|
||||
// Lookup searches the document and potentially subdocuments or arrays for the
|
||||
// provided key. Each key provided to this method represents a layer of depth.
|
||||
//
|
||||
// This method will return an empty Value if they key does not exist. To know if they key actually
|
||||
// exists, use LookupErr.
|
||||
func (d MDoc) Lookup(key ...string) Val {
|
||||
val, _ := d.LookupErr(key...)
|
||||
return val
|
||||
}
|
||||
|
||||
// LookupErr searches the document and potentially subdocuments or arrays for the
|
||||
// provided key. Each key provided to this method represents a layer of depth.
|
||||
func (d MDoc) LookupErr(key ...string) (Val, error) {
|
||||
elem, err := d.LookupElementErr(key...)
|
||||
return elem.Value, err
|
||||
}
|
||||
|
||||
// LookupElement searches the document and potentially subdocuments or arrays for the
|
||||
// provided key. Each key provided to this method represents a layer of depth.
|
||||
//
|
||||
// This method will return an empty Element if they key does not exist. To know if they key actually
|
||||
// exists, use LookupElementErr.
|
||||
func (d MDoc) LookupElement(key ...string) Elem {
|
||||
elem, _ := d.LookupElementErr(key...)
|
||||
return elem
|
||||
}
|
||||
|
||||
// LookupElementErr searches the document and potentially subdocuments for the
|
||||
// provided key. Each key provided to this method represents a layer of depth.
|
||||
func (d MDoc) LookupElementErr(key ...string) (Elem, error) {
|
||||
// KeyNotFound operates by being created where the error happens and then the depth is
|
||||
// incremented by 1 as each function unwinds. Whenever this function returns, it also assigns
|
||||
// the Key slice to the key slice it has. This ensures that the proper depth is identified and
|
||||
// the proper keys.
|
||||
if len(key) == 0 {
|
||||
return Elem{}, KeyNotFound{Key: key}
|
||||
}
|
||||
|
||||
var elem Elem
|
||||
var err error
|
||||
val, ok := d[key[0]]
|
||||
if !ok {
|
||||
return Elem{}, KeyNotFound{Key: key}
|
||||
}
|
||||
|
||||
if len(key) == 1 {
|
||||
return Elem{Key: key[0], Value: val}, nil
|
||||
}
|
||||
|
||||
switch val.Type() {
|
||||
case bsontype.EmbeddedDocument:
|
||||
switch tt := val.primitive.(type) {
|
||||
case Doc:
|
||||
elem, err = tt.LookupElementErr(key[1:]...)
|
||||
case MDoc:
|
||||
elem, err = tt.LookupElementErr(key[1:]...)
|
||||
}
|
||||
default:
|
||||
return Elem{}, KeyNotFound{Type: val.Type()}
|
||||
}
|
||||
switch tt := err.(type) {
|
||||
case KeyNotFound:
|
||||
tt.Depth++
|
||||
tt.Key = key
|
||||
return Elem{}, tt
|
||||
case nil:
|
||||
return elem, nil
|
||||
default:
|
||||
return Elem{}, err // We can't actually hit this.
|
||||
}
|
||||
}
|
||||
|
||||
// MarshalBSONValue implements the bsoncodec.ValueMarshaler interface.
|
||||
//
|
||||
// This method will never return an error.
|
||||
func (d MDoc) MarshalBSONValue() (bsontype.Type, []byte, error) {
|
||||
if d == nil {
|
||||
// TODO: Should we do this?
|
||||
return bsontype.Null, nil, nil
|
||||
}
|
||||
data, _ := d.MarshalBSON()
|
||||
return bsontype.EmbeddedDocument, data, nil
|
||||
}
|
||||
|
||||
// MarshalBSON implements the Marshaler interface.
|
||||
//
|
||||
// This method will never return an error.
|
||||
func (d MDoc) MarshalBSON() ([]byte, error) { return d.AppendMarshalBSON(nil) }
|
||||
|
||||
// AppendMarshalBSON marshals Doc to BSON bytes, appending to dst.
|
||||
//
|
||||
// This method will never return an error.
|
||||
func (d MDoc) AppendMarshalBSON(dst []byte) ([]byte, error) {
|
||||
idx, dst := bsoncore.ReserveLength(dst)
|
||||
for k, v := range d {
|
||||
t, data, _ := v.MarshalBSONValue() // Value.MarshalBSONValue never returns an error.
|
||||
dst = append(dst, byte(t))
|
||||
dst = append(dst, k...)
|
||||
dst = append(dst, 0x00)
|
||||
dst = append(dst, data...)
|
||||
}
|
||||
dst = append(dst, 0x00)
|
||||
dst = bsoncore.UpdateLength(dst, idx, int32(len(dst[idx:])))
|
||||
return dst, nil
|
||||
}
|
||||
|
||||
// UnmarshalBSON implements the Unmarshaler interface.
|
||||
func (d *MDoc) UnmarshalBSON(b []byte) error {
|
||||
if d == nil {
|
||||
return ErrNilDocument
|
||||
}
|
||||
|
||||
if err := bsoncore.Document(b).Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
elems, err := bsoncore.Document(b).Elements()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var val Val
|
||||
for _, elem := range elems {
|
||||
rawv := elem.Value()
|
||||
err = val.UnmarshalBSONValue(rawv.Type, rawv.Data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
(*d)[elem.Key()] = val
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Equal compares this document to another, returning true if they are equal.
|
||||
func (d MDoc) Equal(id IDoc) bool {
|
||||
switch tt := id.(type) {
|
||||
case MDoc:
|
||||
d2 := tt
|
||||
if len(d) != len(d2) {
|
||||
return false
|
||||
}
|
||||
for key, value := range d {
|
||||
value2, ok := d2[key]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if !value.Equal(value2) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
case Doc:
|
||||
unique := make(map[string]struct{})
|
||||
for _, elem := range tt {
|
||||
unique[elem.Key] = struct{}{}
|
||||
val, ok := d[elem.Key]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if !val.Equal(elem.Value) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if len(unique) != len(d) {
|
||||
return false
|
||||
}
|
||||
case nil:
|
||||
return d == nil
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// String implements the fmt.Stringer interface.
|
||||
func (d MDoc) String() string {
|
||||
var buf bytes.Buffer
|
||||
buf.Write([]byte("bson.Document{"))
|
||||
first := true
|
||||
for key, value := range d {
|
||||
if !first {
|
||||
buf.Write([]byte(", "))
|
||||
}
|
||||
fmt.Fprintf(&buf, "%v", Elem{Key: key, Value: value})
|
||||
first = false
|
||||
}
|
||||
buf.WriteByte('}')
|
||||
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func (MDoc) idoc() {}
|
||||
223
mongo/x/bsonx/mdocument_test.go
Normal file
223
mongo/x/bsonx/mdocument_test.go
Normal file
@@ -0,0 +1,223 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bsonx
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"go.mongodb.org/mongo-driver/bson/bsontype"
|
||||
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
|
||||
)
|
||||
|
||||
func TestMDoc(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("ReadMDoc", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
t.Run("UnmarshalingError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
invalid := []byte{0x01, 0x02}
|
||||
want := bsoncore.NewInsufficientBytesError(nil, nil)
|
||||
_, got := ReadMDoc(invalid)
|
||||
if !compareErrors(got, want) {
|
||||
t.Errorf("Expected errors to match. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
t.Run("success", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
valid := bsoncore.BuildDocument(nil, bsoncore.AppendNullElement(nil, "foobar"))
|
||||
var want error
|
||||
wantDoc := MDoc{"foobar": Null()}
|
||||
gotDoc, got := ReadMDoc(valid)
|
||||
if !compareErrors(got, want) {
|
||||
t.Errorf("Expected errors to match. got %v; want %v", got, want)
|
||||
}
|
||||
if !cmp.Equal(gotDoc, wantDoc) {
|
||||
t.Errorf("Expected returned documents to match. got %v; want %v", gotDoc, wantDoc)
|
||||
}
|
||||
})
|
||||
})
|
||||
t.Run("Copy", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
testCases := []struct {
|
||||
name string
|
||||
start MDoc
|
||||
copy MDoc
|
||||
}{
|
||||
{"nil", nil, MDoc{}},
|
||||
{"not-nil", MDoc{"foobar": Null()}, MDoc{"foobar": Null()}},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
copy := tc.start.Copy()
|
||||
if !cmp.Equal(copy, tc.copy) {
|
||||
t.Errorf("Expected copies to be equal. got %v; want %v", copy, tc.copy)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
testCases := []struct {
|
||||
name string
|
||||
fn interface{} // method to call
|
||||
params []interface{} // parameters
|
||||
rets []interface{} // returns
|
||||
}{
|
||||
{
|
||||
"Lookup/err", MDoc{}.Lookup,
|
||||
[]interface{}{[]string{}},
|
||||
[]interface{}{Val{}},
|
||||
},
|
||||
{
|
||||
"Lookup/success", MDoc{"pi": Double(3.14159)}.Lookup,
|
||||
[]interface{}{[]string{"pi"}},
|
||||
[]interface{}{Double(3.14159)},
|
||||
},
|
||||
{
|
||||
"LookupErr/err", MDoc{}.LookupErr,
|
||||
[]interface{}{[]string{}},
|
||||
[]interface{}{Val{}, KeyNotFound{Key: []string{}}},
|
||||
},
|
||||
{
|
||||
"LookupErr/success", MDoc{"pi": Double(3.14159)}.LookupErr,
|
||||
[]interface{}{[]string{"pi"}},
|
||||
[]interface{}{Double(3.14159), error(nil)},
|
||||
},
|
||||
{
|
||||
"LookupElem/err", MDoc{}.LookupElement,
|
||||
[]interface{}{[]string{}},
|
||||
[]interface{}{Elem{}},
|
||||
},
|
||||
{
|
||||
"LookupElem/success", MDoc{"pi": Double(3.14159)}.LookupElement,
|
||||
[]interface{}{[]string{"pi"}},
|
||||
[]interface{}{Elem{"pi", Double(3.14159)}},
|
||||
},
|
||||
{
|
||||
"LookupElementErr/zero length key", MDoc{}.LookupElementErr,
|
||||
[]interface{}{[]string{}},
|
||||
[]interface{}{Elem{}, KeyNotFound{Key: []string{}}},
|
||||
},
|
||||
{
|
||||
"LookupElementErr/key not found", MDoc{}.LookupElementErr,
|
||||
[]interface{}{[]string{"foo"}},
|
||||
[]interface{}{Elem{}, KeyNotFound{Key: []string{"foo"}}},
|
||||
},
|
||||
{
|
||||
"LookupElementErr/key not found/depth 2", MDoc{"foo": Document(Doc{})}.LookupElementErr,
|
||||
[]interface{}{[]string{"foo", "bar"}},
|
||||
[]interface{}{Elem{}, KeyNotFound{Key: []string{"foo", "bar"}, Depth: 1}},
|
||||
},
|
||||
{
|
||||
"LookupElementErr/invalid depth 2 type", MDoc{"foo": Document(MDoc{"pi": Double(3.14159)})}.LookupElementErr,
|
||||
[]interface{}{[]string{"foo", "pi", "baz"}},
|
||||
[]interface{}{Elem{}, KeyNotFound{Key: []string{"foo", "pi", "baz"}, Depth: 1, Type: bsontype.Double}},
|
||||
},
|
||||
{
|
||||
"LookupElementErr/invalid depth 2 type (Doc)", MDoc{"foo": Document(Doc{{"pi", Double(3.14159)}})}.LookupElementErr,
|
||||
[]interface{}{[]string{"foo", "pi", "baz"}},
|
||||
[]interface{}{Elem{}, KeyNotFound{Key: []string{"foo", "pi", "baz"}, Depth: 1, Type: bsontype.Double}},
|
||||
},
|
||||
{
|
||||
"LookupElementErr/success", MDoc{"pi": Double(3.14159)}.LookupElementErr,
|
||||
[]interface{}{[]string{"pi"}},
|
||||
[]interface{}{Elem{"pi", Double(3.14159)}, error(nil)},
|
||||
},
|
||||
{
|
||||
"LookupElementErr/success/depth 2 (Doc)", MDoc{"foo": Document(Doc{{"pi", Double(3.14159)}})}.LookupElementErr,
|
||||
[]interface{}{[]string{"foo", "pi"}},
|
||||
[]interface{}{Elem{"pi", Double(3.14159)}, error(nil)},
|
||||
},
|
||||
{
|
||||
"LookupElementErr/success/depth 2", MDoc{"foo": Document(MDoc{"pi": Double(3.14159)})}.LookupElementErr,
|
||||
[]interface{}{[]string{"foo", "pi"}},
|
||||
[]interface{}{Elem{"pi", Double(3.14159)}, error(nil)},
|
||||
},
|
||||
{
|
||||
"MarshalBSONValue/nil", MDoc(nil).MarshalBSONValue,
|
||||
nil,
|
||||
[]interface{}{bsontype.Null, []byte(nil), error(nil)},
|
||||
},
|
||||
{
|
||||
"MarshalBSONValue/success", MDoc{"pi": Double(3.14159)}.MarshalBSONValue, nil,
|
||||
[]interface{}{
|
||||
bsontype.EmbeddedDocument,
|
||||
bsoncore.BuildDocument(nil, bsoncore.AppendDoubleElement(nil, "pi", 3.14159)),
|
||||
error(nil),
|
||||
},
|
||||
},
|
||||
{
|
||||
"MarshalBSON", MDoc{"pi": Double(3.14159)}.MarshalBSON, nil,
|
||||
[]interface{}{bsoncore.BuildDocument(nil, bsoncore.AppendDoubleElement(nil, "pi", 3.14159)), error(nil)},
|
||||
},
|
||||
{
|
||||
"MarshalBSON/empty", MDoc{}.MarshalBSON, nil,
|
||||
[]interface{}{bsoncore.BuildDocument(nil, nil), error(nil)},
|
||||
},
|
||||
{
|
||||
"AppendMarshalBSON", MDoc{"pi": Double(3.14159)}.AppendMarshalBSON, []interface{}{[]byte{0x01, 0x02, 0x03}},
|
||||
[]interface{}{bsoncore.BuildDocument([]byte{0x01, 0x02, 0x03}, bsoncore.AppendDoubleElement(nil, "pi", 3.14159)), error(nil)},
|
||||
},
|
||||
{
|
||||
"AppendMarshalBSON/empty", MDoc{}.AppendMarshalBSON, []interface{}{[]byte{0x01, 0x02, 0x03}},
|
||||
[]interface{}{bsoncore.BuildDocument([]byte{0x01, 0x02, 0x03}, nil), error(nil)},
|
||||
},
|
||||
{"Equal/IDoc nil", MDoc(nil).Equal, []interface{}{IDoc(nil)}, []interface{}{true}},
|
||||
{"Equal/MDoc nil", MDoc(nil).Equal, []interface{}{Doc(nil)}, []interface{}{true}},
|
||||
{"Equal/Doc/different length", MDoc{"pi": Double(3.14159)}.Equal, []interface{}{Doc{}}, []interface{}{false}},
|
||||
{"Equal/Doc/elems not equal", MDoc{"pi": Double(3.14159)}.Equal, []interface{}{Doc{{"pi", Int32(1)}}}, []interface{}{false}},
|
||||
{"Equal/Doc/success", MDoc{"pi": Double(3.14159)}.Equal, []interface{}{Doc{{"pi", Double(3.14159)}}}, []interface{}{true}},
|
||||
{"Equal/MDoc/elems not equal", MDoc{"pi": Double(3.14159)}.Equal, []interface{}{MDoc{"pi": Int32(1)}}, []interface{}{false}},
|
||||
{"Equal/MDoc/elems not found", MDoc{"pi": Double(3.14159)}.Equal, []interface{}{MDoc{"foo": Int32(1)}}, []interface{}{false}},
|
||||
{
|
||||
"Equal/MDoc/duplicate",
|
||||
Doc{{"a", Int32(1)}, {"a", Int32(1)}}.Equal, []interface{}{MDoc{"a": Int32(1), "b": Int32(2)}},
|
||||
[]interface{}{false},
|
||||
},
|
||||
{"Equal/MDoc/success", Doc{{"pi", Double(3.14159)}}.Equal, []interface{}{MDoc{"pi": Double(3.14159)}}, []interface{}{true}},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
fn := reflect.ValueOf(tc.fn)
|
||||
if fn.Kind() != reflect.Func {
|
||||
t.Fatalf("property fn must be a function, but it is a %v", fn.Kind())
|
||||
}
|
||||
if fn.Type().NumIn() != len(tc.params) && !fn.Type().IsVariadic() {
|
||||
t.Fatalf("number of parameters does not match. fn takes %d, but was provided %d", fn.Type().NumIn(), len(tc.params))
|
||||
}
|
||||
params := make([]reflect.Value, 0, len(tc.params))
|
||||
for idx, param := range tc.params {
|
||||
if param == nil {
|
||||
params = append(params, reflect.New(fn.Type().In(idx)).Elem())
|
||||
continue
|
||||
}
|
||||
params = append(params, reflect.ValueOf(param))
|
||||
}
|
||||
var rets []reflect.Value
|
||||
if fn.Type().IsVariadic() {
|
||||
rets = fn.CallSlice(params)
|
||||
} else {
|
||||
rets = fn.Call(params)
|
||||
}
|
||||
if len(rets) != len(tc.rets) {
|
||||
t.Fatalf("mismatched number of returns. received %d; expected %d", len(rets), len(tc.rets))
|
||||
}
|
||||
for idx := range rets {
|
||||
got, want := rets[idx].Interface(), tc.rets[idx]
|
||||
if !cmp.Equal(got, want) {
|
||||
t.Errorf("Return %d does not match. got %v; want %v", idx, got, want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
637
mongo/x/bsonx/primitive_codecs.go
Normal file
637
mongo/x/bsonx/primitive_codecs.go
Normal file
@@ -0,0 +1,637 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bsonx
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"go.mongodb.org/mongo-driver/bson/bsoncodec"
|
||||
"go.mongodb.org/mongo-driver/bson/bsonrw"
|
||||
"go.mongodb.org/mongo-driver/bson/bsontype"
|
||||
)
|
||||
|
||||
var primitiveCodecs PrimitiveCodecs
|
||||
|
||||
var tDocument = reflect.TypeOf((Doc)(nil))
|
||||
var tArray = reflect.TypeOf((Arr)(nil))
|
||||
var tValue = reflect.TypeOf(Val{})
|
||||
var tElementSlice = reflect.TypeOf(([]Elem)(nil))
|
||||
|
||||
// PrimitiveCodecs is a namespace for all of the default bsoncodec.Codecs for the primitive types
|
||||
// defined in this package.
|
||||
type PrimitiveCodecs struct{}
|
||||
|
||||
// RegisterPrimitiveCodecs will register the encode and decode methods attached to PrimitiveCodecs
|
||||
// with the provided RegistryBuilder. if rb is nil, a new empty RegistryBuilder will be created.
|
||||
func (pc PrimitiveCodecs) RegisterPrimitiveCodecs(rb *bsoncodec.RegistryBuilder) {
|
||||
if rb == nil {
|
||||
panic(errors.New("argument to RegisterPrimitiveCodecs must not be nil"))
|
||||
}
|
||||
|
||||
rb.
|
||||
RegisterTypeEncoder(tDocument, bsoncodec.ValueEncoderFunc(pc.DocumentEncodeValue)).
|
||||
RegisterTypeEncoder(tArray, bsoncodec.ValueEncoderFunc(pc.ArrayEncodeValue)).
|
||||
RegisterTypeEncoder(tValue, bsoncodec.ValueEncoderFunc(pc.ValueEncodeValue)).
|
||||
RegisterTypeEncoder(tElementSlice, bsoncodec.ValueEncoderFunc(pc.ElementSliceEncodeValue)).
|
||||
RegisterTypeDecoder(tDocument, bsoncodec.ValueDecoderFunc(pc.DocumentDecodeValue)).
|
||||
RegisterTypeDecoder(tArray, bsoncodec.ValueDecoderFunc(pc.ArrayDecodeValue)).
|
||||
RegisterTypeDecoder(tValue, bsoncodec.ValueDecoderFunc(pc.ValueDecodeValue)).
|
||||
RegisterTypeDecoder(tElementSlice, bsoncodec.ValueDecoderFunc(pc.ElementSliceDecodeValue))
|
||||
}
|
||||
|
||||
// DocumentEncodeValue is the ValueEncoderFunc for *Document.
|
||||
func (pc PrimitiveCodecs) DocumentEncodeValue(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
|
||||
if !val.IsValid() || val.Type() != tDocument {
|
||||
return bsoncodec.ValueEncoderError{Name: "DocumentEncodeValue", Types: []reflect.Type{tDocument}, Received: val}
|
||||
}
|
||||
|
||||
if val.IsNil() {
|
||||
return vw.WriteNull()
|
||||
}
|
||||
|
||||
doc := val.Interface().(Doc)
|
||||
|
||||
dw, err := vw.WriteDocument()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return pc.encodeDocument(ec, dw, doc)
|
||||
}
|
||||
|
||||
// DocumentDecodeValue is the ValueDecoderFunc for *Document.
|
||||
func (pc PrimitiveCodecs) DocumentDecodeValue(dctx bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
|
||||
if !val.CanSet() || val.Type() != tDocument {
|
||||
return bsoncodec.ValueDecoderError{Name: "DocumentDecodeValue", Types: []reflect.Type{tDocument}, Received: val}
|
||||
}
|
||||
|
||||
return pc.documentDecodeValue(dctx, vr, val.Addr().Interface().(*Doc))
|
||||
}
|
||||
|
||||
func (pc PrimitiveCodecs) documentDecodeValue(dctx bsoncodec.DecodeContext, vr bsonrw.ValueReader, doc *Doc) error {
|
||||
|
||||
dr, err := vr.ReadDocument()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return pc.decodeDocument(dctx, dr, doc)
|
||||
}
|
||||
|
||||
// ArrayEncodeValue is the ValueEncoderFunc for *Array.
|
||||
func (pc PrimitiveCodecs) ArrayEncodeValue(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
|
||||
if !val.IsValid() || val.Type() != tArray {
|
||||
return bsoncodec.ValueEncoderError{Name: "ArrayEncodeValue", Types: []reflect.Type{tArray}, Received: val}
|
||||
}
|
||||
|
||||
if val.IsNil() {
|
||||
return vw.WriteNull()
|
||||
}
|
||||
|
||||
arr := val.Interface().(Arr)
|
||||
|
||||
aw, err := vw.WriteArray()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, val := range arr {
|
||||
dvw, err := aw.WriteArrayElement()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = pc.encodeValue(ec, dvw, val)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return aw.WriteArrayEnd()
|
||||
}
|
||||
|
||||
// ArrayDecodeValue is the ValueDecoderFunc for *Array.
|
||||
func (pc PrimitiveCodecs) ArrayDecodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
|
||||
if !val.CanSet() || val.Type() != tArray {
|
||||
return bsoncodec.ValueDecoderError{Name: "ArrayDecodeValue", Types: []reflect.Type{tArray}, Received: val}
|
||||
}
|
||||
|
||||
ar, err := vr.ReadArray()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if val.IsNil() {
|
||||
val.Set(reflect.MakeSlice(tArray, 0, 0))
|
||||
}
|
||||
val.SetLen(0)
|
||||
|
||||
for {
|
||||
vr, err := ar.ReadValue()
|
||||
if err == bsonrw.ErrEOA {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var elem Val
|
||||
err = pc.valueDecodeValue(dc, vr, &elem)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
val.Set(reflect.Append(val, reflect.ValueOf(elem)))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ElementSliceEncodeValue is the ValueEncoderFunc for []*Element.
|
||||
func (pc PrimitiveCodecs) ElementSliceEncodeValue(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
|
||||
if !val.IsValid() || val.Type() != tElementSlice {
|
||||
return bsoncodec.ValueEncoderError{Name: "ElementSliceEncodeValue", Types: []reflect.Type{tElementSlice}, Received: val}
|
||||
}
|
||||
|
||||
if val.IsNil() {
|
||||
return vw.WriteNull()
|
||||
}
|
||||
|
||||
return pc.DocumentEncodeValue(ec, vw, val.Convert(tDocument))
|
||||
}
|
||||
|
||||
// ElementSliceDecodeValue is the ValueDecoderFunc for []*Element.
|
||||
func (pc PrimitiveCodecs) ElementSliceDecodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
|
||||
if !val.CanSet() || val.Type() != tElementSlice {
|
||||
return bsoncodec.ValueDecoderError{Name: "ElementSliceDecodeValue", Types: []reflect.Type{tElementSlice}, Received: val}
|
||||
}
|
||||
|
||||
if val.IsNil() {
|
||||
val.Set(reflect.MakeSlice(val.Type(), 0, 0))
|
||||
}
|
||||
|
||||
val.SetLen(0)
|
||||
|
||||
dr, err := vr.ReadDocument()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
elems := make([]reflect.Value, 0)
|
||||
for {
|
||||
key, vr, err := dr.ReadElement()
|
||||
if err == bsonrw.ErrEOD {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var elem Elem
|
||||
err = pc.elementDecodeValue(dc, vr, key, &elem)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
elems = append(elems, reflect.ValueOf(elem))
|
||||
}
|
||||
|
||||
val.Set(reflect.Append(val, elems...))
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValueEncodeValue is the ValueEncoderFunc for *Value.
|
||||
func (pc PrimitiveCodecs) ValueEncodeValue(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
|
||||
if !val.IsValid() || val.Type() != tValue {
|
||||
return bsoncodec.ValueEncoderError{Name: "ValueEncodeValue", Types: []reflect.Type{tValue}, Received: val}
|
||||
}
|
||||
|
||||
v := val.Interface().(Val)
|
||||
|
||||
return pc.encodeValue(ec, vw, v)
|
||||
}
|
||||
|
||||
// ValueDecodeValue is the ValueDecoderFunc for *Value.
|
||||
func (pc PrimitiveCodecs) ValueDecodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
|
||||
if !val.CanSet() || val.Type() != tValue {
|
||||
return bsoncodec.ValueDecoderError{Name: "ValueDecodeValue", Types: []reflect.Type{tValue}, Received: val}
|
||||
}
|
||||
|
||||
return pc.valueDecodeValue(dc, vr, val.Addr().Interface().(*Val))
|
||||
}
|
||||
|
||||
// encodeDocument is a separate function that we use because CodeWithScope
|
||||
// returns us a DocumentWriter and we need to do the same logic that we would do
|
||||
// for a document but cannot use a Codec.
|
||||
func (pc PrimitiveCodecs) encodeDocument(ec bsoncodec.EncodeContext, dw bsonrw.DocumentWriter, doc Doc) error {
|
||||
for _, elem := range doc {
|
||||
dvw, err := dw.WriteDocumentElement(elem.Key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = pc.encodeValue(ec, dvw, elem.Value)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return dw.WriteDocumentEnd()
|
||||
}
|
||||
|
||||
// DecodeDocument haves decoding into a Doc from a bsonrw.DocumentReader.
|
||||
func (pc PrimitiveCodecs) DecodeDocument(dctx bsoncodec.DecodeContext, dr bsonrw.DocumentReader, pdoc *Doc) error {
|
||||
return pc.decodeDocument(dctx, dr, pdoc)
|
||||
}
|
||||
|
||||
func (pc PrimitiveCodecs) decodeDocument(dctx bsoncodec.DecodeContext, dr bsonrw.DocumentReader, pdoc *Doc) error {
|
||||
if *pdoc == nil {
|
||||
*pdoc = make(Doc, 0)
|
||||
}
|
||||
*pdoc = (*pdoc)[:0]
|
||||
for {
|
||||
key, vr, err := dr.ReadElement()
|
||||
if err == bsonrw.ErrEOD {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var elem Elem
|
||||
err = pc.elementDecodeValue(dctx, vr, key, &elem)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*pdoc = append(*pdoc, elem)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pc PrimitiveCodecs) elementDecodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, key string, elem *Elem) error {
|
||||
var val Val
|
||||
switch vr.Type() {
|
||||
case bsontype.Double:
|
||||
f64, err := vr.ReadDouble()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
val = Double(f64)
|
||||
case bsontype.String:
|
||||
str, err := vr.ReadString()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
val = String(str)
|
||||
case bsontype.EmbeddedDocument:
|
||||
var embeddedDoc Doc
|
||||
err := pc.documentDecodeValue(dc, vr, &embeddedDoc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
val = Document(embeddedDoc)
|
||||
case bsontype.Array:
|
||||
arr := reflect.New(tArray).Elem()
|
||||
err := pc.ArrayDecodeValue(dc, vr, arr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
val = Array(arr.Interface().(Arr))
|
||||
case bsontype.Binary:
|
||||
data, subtype, err := vr.ReadBinary()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
val = Binary(subtype, data)
|
||||
case bsontype.Undefined:
|
||||
err := vr.ReadUndefined()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
val = Undefined()
|
||||
case bsontype.ObjectID:
|
||||
oid, err := vr.ReadObjectID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
val = ObjectID(oid)
|
||||
case bsontype.Boolean:
|
||||
b, err := vr.ReadBoolean()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
val = Boolean(b)
|
||||
case bsontype.DateTime:
|
||||
dt, err := vr.ReadDateTime()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
val = DateTime(dt)
|
||||
case bsontype.Null:
|
||||
err := vr.ReadNull()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
val = Null()
|
||||
case bsontype.Regex:
|
||||
pattern, options, err := vr.ReadRegex()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
val = Regex(pattern, options)
|
||||
case bsontype.DBPointer:
|
||||
ns, pointer, err := vr.ReadDBPointer()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
val = DBPointer(ns, pointer)
|
||||
case bsontype.JavaScript:
|
||||
js, err := vr.ReadJavascript()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
val = JavaScript(js)
|
||||
case bsontype.Symbol:
|
||||
symbol, err := vr.ReadSymbol()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
val = Symbol(symbol)
|
||||
case bsontype.CodeWithScope:
|
||||
code, scope, err := vr.ReadCodeWithScope()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var doc Doc
|
||||
err = pc.decodeDocument(dc, scope, &doc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
val = CodeWithScope(code, doc)
|
||||
case bsontype.Int32:
|
||||
i32, err := vr.ReadInt32()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
val = Int32(i32)
|
||||
case bsontype.Timestamp:
|
||||
t, i, err := vr.ReadTimestamp()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
val = Timestamp(t, i)
|
||||
case bsontype.Int64:
|
||||
i64, err := vr.ReadInt64()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
val = Int64(i64)
|
||||
case bsontype.Decimal128:
|
||||
d128, err := vr.ReadDecimal128()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
val = Decimal128(d128)
|
||||
case bsontype.MinKey:
|
||||
err := vr.ReadMinKey()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
val = MinKey()
|
||||
case bsontype.MaxKey:
|
||||
err := vr.ReadMaxKey()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
val = MaxKey()
|
||||
default:
|
||||
return fmt.Errorf("Cannot read unknown BSON type %s", vr.Type())
|
||||
}
|
||||
|
||||
*elem = Elem{Key: key, Value: val}
|
||||
return nil
|
||||
}
|
||||
|
||||
// encodeValue does not validation, and the callers must perform validation on val before calling
|
||||
// this method.
|
||||
func (pc PrimitiveCodecs) encodeValue(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val Val) error {
|
||||
var err error
|
||||
switch val.Type() {
|
||||
case bsontype.Double:
|
||||
err = vw.WriteDouble(val.Double())
|
||||
case bsontype.String:
|
||||
err = vw.WriteString(val.StringValue())
|
||||
case bsontype.EmbeddedDocument:
|
||||
var encoder bsoncodec.ValueEncoder
|
||||
encoder, err = ec.LookupEncoder(tDocument)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = encoder.EncodeValue(ec, vw, reflect.ValueOf(val.Document()))
|
||||
case bsontype.Array:
|
||||
var encoder bsoncodec.ValueEncoder
|
||||
encoder, err = ec.LookupEncoder(tArray)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = encoder.EncodeValue(ec, vw, reflect.ValueOf(val.Array()))
|
||||
case bsontype.Binary:
|
||||
// TODO: FIX THIS (╯°□°)╯︵ ┻━┻
|
||||
subtype, data := val.Binary()
|
||||
err = vw.WriteBinaryWithSubtype(data, subtype)
|
||||
case bsontype.Undefined:
|
||||
err = vw.WriteUndefined()
|
||||
case bsontype.ObjectID:
|
||||
err = vw.WriteObjectID(val.ObjectID())
|
||||
case bsontype.Boolean:
|
||||
err = vw.WriteBoolean(val.Boolean())
|
||||
case bsontype.DateTime:
|
||||
err = vw.WriteDateTime(val.DateTime())
|
||||
case bsontype.Null:
|
||||
err = vw.WriteNull()
|
||||
case bsontype.Regex:
|
||||
err = vw.WriteRegex(val.Regex())
|
||||
case bsontype.DBPointer:
|
||||
err = vw.WriteDBPointer(val.DBPointer())
|
||||
case bsontype.JavaScript:
|
||||
err = vw.WriteJavascript(val.JavaScript())
|
||||
case bsontype.Symbol:
|
||||
err = vw.WriteSymbol(val.Symbol())
|
||||
case bsontype.CodeWithScope:
|
||||
code, scope := val.CodeWithScope()
|
||||
|
||||
var cwsw bsonrw.DocumentWriter
|
||||
cwsw, err = vw.WriteCodeWithScope(code)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
err = pc.encodeDocument(ec, cwsw, scope)
|
||||
case bsontype.Int32:
|
||||
err = vw.WriteInt32(val.Int32())
|
||||
case bsontype.Timestamp:
|
||||
err = vw.WriteTimestamp(val.Timestamp())
|
||||
case bsontype.Int64:
|
||||
err = vw.WriteInt64(val.Int64())
|
||||
case bsontype.Decimal128:
|
||||
err = vw.WriteDecimal128(val.Decimal128())
|
||||
case bsontype.MinKey:
|
||||
err = vw.WriteMinKey()
|
||||
case bsontype.MaxKey:
|
||||
err = vw.WriteMaxKey()
|
||||
default:
|
||||
err = fmt.Errorf("%T is not a valid BSON type to encode", val.Type())
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (pc PrimitiveCodecs) valueDecodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, val *Val) error {
|
||||
switch vr.Type() {
|
||||
case bsontype.Double:
|
||||
f64, err := vr.ReadDouble()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*val = Double(f64)
|
||||
case bsontype.String:
|
||||
str, err := vr.ReadString()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*val = String(str)
|
||||
case bsontype.EmbeddedDocument:
|
||||
var embeddedDoc Doc
|
||||
err := pc.documentDecodeValue(dc, vr, &embeddedDoc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*val = Document(embeddedDoc)
|
||||
case bsontype.Array:
|
||||
arr := reflect.New(tArray).Elem()
|
||||
err := pc.ArrayDecodeValue(dc, vr, arr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*val = Array(arr.Interface().(Arr))
|
||||
case bsontype.Binary:
|
||||
data, subtype, err := vr.ReadBinary()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*val = Binary(subtype, data)
|
||||
case bsontype.Undefined:
|
||||
err := vr.ReadUndefined()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*val = Undefined()
|
||||
case bsontype.ObjectID:
|
||||
oid, err := vr.ReadObjectID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*val = ObjectID(oid)
|
||||
case bsontype.Boolean:
|
||||
b, err := vr.ReadBoolean()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*val = Boolean(b)
|
||||
case bsontype.DateTime:
|
||||
dt, err := vr.ReadDateTime()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*val = DateTime(dt)
|
||||
case bsontype.Null:
|
||||
err := vr.ReadNull()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*val = Null()
|
||||
case bsontype.Regex:
|
||||
pattern, options, err := vr.ReadRegex()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*val = Regex(pattern, options)
|
||||
case bsontype.DBPointer:
|
||||
ns, pointer, err := vr.ReadDBPointer()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*val = DBPointer(ns, pointer)
|
||||
case bsontype.JavaScript:
|
||||
js, err := vr.ReadJavascript()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*val = JavaScript(js)
|
||||
case bsontype.Symbol:
|
||||
symbol, err := vr.ReadSymbol()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*val = Symbol(symbol)
|
||||
case bsontype.CodeWithScope:
|
||||
code, scope, err := vr.ReadCodeWithScope()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var scopeDoc Doc
|
||||
err = pc.decodeDocument(dc, scope, &scopeDoc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*val = CodeWithScope(code, scopeDoc)
|
||||
case bsontype.Int32:
|
||||
i32, err := vr.ReadInt32()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*val = Int32(i32)
|
||||
case bsontype.Timestamp:
|
||||
t, i, err := vr.ReadTimestamp()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*val = Timestamp(t, i)
|
||||
case bsontype.Int64:
|
||||
i64, err := vr.ReadInt64()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*val = Int64(i64)
|
||||
case bsontype.Decimal128:
|
||||
d128, err := vr.ReadDecimal128()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*val = Decimal128(d128)
|
||||
case bsontype.MinKey:
|
||||
err := vr.ReadMinKey()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*val = MinKey()
|
||||
case bsontype.MaxKey:
|
||||
err := vr.ReadMaxKey()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*val = MaxKey()
|
||||
default:
|
||||
return fmt.Errorf("Cannot read unknown BSON type %s", vr.Type())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
910
mongo/x/bsonx/primitive_codecs_test.go
Normal file
910
mongo/x/bsonx/primitive_codecs_test.go
Normal file
@@ -0,0 +1,910 @@
|
||||
// Copyright (C) MongoDB, Inc. 2022-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bsonx
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/bson/bsoncodec"
|
||||
"go.mongodb.org/mongo-driver/bson/bsonrw"
|
||||
"go.mongodb.org/mongo-driver/bson/bsonrw/bsonrwtest"
|
||||
"go.mongodb.org/mongo-driver/bson/bsontype"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
|
||||
)
|
||||
|
||||
func TestDefaultValueEncoders(t *testing.T) {
|
||||
var pcx PrimitiveCodecs
|
||||
|
||||
var wrong = func(string, string) string { return "wrong" }
|
||||
|
||||
type subtest struct {
|
||||
name string
|
||||
val interface{}
|
||||
ectx *bsoncodec.EncodeContext
|
||||
llvrw *bsonrwtest.ValueReaderWriter
|
||||
invoke bsonrwtest.Invoked
|
||||
err error
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
ve bsoncodec.ValueEncoder
|
||||
subtests []subtest
|
||||
}{
|
||||
{
|
||||
"ValueEncodeValue",
|
||||
bsoncodec.ValueEncoderFunc(pcx.ValueEncodeValue),
|
||||
[]subtest{
|
||||
{
|
||||
"wrong type",
|
||||
wrong,
|
||||
nil,
|
||||
nil,
|
||||
bsonrwtest.Nothing,
|
||||
bsoncodec.ValueEncoderError{Name: "ValueEncodeValue", Types: []reflect.Type{tValue}, Received: reflect.ValueOf(wrong)},
|
||||
},
|
||||
{"empty value", Val{}, nil, nil, bsonrwtest.WriteNull, nil},
|
||||
{
|
||||
"success",
|
||||
Null(),
|
||||
&bsoncodec.EncodeContext{Registry: DefaultRegistry},
|
||||
&bsonrwtest.ValueReaderWriter{},
|
||||
bsonrwtest.WriteNull,
|
||||
nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"ElementSliceEncodeValue",
|
||||
bsoncodec.ValueEncoderFunc(pcx.ElementSliceEncodeValue),
|
||||
[]subtest{
|
||||
{
|
||||
"wrong type",
|
||||
wrong,
|
||||
nil,
|
||||
nil,
|
||||
bsonrwtest.Nothing,
|
||||
bsoncodec.ValueEncoderError{
|
||||
Name: "ElementSliceEncodeValue",
|
||||
Types: []reflect.Type{tElementSlice},
|
||||
Received: reflect.ValueOf(wrong),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"ArrayEncodeValue",
|
||||
bsoncodec.ValueEncoderFunc(pcx.ArrayEncodeValue),
|
||||
[]subtest{
|
||||
{
|
||||
"wrong type",
|
||||
wrong,
|
||||
nil,
|
||||
nil,
|
||||
bsonrwtest.Nothing,
|
||||
bsoncodec.ValueEncoderError{Name: "ArrayEncodeValue", Types: []reflect.Type{tArray}, Received: reflect.ValueOf(wrong)},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
for _, subtest := range tc.subtests {
|
||||
t.Run(subtest.name, func(t *testing.T) {
|
||||
var ec bsoncodec.EncodeContext
|
||||
if subtest.ectx != nil {
|
||||
ec = *subtest.ectx
|
||||
}
|
||||
llvrw := new(bsonrwtest.ValueReaderWriter)
|
||||
if subtest.llvrw != nil {
|
||||
llvrw = subtest.llvrw
|
||||
}
|
||||
llvrw.T = t
|
||||
err := tc.ve.EncodeValue(ec, llvrw, reflect.ValueOf(subtest.val))
|
||||
if !compareErrors(err, subtest.err) {
|
||||
t.Errorf("Errors do not match. got %v; want %v", err, subtest.err)
|
||||
}
|
||||
invoked := llvrw.Invoked
|
||||
if !cmp.Equal(invoked, subtest.invoke) {
|
||||
t.Errorf("Incorrect method invoked. got %v; want %v", invoked, subtest.invoke)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("DocumentEncodeValue", func(t *testing.T) {
|
||||
t.Run("ValueEncoderError", func(t *testing.T) {
|
||||
val := reflect.ValueOf(bool(true))
|
||||
want := bsoncodec.ValueEncoderError{Name: "DocumentEncodeValue", Types: []reflect.Type{tDocument}, Received: val}
|
||||
got := (PrimitiveCodecs{}).DocumentEncodeValue(bsoncodec.EncodeContext{}, nil, val)
|
||||
if !compareErrors(got, want) {
|
||||
t.Errorf("Errors do not match. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
t.Run("WriteDocument Error", func(t *testing.T) {
|
||||
want := errors.New("WriteDocument Error")
|
||||
llvrw := &bsonrwtest.ValueReaderWriter{
|
||||
T: t,
|
||||
Err: want,
|
||||
ErrAfter: bsonrwtest.WriteDocument,
|
||||
}
|
||||
got := (PrimitiveCodecs{}).DocumentEncodeValue(bsoncodec.EncodeContext{}, llvrw, reflect.MakeSlice(tDocument, 0, 0))
|
||||
if !compareErrors(got, want) {
|
||||
t.Errorf("Errors do not match. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
t.Run("encodeDocument errors", func(t *testing.T) {
|
||||
ec := bsoncodec.EncodeContext{}
|
||||
err := errors.New("encodeDocument error")
|
||||
oid := primitive.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}
|
||||
testCases := []struct {
|
||||
name string
|
||||
ec bsoncodec.EncodeContext
|
||||
llvrw *bsonrwtest.ValueReaderWriter
|
||||
doc Doc
|
||||
err error
|
||||
}{
|
||||
{
|
||||
"WriteDocumentElement",
|
||||
ec,
|
||||
&bsonrwtest.ValueReaderWriter{T: t, Err: errors.New("wde error"), ErrAfter: bsonrwtest.WriteDocumentElement},
|
||||
Doc{{"foo", Null()}},
|
||||
errors.New("wde error"),
|
||||
},
|
||||
{
|
||||
"WriteDouble", ec,
|
||||
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteDouble},
|
||||
Doc{{"foo", Double(3.14159)}}, err,
|
||||
},
|
||||
{
|
||||
"WriteString", ec,
|
||||
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteString},
|
||||
Doc{{"foo", String("bar")}}, err,
|
||||
},
|
||||
{
|
||||
"WriteDocument (Lookup)", bsoncodec.EncodeContext{Registry: bsoncodec.NewRegistryBuilder().Build()},
|
||||
&bsonrwtest.ValueReaderWriter{T: t},
|
||||
Doc{{"foo", Document(Doc{{"bar", Null()}})}},
|
||||
bsoncodec.ErrNoEncoder{Type: tDocument},
|
||||
},
|
||||
{
|
||||
"WriteArray (Lookup)", bsoncodec.EncodeContext{Registry: bsoncodec.NewRegistryBuilder().Build()},
|
||||
&bsonrwtest.ValueReaderWriter{T: t},
|
||||
Doc{{"foo", Array(Arr{Null()})}},
|
||||
bsoncodec.ErrNoEncoder{Type: tArray},
|
||||
},
|
||||
{
|
||||
"WriteBinary", ec,
|
||||
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteBinaryWithSubtype},
|
||||
Doc{{"foo", Binary(0xFF, []byte{0x01, 0x02, 0x03})}}, err,
|
||||
},
|
||||
{
|
||||
"WriteUndefined", ec,
|
||||
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteUndefined},
|
||||
Doc{{"foo", Undefined()}}, err,
|
||||
},
|
||||
{
|
||||
"WriteObjectID", ec,
|
||||
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteObjectID},
|
||||
Doc{{"foo", ObjectID(oid)}}, err,
|
||||
},
|
||||
{
|
||||
"WriteBoolean", ec,
|
||||
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteBoolean},
|
||||
Doc{{"foo", Boolean(true)}}, err,
|
||||
},
|
||||
{
|
||||
"WriteDateTime", ec,
|
||||
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteDateTime},
|
||||
Doc{{"foo", DateTime(1234567890)}}, err,
|
||||
},
|
||||
{
|
||||
"WriteNull", ec,
|
||||
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteNull},
|
||||
Doc{{"foo", Null()}}, err,
|
||||
},
|
||||
{
|
||||
"WriteRegex", ec,
|
||||
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteRegex},
|
||||
Doc{{"foo", Regex("bar", "baz")}}, err,
|
||||
},
|
||||
{
|
||||
"WriteDBPointer", ec,
|
||||
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteDBPointer},
|
||||
Doc{{"foo", DBPointer("bar", oid)}}, err,
|
||||
},
|
||||
{
|
||||
"WriteJavascript", ec,
|
||||
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteJavascript},
|
||||
Doc{{"foo", JavaScript("var hello = 'world';")}}, err,
|
||||
},
|
||||
{
|
||||
"WriteSymbol", ec,
|
||||
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteSymbol},
|
||||
Doc{{"foo", Symbol("symbolbaz")}}, err,
|
||||
},
|
||||
{
|
||||
"WriteCodeWithScope (Lookup)", bsoncodec.EncodeContext{Registry: bsoncodec.NewRegistryBuilder().Build()},
|
||||
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteCodeWithScope},
|
||||
Doc{{"foo", CodeWithScope("var hello = 'world';", Doc{}.Append("bar", Null()))}},
|
||||
err,
|
||||
},
|
||||
{
|
||||
"WriteInt32", ec,
|
||||
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteInt32},
|
||||
Doc{{"foo", Int32(12345)}}, err,
|
||||
},
|
||||
{
|
||||
"WriteInt64", ec,
|
||||
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteInt64},
|
||||
Doc{{"foo", Int64(1234567890)}}, err,
|
||||
},
|
||||
{
|
||||
"WriteTimestamp", ec,
|
||||
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteTimestamp},
|
||||
Doc{{"foo", Timestamp(10, 20)}}, err,
|
||||
},
|
||||
{
|
||||
"WriteDecimal128", ec,
|
||||
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteDecimal128},
|
||||
Doc{{"foo", Decimal128(primitive.NewDecimal128(10, 20))}}, err,
|
||||
},
|
||||
{
|
||||
"WriteMinKey", ec,
|
||||
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteMinKey},
|
||||
Doc{{"foo", MinKey()}}, err,
|
||||
},
|
||||
{
|
||||
"WriteMaxKey", ec,
|
||||
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteMaxKey},
|
||||
Doc{{"foo", MaxKey()}}, err,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := (PrimitiveCodecs{}).DocumentEncodeValue(tc.ec, tc.llvrw, reflect.ValueOf(tc.doc))
|
||||
if !compareErrors(err, tc.err) {
|
||||
t.Errorf("Errors do not match. got %v; want %v", err, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
oid := primitive.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}
|
||||
d128 := primitive.NewDecimal128(10, 20)
|
||||
want := Doc{
|
||||
{"a", Double(3.14159)}, {"b", String("foo")},
|
||||
{"c", Document(Doc{{"aa", Null()}})}, {"d", Array(Arr{Null()})},
|
||||
{"e", Binary(0xFF, []byte{0x01, 0x02, 0x03})}, {"f", Undefined()},
|
||||
{"g", ObjectID(oid)}, {"h", Boolean(true)},
|
||||
{"i", DateTime(1234567890)}, {"j", Null()},
|
||||
{"k", Regex("foo", "abr")},
|
||||
{"l", DBPointer("foobar", oid)}, {"m", JavaScript("var hello = 'world';")},
|
||||
{"n", Symbol("bazqux")},
|
||||
{"o", CodeWithScope("var hello = 'world';", Doc{{"ab", Null()}})},
|
||||
{"p", Int32(12345)},
|
||||
{"q", Timestamp(10, 20)}, {"r", Int64(1234567890)}, {"s", Decimal128(d128)}, {"t", MinKey()}, {"u", MaxKey()},
|
||||
}
|
||||
slc := make(bsonrw.SliceWriter, 0, 128)
|
||||
vw, err := bsonrw.NewBSONValueWriter(&slc)
|
||||
noerr(t, err)
|
||||
|
||||
ec := bsoncodec.EncodeContext{Registry: DefaultRegistry}
|
||||
err = (PrimitiveCodecs{}).DocumentEncodeValue(ec, vw, reflect.ValueOf(want))
|
||||
noerr(t, err)
|
||||
got, err := ReadDoc(slc)
|
||||
noerr(t, err)
|
||||
if !got.Equal(want) {
|
||||
t.Error("Documents do not match")
|
||||
t.Errorf("\ngot :%v\nwant:%v", got, want)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("ArrayEncodeValue", func(t *testing.T) {
|
||||
t.Run("CodecEncodeError", func(t *testing.T) {
|
||||
val := reflect.ValueOf(bool(true))
|
||||
want := bsoncodec.ValueEncoderError{Name: "ArrayEncodeValue", Types: []reflect.Type{tArray}, Received: val}
|
||||
got := (PrimitiveCodecs{}).ArrayEncodeValue(bsoncodec.EncodeContext{}, nil, val)
|
||||
if !compareErrors(got, want) {
|
||||
t.Errorf("Errors do not match. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
t.Run("WriteArray Error", func(t *testing.T) {
|
||||
want := errors.New("WriteArray Error")
|
||||
llvrw := &bsonrwtest.ValueReaderWriter{
|
||||
T: t,
|
||||
Err: want,
|
||||
ErrAfter: bsonrwtest.WriteArray,
|
||||
}
|
||||
got := (PrimitiveCodecs{}).ArrayEncodeValue(bsoncodec.EncodeContext{}, llvrw, reflect.MakeSlice(tArray, 0, 0))
|
||||
if !compareErrors(got, want) {
|
||||
t.Errorf("Errors do not match. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
t.Run("encode array errors", func(t *testing.T) {
|
||||
ec := bsoncodec.EncodeContext{}
|
||||
err := errors.New("encode array error")
|
||||
oid := primitive.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}
|
||||
testCases := []struct {
|
||||
name string
|
||||
ec bsoncodec.EncodeContext
|
||||
llvrw *bsonrwtest.ValueReaderWriter
|
||||
arr Arr
|
||||
err error
|
||||
}{
|
||||
{
|
||||
"WriteDocumentElement",
|
||||
ec,
|
||||
&bsonrwtest.ValueReaderWriter{T: t, Err: errors.New("wde error"), ErrAfter: bsonrwtest.WriteArrayElement},
|
||||
Arr{Null()},
|
||||
errors.New("wde error"),
|
||||
},
|
||||
{
|
||||
"WriteDouble", ec,
|
||||
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteDouble},
|
||||
Arr{Double(3.14159)}, err,
|
||||
},
|
||||
{
|
||||
"WriteString", ec,
|
||||
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteString},
|
||||
Arr{String("bar")}, err,
|
||||
},
|
||||
{
|
||||
"WriteDocument (Lookup)", bsoncodec.EncodeContext{Registry: bsoncodec.NewRegistryBuilder().Build()},
|
||||
&bsonrwtest.ValueReaderWriter{T: t},
|
||||
Arr{Document(Doc{{"bar", Null()}})},
|
||||
bsoncodec.ErrNoEncoder{Type: tDocument},
|
||||
},
|
||||
{
|
||||
"WriteArray (Lookup)", bsoncodec.EncodeContext{Registry: bsoncodec.NewRegistryBuilder().Build()},
|
||||
&bsonrwtest.ValueReaderWriter{T: t},
|
||||
Arr{Array(Arr{Null()})},
|
||||
bsoncodec.ErrNoEncoder{Type: tArray},
|
||||
},
|
||||
{
|
||||
"WriteBinary", ec,
|
||||
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteBinaryWithSubtype},
|
||||
Arr{Binary(0xFF, []byte{0x01, 0x02, 0x03})}, err,
|
||||
},
|
||||
{
|
||||
"WriteUndefined", ec,
|
||||
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteUndefined},
|
||||
Arr{Undefined()}, err,
|
||||
},
|
||||
{
|
||||
"WriteObjectID", ec,
|
||||
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteObjectID},
|
||||
Arr{ObjectID(oid)}, err,
|
||||
},
|
||||
{
|
||||
"WriteBoolean", ec,
|
||||
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteBoolean},
|
||||
Arr{Boolean(true)}, err,
|
||||
},
|
||||
{
|
||||
"WriteDateTime", ec,
|
||||
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteDateTime},
|
||||
Arr{DateTime(1234567890)}, err,
|
||||
},
|
||||
{
|
||||
"WriteNull", ec,
|
||||
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteNull},
|
||||
Arr{Null()}, err,
|
||||
},
|
||||
{
|
||||
"WriteRegex", ec,
|
||||
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteRegex},
|
||||
Arr{Regex("bar", "baz")}, err,
|
||||
},
|
||||
{
|
||||
"WriteDBPointer", ec,
|
||||
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteDBPointer},
|
||||
Arr{DBPointer("bar", oid)}, err,
|
||||
},
|
||||
{
|
||||
"WriteJavascript", ec,
|
||||
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteJavascript},
|
||||
Arr{JavaScript("var hello = 'world';")}, err,
|
||||
},
|
||||
{
|
||||
"WriteSymbol", ec,
|
||||
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteSymbol},
|
||||
Arr{Symbol("symbolbaz")}, err,
|
||||
},
|
||||
{
|
||||
"WriteCodeWithScope (Lookup)", bsoncodec.EncodeContext{Registry: bsoncodec.NewRegistryBuilder().Build()},
|
||||
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteCodeWithScope},
|
||||
Arr{CodeWithScope("var hello = 'world';", Doc{{"bar", Null()}})},
|
||||
err,
|
||||
},
|
||||
{
|
||||
"WriteInt32", ec,
|
||||
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteInt32},
|
||||
Arr{Int32(12345)}, err,
|
||||
},
|
||||
{
|
||||
"WriteInt64", ec,
|
||||
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteInt64},
|
||||
Arr{Int64(1234567890)}, err,
|
||||
},
|
||||
{
|
||||
"WriteTimestamp", ec,
|
||||
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteTimestamp},
|
||||
Arr{Timestamp(10, 20)}, err,
|
||||
},
|
||||
{
|
||||
"WriteDecimal128", ec,
|
||||
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteDecimal128},
|
||||
Arr{Decimal128(primitive.NewDecimal128(10, 20))}, err,
|
||||
},
|
||||
{
|
||||
"WriteMinKey", ec,
|
||||
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteMinKey},
|
||||
Arr{MinKey()}, err,
|
||||
},
|
||||
{
|
||||
"WriteMaxKey", ec,
|
||||
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteMaxKey},
|
||||
Arr{MaxKey()}, err,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := (PrimitiveCodecs{}).ArrayEncodeValue(tc.ec, tc.llvrw, reflect.ValueOf(tc.arr))
|
||||
if !compareErrors(err, tc.err) {
|
||||
t.Errorf("Errors do not match. got %v; want %v", err, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
oid := primitive.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}
|
||||
d128 := primitive.NewDecimal128(10, 20)
|
||||
want := Arr{
|
||||
Double(3.14159), String("foo"), Document(Doc{{"aa", Null()}}),
|
||||
Array(Arr{Null()}),
|
||||
Binary(0xFF, []byte{0x01, 0x02, 0x03}), Undefined(),
|
||||
ObjectID(oid), Boolean(true), DateTime(1234567890), Null(), Regex("foo", "abr"),
|
||||
DBPointer("foobar", oid), JavaScript("var hello = 'world';"), Symbol("bazqux"),
|
||||
CodeWithScope("var hello = 'world';", Doc{{"ab", Null()}}), Int32(12345),
|
||||
Timestamp(10, 20), Int64(1234567890), Decimal128(d128), MinKey(), MaxKey(),
|
||||
}
|
||||
|
||||
ec := bsoncodec.EncodeContext{Registry: DefaultRegistry}
|
||||
|
||||
slc := make(bsonrw.SliceWriter, 0, 128)
|
||||
vw, err := bsonrw.NewBSONValueWriter(&slc)
|
||||
noerr(t, err)
|
||||
|
||||
dr, err := vw.WriteDocument()
|
||||
noerr(t, err)
|
||||
vr, err := dr.WriteDocumentElement("foo")
|
||||
noerr(t, err)
|
||||
|
||||
err = (PrimitiveCodecs{}).ArrayEncodeValue(ec, vr, reflect.ValueOf(want))
|
||||
noerr(t, err)
|
||||
|
||||
err = dr.WriteDocumentEnd()
|
||||
noerr(t, err)
|
||||
|
||||
val, err := bsoncore.Document(slc).LookupErr("foo")
|
||||
noerr(t, err)
|
||||
rgot := val.Array()
|
||||
doc, err := ReadDoc(rgot)
|
||||
noerr(t, err)
|
||||
got := make(Arr, 0)
|
||||
for _, elem := range doc {
|
||||
got = append(got, elem.Value)
|
||||
}
|
||||
if !got.Equal(want) {
|
||||
t.Error("Documents do not match")
|
||||
t.Errorf("\ngot :%v\nwant:%v", got, want)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestDefaultValueDecoders(t *testing.T) {
|
||||
var pcx PrimitiveCodecs
|
||||
|
||||
var wrong = func(string, string) string { return "wrong" }
|
||||
|
||||
const cansetreflectiontest = "cansetreflectiontest"
|
||||
|
||||
type subtest struct {
|
||||
name string
|
||||
val interface{}
|
||||
dctx *bsoncodec.DecodeContext
|
||||
llvrw *bsonrwtest.ValueReaderWriter
|
||||
invoke bsonrwtest.Invoked
|
||||
err error
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
vd bsoncodec.ValueDecoder
|
||||
subtests []subtest
|
||||
}{
|
||||
{
|
||||
"ValueDecodeValue",
|
||||
bsoncodec.ValueDecoderFunc(pcx.ValueDecodeValue),
|
||||
[]subtest{
|
||||
{
|
||||
"wrong type",
|
||||
wrong,
|
||||
nil,
|
||||
nil,
|
||||
bsonrwtest.Nothing,
|
||||
bsoncodec.ValueDecoderError{
|
||||
Name: "ValueDecodeValue",
|
||||
Types: []reflect.Type{tValue},
|
||||
Received: reflect.ValueOf(wrong),
|
||||
},
|
||||
},
|
||||
{
|
||||
"invalid value",
|
||||
(*Val)(nil),
|
||||
nil,
|
||||
nil,
|
||||
bsonrwtest.Nothing,
|
||||
bsoncodec.ValueDecoderError{
|
||||
Name: "ValueDecodeValue",
|
||||
Types: []reflect.Type{tValue},
|
||||
Received: reflect.ValueOf((*Val)(nil)),
|
||||
},
|
||||
},
|
||||
{
|
||||
"success",
|
||||
Double(3.14159),
|
||||
&bsoncodec.DecodeContext{Registry: NewRegistryBuilder().Build()},
|
||||
&bsonrwtest.ValueReaderWriter{BSONType: bsontype.Double, Return: float64(3.14159)},
|
||||
bsonrwtest.ReadDouble,
|
||||
nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
for _, rc := range tc.subtests {
|
||||
t.Run(rc.name, func(t *testing.T) {
|
||||
var dc bsoncodec.DecodeContext
|
||||
if rc.dctx != nil {
|
||||
dc = *rc.dctx
|
||||
}
|
||||
llvrw := new(bsonrwtest.ValueReaderWriter)
|
||||
if rc.llvrw != nil {
|
||||
llvrw = rc.llvrw
|
||||
}
|
||||
llvrw.T = t
|
||||
// var got interface{}
|
||||
if rc.val == cansetreflectiontest { // We're doing a CanSet reflection test
|
||||
err := tc.vd.DecodeValue(dc, llvrw, reflect.Value{})
|
||||
if !compareErrors(err, rc.err) {
|
||||
t.Errorf("Errors do not match. got %v; want %v", err, rc.err)
|
||||
}
|
||||
|
||||
val := reflect.New(reflect.TypeOf(rc.val)).Elem()
|
||||
err = tc.vd.DecodeValue(dc, llvrw, val)
|
||||
if !compareErrors(err, rc.err) {
|
||||
t.Errorf("Errors do not match. got %v; want %v", err, rc.err)
|
||||
}
|
||||
return
|
||||
}
|
||||
var val reflect.Value
|
||||
if rtype := reflect.TypeOf(rc.val); rtype != nil {
|
||||
val = reflect.New(rtype).Elem()
|
||||
}
|
||||
want := rc.val
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
fmt.Println(t.Name())
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
err := tc.vd.DecodeValue(dc, llvrw, val)
|
||||
if !compareErrors(err, rc.err) {
|
||||
t.Errorf("Errors do not match. got %v; want %v", err, rc.err)
|
||||
}
|
||||
invoked := llvrw.Invoked
|
||||
if !cmp.Equal(invoked, rc.invoke) {
|
||||
t.Errorf("Incorrect method invoked. got %v; want %v", invoked, rc.invoke)
|
||||
}
|
||||
var got interface{}
|
||||
if val.IsValid() && val.CanInterface() {
|
||||
got = val.Interface()
|
||||
}
|
||||
if rc.err == nil && !cmp.Equal(got, want, cmp.Comparer(compareValues)) {
|
||||
t.Errorf("Values do not match. got (%T)%v; want (%T)%v", got, got, want, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("DocumentDecodeValue", func(t *testing.T) {
|
||||
t.Run("CodecDecodeError", func(t *testing.T) {
|
||||
val := reflect.New(reflect.TypeOf(false)).Elem()
|
||||
want := bsoncodec.ValueDecoderError{Name: "DocumentDecodeValue", Types: []reflect.Type{tDocument}, Received: val}
|
||||
got := pcx.DocumentDecodeValue(bsoncodec.DecodeContext{}, &bsonrwtest.ValueReaderWriter{BSONType: bsontype.EmbeddedDocument}, val)
|
||||
if !compareErrors(got, want) {
|
||||
t.Errorf("Errors do not match. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
t.Run("ReadDocument Error", func(t *testing.T) {
|
||||
want := errors.New("ReadDocument Error")
|
||||
llvrw := &bsonrwtest.ValueReaderWriter{
|
||||
T: t,
|
||||
Err: want,
|
||||
ErrAfter: bsonrwtest.ReadDocument,
|
||||
BSONType: bsontype.EmbeddedDocument,
|
||||
}
|
||||
got := pcx.DocumentDecodeValue(bsoncodec.DecodeContext{}, llvrw, reflect.New(reflect.TypeOf(Doc{})).Elem())
|
||||
if !compareErrors(got, want) {
|
||||
t.Errorf("Errors do not match. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
t.Run("decodeDocument errors", func(t *testing.T) {
|
||||
dc := bsoncodec.DecodeContext{}
|
||||
err := errors.New("decodeDocument error")
|
||||
testCases := []struct {
|
||||
name string
|
||||
dc bsoncodec.DecodeContext
|
||||
llvrw *bsonrwtest.ValueReaderWriter
|
||||
err error
|
||||
}{
|
||||
{
|
||||
"ReadElement",
|
||||
dc,
|
||||
&bsonrwtest.ValueReaderWriter{T: t, Err: errors.New("re error"), ErrAfter: bsonrwtest.ReadElement},
|
||||
errors.New("re error"),
|
||||
},
|
||||
{"ReadDouble", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadDouble, BSONType: bsontype.Double}, err},
|
||||
{"ReadString", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadString, BSONType: bsontype.String}, err},
|
||||
{"ReadBinary", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadBinary, BSONType: bsontype.Binary}, err},
|
||||
{"ReadUndefined", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadUndefined, BSONType: bsontype.Undefined}, err},
|
||||
{"ReadObjectID", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadObjectID, BSONType: bsontype.ObjectID}, err},
|
||||
{"ReadBoolean", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadBoolean, BSONType: bsontype.Boolean}, err},
|
||||
{"ReadDateTime", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadDateTime, BSONType: bsontype.DateTime}, err},
|
||||
{"ReadNull", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadNull, BSONType: bsontype.Null}, err},
|
||||
{"ReadRegex", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadRegex, BSONType: bsontype.Regex}, err},
|
||||
{"ReadDBPointer", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadDBPointer, BSONType: bsontype.DBPointer}, err},
|
||||
{"ReadJavascript", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadJavascript, BSONType: bsontype.JavaScript}, err},
|
||||
{"ReadSymbol", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadSymbol, BSONType: bsontype.Symbol}, err},
|
||||
{
|
||||
"ReadCodeWithScope (Lookup)", bsoncodec.DecodeContext{Registry: bsoncodec.NewRegistryBuilder().Build()},
|
||||
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadCodeWithScope, BSONType: bsontype.CodeWithScope},
|
||||
err,
|
||||
},
|
||||
{"ReadInt32", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadInt32, BSONType: bsontype.Int32}, err},
|
||||
{"ReadInt64", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadInt64, BSONType: bsontype.Int64}, err},
|
||||
{"ReadTimestamp", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadTimestamp, BSONType: bsontype.Timestamp}, err},
|
||||
{"ReadDecimal128", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadDecimal128, BSONType: bsontype.Decimal128}, err},
|
||||
{"ReadMinKey", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadMinKey, BSONType: bsontype.MinKey}, err},
|
||||
{"ReadMaxKey", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadMaxKey, BSONType: bsontype.MaxKey}, err},
|
||||
{"Invalid Type", dc, &bsonrwtest.ValueReaderWriter{T: t, BSONType: bsontype.Type(0)}, fmt.Errorf("Cannot read unknown BSON type %s", bsontype.Type(0))},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := pcx.DecodeDocument(tc.dc, tc.llvrw, new(Doc))
|
||||
if !compareErrors(err, tc.err) {
|
||||
t.Errorf("Errors do not match. got %v; want %v", err, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
oid := primitive.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}
|
||||
d128 := primitive.NewDecimal128(10, 20)
|
||||
want := Doc{
|
||||
{"a", Double(3.14159)}, {"b", String("foo")},
|
||||
{"c", Document(Doc{{"aa", Null()}})},
|
||||
{"d", Array(Arr{Null()})},
|
||||
{"e", Binary(0xFF, []byte{0x01, 0x02, 0x03})}, {"f", Undefined()},
|
||||
{"g", ObjectID(oid)}, {"h", Boolean(true)},
|
||||
{"i", DateTime(1234567890)}, {"j", Null()}, {"k", Regex("foo", "bar")},
|
||||
{"l", DBPointer("foobar", oid)}, {"m", JavaScript("var hello = 'world';")},
|
||||
{"n", Symbol("bazqux")},
|
||||
{"o", CodeWithScope("var hello = 'world';", Doc{{"ab", Null()}})},
|
||||
{"p", Int32(12345)},
|
||||
{"q", Timestamp(10, 20)}, {"r", Int64(1234567890)},
|
||||
{"s", Decimal128(d128)}, {"t", MinKey()}, {"u", MaxKey()},
|
||||
}
|
||||
got := reflect.New(reflect.TypeOf(Doc{})).Elem()
|
||||
dc := bsoncodec.DecodeContext{Registry: NewRegistryBuilder().Build()}
|
||||
b, err := want.MarshalBSON()
|
||||
noerr(t, err)
|
||||
err = pcx.DocumentDecodeValue(dc, bsonrw.NewBSONDocumentReader(b), got)
|
||||
noerr(t, err)
|
||||
if !got.Interface().(Doc).Equal(want) {
|
||||
t.Error("Documents do not match")
|
||||
t.Errorf("\ngot :%v\nwant:%v", got, want)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("ArrayDecodeValue", func(t *testing.T) {
|
||||
t.Run("CodecDecodeError", func(t *testing.T) {
|
||||
val := reflect.New(reflect.TypeOf(false)).Elem()
|
||||
want := bsoncodec.ValueDecoderError{Name: "ArrayDecodeValue", Types: []reflect.Type{tArray}, Received: val}
|
||||
got := pcx.ArrayDecodeValue(bsoncodec.DecodeContext{}, &bsonrwtest.ValueReaderWriter{BSONType: bsontype.Array}, val)
|
||||
if !compareErrors(got, want) {
|
||||
t.Errorf("Errors do not match. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
t.Run("ReadArray Error", func(t *testing.T) {
|
||||
want := errors.New("ReadArray Error")
|
||||
llvrw := &bsonrwtest.ValueReaderWriter{
|
||||
T: t,
|
||||
Err: want,
|
||||
ErrAfter: bsonrwtest.ReadArray,
|
||||
BSONType: bsontype.Array,
|
||||
}
|
||||
got := pcx.ArrayDecodeValue(bsoncodec.DecodeContext{}, llvrw, reflect.New(tArray).Elem())
|
||||
if !compareErrors(got, want) {
|
||||
t.Errorf("Errors do not match. got %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
t.Run("decode array errors", func(t *testing.T) {
|
||||
dc := bsoncodec.DecodeContext{}
|
||||
err := errors.New("decode array error")
|
||||
testCases := []struct {
|
||||
name string
|
||||
dc bsoncodec.DecodeContext
|
||||
llvrw *bsonrwtest.ValueReaderWriter
|
||||
err error
|
||||
}{
|
||||
{
|
||||
"ReadValue",
|
||||
dc,
|
||||
&bsonrwtest.ValueReaderWriter{T: t, Err: errors.New("re error"), ErrAfter: bsonrwtest.ReadValue},
|
||||
errors.New("re error"),
|
||||
},
|
||||
{"ReadDouble", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadDouble, BSONType: bsontype.Double}, err},
|
||||
{"ReadString", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadString, BSONType: bsontype.String}, err},
|
||||
{"ReadBinary", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadBinary, BSONType: bsontype.Binary}, err},
|
||||
{"ReadUndefined", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadUndefined, BSONType: bsontype.Undefined}, err},
|
||||
{"ReadObjectID", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadObjectID, BSONType: bsontype.ObjectID}, err},
|
||||
{"ReadBoolean", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadBoolean, BSONType: bsontype.Boolean}, err},
|
||||
{"ReadDateTime", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadDateTime, BSONType: bsontype.DateTime}, err},
|
||||
{"ReadNull", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadNull, BSONType: bsontype.Null}, err},
|
||||
{"ReadRegex", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadRegex, BSONType: bsontype.Regex}, err},
|
||||
{"ReadDBPointer", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadDBPointer, BSONType: bsontype.DBPointer}, err},
|
||||
{"ReadJavascript", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadJavascript, BSONType: bsontype.JavaScript}, err},
|
||||
{"ReadSymbol", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadSymbol, BSONType: bsontype.Symbol}, err},
|
||||
{
|
||||
"ReadCodeWithScope (Lookup)", bsoncodec.DecodeContext{Registry: bsoncodec.NewRegistryBuilder().Build()},
|
||||
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadCodeWithScope, BSONType: bsontype.CodeWithScope},
|
||||
err,
|
||||
},
|
||||
{"ReadInt32", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadInt32, BSONType: bsontype.Int32}, err},
|
||||
{"ReadInt64", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadInt64, BSONType: bsontype.Int64}, err},
|
||||
{"ReadTimestamp", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadTimestamp, BSONType: bsontype.Timestamp}, err},
|
||||
{"ReadDecimal128", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadDecimal128, BSONType: bsontype.Decimal128}, err},
|
||||
{"ReadMinKey", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadMinKey, BSONType: bsontype.MinKey}, err},
|
||||
{"ReadMaxKey", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadMaxKey, BSONType: bsontype.MaxKey}, err},
|
||||
{"Invalid Type", dc, &bsonrwtest.ValueReaderWriter{T: t, BSONType: bsontype.Type(0)}, fmt.Errorf("Cannot read unknown BSON type %s", bsontype.Type(0))},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := pcx.ArrayDecodeValue(tc.dc, tc.llvrw, reflect.New(tArray).Elem())
|
||||
if !compareErrors(err, tc.err) {
|
||||
t.Errorf("Errors do not match. got %v; want %v", err, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("success", func(t *testing.T) {
|
||||
oid := primitive.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}
|
||||
d128 := primitive.NewDecimal128(10, 20)
|
||||
want := Arr{
|
||||
Double(3.14159), String("foo"), Document(Doc{{"aa", Null()}}),
|
||||
Array(Arr{Null()}),
|
||||
Binary(0xFF, []byte{0x01, 0x02, 0x03}), Undefined(),
|
||||
ObjectID(oid), Boolean(true), DateTime(1234567890), Null(), Regex("foo", "bar"),
|
||||
DBPointer("foobar", oid), JavaScript("var hello = 'world';"), Symbol("bazqux"),
|
||||
CodeWithScope("var hello = 'world';", Doc{{"ab", Null()}}), Int32(12345),
|
||||
Timestamp(10, 20), Int64(1234567890), Decimal128(d128), MinKey(), MaxKey(),
|
||||
}
|
||||
dc := bsoncodec.DecodeContext{Registry: NewRegistryBuilder().Build()}
|
||||
|
||||
b, err := Doc{{"", Array(want)}}.MarshalBSON()
|
||||
noerr(t, err)
|
||||
dvr := bsonrw.NewBSONDocumentReader(b)
|
||||
dr, err := dvr.ReadDocument()
|
||||
noerr(t, err)
|
||||
_, vr, err := dr.ReadElement()
|
||||
noerr(t, err)
|
||||
|
||||
val := reflect.New(tArray).Elem()
|
||||
err = pcx.ArrayDecodeValue(dc, vr, val)
|
||||
noerr(t, err)
|
||||
got := val.Interface().(Arr)
|
||||
if !got.Equal(want) {
|
||||
t.Error("Documents do not match")
|
||||
t.Errorf("\ngot :%v\nwant:%v", got, want)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("success path", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
value interface{}
|
||||
b []byte
|
||||
err error
|
||||
}{
|
||||
{
|
||||
"map[string][]Element",
|
||||
map[string][]Elem{"Z": {{"A", Int32(1)}, {"B", Int32(2)}, {"EC", Int32(3)}}},
|
||||
docToBytes(Doc{{"Z", Document(Doc{{"A", Int32(1)}, {"B", Int32(2)}, {"EC", Int32(3)}})}}),
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"map[string][]Value",
|
||||
map[string][]Val{"Z": {Int32(1), Int32(2), Int32(3)}},
|
||||
docToBytes(Doc{{"Z", Array(Arr{Int32(1), Int32(2), Int32(3)})}}),
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"map[string]*Document",
|
||||
map[string]Doc{"Z": {{"foo", Null()}}},
|
||||
docToBytes(Doc{{"Z", Document(Doc{{"foo", Null()}})}}),
|
||||
nil,
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("Decode", func(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
vr := bsonrw.NewBSONDocumentReader(tc.b)
|
||||
dec, err := bson.NewDecoderWithContext(bsoncodec.DecodeContext{Registry: DefaultRegistry}, vr)
|
||||
noerr(t, err)
|
||||
gotVal := reflect.New(reflect.TypeOf(tc.value))
|
||||
err = dec.Decode(gotVal.Interface())
|
||||
noerr(t, err)
|
||||
got := gotVal.Elem().Interface()
|
||||
want := tc.value
|
||||
if diff := cmp.Diff(
|
||||
got, want,
|
||||
); diff != "" {
|
||||
t.Errorf("difference:\n%s", diff)
|
||||
t.Errorf("Values are not equal.\ngot: %#v\nwant:%#v", got, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func compareValues(v1, v2 Val) bool { return v1.Equal(v2) }
|
||||
|
||||
func docToBytes(d Doc) []byte {
|
||||
b, err := d.MarshalBSON()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return b
|
||||
}
|
||||
1025
mongo/x/bsonx/reflectionfree_d_codec.go
Normal file
1025
mongo/x/bsonx/reflectionfree_d_codec.go
Normal file
File diff suppressed because it is too large
Load Diff
123
mongo/x/bsonx/reflectionfree_d_codec_test.go
Normal file
123
mongo/x/bsonx/reflectionfree_d_codec_test.go
Normal file
@@ -0,0 +1,123 @@
|
||||
// Copyright (C) MongoDB, Inc. 2022-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bsonx
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/internal/testutil/assert"
|
||||
)
|
||||
|
||||
func TestReflectionFreeDCodec(t *testing.T) {
|
||||
assert.RegisterOpts(reflect.TypeOf(primitive.D{}), cmp.AllowUnexported(primitive.Decimal128{}))
|
||||
|
||||
now := time.Now()
|
||||
oid := primitive.NewObjectID()
|
||||
d128 := primitive.NewDecimal128(10, 20)
|
||||
js := primitive.JavaScript("js")
|
||||
symbol := primitive.Symbol("sybmol")
|
||||
binary := primitive.Binary{Subtype: 2, Data: []byte("binary")}
|
||||
datetime := primitive.NewDateTimeFromTime(now)
|
||||
regex := primitive.Regex{Pattern: "pattern", Options: "i"}
|
||||
dbPointer := primitive.DBPointer{DB: "db", Pointer: oid}
|
||||
timestamp := primitive.Timestamp{T: 5, I: 10}
|
||||
cws := primitive.CodeWithScope{Code: js, Scope: bson.D{{"x", 1}}}
|
||||
noReflectionRegistry := bson.NewRegistryBuilder().RegisterCodec(tPrimitiveD, ReflectionFreeDCodec).Build()
|
||||
docWithAllTypes := primitive.D{
|
||||
{"byteSlice", []byte("foobar")},
|
||||
{"sliceByteSlice", [][]byte{[]byte("foobar")}},
|
||||
{"timeTime", now},
|
||||
{"sliceTimeTime", []time.Time{now}},
|
||||
{"objectID", oid},
|
||||
{"sliceObjectID", []primitive.ObjectID{oid}},
|
||||
{"decimal128", d128},
|
||||
{"sliceDecimal128", []primitive.Decimal128{d128}},
|
||||
{"js", js},
|
||||
{"sliceJS", []primitive.JavaScript{js}},
|
||||
{"symbol", symbol},
|
||||
{"sliceSymbol", []primitive.Symbol{symbol}},
|
||||
{"binary", binary},
|
||||
{"sliceBinary", []primitive.Binary{binary}},
|
||||
{"undefined", primitive.Undefined{}},
|
||||
{"sliceUndefined", []primitive.Undefined{{}}},
|
||||
{"datetime", datetime},
|
||||
{"sliceDateTime", []primitive.DateTime{datetime}},
|
||||
{"null", primitive.Null{}},
|
||||
{"sliceNull", []primitive.Null{{}}},
|
||||
{"regex", regex},
|
||||
{"sliceRegex", []primitive.Regex{regex}},
|
||||
{"dbPointer", dbPointer},
|
||||
{"sliceDBPointer", []primitive.DBPointer{dbPointer}},
|
||||
{"timestamp", timestamp},
|
||||
{"sliceTimestamp", []primitive.Timestamp{timestamp}},
|
||||
{"minKey", primitive.MinKey{}},
|
||||
{"sliceMinKey", []primitive.MinKey{{}}},
|
||||
{"maxKey", primitive.MaxKey{}},
|
||||
{"sliceMaxKey", []primitive.MaxKey{{}}},
|
||||
{"cws", cws},
|
||||
{"sliceCWS", []primitive.CodeWithScope{cws}},
|
||||
{"bool", true},
|
||||
{"sliceBool", []bool{true}},
|
||||
{"int", int(10)},
|
||||
{"sliceInt", []int{10}},
|
||||
{"int8", int8(10)},
|
||||
{"sliceInt8", []int8{10}},
|
||||
{"int16", int16(10)},
|
||||
{"sliceInt16", []int16{10}},
|
||||
{"int32", int32(10)},
|
||||
{"sliceInt32", []int32{10}},
|
||||
{"int64", int64(10)},
|
||||
{"sliceInt64", []int64{10}},
|
||||
{"uint", uint(10)},
|
||||
{"sliceUint", []uint{10}},
|
||||
{"uint8", uint8(10)},
|
||||
{"sliceUint8", []uint8{10}},
|
||||
{"uint16", uint16(10)},
|
||||
{"sliceUint16", []uint16{10}},
|
||||
{"uint32", uint32(10)},
|
||||
{"sliceUint32", []uint32{10}},
|
||||
{"uint64", uint64(10)},
|
||||
{"sliceUint64", []uint64{10}},
|
||||
{"float32", float32(10)},
|
||||
{"sliceFloat32", []float32{10}},
|
||||
{"float64", float64(10)},
|
||||
{"sliceFloat64", []float64{10}},
|
||||
{"primitiveA", primitive.A{"foo", "bar"}},
|
||||
}
|
||||
|
||||
t.Run("encode", func(t *testing.T) {
|
||||
// Assert that bson.Marshal returns the same result when using the default registry and noReflectionRegistry.
|
||||
|
||||
expected, err := bson.Marshal(docWithAllTypes)
|
||||
assert.Nil(t, err, "Marshal error with default registry: %v", err)
|
||||
actual, err := bson.MarshalWithRegistry(noReflectionRegistry, docWithAllTypes)
|
||||
assert.Nil(t, err, "Marshal error with noReflectionRegistry: %v", err)
|
||||
assert.Equal(t, expected, actual, "expected doc %s, got %s", bson.Raw(expected), bson.Raw(actual))
|
||||
})
|
||||
t.Run("decode", func(t *testing.T) {
|
||||
// Assert that bson.Unmarshal returns the same result when using the default registry and noReflectionRegistry.
|
||||
|
||||
// docWithAllTypes contains some types that can't be roundtripped. For example, any slices besides primitive.A
|
||||
// would start of as []T but unmarshal to primitive.A. To get around this, we first marshal docWithAllTypes to
|
||||
// raw bytes and then Unmarshal to another primitive.D rather than asserting directly against docWithAllTypes.
|
||||
docBytes, err := bson.Marshal(docWithAllTypes)
|
||||
assert.Nil(t, err, "Marshal error: %v", err)
|
||||
|
||||
var expected, actual primitive.D
|
||||
err = bson.Unmarshal(docBytes, &expected)
|
||||
assert.Nil(t, err, "Unmarshal error with default registry: %v", err)
|
||||
err = bson.UnmarshalWithRegistry(noReflectionRegistry, docBytes, &actual)
|
||||
assert.Nil(t, err, "Unmarshal error with noReflectionRegistry: %v", err)
|
||||
|
||||
assert.Equal(t, expected, actual, "expected document %v, got %v", expected, actual)
|
||||
})
|
||||
}
|
||||
28
mongo/x/bsonx/registry.go
Normal file
28
mongo/x/bsonx/registry.go
Normal file
@@ -0,0 +1,28 @@
|
||||
// Copyright (C) MongoDB, Inc. 2022-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bsonx
|
||||
|
||||
import (
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/bson/bsoncodec"
|
||||
)
|
||||
|
||||
// DefaultRegistry is the default bsoncodec.Registry. It contains the default codecs and the
|
||||
// primitive codecs.
|
||||
var DefaultRegistry = NewRegistryBuilder().Build()
|
||||
|
||||
// NewRegistryBuilder creates a new RegistryBuilder configured with the default encoders and
|
||||
// decoders from the bsoncodec.DefaultValueEncoders and bsoncodec.DefaultValueDecoders types and the
|
||||
// PrimitiveCodecs type in this package.
|
||||
func NewRegistryBuilder() *bsoncodec.RegistryBuilder {
|
||||
rb := bsoncodec.NewRegistryBuilder()
|
||||
bsoncodec.DefaultValueEncoders{}.RegisterDefaultEncoders(rb)
|
||||
bsoncodec.DefaultValueDecoders{}.RegisterDefaultDecoders(rb)
|
||||
bson.PrimitiveCodecs{}.RegisterPrimitiveCodecs(rb)
|
||||
primitiveCodecs.RegisterPrimitiveCodecs(rb)
|
||||
return rb
|
||||
}
|
||||
866
mongo/x/bsonx/value.go
Normal file
866
mongo/x/bsonx/value.go
Normal file
@@ -0,0 +1,866 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bsonx
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"go.mongodb.org/mongo-driver/bson/bsontype"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
|
||||
)
|
||||
|
||||
// Val represents a BSON value.
|
||||
type Val struct {
|
||||
// NOTE: The bootstrap is a small amount of space that'll be on the stack. At 15 bytes this
|
||||
// doesn't make this type any larger, since there are 7 bytes of padding and we want an int64 to
|
||||
// store small values (e.g. boolean, double, int64, etc...). The primitive property is where all
|
||||
// of the larger values go. They will use either Go primitives or the primitive.* types.
|
||||
t bsontype.Type
|
||||
bootstrap [15]byte
|
||||
primitive interface{}
|
||||
}
|
||||
|
||||
func (v Val) string() string {
|
||||
if v.primitive != nil {
|
||||
return v.primitive.(string)
|
||||
}
|
||||
// The string will either end with a null byte or it fills the entire bootstrap space.
|
||||
length := v.bootstrap[0]
|
||||
return string(v.bootstrap[1 : length+1])
|
||||
}
|
||||
|
||||
func (v Val) writestring(str string) Val {
|
||||
switch {
|
||||
case len(str) < 15:
|
||||
v.bootstrap[0] = uint8(len(str))
|
||||
copy(v.bootstrap[1:], str)
|
||||
default:
|
||||
v.primitive = str
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func (v Val) i64() int64 {
|
||||
return int64(v.bootstrap[0]) | int64(v.bootstrap[1])<<8 | int64(v.bootstrap[2])<<16 |
|
||||
int64(v.bootstrap[3])<<24 | int64(v.bootstrap[4])<<32 | int64(v.bootstrap[5])<<40 |
|
||||
int64(v.bootstrap[6])<<48 | int64(v.bootstrap[7])<<56
|
||||
}
|
||||
|
||||
func (v Val) writei64(i64 int64) Val {
|
||||
v.bootstrap[0] = byte(i64)
|
||||
v.bootstrap[1] = byte(i64 >> 8)
|
||||
v.bootstrap[2] = byte(i64 >> 16)
|
||||
v.bootstrap[3] = byte(i64 >> 24)
|
||||
v.bootstrap[4] = byte(i64 >> 32)
|
||||
v.bootstrap[5] = byte(i64 >> 40)
|
||||
v.bootstrap[6] = byte(i64 >> 48)
|
||||
v.bootstrap[7] = byte(i64 >> 56)
|
||||
return v
|
||||
}
|
||||
|
||||
// IsZero returns true if this value is zero or a BSON null.
|
||||
func (v Val) IsZero() bool { return v.t == bsontype.Type(0) || v.t == bsontype.Null }
|
||||
|
||||
func (v Val) String() string {
|
||||
// TODO(GODRIVER-612): When bsoncore has appenders for extended JSON use that here.
|
||||
return fmt.Sprintf("%v", v.Interface())
|
||||
}
|
||||
|
||||
// Interface returns the Go value of this Value as an empty interface.
|
||||
//
|
||||
// This method will return nil if it is empty, otherwise it will return a Go primitive or a
|
||||
// primitive.* instance.
|
||||
func (v Val) Interface() interface{} {
|
||||
switch v.Type() {
|
||||
case bsontype.Double:
|
||||
return v.Double()
|
||||
case bsontype.String:
|
||||
return v.StringValue()
|
||||
case bsontype.EmbeddedDocument:
|
||||
switch v.primitive.(type) {
|
||||
case Doc:
|
||||
return v.primitive.(Doc)
|
||||
case MDoc:
|
||||
return v.primitive.(MDoc)
|
||||
default:
|
||||
return primitive.Null{}
|
||||
}
|
||||
case bsontype.Array:
|
||||
return v.Array()
|
||||
case bsontype.Binary:
|
||||
return v.primitive.(primitive.Binary)
|
||||
case bsontype.Undefined:
|
||||
return primitive.Undefined{}
|
||||
case bsontype.ObjectID:
|
||||
return v.ObjectID()
|
||||
case bsontype.Boolean:
|
||||
return v.Boolean()
|
||||
case bsontype.DateTime:
|
||||
return v.DateTime()
|
||||
case bsontype.Null:
|
||||
return primitive.Null{}
|
||||
case bsontype.Regex:
|
||||
return v.primitive.(primitive.Regex)
|
||||
case bsontype.DBPointer:
|
||||
return v.primitive.(primitive.DBPointer)
|
||||
case bsontype.JavaScript:
|
||||
return v.JavaScript()
|
||||
case bsontype.Symbol:
|
||||
return v.Symbol()
|
||||
case bsontype.CodeWithScope:
|
||||
return v.primitive.(primitive.CodeWithScope)
|
||||
case bsontype.Int32:
|
||||
return v.Int32()
|
||||
case bsontype.Timestamp:
|
||||
t, i := v.Timestamp()
|
||||
return primitive.Timestamp{T: t, I: i}
|
||||
case bsontype.Int64:
|
||||
return v.Int64()
|
||||
case bsontype.Decimal128:
|
||||
return v.Decimal128()
|
||||
case bsontype.MinKey:
|
||||
return primitive.MinKey{}
|
||||
case bsontype.MaxKey:
|
||||
return primitive.MaxKey{}
|
||||
default:
|
||||
return primitive.Null{}
|
||||
}
|
||||
}
|
||||
|
||||
// MarshalBSONValue implements the bsoncodec.ValueMarshaler interface.
|
||||
func (v Val) MarshalBSONValue() (bsontype.Type, []byte, error) {
|
||||
return v.MarshalAppendBSONValue(nil)
|
||||
}
|
||||
|
||||
// MarshalAppendBSONValue is similar to MarshalBSONValue, but allows the caller to specify a slice
|
||||
// to add the bytes to.
|
||||
func (v Val) MarshalAppendBSONValue(dst []byte) (bsontype.Type, []byte, error) {
|
||||
t := v.Type()
|
||||
switch v.Type() {
|
||||
case bsontype.Double:
|
||||
dst = bsoncore.AppendDouble(dst, v.Double())
|
||||
case bsontype.String:
|
||||
dst = bsoncore.AppendString(dst, v.String())
|
||||
case bsontype.EmbeddedDocument:
|
||||
switch v.primitive.(type) {
|
||||
case Doc:
|
||||
t, dst, _ = v.primitive.(Doc).MarshalBSONValue() // Doc.MarshalBSONValue never returns an error.
|
||||
case MDoc:
|
||||
t, dst, _ = v.primitive.(MDoc).MarshalBSONValue() // MDoc.MarshalBSONValue never returns an error.
|
||||
}
|
||||
case bsontype.Array:
|
||||
t, dst, _ = v.Array().MarshalBSONValue() // Arr.MarshalBSON never returns an error.
|
||||
case bsontype.Binary:
|
||||
subtype, bindata := v.Binary()
|
||||
dst = bsoncore.AppendBinary(dst, subtype, bindata)
|
||||
case bsontype.Undefined:
|
||||
case bsontype.ObjectID:
|
||||
dst = bsoncore.AppendObjectID(dst, v.ObjectID())
|
||||
case bsontype.Boolean:
|
||||
dst = bsoncore.AppendBoolean(dst, v.Boolean())
|
||||
case bsontype.DateTime:
|
||||
dst = bsoncore.AppendDateTime(dst, v.DateTime())
|
||||
case bsontype.Null:
|
||||
case bsontype.Regex:
|
||||
pattern, options := v.Regex()
|
||||
dst = bsoncore.AppendRegex(dst, pattern, options)
|
||||
case bsontype.DBPointer:
|
||||
ns, ptr := v.DBPointer()
|
||||
dst = bsoncore.AppendDBPointer(dst, ns, ptr)
|
||||
case bsontype.JavaScript:
|
||||
dst = bsoncore.AppendJavaScript(dst, v.JavaScript())
|
||||
case bsontype.Symbol:
|
||||
dst = bsoncore.AppendSymbol(dst, v.Symbol())
|
||||
case bsontype.CodeWithScope:
|
||||
code, doc := v.CodeWithScope()
|
||||
var scope []byte
|
||||
scope, _ = doc.MarshalBSON() // Doc.MarshalBSON never returns an error.
|
||||
dst = bsoncore.AppendCodeWithScope(dst, code, scope)
|
||||
case bsontype.Int32:
|
||||
dst = bsoncore.AppendInt32(dst, v.Int32())
|
||||
case bsontype.Timestamp:
|
||||
t, i := v.Timestamp()
|
||||
dst = bsoncore.AppendTimestamp(dst, t, i)
|
||||
case bsontype.Int64:
|
||||
dst = bsoncore.AppendInt64(dst, v.Int64())
|
||||
case bsontype.Decimal128:
|
||||
dst = bsoncore.AppendDecimal128(dst, v.Decimal128())
|
||||
case bsontype.MinKey:
|
||||
case bsontype.MaxKey:
|
||||
default:
|
||||
panic(fmt.Errorf("invalid BSON type %v", t))
|
||||
}
|
||||
|
||||
return t, dst, nil
|
||||
}
|
||||
|
||||
// UnmarshalBSONValue implements the bsoncodec.ValueUnmarshaler interface.
|
||||
func (v *Val) UnmarshalBSONValue(t bsontype.Type, data []byte) error {
|
||||
if v == nil {
|
||||
return errors.New("cannot unmarshal into nil Value")
|
||||
}
|
||||
var err error
|
||||
var ok = true
|
||||
var rem []byte
|
||||
switch t {
|
||||
case bsontype.Double:
|
||||
var f64 float64
|
||||
f64, rem, ok = bsoncore.ReadDouble(data)
|
||||
*v = Double(f64)
|
||||
case bsontype.String:
|
||||
var str string
|
||||
str, rem, ok = bsoncore.ReadString(data)
|
||||
*v = String(str)
|
||||
case bsontype.EmbeddedDocument:
|
||||
var raw []byte
|
||||
var doc Doc
|
||||
raw, rem, ok = bsoncore.ReadDocument(data)
|
||||
doc, err = ReadDoc(raw)
|
||||
*v = Document(doc)
|
||||
case bsontype.Array:
|
||||
var raw []byte
|
||||
arr := make(Arr, 0)
|
||||
raw, rem, ok = bsoncore.ReadArray(data)
|
||||
err = arr.UnmarshalBSONValue(t, raw)
|
||||
*v = Array(arr)
|
||||
case bsontype.Binary:
|
||||
var subtype byte
|
||||
var bindata []byte
|
||||
subtype, bindata, rem, ok = bsoncore.ReadBinary(data)
|
||||
*v = Binary(subtype, bindata)
|
||||
case bsontype.Undefined:
|
||||
*v = Undefined()
|
||||
case bsontype.ObjectID:
|
||||
var oid primitive.ObjectID
|
||||
oid, rem, ok = bsoncore.ReadObjectID(data)
|
||||
*v = ObjectID(oid)
|
||||
case bsontype.Boolean:
|
||||
var b bool
|
||||
b, rem, ok = bsoncore.ReadBoolean(data)
|
||||
*v = Boolean(b)
|
||||
case bsontype.DateTime:
|
||||
var dt int64
|
||||
dt, rem, ok = bsoncore.ReadDateTime(data)
|
||||
*v = DateTime(dt)
|
||||
case bsontype.Null:
|
||||
*v = Null()
|
||||
case bsontype.Regex:
|
||||
var pattern, options string
|
||||
pattern, options, rem, ok = bsoncore.ReadRegex(data)
|
||||
*v = Regex(pattern, options)
|
||||
case bsontype.DBPointer:
|
||||
var ns string
|
||||
var ptr primitive.ObjectID
|
||||
ns, ptr, rem, ok = bsoncore.ReadDBPointer(data)
|
||||
*v = DBPointer(ns, ptr)
|
||||
case bsontype.JavaScript:
|
||||
var js string
|
||||
js, rem, ok = bsoncore.ReadJavaScript(data)
|
||||
*v = JavaScript(js)
|
||||
case bsontype.Symbol:
|
||||
var symbol string
|
||||
symbol, rem, ok = bsoncore.ReadSymbol(data)
|
||||
*v = Symbol(symbol)
|
||||
case bsontype.CodeWithScope:
|
||||
var raw []byte
|
||||
var code string
|
||||
var scope Doc
|
||||
code, raw, rem, ok = bsoncore.ReadCodeWithScope(data)
|
||||
scope, err = ReadDoc(raw)
|
||||
*v = CodeWithScope(code, scope)
|
||||
case bsontype.Int32:
|
||||
var i32 int32
|
||||
i32, rem, ok = bsoncore.ReadInt32(data)
|
||||
*v = Int32(i32)
|
||||
case bsontype.Timestamp:
|
||||
var i, t uint32
|
||||
t, i, rem, ok = bsoncore.ReadTimestamp(data)
|
||||
*v = Timestamp(t, i)
|
||||
case bsontype.Int64:
|
||||
var i64 int64
|
||||
i64, rem, ok = bsoncore.ReadInt64(data)
|
||||
*v = Int64(i64)
|
||||
case bsontype.Decimal128:
|
||||
var d128 primitive.Decimal128
|
||||
d128, rem, ok = bsoncore.ReadDecimal128(data)
|
||||
*v = Decimal128(d128)
|
||||
case bsontype.MinKey:
|
||||
*v = MinKey()
|
||||
case bsontype.MaxKey:
|
||||
*v = MaxKey()
|
||||
default:
|
||||
err = fmt.Errorf("invalid BSON type %v", t)
|
||||
}
|
||||
|
||||
if !ok && err == nil {
|
||||
err = bsoncore.NewInsufficientBytesError(data, rem)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Type returns the BSON type of this value.
|
||||
func (v Val) Type() bsontype.Type {
|
||||
if v.t == bsontype.Type(0) {
|
||||
return bsontype.Null
|
||||
}
|
||||
return v.t
|
||||
}
|
||||
|
||||
// IsNumber returns true if the type of v is a numberic BSON type.
|
||||
func (v Val) IsNumber() bool {
|
||||
switch v.Type() {
|
||||
case bsontype.Double, bsontype.Int32, bsontype.Int64, bsontype.Decimal128:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Double returns the BSON double value the Value represents. It panics if the value is a BSON type
|
||||
// other than double.
|
||||
func (v Val) Double() float64 {
|
||||
if v.t != bsontype.Double {
|
||||
panic(ElementTypeError{"bson.Value.Double", v.t})
|
||||
}
|
||||
return math.Float64frombits(binary.LittleEndian.Uint64(v.bootstrap[0:8]))
|
||||
}
|
||||
|
||||
// DoubleOK is the same as Double, but returns a boolean instead of panicking.
|
||||
func (v Val) DoubleOK() (float64, bool) {
|
||||
if v.t != bsontype.Double {
|
||||
return 0, false
|
||||
}
|
||||
return math.Float64frombits(binary.LittleEndian.Uint64(v.bootstrap[0:8])), true
|
||||
}
|
||||
|
||||
// StringValue returns the BSON string the Value represents. It panics if the value is a BSON type
|
||||
// other than string.
|
||||
//
|
||||
// NOTE: This method is called StringValue to avoid it implementing the
|
||||
// fmt.Stringer interface.
|
||||
func (v Val) StringValue() string {
|
||||
if v.t != bsontype.String {
|
||||
panic(ElementTypeError{"bson.Value.StringValue", v.t})
|
||||
}
|
||||
return v.string()
|
||||
}
|
||||
|
||||
// StringValueOK is the same as StringValue, but returns a boolean instead of
|
||||
// panicking.
|
||||
func (v Val) StringValueOK() (string, bool) {
|
||||
if v.t != bsontype.String {
|
||||
return "", false
|
||||
}
|
||||
return v.string(), true
|
||||
}
|
||||
|
||||
func (v Val) asDoc() Doc {
|
||||
doc, ok := v.primitive.(Doc)
|
||||
if ok {
|
||||
return doc
|
||||
}
|
||||
mdoc := v.primitive.(MDoc)
|
||||
for k, v := range mdoc {
|
||||
doc = append(doc, Elem{k, v})
|
||||
}
|
||||
return doc
|
||||
}
|
||||
|
||||
func (v Val) asMDoc() MDoc {
|
||||
mdoc, ok := v.primitive.(MDoc)
|
||||
if ok {
|
||||
return mdoc
|
||||
}
|
||||
mdoc = make(MDoc)
|
||||
doc := v.primitive.(Doc)
|
||||
for _, elem := range doc {
|
||||
mdoc[elem.Key] = elem.Value
|
||||
}
|
||||
return mdoc
|
||||
}
|
||||
|
||||
// Document returns the BSON embedded document value the Value represents. It panics if the value
|
||||
// is a BSON type other than embedded document.
|
||||
func (v Val) Document() Doc {
|
||||
if v.t != bsontype.EmbeddedDocument {
|
||||
panic(ElementTypeError{"bson.Value.Document", v.t})
|
||||
}
|
||||
return v.asDoc()
|
||||
}
|
||||
|
||||
// DocumentOK is the same as Document, except it returns a boolean
|
||||
// instead of panicking.
|
||||
func (v Val) DocumentOK() (Doc, bool) {
|
||||
if v.t != bsontype.EmbeddedDocument {
|
||||
return nil, false
|
||||
}
|
||||
return v.asDoc(), true
|
||||
}
|
||||
|
||||
// MDocument returns the BSON embedded document value the Value represents. It panics if the value
|
||||
// is a BSON type other than embedded document.
|
||||
func (v Val) MDocument() MDoc {
|
||||
if v.t != bsontype.EmbeddedDocument {
|
||||
panic(ElementTypeError{"bson.Value.MDocument", v.t})
|
||||
}
|
||||
return v.asMDoc()
|
||||
}
|
||||
|
||||
// MDocumentOK is the same as Document, except it returns a boolean
|
||||
// instead of panicking.
|
||||
func (v Val) MDocumentOK() (MDoc, bool) {
|
||||
if v.t != bsontype.EmbeddedDocument {
|
||||
return nil, false
|
||||
}
|
||||
return v.asMDoc(), true
|
||||
}
|
||||
|
||||
// Array returns the BSON array value the Value represents. It panics if the value is a BSON type
|
||||
// other than array.
|
||||
func (v Val) Array() Arr {
|
||||
if v.t != bsontype.Array {
|
||||
panic(ElementTypeError{"bson.Value.Array", v.t})
|
||||
}
|
||||
return v.primitive.(Arr)
|
||||
}
|
||||
|
||||
// ArrayOK is the same as Array, except it returns a boolean
|
||||
// instead of panicking.
|
||||
func (v Val) ArrayOK() (Arr, bool) {
|
||||
if v.t != bsontype.Array {
|
||||
return nil, false
|
||||
}
|
||||
return v.primitive.(Arr), true
|
||||
}
|
||||
|
||||
// Binary returns the BSON binary value the Value represents. It panics if the value is a BSON type
|
||||
// other than binary.
|
||||
func (v Val) Binary() (byte, []byte) {
|
||||
if v.t != bsontype.Binary {
|
||||
panic(ElementTypeError{"bson.Value.Binary", v.t})
|
||||
}
|
||||
bin := v.primitive.(primitive.Binary)
|
||||
return bin.Subtype, bin.Data
|
||||
}
|
||||
|
||||
// BinaryOK is the same as Binary, except it returns a boolean instead of
|
||||
// panicking.
|
||||
func (v Val) BinaryOK() (byte, []byte, bool) {
|
||||
if v.t != bsontype.Binary {
|
||||
return 0x00, nil, false
|
||||
}
|
||||
bin := v.primitive.(primitive.Binary)
|
||||
return bin.Subtype, bin.Data, true
|
||||
}
|
||||
|
||||
// Undefined returns the BSON undefined the Value represents. It panics if the value is a BSON type
|
||||
// other than binary.
|
||||
func (v Val) Undefined() {
|
||||
if v.t != bsontype.Undefined {
|
||||
panic(ElementTypeError{"bson.Value.Undefined", v.t})
|
||||
}
|
||||
}
|
||||
|
||||
// UndefinedOK is the same as Undefined, except it returns a boolean instead of
|
||||
// panicking.
|
||||
func (v Val) UndefinedOK() bool {
|
||||
return v.t == bsontype.Undefined
|
||||
}
|
||||
|
||||
// ObjectID returns the BSON ObjectID the Value represents. It panics if the value is a BSON type
|
||||
// other than ObjectID.
|
||||
func (v Val) ObjectID() primitive.ObjectID {
|
||||
if v.t != bsontype.ObjectID {
|
||||
panic(ElementTypeError{"bson.Value.ObjectID", v.t})
|
||||
}
|
||||
var oid primitive.ObjectID
|
||||
copy(oid[:], v.bootstrap[:12])
|
||||
return oid
|
||||
}
|
||||
|
||||
// ObjectIDOK is the same as ObjectID, except it returns a boolean instead of
|
||||
// panicking.
|
||||
func (v Val) ObjectIDOK() (primitive.ObjectID, bool) {
|
||||
if v.t != bsontype.ObjectID {
|
||||
return primitive.ObjectID{}, false
|
||||
}
|
||||
var oid primitive.ObjectID
|
||||
copy(oid[:], v.bootstrap[:12])
|
||||
return oid, true
|
||||
}
|
||||
|
||||
// Boolean returns the BSON boolean the Value represents. It panics if the value is a BSON type
|
||||
// other than boolean.
|
||||
func (v Val) Boolean() bool {
|
||||
if v.t != bsontype.Boolean {
|
||||
panic(ElementTypeError{"bson.Value.Boolean", v.t})
|
||||
}
|
||||
return v.bootstrap[0] == 0x01
|
||||
}
|
||||
|
||||
// BooleanOK is the same as Boolean, except it returns a boolean instead of
|
||||
// panicking.
|
||||
func (v Val) BooleanOK() (bool, bool) {
|
||||
if v.t != bsontype.Boolean {
|
||||
return false, false
|
||||
}
|
||||
return v.bootstrap[0] == 0x01, true
|
||||
}
|
||||
|
||||
// DateTime returns the BSON datetime the Value represents. It panics if the value is a BSON type
|
||||
// other than datetime.
|
||||
func (v Val) DateTime() int64 {
|
||||
if v.t != bsontype.DateTime {
|
||||
panic(ElementTypeError{"bson.Value.DateTime", v.t})
|
||||
}
|
||||
return v.i64()
|
||||
}
|
||||
|
||||
// DateTimeOK is the same as DateTime, except it returns a boolean instead of
|
||||
// panicking.
|
||||
func (v Val) DateTimeOK() (int64, bool) {
|
||||
if v.t != bsontype.DateTime {
|
||||
return 0, false
|
||||
}
|
||||
return v.i64(), true
|
||||
}
|
||||
|
||||
// Time returns the BSON datetime the Value represents as time.Time. It panics if the value is a BSON
|
||||
// type other than datetime.
|
||||
func (v Val) Time() time.Time {
|
||||
if v.t != bsontype.DateTime {
|
||||
panic(ElementTypeError{"bson.Value.Time", v.t})
|
||||
}
|
||||
i := v.i64()
|
||||
return time.Unix(i/1000, i%1000*1000000)
|
||||
}
|
||||
|
||||
// TimeOK is the same as Time, except it returns a boolean instead of
|
||||
// panicking.
|
||||
func (v Val) TimeOK() (time.Time, bool) {
|
||||
if v.t != bsontype.DateTime {
|
||||
return time.Time{}, false
|
||||
}
|
||||
i := v.i64()
|
||||
return time.Unix(i/1000, i%1000*1000000), true
|
||||
}
|
||||
|
||||
// Null returns the BSON undefined the Value represents. It panics if the value is a BSON type
|
||||
// other than binary.
|
||||
func (v Val) Null() {
|
||||
if v.t != bsontype.Null && v.t != bsontype.Type(0) {
|
||||
panic(ElementTypeError{"bson.Value.Null", v.t})
|
||||
}
|
||||
}
|
||||
|
||||
// NullOK is the same as Null, except it returns a boolean instead of
|
||||
// panicking.
|
||||
func (v Val) NullOK() bool {
|
||||
if v.t != bsontype.Null && v.t != bsontype.Type(0) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Regex returns the BSON regex the Value represents. It panics if the value is a BSON type
|
||||
// other than regex.
|
||||
func (v Val) Regex() (pattern, options string) {
|
||||
if v.t != bsontype.Regex {
|
||||
panic(ElementTypeError{"bson.Value.Regex", v.t})
|
||||
}
|
||||
regex := v.primitive.(primitive.Regex)
|
||||
return regex.Pattern, regex.Options
|
||||
}
|
||||
|
||||
// RegexOK is the same as Regex, except that it returns a boolean
|
||||
// instead of panicking.
|
||||
func (v Val) RegexOK() (pattern, options string, ok bool) {
|
||||
if v.t != bsontype.Regex {
|
||||
return "", "", false
|
||||
}
|
||||
regex := v.primitive.(primitive.Regex)
|
||||
return regex.Pattern, regex.Options, true
|
||||
}
|
||||
|
||||
// DBPointer returns the BSON dbpointer the Value represents. It panics if the value is a BSON type
|
||||
// other than dbpointer.
|
||||
func (v Val) DBPointer() (string, primitive.ObjectID) {
|
||||
if v.t != bsontype.DBPointer {
|
||||
panic(ElementTypeError{"bson.Value.DBPointer", v.t})
|
||||
}
|
||||
dbptr := v.primitive.(primitive.DBPointer)
|
||||
return dbptr.DB, dbptr.Pointer
|
||||
}
|
||||
|
||||
// DBPointerOK is the same as DBPoitner, except that it returns a boolean
|
||||
// instead of panicking.
|
||||
func (v Val) DBPointerOK() (string, primitive.ObjectID, bool) {
|
||||
if v.t != bsontype.DBPointer {
|
||||
return "", primitive.ObjectID{}, false
|
||||
}
|
||||
dbptr := v.primitive.(primitive.DBPointer)
|
||||
return dbptr.DB, dbptr.Pointer, true
|
||||
}
|
||||
|
||||
// JavaScript returns the BSON JavaScript the Value represents. It panics if the value is a BSON type
|
||||
// other than JavaScript.
|
||||
func (v Val) JavaScript() string {
|
||||
if v.t != bsontype.JavaScript {
|
||||
panic(ElementTypeError{"bson.Value.JavaScript", v.t})
|
||||
}
|
||||
return v.string()
|
||||
}
|
||||
|
||||
// JavaScriptOK is the same as Javascript, except that it returns a boolean
|
||||
// instead of panicking.
|
||||
func (v Val) JavaScriptOK() (string, bool) {
|
||||
if v.t != bsontype.JavaScript {
|
||||
return "", false
|
||||
}
|
||||
return v.string(), true
|
||||
}
|
||||
|
||||
// Symbol returns the BSON symbol the Value represents. It panics if the value is a BSON type
|
||||
// other than symbol.
|
||||
func (v Val) Symbol() string {
|
||||
if v.t != bsontype.Symbol {
|
||||
panic(ElementTypeError{"bson.Value.Symbol", v.t})
|
||||
}
|
||||
return v.string()
|
||||
}
|
||||
|
||||
// SymbolOK is the same as Javascript, except that it returns a boolean
|
||||
// instead of panicking.
|
||||
func (v Val) SymbolOK() (string, bool) {
|
||||
if v.t != bsontype.Symbol {
|
||||
return "", false
|
||||
}
|
||||
return v.string(), true
|
||||
}
|
||||
|
||||
// CodeWithScope returns the BSON code with scope value the Value represents. It panics if the
|
||||
// value is a BSON type other than code with scope.
|
||||
func (v Val) CodeWithScope() (string, Doc) {
|
||||
if v.t != bsontype.CodeWithScope {
|
||||
panic(ElementTypeError{"bson.Value.CodeWithScope", v.t})
|
||||
}
|
||||
cws := v.primitive.(primitive.CodeWithScope)
|
||||
return string(cws.Code), cws.Scope.(Doc)
|
||||
}
|
||||
|
||||
// CodeWithScopeOK is the same as JavascriptWithScope,
|
||||
// except that it returns a boolean instead of panicking.
|
||||
func (v Val) CodeWithScopeOK() (string, Doc, bool) {
|
||||
if v.t != bsontype.CodeWithScope {
|
||||
return "", nil, false
|
||||
}
|
||||
cws := v.primitive.(primitive.CodeWithScope)
|
||||
return string(cws.Code), cws.Scope.(Doc), true
|
||||
}
|
||||
|
||||
// Int32 returns the BSON int32 the Value represents. It panics if the value is a BSON type
|
||||
// other than int32.
|
||||
func (v Val) Int32() int32 {
|
||||
if v.t != bsontype.Int32 {
|
||||
panic(ElementTypeError{"bson.Value.Int32", v.t})
|
||||
}
|
||||
return int32(v.bootstrap[0]) | int32(v.bootstrap[1])<<8 |
|
||||
int32(v.bootstrap[2])<<16 | int32(v.bootstrap[3])<<24
|
||||
}
|
||||
|
||||
// Int32OK is the same as Int32, except that it returns a boolean instead of
|
||||
// panicking.
|
||||
func (v Val) Int32OK() (int32, bool) {
|
||||
if v.t != bsontype.Int32 {
|
||||
return 0, false
|
||||
}
|
||||
return int32(v.bootstrap[0]) | int32(v.bootstrap[1])<<8 |
|
||||
int32(v.bootstrap[2])<<16 | int32(v.bootstrap[3])<<24,
|
||||
true
|
||||
}
|
||||
|
||||
// Timestamp returns the BSON timestamp the Value represents. It panics if the value is a
|
||||
// BSON type other than timestamp.
|
||||
func (v Val) Timestamp() (t, i uint32) {
|
||||
if v.t != bsontype.Timestamp {
|
||||
panic(ElementTypeError{"bson.Value.Timestamp", v.t})
|
||||
}
|
||||
return uint32(v.bootstrap[4]) | uint32(v.bootstrap[5])<<8 |
|
||||
uint32(v.bootstrap[6])<<16 | uint32(v.bootstrap[7])<<24,
|
||||
uint32(v.bootstrap[0]) | uint32(v.bootstrap[1])<<8 |
|
||||
uint32(v.bootstrap[2])<<16 | uint32(v.bootstrap[3])<<24
|
||||
}
|
||||
|
||||
// TimestampOK is the same as Timestamp, except that it returns a boolean
|
||||
// instead of panicking.
|
||||
func (v Val) TimestampOK() (t uint32, i uint32, ok bool) {
|
||||
if v.t != bsontype.Timestamp {
|
||||
return 0, 0, false
|
||||
}
|
||||
return uint32(v.bootstrap[4]) | uint32(v.bootstrap[5])<<8 |
|
||||
uint32(v.bootstrap[6])<<16 | uint32(v.bootstrap[7])<<24,
|
||||
uint32(v.bootstrap[0]) | uint32(v.bootstrap[1])<<8 |
|
||||
uint32(v.bootstrap[2])<<16 | uint32(v.bootstrap[3])<<24,
|
||||
true
|
||||
}
|
||||
|
||||
// Int64 returns the BSON int64 the Value represents. It panics if the value is a BSON type
|
||||
// other than int64.
|
||||
func (v Val) Int64() int64 {
|
||||
if v.t != bsontype.Int64 {
|
||||
panic(ElementTypeError{"bson.Value.Int64", v.t})
|
||||
}
|
||||
return v.i64()
|
||||
}
|
||||
|
||||
// Int64OK is the same as Int64, except that it returns a boolean instead of
|
||||
// panicking.
|
||||
func (v Val) Int64OK() (int64, bool) {
|
||||
if v.t != bsontype.Int64 {
|
||||
return 0, false
|
||||
}
|
||||
return v.i64(), true
|
||||
}
|
||||
|
||||
// Decimal128 returns the BSON decimal128 value the Value represents. It panics if the value is a
|
||||
// BSON type other than decimal128.
|
||||
func (v Val) Decimal128() primitive.Decimal128 {
|
||||
if v.t != bsontype.Decimal128 {
|
||||
panic(ElementTypeError{"bson.Value.Decimal128", v.t})
|
||||
}
|
||||
return v.primitive.(primitive.Decimal128)
|
||||
}
|
||||
|
||||
// Decimal128OK is the same as Decimal128, except that it returns a boolean
|
||||
// instead of panicking.
|
||||
func (v Val) Decimal128OK() (primitive.Decimal128, bool) {
|
||||
if v.t != bsontype.Decimal128 {
|
||||
return primitive.Decimal128{}, false
|
||||
}
|
||||
return v.primitive.(primitive.Decimal128), true
|
||||
}
|
||||
|
||||
// MinKey returns the BSON minkey the Value represents. It panics if the value is a BSON type
|
||||
// other than binary.
|
||||
func (v Val) MinKey() {
|
||||
if v.t != bsontype.MinKey {
|
||||
panic(ElementTypeError{"bson.Value.MinKey", v.t})
|
||||
}
|
||||
}
|
||||
|
||||
// MinKeyOK is the same as MinKey, except it returns a boolean instead of
|
||||
// panicking.
|
||||
func (v Val) MinKeyOK() bool {
|
||||
return v.t == bsontype.MinKey
|
||||
}
|
||||
|
||||
// MaxKey returns the BSON maxkey the Value represents. It panics if the value is a BSON type
|
||||
// other than binary.
|
||||
func (v Val) MaxKey() {
|
||||
if v.t != bsontype.MaxKey {
|
||||
panic(ElementTypeError{"bson.Value.MaxKey", v.t})
|
||||
}
|
||||
}
|
||||
|
||||
// MaxKeyOK is the same as MaxKey, except it returns a boolean instead of
|
||||
// panicking.
|
||||
func (v Val) MaxKeyOK() bool {
|
||||
return v.t == bsontype.MaxKey
|
||||
}
|
||||
|
||||
// Equal compares v to v2 and returns true if they are equal. Unknown BSON types are
|
||||
// never equal. Two empty values are equal.
|
||||
func (v Val) Equal(v2 Val) bool {
|
||||
if v.Type() != v2.Type() {
|
||||
return false
|
||||
}
|
||||
if v.IsZero() && v2.IsZero() {
|
||||
return true
|
||||
}
|
||||
|
||||
switch v.Type() {
|
||||
case bsontype.Double, bsontype.DateTime, bsontype.Timestamp, bsontype.Int64:
|
||||
return bytes.Equal(v.bootstrap[0:8], v2.bootstrap[0:8])
|
||||
case bsontype.String:
|
||||
return v.string() == v2.string()
|
||||
case bsontype.EmbeddedDocument:
|
||||
return v.equalDocs(v2)
|
||||
case bsontype.Array:
|
||||
return v.Array().Equal(v2.Array())
|
||||
case bsontype.Binary:
|
||||
return v.primitive.(primitive.Binary).Equal(v2.primitive.(primitive.Binary))
|
||||
case bsontype.Undefined:
|
||||
return true
|
||||
case bsontype.ObjectID:
|
||||
return bytes.Equal(v.bootstrap[0:12], v2.bootstrap[0:12])
|
||||
case bsontype.Boolean:
|
||||
return v.bootstrap[0] == v2.bootstrap[0]
|
||||
case bsontype.Null:
|
||||
return true
|
||||
case bsontype.Regex:
|
||||
return v.primitive.(primitive.Regex).Equal(v2.primitive.(primitive.Regex))
|
||||
case bsontype.DBPointer:
|
||||
return v.primitive.(primitive.DBPointer).Equal(v2.primitive.(primitive.DBPointer))
|
||||
case bsontype.JavaScript:
|
||||
return v.JavaScript() == v2.JavaScript()
|
||||
case bsontype.Symbol:
|
||||
return v.Symbol() == v2.Symbol()
|
||||
case bsontype.CodeWithScope:
|
||||
code1, scope1 := v.primitive.(primitive.CodeWithScope).Code, v.primitive.(primitive.CodeWithScope).Scope
|
||||
code2, scope2 := v2.primitive.(primitive.CodeWithScope).Code, v2.primitive.(primitive.CodeWithScope).Scope
|
||||
return code1 == code2 && v.equalInterfaceDocs(scope1, scope2)
|
||||
case bsontype.Int32:
|
||||
return v.Int32() == v2.Int32()
|
||||
case bsontype.Decimal128:
|
||||
h, l := v.Decimal128().GetBytes()
|
||||
h2, l2 := v2.Decimal128().GetBytes()
|
||||
return h == h2 && l == l2
|
||||
case bsontype.MinKey:
|
||||
return true
|
||||
case bsontype.MaxKey:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (v Val) equalDocs(v2 Val) bool {
|
||||
_, ok1 := v.primitive.(MDoc)
|
||||
_, ok2 := v2.primitive.(MDoc)
|
||||
if ok1 || ok2 {
|
||||
return v.asMDoc().Equal(v2.asMDoc())
|
||||
}
|
||||
return v.asDoc().Equal(v2.asDoc())
|
||||
}
|
||||
|
||||
func (Val) equalInterfaceDocs(i, i2 interface{}) bool {
|
||||
switch d := i.(type) {
|
||||
case MDoc:
|
||||
d2, ok := i2.(IDoc)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return d.Equal(d2)
|
||||
case Doc:
|
||||
d2, ok := i2.(IDoc)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return d.Equal(d2)
|
||||
case nil:
|
||||
return i2 == nil
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
271
mongo/x/bsonx/value_test.go
Normal file
271
mongo/x/bsonx/value_test.go
Normal file
@@ -0,0 +1,271 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package bsonx
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"go.mongodb.org/mongo-driver/bson/bsontype"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
)
|
||||
|
||||
func TestValue(t *testing.T) {
|
||||
longstr := "foobarbazqux, hello, world!"
|
||||
bytestr14 := "fourteen bytes"
|
||||
bin := primitive.Binary{Subtype: 0xFF, Data: []byte{0x01, 0x02, 0x03}}
|
||||
oid := primitive.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}
|
||||
now := time.Now().Truncate(time.Millisecond)
|
||||
nowdt := now.Unix()*1e3 + int64(now.Nanosecond()/1e6)
|
||||
regex := primitive.Regex{Pattern: "/foobarbaz/", Options: "abr"}
|
||||
dbptr := primitive.DBPointer{DB: "foobar", Pointer: oid}
|
||||
js := "var hello ='world';"
|
||||
symbol := "foobarbaz"
|
||||
cws := primitive.CodeWithScope{Code: primitive.JavaScript(js), Scope: Doc{}}
|
||||
code, scope := js, Doc{}
|
||||
ts := primitive.Timestamp{I: 12345, T: 67890}
|
||||
d128 := primitive.NewDecimal128(12345, 67890)
|
||||
|
||||
t.Parallel()
|
||||
testCases := []struct {
|
||||
name string
|
||||
fn interface{} // method to call
|
||||
ret []interface{} // return value
|
||||
err interface{} // panic result or bool
|
||||
}{
|
||||
{"Interface/Double", Double(3.14159).Interface, []interface{}{float64(3.14159)}, nil},
|
||||
{"Interface/String", String("foo").Interface, []interface{}{"foo"}, nil},
|
||||
{"Interface/Document", Document(Doc{}).Interface, []interface{}{Doc{}}, nil},
|
||||
{"Interface/Array", Array(Arr{}).Interface, []interface{}{Arr{}}, nil},
|
||||
{"Interface/Binary", Binary(bin.Subtype, bin.Data).Interface, []interface{}{bin}, nil},
|
||||
{"Interface/Undefined", Undefined().Interface, []interface{}{primitive.Undefined{}}, nil},
|
||||
{"Interface/Null", Null().Interface, []interface{}{primitive.Null{}}, nil},
|
||||
{"Interface/ObjectID", ObjectID(oid).Interface, []interface{}{oid}, nil},
|
||||
{"Interface/Boolean", Boolean(true).Interface, []interface{}{bool(true)}, nil},
|
||||
{"Interface/DateTime", DateTime(1234567890).Interface, []interface{}{int64(1234567890)}, nil},
|
||||
{"Interface/Time", Time(now).Interface, []interface{}{nowdt}, nil},
|
||||
{"Interface/Regex", Regex(regex.Pattern, regex.Options).Interface, []interface{}{regex}, nil},
|
||||
{"Interface/DBPointer", DBPointer(dbptr.DB, dbptr.Pointer).Interface, []interface{}{dbptr}, nil},
|
||||
{"Interface/JavaScript", JavaScript(js).Interface, []interface{}{js}, nil},
|
||||
{"Interface/Symbol", Symbol(symbol).Interface, []interface{}{symbol}, nil},
|
||||
{"Interface/CodeWithScope", CodeWithScope(string(cws.Code), cws.Scope.(Doc)).Interface, []interface{}{cws}, nil},
|
||||
{"Interface/Int32", Int32(12345).Interface, []interface{}{int32(12345)}, nil},
|
||||
{"Interface/Timestamp", Timestamp(ts.T, ts.I).Interface, []interface{}{ts}, nil},
|
||||
{"Interface/Int64", Int64(1234567890).Interface, []interface{}{int64(1234567890)}, nil},
|
||||
{"Interface/Decimal128", Decimal128(d128).Interface, []interface{}{d128}, nil},
|
||||
{"Interface/MinKey", MinKey().Interface, []interface{}{primitive.MinKey{}}, nil},
|
||||
{"Interface/MaxKey", MaxKey().Interface, []interface{}{primitive.MaxKey{}}, nil},
|
||||
{"Interface/Empty", Val{}.Interface, []interface{}{primitive.Null{}}, nil},
|
||||
{"IsNumber/Double", Double(0).IsNumber, []interface{}{bool(true)}, nil},
|
||||
{"IsNumber/Int32", Int32(0).IsNumber, []interface{}{bool(true)}, nil},
|
||||
{"IsNumber/Int64", Int64(0).IsNumber, []interface{}{bool(true)}, nil},
|
||||
{"IsNumber/Decimal128", Decimal128(primitive.Decimal128{}).IsNumber, []interface{}{bool(true)}, nil},
|
||||
{"IsNumber/String", String("").IsNumber, []interface{}{bool(false)}, nil},
|
||||
{"Double/panic", String("").Double, nil, ElementTypeError{"bson.Value.Double", bsontype.String}},
|
||||
{"Double/success", Double(3.14159).Double, []interface{}{float64(3.14159)}, nil},
|
||||
{"DoubleOK/error", String("").DoubleOK, []interface{}{float64(0), false}, nil},
|
||||
{"DoubleOK/success", Double(3.14159).DoubleOK, []interface{}{float64(3.14159), true}, nil},
|
||||
{"String/panic", Double(0).StringValue, nil, ElementTypeError{"bson.Value.StringValue", bsontype.Double}},
|
||||
{"String/success", String("bar").StringValue, []interface{}{"bar"}, nil},
|
||||
{"String/14bytes", String(bytestr14).StringValue, []interface{}{bytestr14}, nil},
|
||||
{"String/success(long)", String(longstr).StringValue, []interface{}{longstr}, nil},
|
||||
{"StringOK/error", Double(0).StringValueOK, []interface{}{"", false}, nil},
|
||||
{"StringOK/success", String("bar").StringValueOK, []interface{}{"bar", true}, nil},
|
||||
{"Document/panic", Double(0).Document, nil, ElementTypeError{"bson.Value.Document", bsontype.Double}},
|
||||
{"Document/success", Document(Doc{}).Document, []interface{}{Doc{}}, nil},
|
||||
{"DocumentOK/error", Double(0).DocumentOK, []interface{}{(Doc)(nil), false}, nil},
|
||||
{"DocumentOK/success", Document(Doc{}).DocumentOK, []interface{}{Doc{}, true}, nil},
|
||||
{"MDocument/panic", Double(0).MDocument, nil, ElementTypeError{"bson.Value.MDocument", bsontype.Double}},
|
||||
{"MDocument/success", Document(MDoc{}).MDocument, []interface{}{MDoc{}}, nil},
|
||||
{"MDocumentOK/error", Double(0).MDocumentOK, []interface{}{(MDoc)(nil), false}, nil},
|
||||
{"MDocumentOK/success", Document(MDoc{}).MDocumentOK, []interface{}{MDoc{}, true}, nil},
|
||||
{"Document->MDocument/success", Document(Doc{}).MDocument, []interface{}{MDoc{}}, nil},
|
||||
{"MDocument->Document/success", Document(MDoc{}).Document, []interface{}{Doc{}}, nil},
|
||||
{"Document->MDocumentOK/success", Document(Doc{}).MDocumentOK, []interface{}{MDoc{}, true}, nil},
|
||||
{"MDocument->DocumentOK/success", Document(MDoc{}).DocumentOK, []interface{}{Doc{}, true}, nil},
|
||||
{"Array/panic", Double(0).Array, nil, ElementTypeError{"bson.Value.Array", bsontype.Double}},
|
||||
{"Array/success", Array(Arr{}).Array, []interface{}{Arr{}}, nil},
|
||||
{"ArrayOK/error", Double(0).ArrayOK, []interface{}{(Arr)(nil), false}, nil},
|
||||
{"ArrayOK/success", Array(Arr{}).ArrayOK, []interface{}{Arr{}, true}, nil},
|
||||
{"Document/NilDocument", Document((Doc)(nil)).Interface, []interface{}{primitive.Null{}}, nil},
|
||||
{"Array/NilArray", Array((Arr)(nil)).Interface, []interface{}{primitive.Null{}}, nil},
|
||||
{"Document/Nil", Document(nil).Interface, []interface{}{primitive.Null{}}, nil},
|
||||
{"Array/Nil", Array(nil).Interface, []interface{}{primitive.Null{}}, nil},
|
||||
{"Binary/panic", Double(0).Binary, nil, ElementTypeError{"bson.Value.Binary", bsontype.Double}},
|
||||
{"Binary/success", Binary(bin.Subtype, bin.Data).Binary, []interface{}{bin.Subtype, bin.Data}, nil},
|
||||
{"BinaryOK/error", Double(0).BinaryOK, []interface{}{byte(0x00), []byte(nil), false}, nil},
|
||||
{"BinaryOK/success", Binary(bin.Subtype, bin.Data).BinaryOK, []interface{}{bin.Subtype, bin.Data, true}, nil},
|
||||
{"Undefined/panic", Double(0).Undefined, nil, ElementTypeError{"bson.Value.Undefined", bsontype.Double}},
|
||||
{"Undefined/success", Undefined().Undefined, nil, nil},
|
||||
{"UndefinedOK/error", Double(0).UndefinedOK, []interface{}{false}, nil},
|
||||
{"UndefinedOK/success", Undefined().UndefinedOK, []interface{}{true}, nil},
|
||||
{"ObjectID/panic", Double(0).ObjectID, nil, ElementTypeError{"bson.Value.ObjectID", bsontype.Double}},
|
||||
{"ObjectID/success", ObjectID(oid).ObjectID, []interface{}{oid}, nil},
|
||||
{"ObjectIDOK/error", Double(0).ObjectIDOK, []interface{}{primitive.ObjectID{}, false}, nil},
|
||||
{"ObjectIDOK/success", ObjectID(oid).ObjectIDOK, []interface{}{oid, true}, nil},
|
||||
{"Boolean/panic", Double(0).Boolean, nil, ElementTypeError{"bson.Value.Boolean", bsontype.Double}},
|
||||
{"Boolean/success", Boolean(true).Boolean, []interface{}{bool(true)}, nil},
|
||||
{"BooleanOK/error", Double(0).BooleanOK, []interface{}{bool(false), false}, nil},
|
||||
{"BooleanOK/success", Boolean(false).BooleanOK, []interface{}{false, true}, nil},
|
||||
{"DateTime/panic", Double(0).DateTime, nil, ElementTypeError{"bson.Value.DateTime", bsontype.Double}},
|
||||
{"DateTime/success", DateTime(1234567890).DateTime, []interface{}{int64(1234567890)}, nil},
|
||||
{"DateTimeOK/error", Double(0).DateTimeOK, []interface{}{int64(0), false}, nil},
|
||||
{"DateTimeOK/success", DateTime(987654321).DateTimeOK, []interface{}{int64(987654321), true}, nil},
|
||||
{"Time/panic", Double(0).Time, nil, ElementTypeError{"bson.Value.Time", bsontype.Double}},
|
||||
{"Time/success", Time(now).Time, []interface{}{now}, nil},
|
||||
{"TimeOK/error", Double(0).TimeOK, []interface{}{time.Time{}, false}, nil},
|
||||
{"TimeOK/success", Time(now).TimeOK, []interface{}{now, true}, nil},
|
||||
{"Time->DateTime", Time(now).DateTime, []interface{}{nowdt}, nil},
|
||||
{"DateTime->Time", DateTime(nowdt).Time, []interface{}{now}, nil},
|
||||
{"Null/panic", Double(0).Null, nil, ElementTypeError{"bson.Value.Null", bsontype.Double}},
|
||||
{"Null/success", Null().Null, nil, nil},
|
||||
{"NullOK/error", Double(0).NullOK, []interface{}{false}, nil},
|
||||
{"NullOK/success", Null().NullOK, []interface{}{true}, nil},
|
||||
{"Regex/panic", Double(0).Regex, nil, ElementTypeError{"bson.Value.Regex", bsontype.Double}},
|
||||
{"Regex/success", Regex(regex.Pattern, regex.Options).Regex, []interface{}{regex.Pattern, regex.Options}, nil},
|
||||
{"RegexOK/error", Double(0).RegexOK, []interface{}{"", "", false}, nil},
|
||||
{"RegexOK/success", Regex(regex.Pattern, regex.Options).RegexOK, []interface{}{regex.Pattern, regex.Options, true}, nil},
|
||||
{"DBPointer/panic", Double(0).DBPointer, nil, ElementTypeError{"bson.Value.DBPointer", bsontype.Double}},
|
||||
{"DBPointer/success", DBPointer(dbptr.DB, dbptr.Pointer).DBPointer, []interface{}{dbptr.DB, dbptr.Pointer}, nil},
|
||||
{"DBPointerOK/error", Double(0).DBPointerOK, []interface{}{"", primitive.ObjectID{}, false}, nil},
|
||||
{"DBPointerOK/success", DBPointer(dbptr.DB, dbptr.Pointer).DBPointerOK, []interface{}{dbptr.DB, dbptr.Pointer, true}, nil},
|
||||
{"JavaScript/panic", Double(0).JavaScript, nil, ElementTypeError{"bson.Value.JavaScript", bsontype.Double}},
|
||||
{"JavaScript/success", JavaScript(js).JavaScript, []interface{}{js}, nil},
|
||||
{"JavaScriptOK/error", Double(0).JavaScriptOK, []interface{}{"", false}, nil},
|
||||
{"JavaScriptOK/success", JavaScript(js).JavaScriptOK, []interface{}{js, true}, nil},
|
||||
{"Symbol/panic", Double(0).Symbol, nil, ElementTypeError{"bson.Value.Symbol", bsontype.Double}},
|
||||
{"Symbol/success", Symbol(symbol).Symbol, []interface{}{symbol}, nil},
|
||||
{"SymbolOK/error", Double(0).SymbolOK, []interface{}{"", false}, nil},
|
||||
{"SymbolOK/success", Symbol(symbol).SymbolOK, []interface{}{symbol, true}, nil},
|
||||
{"CodeWithScope/panic", Double(0).CodeWithScope, nil, ElementTypeError{"bson.Value.CodeWithScope", bsontype.Double}},
|
||||
{"CodeWithScope/success", CodeWithScope(code, scope).CodeWithScope, []interface{}{code, scope}, nil},
|
||||
{"CodeWithScopeOK/error", Double(0).CodeWithScopeOK, []interface{}{"", (Doc)(nil), false}, nil},
|
||||
{"CodeWithScopeOK/success", CodeWithScope(code, scope).CodeWithScopeOK, []interface{}{code, scope, true}, nil},
|
||||
{"Int32/panic", Double(0).Int32, nil, ElementTypeError{"bson.Value.Int32", bsontype.Double}},
|
||||
{"Int32/success", Int32(12345).Int32, []interface{}{int32(12345)}, nil},
|
||||
{"Int32OK/error", Double(0).Int32OK, []interface{}{int32(0), false}, nil},
|
||||
{"Int32OK/success", Int32(54321).Int32OK, []interface{}{int32(54321), true}, nil},
|
||||
{"Timestamp/panic", Double(0).Timestamp, nil, ElementTypeError{"bson.Value.Timestamp", bsontype.Double}},
|
||||
{"Timestamp/success", Timestamp(ts.T, ts.I).Timestamp, []interface{}{ts.T, ts.I}, nil},
|
||||
{"TimestampOK/error", Double(0).TimestampOK, []interface{}{uint32(0), uint32(0), false}, nil},
|
||||
{"TimestampOK/success", Timestamp(ts.T, ts.I).TimestampOK, []interface{}{ts.T, ts.I, true}, nil},
|
||||
{"Int64/panic", Double(0).Int64, nil, ElementTypeError{"bson.Value.Int64", bsontype.Double}},
|
||||
{"Int64/success", Int64(1234567890).Int64, []interface{}{int64(1234567890)}, nil},
|
||||
{"Int64OK/error", Double(0).Int64OK, []interface{}{int64(0), false}, nil},
|
||||
{"Int64OK/success", Int64(9876543210).Int64OK, []interface{}{int64(9876543210), true}, nil},
|
||||
{"Decimal128/panic", Double(0).Decimal128, nil, ElementTypeError{"bson.Value.Decimal128", bsontype.Double}},
|
||||
{"Decimal128/success", Decimal128(d128).Decimal128, []interface{}{d128}, nil},
|
||||
{"Decimal128OK/error", Double(0).Decimal128OK, []interface{}{primitive.Decimal128{}, false}, nil},
|
||||
{"Decimal128OK/success", Decimal128(d128).Decimal128OK, []interface{}{d128, true}, nil},
|
||||
{"MinKey/panic", Double(0).MinKey, nil, ElementTypeError{"bson.Value.MinKey", bsontype.Double}},
|
||||
{"MinKey/success", MinKey().MinKey, nil, nil},
|
||||
{"MinKeyOK/error", Double(0).MinKeyOK, []interface{}{false}, nil},
|
||||
{"MinKeyOK/success", MinKey().MinKeyOK, []interface{}{true}, nil},
|
||||
{"MaxKey/panic", Double(0).MaxKey, nil, ElementTypeError{"bson.Value.MaxKey", bsontype.Double}},
|
||||
{"MaxKey/success", MaxKey().MaxKey, nil, nil},
|
||||
{"MaxKeyOK/error", Double(0).MaxKeyOK, []interface{}{false}, nil},
|
||||
{"MaxKeyOK/success", MaxKey().MaxKeyOK, []interface{}{true}, nil},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc // capture range variable
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
defer func() {
|
||||
err := recover()
|
||||
if err != nil && !cmp.Equal(err, tc.err) {
|
||||
t.Errorf("panic errors are not equal. got %v; want %v", err, tc.err)
|
||||
if tc.err == nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
fn := reflect.ValueOf(tc.fn)
|
||||
if fn.Kind() != reflect.Func {
|
||||
t.Fatalf("fn must be a function, but is a %s", fn.Kind())
|
||||
}
|
||||
ret := fn.Call(nil)
|
||||
if len(ret) != len(tc.ret) {
|
||||
t.Fatalf("number of returned values does not match. got %d; want %d", len(ret), len(tc.ret))
|
||||
}
|
||||
|
||||
for idx := range ret {
|
||||
got, want := ret[idx].Interface(), tc.ret[idx]
|
||||
if !cmp.Equal(got, want, cmp.Comparer(compareDecimal128)) {
|
||||
t.Errorf("Return %d does not match. got %v; want %v", idx, got, want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("Equal", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
v1 Val
|
||||
v2 Val
|
||||
res bool
|
||||
}{
|
||||
{"Different Types", String(""), Double(0), false},
|
||||
{"Unknown Types", Val{t: bsontype.Type(0x77)}, Val{t: bsontype.Type(0x77)}, false},
|
||||
{"Empty Types", Val{}, Val{}, true},
|
||||
{"Double/Equal", Double(3.14159), Double(3.14159), true},
|
||||
{"Double/Not Equal", Double(3.14159), Double(9.51413), false},
|
||||
{"DateTime/Equal", DateTime(nowdt), DateTime(nowdt), true},
|
||||
{"DateTime/Not Equal", DateTime(nowdt), DateTime(0), false},
|
||||
{"String/Equal", String("hello"), String("hello"), true},
|
||||
{"String/Not Equal", String("hello"), String("world"), false},
|
||||
{"Document/Equal", Document(Doc{}), Document(Doc{}), true},
|
||||
{"Document/Not Equal", Document(Doc{}), Document(Doc{{"", Null()}}), false},
|
||||
{"Array/Equal", Array(Arr{}), Array(Arr{}), true},
|
||||
{"Array/Not Equal", Array(Arr{}), Array(Arr{Null()}), false},
|
||||
{"Binary/Equal", Binary(bin.Subtype, bin.Data), Binary(bin.Subtype, bin.Data), true},
|
||||
{"Binary/Not Equal", Binary(bin.Subtype, bin.Data), Binary(0x00, nil), false},
|
||||
{"Undefined/Equal", Undefined(), Undefined(), true},
|
||||
{"ObjectID/Equal", ObjectID(oid), ObjectID(oid), true},
|
||||
{"ObjectID/Not Equal", ObjectID(oid), ObjectID(primitive.ObjectID{}), false},
|
||||
{"Boolean/Equal", Boolean(true), Boolean(true), true},
|
||||
{"Boolean/Not Equal", Boolean(true), Boolean(false), false},
|
||||
{"Null/Equal", Null(), Null(), true},
|
||||
{"Regex/Equal", Regex(regex.Pattern, regex.Options), Regex(regex.Pattern, regex.Options), true},
|
||||
{"Regex/Not Equal", Regex(regex.Pattern, regex.Options), Regex("", ""), false},
|
||||
{"DBPointer/Equal", DBPointer(dbptr.DB, dbptr.Pointer), DBPointer(dbptr.DB, dbptr.Pointer), true},
|
||||
{"DBPointer/Not Equal", DBPointer(dbptr.DB, dbptr.Pointer), DBPointer("", primitive.ObjectID{}), false},
|
||||
{"JavaScript/Equal", JavaScript(js), JavaScript(js), true},
|
||||
{"JavaScript/Not Equal", JavaScript(js), JavaScript(""), false},
|
||||
{"Symbol/Equal", Symbol(symbol), Symbol(symbol), true},
|
||||
{"Symbol/Not Equal", Symbol(symbol), Symbol(""), false},
|
||||
{"CodeWithScope/Equal", CodeWithScope(code, scope), CodeWithScope(code, scope), true},
|
||||
{"CodeWithScope/Equal (equal scope)", CodeWithScope(code, scope), CodeWithScope(code, Doc{}), true},
|
||||
{"CodeWithScope/Not Equal", CodeWithScope(code, scope), CodeWithScope("", nil), false},
|
||||
{"Int32/Equal", Int32(12345), Int32(12345), true},
|
||||
{"Int32/Not Equal", Int32(12345), Int32(54321), false},
|
||||
{"Timestamp/Equal", Timestamp(ts.T, ts.I), Timestamp(ts.T, ts.I), true},
|
||||
{"Timestamp/Not Equal", Timestamp(ts.T, ts.I), Timestamp(0, 0), false},
|
||||
{"Int64/Equal", Int64(1234567890), Int64(1234567890), true},
|
||||
{"Int64/Not Equal", Int64(1234567890), Int64(9876543210), false},
|
||||
{"Decimal128/Equal", Decimal128(d128), Decimal128(d128), true},
|
||||
{"Decimal128/Not Equal", Decimal128(d128), Decimal128(primitive.Decimal128{}), false},
|
||||
{"MinKey/Equal", MinKey(), MinKey(), true},
|
||||
{"MaxKey/Equal", MaxKey(), MaxKey(), true},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc // capture range variable
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
res := tc.v1.Equal(tc.v2)
|
||||
if res != tc.res {
|
||||
t.Errorf("results do not match. got %v; want %v", res, tc.res)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
23
mongo/x/mongo/driver/DESIGN.md
Normal file
23
mongo/x/mongo/driver/DESIGN.md
Normal file
@@ -0,0 +1,23 @@
|
||||
# Driver Library Design
|
||||
This document outlines the design for this package.
|
||||
|
||||
## Deployment, Server, and Connection
|
||||
Acquiring a `Connection` from a `Server` selected from a `Deployment` enables sending and receiving
|
||||
wire messages. A `Deployment` represents an set of MongoDB servers and a `Server` represents a
|
||||
member of that set. These three types form the operation execution stack.
|
||||
|
||||
### Compression
|
||||
Compression is handled by Connection type while uncompression is handled automatically by the
|
||||
Operation type. This is done because the compressor to use for compressing a wire message is
|
||||
chosen by the connection during handshake, while uncompression can be performed without this
|
||||
information. This does make the design of compression non-symmetric, but it makes the design simpler
|
||||
to implement and more consistent.
|
||||
|
||||
## Operation
|
||||
The `Operation` type handles executing a series of commands using a `Deployment`. For most uses
|
||||
`Operation` will only execute a single command, but the main use case for a series of commands is
|
||||
batch split write commands, such as insert. The type itself is heavily documented, so reading the
|
||||
code and comments together should provide an understanding of how the type works.
|
||||
|
||||
This type is not meant to be used directly by callers. Instead a wrapping type should be defined
|
||||
using the IDL.
|
||||
229
mongo/x/mongo/driver/auth/auth.go
Normal file
229
mongo/x/mongo/driver/auth/auth.go
Normal file
@@ -0,0 +1,229 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"go.mongodb.org/mongo-driver/mongo/address"
|
||||
"go.mongodb.org/mongo-driver/mongo/description"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
|
||||
)
|
||||
|
||||
// AuthenticatorFactory constructs an authenticator.
|
||||
type AuthenticatorFactory func(cred *Cred) (Authenticator, error)
|
||||
|
||||
var authFactories = make(map[string]AuthenticatorFactory)
|
||||
|
||||
func init() {
|
||||
RegisterAuthenticatorFactory("", newDefaultAuthenticator)
|
||||
RegisterAuthenticatorFactory(SCRAMSHA1, newScramSHA1Authenticator)
|
||||
RegisterAuthenticatorFactory(SCRAMSHA256, newScramSHA256Authenticator)
|
||||
RegisterAuthenticatorFactory(MONGODBCR, newMongoDBCRAuthenticator)
|
||||
RegisterAuthenticatorFactory(PLAIN, newPlainAuthenticator)
|
||||
RegisterAuthenticatorFactory(GSSAPI, newGSSAPIAuthenticator)
|
||||
RegisterAuthenticatorFactory(MongoDBX509, newMongoDBX509Authenticator)
|
||||
RegisterAuthenticatorFactory(MongoDBAWS, newMongoDBAWSAuthenticator)
|
||||
}
|
||||
|
||||
// CreateAuthenticator creates an authenticator.
|
||||
func CreateAuthenticator(name string, cred *Cred) (Authenticator, error) {
|
||||
if f, ok := authFactories[name]; ok {
|
||||
return f(cred)
|
||||
}
|
||||
|
||||
return nil, newAuthError(fmt.Sprintf("unknown authenticator: %s", name), nil)
|
||||
}
|
||||
|
||||
// RegisterAuthenticatorFactory registers the authenticator factory.
|
||||
func RegisterAuthenticatorFactory(name string, factory AuthenticatorFactory) {
|
||||
authFactories[name] = factory
|
||||
}
|
||||
|
||||
// HandshakeOptions packages options that can be passed to the Handshaker()
|
||||
// function. DBUser is optional but must be of the form <dbname.username>;
|
||||
// if non-empty, then the connection will do SASL mechanism negotiation.
|
||||
type HandshakeOptions struct {
|
||||
AppName string
|
||||
Authenticator Authenticator
|
||||
Compressors []string
|
||||
DBUser string
|
||||
PerformAuthentication func(description.Server) bool
|
||||
ClusterClock *session.ClusterClock
|
||||
ServerAPI *driver.ServerAPIOptions
|
||||
LoadBalanced bool
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
|
||||
type authHandshaker struct {
|
||||
wrapped driver.Handshaker
|
||||
options *HandshakeOptions
|
||||
|
||||
handshakeInfo driver.HandshakeInformation
|
||||
conversation SpeculativeConversation
|
||||
}
|
||||
|
||||
var _ driver.Handshaker = (*authHandshaker)(nil)
|
||||
|
||||
// GetHandshakeInformation performs the initial MongoDB handshake to retrieve the required information for the provided
|
||||
// connection.
|
||||
func (ah *authHandshaker) GetHandshakeInformation(ctx context.Context, addr address.Address, conn driver.Connection) (driver.HandshakeInformation, error) {
|
||||
if ah.wrapped != nil {
|
||||
return ah.wrapped.GetHandshakeInformation(ctx, addr, conn)
|
||||
}
|
||||
|
||||
op := operation.NewHello().
|
||||
AppName(ah.options.AppName).
|
||||
Compressors(ah.options.Compressors).
|
||||
SASLSupportedMechs(ah.options.DBUser).
|
||||
ClusterClock(ah.options.ClusterClock).
|
||||
ServerAPI(ah.options.ServerAPI).
|
||||
LoadBalanced(ah.options.LoadBalanced)
|
||||
|
||||
if ah.options.Authenticator != nil {
|
||||
if speculativeAuth, ok := ah.options.Authenticator.(SpeculativeAuthenticator); ok {
|
||||
var err error
|
||||
ah.conversation, err = speculativeAuth.CreateSpeculativeConversation()
|
||||
if err != nil {
|
||||
return driver.HandshakeInformation{}, newAuthError("failed to create conversation", err)
|
||||
}
|
||||
|
||||
firstMsg, err := ah.conversation.FirstMessage()
|
||||
if err != nil {
|
||||
return driver.HandshakeInformation{}, newAuthError("failed to create speculative authentication message", err)
|
||||
}
|
||||
|
||||
op = op.SpeculativeAuthenticate(firstMsg)
|
||||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
ah.handshakeInfo, err = op.GetHandshakeInformation(ctx, addr, conn)
|
||||
if err != nil {
|
||||
return driver.HandshakeInformation{}, newAuthError("handshake failure", err)
|
||||
}
|
||||
return ah.handshakeInfo, nil
|
||||
}
|
||||
|
||||
// FinishHandshake performs authentication for conn if necessary.
|
||||
func (ah *authHandshaker) FinishHandshake(ctx context.Context, conn driver.Connection) error {
|
||||
performAuth := ah.options.PerformAuthentication
|
||||
if performAuth == nil {
|
||||
performAuth = func(serv description.Server) bool {
|
||||
// Authentication is possible against all server types except arbiters
|
||||
return serv.Kind != description.RSArbiter
|
||||
}
|
||||
}
|
||||
|
||||
desc := conn.Description()
|
||||
if performAuth(desc) && ah.options.Authenticator != nil {
|
||||
cfg := &Config{
|
||||
Description: desc,
|
||||
Connection: conn,
|
||||
ClusterClock: ah.options.ClusterClock,
|
||||
HandshakeInfo: ah.handshakeInfo,
|
||||
ServerAPI: ah.options.ServerAPI,
|
||||
HTTPClient: ah.options.HTTPClient,
|
||||
}
|
||||
|
||||
if err := ah.authenticate(ctx, cfg); err != nil {
|
||||
return newAuthError("auth error", err)
|
||||
}
|
||||
}
|
||||
|
||||
if ah.wrapped == nil {
|
||||
return nil
|
||||
}
|
||||
return ah.wrapped.FinishHandshake(ctx, conn)
|
||||
}
|
||||
|
||||
func (ah *authHandshaker) authenticate(ctx context.Context, cfg *Config) error {
|
||||
// If the initial hello reply included a response to the speculative authentication attempt, we only need to
|
||||
// conduct the remainder of the conversation.
|
||||
if speculativeResponse := ah.handshakeInfo.SpeculativeAuthenticate; speculativeResponse != nil {
|
||||
// Defensively ensure that the server did not include a response if speculative auth was not attempted.
|
||||
if ah.conversation == nil {
|
||||
return errors.New("speculative auth was not attempted but the server included a response")
|
||||
}
|
||||
return ah.conversation.Finish(ctx, cfg, speculativeResponse)
|
||||
}
|
||||
|
||||
// If the server does not support speculative authentication or the first attempt was not successful, we need to
|
||||
// perform authentication from scratch.
|
||||
return ah.options.Authenticator.Auth(ctx, cfg)
|
||||
}
|
||||
|
||||
// Handshaker creates a connection handshaker for the given authenticator.
|
||||
func Handshaker(h driver.Handshaker, options *HandshakeOptions) driver.Handshaker {
|
||||
return &authHandshaker{
|
||||
wrapped: h,
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
|
||||
// Config holds the information necessary to perform an authentication attempt.
|
||||
type Config struct {
|
||||
Description description.Server
|
||||
Connection driver.Connection
|
||||
ClusterClock *session.ClusterClock
|
||||
HandshakeInfo driver.HandshakeInformation
|
||||
ServerAPI *driver.ServerAPIOptions
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
|
||||
// Authenticator handles authenticating a connection.
|
||||
type Authenticator interface {
|
||||
// Auth authenticates the connection.
|
||||
Auth(context.Context, *Config) error
|
||||
}
|
||||
|
||||
func newAuthError(msg string, inner error) error {
|
||||
return &Error{
|
||||
message: msg,
|
||||
inner: inner,
|
||||
}
|
||||
}
|
||||
|
||||
func newError(err error, mech string) error {
|
||||
return &Error{
|
||||
message: fmt.Sprintf("unable to authenticate using mechanism \"%s\"", mech),
|
||||
inner: err,
|
||||
}
|
||||
}
|
||||
|
||||
// Error is an error that occurred during authentication.
|
||||
type Error struct {
|
||||
message string
|
||||
inner error
|
||||
}
|
||||
|
||||
func (e *Error) Error() string {
|
||||
if e.inner == nil {
|
||||
return e.message
|
||||
}
|
||||
return fmt.Sprintf("%s: %s", e.message, e.inner)
|
||||
}
|
||||
|
||||
// Inner returns the wrapped error.
|
||||
func (e *Error) Inner() error {
|
||||
return e.inner
|
||||
}
|
||||
|
||||
// Unwrap returns the underlying error.
|
||||
func (e *Error) Unwrap() error {
|
||||
return e.inner
|
||||
}
|
||||
|
||||
// Message returns the message.
|
||||
func (e *Error) Message() string {
|
||||
return e.message
|
||||
}
|
||||
111
mongo/x/mongo/driver/auth/auth_spec_test.go
Normal file
111
mongo/x/mongo/driver/auth/auth_spec_test.go
Normal file
@@ -0,0 +1,111 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package auth_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"path"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.mongodb.org/mongo-driver/internal/testutil/helpers"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
type credential struct {
|
||||
Username string
|
||||
Password *string
|
||||
Source string
|
||||
Mechanism string
|
||||
MechProps map[string]interface{} `json:"mechanism_properties"`
|
||||
}
|
||||
|
||||
type testCase struct {
|
||||
Description string
|
||||
URI string
|
||||
Valid bool
|
||||
Credential *credential
|
||||
}
|
||||
|
||||
type testContainer struct {
|
||||
Tests []testCase
|
||||
}
|
||||
|
||||
// Note a test supporting the deprecated gssapiServiceName property was removed from data/auth/auth_tests.json
|
||||
const authTestsDir = "../../../../testdata/auth/"
|
||||
|
||||
func runTestsInFile(t *testing.T, dirname string, filename string) {
|
||||
filepath := path.Join(dirname, filename)
|
||||
content, err := ioutil.ReadFile(filepath)
|
||||
require.NoError(t, err)
|
||||
|
||||
var container testContainer
|
||||
require.NoError(t, json.Unmarshal(content, &container))
|
||||
|
||||
// Remove ".json" from filename.
|
||||
filename = filename[:len(filename)-5]
|
||||
|
||||
for _, testCase := range container.Tests {
|
||||
runTest(t, filename, testCase)
|
||||
}
|
||||
}
|
||||
|
||||
func runTest(t *testing.T, filename string, test testCase) {
|
||||
t.Run(filename+":"+test.Description, func(t *testing.T) {
|
||||
opts := options.Client().ApplyURI(test.URI)
|
||||
if test.Valid {
|
||||
require.NoError(t, opts.Validate())
|
||||
} else {
|
||||
require.Error(t, opts.Validate())
|
||||
return
|
||||
}
|
||||
|
||||
if test.Credential == nil {
|
||||
require.Nil(t, opts.Auth)
|
||||
return
|
||||
}
|
||||
require.NotNil(t, opts.Auth)
|
||||
require.Equal(t, test.Credential.Username, opts.Auth.Username)
|
||||
|
||||
if test.Credential.Password == nil {
|
||||
require.False(t, opts.Auth.PasswordSet)
|
||||
} else {
|
||||
require.True(t, opts.Auth.PasswordSet)
|
||||
require.Equal(t, *test.Credential.Password, opts.Auth.Password)
|
||||
}
|
||||
|
||||
require.Equal(t, test.Credential.Source, opts.Auth.AuthSource)
|
||||
|
||||
require.Equal(t, test.Credential.Mechanism, opts.Auth.AuthMechanism)
|
||||
|
||||
if len(test.Credential.MechProps) > 0 {
|
||||
require.Equal(t, mapInterfaceToString(test.Credential.MechProps), opts.Auth.AuthMechanismProperties)
|
||||
} else {
|
||||
require.Equal(t, 0, len(opts.Auth.AuthMechanismProperties))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Convert each interface{} value in the map to a string.
|
||||
func mapInterfaceToString(m map[string]interface{}) map[string]string {
|
||||
out := make(map[string]string)
|
||||
|
||||
for key, value := range m {
|
||||
out[key] = fmt.Sprint(value)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// Test case for all connection string spec tests.
|
||||
func TestAuthSpec(t *testing.T) {
|
||||
for _, file := range helpers.FindJSONFilesInDir(t, authTestsDir) {
|
||||
runTestsInFile(t, authTestsDir, file)
|
||||
}
|
||||
}
|
||||
126
mongo/x/mongo/driver/auth/auth_test.go
Normal file
126
mongo/x/mongo/driver/auth/auth_test.go
Normal file
@@ -0,0 +1,126 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package auth_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
|
||||
. "go.mongodb.org/mongo-driver/x/mongo/driver/auth"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
|
||||
)
|
||||
|
||||
func TestCreateAuthenticator(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
source string
|
||||
auth Authenticator
|
||||
}{
|
||||
{name: "", auth: &DefaultAuthenticator{}},
|
||||
{name: "SCRAM-SHA-1", auth: &ScramAuthenticator{}},
|
||||
{name: "SCRAM-SHA-256", auth: &ScramAuthenticator{}},
|
||||
{name: "MONGODB-CR", auth: &MongoDBCRAuthenticator{}},
|
||||
{name: "PLAIN", auth: &PlainAuthenticator{}},
|
||||
{name: "MONGODB-X509", auth: &MongoDBX509Authenticator{}},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
cred := &Cred{
|
||||
Username: "user",
|
||||
Password: "pencil",
|
||||
PasswordSet: true,
|
||||
}
|
||||
|
||||
a, err := CreateAuthenticator(test.name, cred)
|
||||
require.NoError(t, err)
|
||||
require.IsType(t, test.auth, a)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func compareResponses(t *testing.T, wm []byte, expectedPayload bsoncore.Document, dbName string) {
|
||||
_, _, _, opcode, wm, ok := wiremessage.ReadHeader(wm)
|
||||
if !ok {
|
||||
t.Fatalf("wiremessage is too short to unmarshal")
|
||||
}
|
||||
var actualPayload bsoncore.Document
|
||||
switch opcode {
|
||||
case wiremessage.OpQuery:
|
||||
_, wm, ok := wiremessage.ReadQueryFlags(wm)
|
||||
if !ok {
|
||||
t.Fatalf("wiremessage is too short to unmarshal")
|
||||
}
|
||||
_, wm, ok = wiremessage.ReadQueryFullCollectionName(wm)
|
||||
if !ok {
|
||||
t.Fatalf("wiremessage is too short to unmarshal")
|
||||
}
|
||||
_, wm, ok = wiremessage.ReadQueryNumberToSkip(wm)
|
||||
if !ok {
|
||||
t.Fatalf("wiremessage is too short to unmarshal")
|
||||
}
|
||||
_, wm, ok = wiremessage.ReadQueryNumberToReturn(wm)
|
||||
if !ok {
|
||||
t.Fatalf("wiremessage is too short to unmarshal")
|
||||
}
|
||||
actualPayload, _, ok = wiremessage.ReadQueryQuery(wm)
|
||||
if !ok {
|
||||
t.Fatalf("wiremessage is too short to unmarshal")
|
||||
}
|
||||
case wiremessage.OpMsg:
|
||||
// Append the $db field.
|
||||
elems, err := expectedPayload.Elements()
|
||||
if err != nil {
|
||||
t.Fatalf("expectedPayload is not valid: %v", err)
|
||||
}
|
||||
elems = append(elems, bsoncore.AppendStringElement(nil, "$db", dbName))
|
||||
elems = append(elems, bsoncore.AppendDocumentElement(nil,
|
||||
"$readPreference",
|
||||
bsoncore.BuildDocumentFromElements(nil, bsoncore.AppendStringElement(nil, "mode", "primaryPreferred")),
|
||||
))
|
||||
bslc := make([][]byte, 0, len(elems)) // BuildDocumentFromElements takes a [][]byte, not a []bsoncore.Element.
|
||||
for _, elem := range elems {
|
||||
bslc = append(bslc, elem)
|
||||
}
|
||||
expectedPayload = bsoncore.BuildDocumentFromElements(nil, bslc...)
|
||||
|
||||
_, wm, ok := wiremessage.ReadMsgFlags(wm)
|
||||
if !ok {
|
||||
t.Fatalf("wiremessage is too short to unmarshal")
|
||||
}
|
||||
loop:
|
||||
for {
|
||||
var stype wiremessage.SectionType
|
||||
stype, wm, ok = wiremessage.ReadMsgSectionType(wm)
|
||||
if !ok {
|
||||
t.Fatalf("wiremessage is too short to unmarshal")
|
||||
break
|
||||
}
|
||||
switch stype {
|
||||
case wiremessage.DocumentSequence:
|
||||
_, _, wm, ok = wiremessage.ReadMsgSectionDocumentSequence(wm)
|
||||
if !ok {
|
||||
t.Fatalf("wiremessage is too short to unmarshal")
|
||||
break loop
|
||||
}
|
||||
case wiremessage.SingleDocument:
|
||||
actualPayload, wm, ok = wiremessage.ReadMsgSectionSingleDocument(wm)
|
||||
if !ok {
|
||||
t.Fatalf("wiremessage is too short to unmarshal")
|
||||
}
|
||||
break loop
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !cmp.Equal(actualPayload, expectedPayload) {
|
||||
t.Errorf("Payloads don't match. got %v; want %v", actualPayload, expectedPayload)
|
||||
}
|
||||
}
|
||||
348
mongo/x/mongo/driver/auth/aws_conv.go
Normal file
348
mongo/x/mongo/driver/auth/aws_conv.go
Normal file
@@ -0,0 +1,348 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/auth/internal/awsv4"
|
||||
)
|
||||
|
||||
type clientState int
|
||||
|
||||
const (
|
||||
clientStarting clientState = iota
|
||||
clientFirst
|
||||
clientFinal
|
||||
clientDone
|
||||
)
|
||||
|
||||
type awsConversation struct {
|
||||
state clientState
|
||||
valid bool
|
||||
nonce []byte
|
||||
username string
|
||||
password string
|
||||
token string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
type serverMessage struct {
|
||||
Nonce primitive.Binary `bson:"s"`
|
||||
Host string `bson:"h"`
|
||||
}
|
||||
|
||||
type ecsResponse struct {
|
||||
AccessKeyID string `json:"AccessKeyId"`
|
||||
SecretAccessKey string `json:"SecretAccessKey"`
|
||||
Token string `json:"Token"`
|
||||
}
|
||||
|
||||
const (
|
||||
amzDateFormat = "20060102T150405Z"
|
||||
awsRelativeURI = "http://169.254.170.2/"
|
||||
awsEC2URI = "http://169.254.169.254/"
|
||||
awsEC2RolePath = "latest/meta-data/iam/security-credentials/"
|
||||
awsEC2TokenPath = "latest/api/token"
|
||||
defaultRegion = "us-east-1"
|
||||
maxHostLength = 255
|
||||
defaultHTTPTimeout = 10 * time.Second
|
||||
responceNonceLength = 64
|
||||
)
|
||||
|
||||
// Step takes a string provided from a server (or just an empty string for the
|
||||
// very first conversation step) and attempts to move the authentication
|
||||
// conversation forward. It returns a string to be sent to the server or an
|
||||
// error if the server message is invalid. Calling Step after a conversation
|
||||
// completes is also an error.
|
||||
func (ac *awsConversation) Step(challenge []byte) (response []byte, err error) {
|
||||
switch ac.state {
|
||||
case clientStarting:
|
||||
ac.state = clientFirst
|
||||
response = ac.firstMsg()
|
||||
case clientFirst:
|
||||
ac.state = clientFinal
|
||||
response, err = ac.finalMsg(challenge)
|
||||
case clientFinal:
|
||||
ac.state = clientDone
|
||||
ac.valid = true
|
||||
default:
|
||||
response, err = nil, errors.New("Conversation already completed")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Done returns true if the conversation is completed or has errored.
|
||||
func (ac *awsConversation) Done() bool {
|
||||
return ac.state == clientDone
|
||||
}
|
||||
|
||||
// Valid returns true if the conversation successfully authenticated with the
|
||||
// server, including counter-validation that the server actually has the
|
||||
// user's stored credentials.
|
||||
func (ac *awsConversation) Valid() bool {
|
||||
return ac.valid
|
||||
}
|
||||
|
||||
func getRegion(host string) (string, error) {
|
||||
region := defaultRegion
|
||||
|
||||
if len(host) == 0 {
|
||||
return "", errors.New("invalid STS host: empty")
|
||||
}
|
||||
if len(host) > maxHostLength {
|
||||
return "", errors.New("invalid STS host: too large")
|
||||
}
|
||||
// The implicit region for sts.amazonaws.com is us-east-1
|
||||
if host == "sts.amazonaws.com" {
|
||||
return region, nil
|
||||
}
|
||||
if strings.HasPrefix(host, ".") || strings.HasSuffix(host, ".") || strings.Contains(host, "..") {
|
||||
return "", errors.New("invalid STS host: empty part")
|
||||
}
|
||||
|
||||
// If the host has multiple parts, the second part is the region
|
||||
parts := strings.Split(host, ".")
|
||||
if len(parts) >= 2 {
|
||||
region = parts[1]
|
||||
}
|
||||
|
||||
return region, nil
|
||||
}
|
||||
|
||||
func (ac *awsConversation) validateAndMakeCredentials() (*awsv4.StaticProvider, error) {
|
||||
if ac.username != "" && ac.password == "" {
|
||||
return nil, errors.New("ACCESS_KEY_ID is set, but SECRET_ACCESS_KEY is missing")
|
||||
}
|
||||
if ac.username == "" && ac.password != "" {
|
||||
return nil, errors.New("SECRET_ACCESS_KEY is set, but ACCESS_KEY_ID is missing")
|
||||
}
|
||||
if ac.username == "" && ac.password == "" && ac.token != "" {
|
||||
return nil, errors.New("AWS_SESSION_TOKEN is set, but ACCESS_KEY_ID and SECRET_ACCESS_KEY are missing")
|
||||
}
|
||||
if ac.username != "" || ac.password != "" || ac.token != "" {
|
||||
return &awsv4.StaticProvider{Value: awsv4.Value{
|
||||
AccessKeyID: ac.username,
|
||||
SecretAccessKey: ac.password,
|
||||
SessionToken: ac.token,
|
||||
}}, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func executeAWSHTTPRequest(httpClient *http.Client, req *http.Request) ([]byte, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultHTTPTimeout)
|
||||
defer cancel()
|
||||
resp, err := httpClient.Do(req.WithContext(ctx))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
return ioutil.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
func (ac *awsConversation) getEC2Credentials() (*awsv4.StaticProvider, error) {
|
||||
// get token
|
||||
req, err := http.NewRequest("PUT", awsEC2URI+awsEC2TokenPath, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("X-aws-ec2-metadata-token-ttl-seconds", "30")
|
||||
|
||||
token, err := executeAWSHTTPRequest(ac.httpClient, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(token) == 0 {
|
||||
return nil, errors.New("unable to retrieve token from EC2 metadata")
|
||||
}
|
||||
tokenStr := string(token)
|
||||
|
||||
// get role name
|
||||
req, err = http.NewRequest("GET", awsEC2URI+awsEC2RolePath, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("X-aws-ec2-metadata-token", tokenStr)
|
||||
|
||||
role, err := executeAWSHTTPRequest(ac.httpClient, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(role) == 0 {
|
||||
return nil, errors.New("unable to retrieve role_name from EC2 metadata")
|
||||
}
|
||||
|
||||
// get credentials
|
||||
pathWithRole := awsEC2URI + awsEC2RolePath + string(role)
|
||||
req, err = http.NewRequest("GET", pathWithRole, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("X-aws-ec2-metadata-token", tokenStr)
|
||||
creds, err := executeAWSHTTPRequest(ac.httpClient, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var es2Resp ecsResponse
|
||||
err = json.Unmarshal(creds, &es2Resp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ac.username = es2Resp.AccessKeyID
|
||||
ac.password = es2Resp.SecretAccessKey
|
||||
ac.token = es2Resp.Token
|
||||
|
||||
return ac.validateAndMakeCredentials()
|
||||
}
|
||||
|
||||
func (ac *awsConversation) getCredentials() (*awsv4.StaticProvider, error) {
|
||||
// Credentials passed through URI
|
||||
creds, err := ac.validateAndMakeCredentials()
|
||||
if creds != nil || err != nil {
|
||||
return creds, err
|
||||
}
|
||||
|
||||
// Credentials from environment variables
|
||||
ac.username = os.Getenv("AWS_ACCESS_KEY_ID")
|
||||
ac.password = os.Getenv("AWS_SECRET_ACCESS_KEY")
|
||||
ac.token = os.Getenv("AWS_SESSION_TOKEN")
|
||||
|
||||
creds, err = ac.validateAndMakeCredentials()
|
||||
if creds != nil || err != nil {
|
||||
return creds, err
|
||||
}
|
||||
|
||||
// Credentials from ECS metadata
|
||||
relativeEcsURI := os.Getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI")
|
||||
if len(relativeEcsURI) > 0 {
|
||||
fullURI := awsRelativeURI + relativeEcsURI
|
||||
|
||||
req, err := http.NewRequest("GET", fullURI, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
body, err := executeAWSHTTPRequest(ac.httpClient, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var espResp ecsResponse
|
||||
err = json.Unmarshal(body, &espResp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ac.username = espResp.AccessKeyID
|
||||
ac.password = espResp.SecretAccessKey
|
||||
ac.token = espResp.Token
|
||||
|
||||
creds, err = ac.validateAndMakeCredentials()
|
||||
if creds != nil || err != nil {
|
||||
return creds, err
|
||||
}
|
||||
}
|
||||
|
||||
// Credentials from EC2 metadata
|
||||
creds, err = ac.getEC2Credentials()
|
||||
if creds == nil && err == nil {
|
||||
return nil, errors.New("unable to get credentials")
|
||||
}
|
||||
return creds, err
|
||||
}
|
||||
|
||||
func (ac *awsConversation) firstMsg() []byte {
|
||||
// Values are cached for use in final message parameters
|
||||
ac.nonce = make([]byte, 32)
|
||||
_, _ = rand.Read(ac.nonce)
|
||||
|
||||
idx, msg := bsoncore.AppendDocumentStart(nil)
|
||||
msg = bsoncore.AppendInt32Element(msg, "p", 110)
|
||||
msg = bsoncore.AppendBinaryElement(msg, "r", 0x00, ac.nonce)
|
||||
msg, _ = bsoncore.AppendDocumentEnd(msg, idx)
|
||||
return msg
|
||||
}
|
||||
|
||||
func (ac *awsConversation) finalMsg(s1 []byte) ([]byte, error) {
|
||||
var sm serverMessage
|
||||
err := bson.Unmarshal(s1, &sm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check nonce prefix
|
||||
if sm.Nonce.Subtype != 0x00 {
|
||||
return nil, errors.New("server reply contained unexpected binary subtype")
|
||||
}
|
||||
if len(sm.Nonce.Data) != responceNonceLength {
|
||||
return nil, fmt.Errorf("server reply nonce was not %v bytes", responceNonceLength)
|
||||
}
|
||||
if !bytes.HasPrefix(sm.Nonce.Data, ac.nonce) {
|
||||
return nil, errors.New("server nonce did not extend client nonce")
|
||||
}
|
||||
|
||||
region, err := getRegion(sm.Host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
creds, err := ac.getCredentials()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
currentTime := time.Now().UTC()
|
||||
body := "Action=GetCallerIdentity&Version=2011-06-15"
|
||||
|
||||
// Create http.Request
|
||||
req, _ := http.NewRequest("POST", "/", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Content-Length", "43")
|
||||
req.Host = sm.Host
|
||||
req.Header.Set("X-Amz-Date", currentTime.Format(amzDateFormat))
|
||||
if len(ac.token) > 0 {
|
||||
req.Header.Set("X-Amz-Security-Token", ac.token)
|
||||
}
|
||||
req.Header.Set("X-MongoDB-Server-Nonce", base64.StdEncoding.EncodeToString(sm.Nonce.Data))
|
||||
req.Header.Set("X-MongoDB-GS2-CB-Flag", "n")
|
||||
|
||||
// Create signer with credentials
|
||||
signer := awsv4.NewSigner(creds)
|
||||
|
||||
// Get signed header
|
||||
_, err = signer.Sign(req, strings.NewReader(body), "sts", region, currentTime)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// create message
|
||||
idx, msg := bsoncore.AppendDocumentStart(nil)
|
||||
msg = bsoncore.AppendStringElement(msg, "a", req.Header.Get("Authorization"))
|
||||
msg = bsoncore.AppendStringElement(msg, "d", req.Header.Get("X-Amz-Date"))
|
||||
if len(ac.token) > 0 {
|
||||
msg = bsoncore.AppendStringElement(msg, "t", ac.token)
|
||||
}
|
||||
msg, _ = bsoncore.AppendDocumentEnd(msg, idx)
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
31
mongo/x/mongo/driver/auth/conversation.go
Normal file
31
mongo/x/mongo/driver/auth/conversation.go
Normal file
@@ -0,0 +1,31 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
|
||||
)
|
||||
|
||||
// SpeculativeConversation represents an authentication conversation that can be merged with the initial connection
|
||||
// handshake.
|
||||
//
|
||||
// FirstMessage method returns the first message to be sent to the server. This message will be included in the initial
|
||||
// hello command.
|
||||
//
|
||||
// Finish takes the server response to the initial message and conducts the remainder of the conversation to
|
||||
// authenticate the provided connection.
|
||||
type SpeculativeConversation interface {
|
||||
FirstMessage() (bsoncore.Document, error)
|
||||
Finish(ctx context.Context, cfg *Config, firstResponse bsoncore.Document) error
|
||||
}
|
||||
|
||||
// SpeculativeAuthenticator represents an authenticator that supports speculative authentication.
|
||||
type SpeculativeAuthenticator interface {
|
||||
CreateSpeculativeConversation() (SpeculativeConversation, error)
|
||||
}
|
||||
16
mongo/x/mongo/driver/auth/cred.go
Normal file
16
mongo/x/mongo/driver/auth/cred.go
Normal file
@@ -0,0 +1,16 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package auth
|
||||
|
||||
// Cred is a user's credential.
|
||||
type Cred struct {
|
||||
Source string
|
||||
Username string
|
||||
Password string
|
||||
PasswordSet bool
|
||||
Props map[string]string
|
||||
}
|
||||
98
mongo/x/mongo/driver/auth/default.go
Normal file
98
mongo/x/mongo/driver/auth/default.go
Normal file
@@ -0,0 +1,98 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"go.mongodb.org/mongo-driver/mongo/description"
|
||||
)
|
||||
|
||||
func newDefaultAuthenticator(cred *Cred) (Authenticator, error) {
|
||||
scram, err := newScramSHA256Authenticator(cred)
|
||||
if err != nil {
|
||||
return nil, newAuthError("failed to create internal authenticator", err)
|
||||
}
|
||||
speculative, ok := scram.(SpeculativeAuthenticator)
|
||||
if !ok {
|
||||
typeErr := fmt.Errorf("expected SCRAM authenticator to be SpeculativeAuthenticator but got %T", scram)
|
||||
return nil, newAuthError("failed to create internal authenticator", typeErr)
|
||||
}
|
||||
|
||||
return &DefaultAuthenticator{
|
||||
Cred: cred,
|
||||
speculativeAuthenticator: speculative,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// DefaultAuthenticator uses SCRAM-SHA-1 or MONGODB-CR depending
|
||||
// on the server version.
|
||||
type DefaultAuthenticator struct {
|
||||
Cred *Cred
|
||||
|
||||
// The authenticator to use for speculative authentication. Because the correct auth mechanism is unknown when doing
|
||||
// the initial hello, SCRAM-SHA-256 is used for the speculative attempt.
|
||||
speculativeAuthenticator SpeculativeAuthenticator
|
||||
}
|
||||
|
||||
var _ SpeculativeAuthenticator = (*DefaultAuthenticator)(nil)
|
||||
|
||||
// CreateSpeculativeConversation creates a speculative conversation for SCRAM authentication.
|
||||
func (a *DefaultAuthenticator) CreateSpeculativeConversation() (SpeculativeConversation, error) {
|
||||
return a.speculativeAuthenticator.CreateSpeculativeConversation()
|
||||
}
|
||||
|
||||
// Auth authenticates the connection.
|
||||
func (a *DefaultAuthenticator) Auth(ctx context.Context, cfg *Config) error {
|
||||
var actual Authenticator
|
||||
var err error
|
||||
|
||||
switch chooseAuthMechanism(cfg) {
|
||||
case SCRAMSHA256:
|
||||
actual, err = newScramSHA256Authenticator(a.Cred)
|
||||
case SCRAMSHA1:
|
||||
actual, err = newScramSHA1Authenticator(a.Cred)
|
||||
default:
|
||||
actual, err = newMongoDBCRAuthenticator(a.Cred)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return newAuthError("error creating authenticator", err)
|
||||
}
|
||||
|
||||
return actual.Auth(ctx, cfg)
|
||||
}
|
||||
|
||||
// If a server provides a list of supported mechanisms, we choose
|
||||
// SCRAM-SHA-256 if it exists or else MUST use SCRAM-SHA-1.
|
||||
// Otherwise, we decide based on what is supported.
|
||||
func chooseAuthMechanism(cfg *Config) string {
|
||||
if saslSupportedMechs := cfg.HandshakeInfo.SaslSupportedMechs; saslSupportedMechs != nil {
|
||||
for _, v := range saslSupportedMechs {
|
||||
if v == SCRAMSHA256 {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return SCRAMSHA1
|
||||
}
|
||||
|
||||
if err := scramSHA1Supported(cfg.HandshakeInfo.Description.WireVersion); err == nil {
|
||||
return SCRAMSHA1
|
||||
}
|
||||
|
||||
return MONGODBCR
|
||||
}
|
||||
|
||||
// scramSHA1Supported returns an error if the given server version does not support scram-sha-1.
|
||||
func scramSHA1Supported(wireVersion *description.VersionRange) error {
|
||||
if wireVersion != nil && wireVersion.Max < 3 {
|
||||
return fmt.Errorf("SCRAM-SHA-1 is only supported for servers 3.0 or newer")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
23
mongo/x/mongo/driver/auth/doc.go
Normal file
23
mongo/x/mongo/driver/auth/doc.go
Normal file
@@ -0,0 +1,23 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
// Package auth is not for public use.
|
||||
//
|
||||
// The API for packages in the 'private' directory have no stability
|
||||
// guarantee.
|
||||
//
|
||||
// The packages within the 'private' directory would normally be put into an
|
||||
// 'internal' directory to prohibit their use outside the 'mongo' directory.
|
||||
// However, some MongoDB tools require very low-level access to the building
|
||||
// blocks of a driver, so we have placed them under 'private' to allow these
|
||||
// packages to be imported by projects that need them.
|
||||
//
|
||||
// These package APIs may be modified in backwards-incompatible ways at any
|
||||
// time.
|
||||
//
|
||||
// You are strongly discouraged from directly using any packages
|
||||
// under 'private'.
|
||||
package auth
|
||||
59
mongo/x/mongo/driver/auth/gssapi.go
Normal file
59
mongo/x/mongo/driver/auth/gssapi.go
Normal file
@@ -0,0 +1,59 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
//go:build gssapi && (windows || linux || darwin)
|
||||
// +build gssapi
|
||||
// +build windows linux darwin
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/auth/internal/gssapi"
|
||||
)
|
||||
|
||||
// GSSAPI is the mechanism name for GSSAPI.
|
||||
const GSSAPI = "GSSAPI"
|
||||
|
||||
func newGSSAPIAuthenticator(cred *Cred) (Authenticator, error) {
|
||||
if cred.Source != "" && cred.Source != "$external" {
|
||||
return nil, newAuthError("GSSAPI source must be empty or $external", nil)
|
||||
}
|
||||
|
||||
return &GSSAPIAuthenticator{
|
||||
Username: cred.Username,
|
||||
Password: cred.Password,
|
||||
PasswordSet: cred.PasswordSet,
|
||||
Props: cred.Props,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GSSAPIAuthenticator uses the GSSAPI algorithm over SASL to authenticate a connection.
|
||||
type GSSAPIAuthenticator struct {
|
||||
Username string
|
||||
Password string
|
||||
PasswordSet bool
|
||||
Props map[string]string
|
||||
}
|
||||
|
||||
// Auth authenticates the connection.
|
||||
func (a *GSSAPIAuthenticator) Auth(ctx context.Context, cfg *Config) error {
|
||||
target := cfg.Description.Addr.String()
|
||||
hostname, _, err := net.SplitHostPort(target)
|
||||
if err != nil {
|
||||
return newAuthError(fmt.Sprintf("invalid endpoint (%s) specified: %s", target, err), nil)
|
||||
}
|
||||
|
||||
client, err := gssapi.New(hostname, a.Username, a.Password, a.PasswordSet, a.Props)
|
||||
|
||||
if err != nil {
|
||||
return newAuthError("error creating gssapi", err)
|
||||
}
|
||||
return ConductSaslConversation(ctx, cfg, "$external", client)
|
||||
}
|
||||
17
mongo/x/mongo/driver/auth/gssapi_not_enabled.go
Normal file
17
mongo/x/mongo/driver/auth/gssapi_not_enabled.go
Normal file
@@ -0,0 +1,17 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
//go:build !gssapi
|
||||
// +build !gssapi
|
||||
|
||||
package auth
|
||||
|
||||
// GSSAPI is the mechanism name for GSSAPI.
|
||||
const GSSAPI = "GSSAPI"
|
||||
|
||||
func newGSSAPIAuthenticator(cred *Cred) (Authenticator, error) {
|
||||
return nil, newAuthError("GSSAPI support not enabled during build (-tags gssapi)", nil)
|
||||
}
|
||||
22
mongo/x/mongo/driver/auth/gssapi_not_supported.go
Normal file
22
mongo/x/mongo/driver/auth/gssapi_not_supported.go
Normal file
@@ -0,0 +1,22 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
//go:build gssapi && !windows && !linux && !darwin
|
||||
// +build gssapi,!windows,!linux,!darwin
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// GSSAPI is the mechanism name for GSSAPI.
|
||||
const GSSAPI = "GSSAPI"
|
||||
|
||||
func newGSSAPIAuthenticator(cred *Cred) (Authenticator, error) {
|
||||
return nil, newAuthError(fmt.Sprintf("GSSAPI is not supported on %s", runtime.GOOS), nil)
|
||||
}
|
||||
45
mongo/x/mongo/driver/auth/gssapi_test.go
Normal file
45
mongo/x/mongo/driver/auth/gssapi_test.go
Normal file
@@ -0,0 +1,45 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
//go:build gssapi
|
||||
// +build gssapi
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"go.mongodb.org/mongo-driver/mongo/address"
|
||||
"go.mongodb.org/mongo-driver/mongo/description"
|
||||
)
|
||||
|
||||
func TestGSSAPIAuthenticator(t *testing.T) {
|
||||
t.Run("PropsError", func(t *testing.T) {
|
||||
// Cannot specify both CANONICALIZE_HOST_NAME and SERVICE_HOST
|
||||
|
||||
authenticator := &GSSAPIAuthenticator{
|
||||
Username: "foo",
|
||||
Password: "bar",
|
||||
PasswordSet: true,
|
||||
Props: map[string]string{
|
||||
"CANONICALIZE_HOST_NAME": "true",
|
||||
"SERVICE_HOST": "localhost",
|
||||
},
|
||||
}
|
||||
desc := description.Server{
|
||||
WireVersion: &description.VersionRange{
|
||||
Max: 6,
|
||||
},
|
||||
Addr: address.Address("foo:27017"),
|
||||
}
|
||||
err := authenticator.Auth(context.Background(), &Config{Description: desc})
|
||||
if err == nil {
|
||||
t.Fatalf("expected err, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
63
mongo/x/mongo/driver/auth/internal/awsv4/credentials.go
Normal file
63
mongo/x/mongo/driver/auth/internal/awsv4/credentials.go
Normal file
@@ -0,0 +1,63 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
|
||||
// - github.com/aws/aws-sdk-go/blob/v1.34.28/aws/credentials/static_provider.go
|
||||
// - github.com/aws/aws-sdk-go/blob/v1.34.28/aws/credentials/credentials.go
|
||||
// See THIRD-PARTY-NOTICES for original license terms
|
||||
|
||||
package awsv4
|
||||
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
// StaticProviderName provides a name of Static provider
|
||||
const StaticProviderName = "StaticProvider"
|
||||
|
||||
var (
|
||||
// ErrStaticCredentialsEmpty is emitted when static credentials are empty.
|
||||
ErrStaticCredentialsEmpty = errors.New("EmptyStaticCreds: static credentials are empty")
|
||||
)
|
||||
|
||||
// A Value is the AWS credentials value for individual credential fields.
|
||||
type Value struct {
|
||||
// AWS Access key ID
|
||||
AccessKeyID string
|
||||
|
||||
// AWS Secret Access Key
|
||||
SecretAccessKey string
|
||||
|
||||
// AWS Session Token
|
||||
SessionToken string
|
||||
|
||||
// Provider used to get credentials
|
||||
ProviderName string
|
||||
}
|
||||
|
||||
// HasKeys returns if the credentials Value has both AccessKeyID and
|
||||
// SecretAccessKey value set.
|
||||
func (v Value) HasKeys() bool {
|
||||
return len(v.AccessKeyID) != 0 && len(v.SecretAccessKey) != 0
|
||||
}
|
||||
|
||||
// A StaticProvider is a set of credentials which are set programmatically,
|
||||
// and will never expire.
|
||||
type StaticProvider struct {
|
||||
Value
|
||||
}
|
||||
|
||||
// Retrieve returns the credentials or error if the credentials are invalid.
|
||||
func (s *StaticProvider) Retrieve() (Value, error) {
|
||||
if s.AccessKeyID == "" || s.SecretAccessKey == "" {
|
||||
return Value{ProviderName: StaticProviderName}, ErrStaticCredentialsEmpty
|
||||
}
|
||||
|
||||
if len(s.Value.ProviderName) == 0 {
|
||||
s.Value.ProviderName = StaticProviderName
|
||||
}
|
||||
return s.Value, nil
|
||||
}
|
||||
15
mongo/x/mongo/driver/auth/internal/awsv4/doc.go
Normal file
15
mongo/x/mongo/driver/auth/internal/awsv4/doc.go
Normal file
@@ -0,0 +1,15 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Based on github.com/aws/aws-sdk-go v1.34.28 by Amazon.com, Inc.
|
||||
// See THIRD-PARTY-NOTICES for original license terms
|
||||
|
||||
// Package awsv4 implements signing for AWS V4 signer with static credentials,
|
||||
// and is based on and modified from code in the package aws-sdk-go. The
|
||||
// modifications remove non-static credentials, support for non-sts services,
|
||||
// and the options for v4.Signer. They also reduce the number of non-Go
|
||||
// library dependencies.
|
||||
package awsv4
|
||||
80
mongo/x/mongo/driver/auth/internal/awsv4/request.go
Normal file
80
mongo/x/mongo/driver/auth/internal/awsv4/request.go
Normal file
@@ -0,0 +1,80 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
|
||||
// - github.com/aws/aws-sdk-go/blob/v1.34.28/aws/request/request.go
|
||||
// See THIRD-PARTY-NOTICES for original license terms
|
||||
|
||||
package awsv4
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Returns host from request
|
||||
func getHost(r *http.Request) string {
|
||||
if r.Host != "" {
|
||||
return r.Host
|
||||
}
|
||||
|
||||
if r.URL == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return r.URL.Host
|
||||
}
|
||||
|
||||
// Hostname returns u.Host, without any port number.
|
||||
//
|
||||
// If Host is an IPv6 literal with a port number, Hostname returns the
|
||||
// IPv6 literal without the square brackets. IPv6 literals may include
|
||||
// a zone identifier.
|
||||
//
|
||||
// Copied from the Go 1.8 standard library (net/url)
|
||||
func stripPort(hostport string) string {
|
||||
colon := strings.IndexByte(hostport, ':')
|
||||
if colon == -1 {
|
||||
return hostport
|
||||
}
|
||||
if i := strings.IndexByte(hostport, ']'); i != -1 {
|
||||
return strings.TrimPrefix(hostport[:i], "[")
|
||||
}
|
||||
return hostport[:colon]
|
||||
}
|
||||
|
||||
// Port returns the port part of u.Host, without the leading colon.
|
||||
// If u.Host doesn't contain a port, Port returns an empty string.
|
||||
//
|
||||
// Copied from the Go 1.8 standard library (net/url)
|
||||
func portOnly(hostport string) string {
|
||||
colon := strings.IndexByte(hostport, ':')
|
||||
if colon == -1 {
|
||||
return ""
|
||||
}
|
||||
if i := strings.Index(hostport, "]:"); i != -1 {
|
||||
return hostport[i+len("]:"):]
|
||||
}
|
||||
if strings.Contains(hostport, "]") {
|
||||
return ""
|
||||
}
|
||||
return hostport[colon+len(":"):]
|
||||
}
|
||||
|
||||
// Returns true if the specified URI is using the standard port
|
||||
// (i.e. port 80 for HTTP URIs or 443 for HTTPS URIs)
|
||||
func isDefaultPort(scheme, port string) bool {
|
||||
if port == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
lowerCaseScheme := strings.ToLower(scheme)
|
||||
if (lowerCaseScheme == "http" && port == "80") || (lowerCaseScheme == "https" && port == "443") {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
46
mongo/x/mongo/driver/auth/internal/awsv4/rest.go
Normal file
46
mongo/x/mongo/driver/auth/internal/awsv4/rest.go
Normal file
@@ -0,0 +1,46 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
|
||||
// - github.com/aws/aws-sdk-go/blob/v1.34.28/private/protocol/rest/build.go
|
||||
// See THIRD-PARTY-NOTICES for original license terms
|
||||
|
||||
package awsv4
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Whether the byte value can be sent without escaping in AWS URLs
|
||||
var noEscape [256]bool
|
||||
|
||||
func init() {
|
||||
for i := 0; i < len(noEscape); i++ {
|
||||
// AWS expects every character except these to be escaped
|
||||
noEscape[i] = (i >= 'A' && i <= 'Z') ||
|
||||
(i >= 'a' && i <= 'z') ||
|
||||
(i >= '0' && i <= '9') ||
|
||||
i == '-' ||
|
||||
i == '.' ||
|
||||
i == '_' ||
|
||||
i == '~'
|
||||
}
|
||||
}
|
||||
|
||||
// EscapePath escapes part of a URL path in Amazon style
|
||||
func EscapePath(path string, encodeSep bool) string {
|
||||
var buf bytes.Buffer
|
||||
for i := 0; i < len(path); i++ {
|
||||
c := path[i]
|
||||
if noEscape[c] || (c == '/' && !encodeSep) {
|
||||
buf.WriteByte(c)
|
||||
} else {
|
||||
fmt.Fprintf(&buf, "%%%02X", c)
|
||||
}
|
||||
}
|
||||
return buf.String()
|
||||
}
|
||||
98
mongo/x/mongo/driver/auth/internal/awsv4/rules.go
Normal file
98
mongo/x/mongo/driver/auth/internal/awsv4/rules.go
Normal file
@@ -0,0 +1,98 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
|
||||
// - github.com/aws/aws-sdk-go/blob/v1.34.28/aws/signer/v4/header_rules.go
|
||||
// - github.com/aws/aws-sdk-go/blob/v1.34.28/internal/strings/strings.go
|
||||
// See THIRD-PARTY-NOTICES for original license terms
|
||||
|
||||
package awsv4
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// validator houses a set of rule needed for validation of a
|
||||
// string value
|
||||
type rules []rule
|
||||
|
||||
// rule interface allows for more flexible rules and just simply
|
||||
// checks whether or not a value adheres to that rule
|
||||
type rule interface {
|
||||
IsValid(value string) bool
|
||||
}
|
||||
|
||||
// IsValid will iterate through all rules and see if any rules
|
||||
// apply to the value and supports nested rules
|
||||
func (r rules) IsValid(value string) bool {
|
||||
for _, rule := range r {
|
||||
if rule.IsValid(value) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// mapRule generic rule for maps
|
||||
type mapRule map[string]struct{}
|
||||
|
||||
// IsValid for the map rule satisfies whether it exists in the map
|
||||
func (m mapRule) IsValid(value string) bool {
|
||||
_, ok := m[value]
|
||||
return ok
|
||||
}
|
||||
|
||||
// allowlist is a generic rule for allowlisting
|
||||
type allowlist struct {
|
||||
rule
|
||||
}
|
||||
|
||||
// IsValid for allowlist checks if the value is within the allowlist
|
||||
func (a allowlist) IsValid(value string) bool {
|
||||
return a.rule.IsValid(value)
|
||||
}
|
||||
|
||||
// denylist is a generic rule for denylisting
|
||||
type denylist struct {
|
||||
rule
|
||||
}
|
||||
|
||||
// IsValid for allowlist checks if the value is within the allowlist
|
||||
func (d denylist) IsValid(value string) bool {
|
||||
return !d.rule.IsValid(value)
|
||||
}
|
||||
|
||||
type patterns []string
|
||||
|
||||
// hasPrefixFold tests whether the string s begins with prefix, interpreted as UTF-8 strings,
|
||||
// under Unicode case-folding.
|
||||
func hasPrefixFold(s, prefix string) bool {
|
||||
return len(s) >= len(prefix) && strings.EqualFold(s[0:len(prefix)], prefix)
|
||||
}
|
||||
|
||||
// IsValid for patterns checks each pattern and returns if a match has
|
||||
// been found
|
||||
func (p patterns) IsValid(value string) bool {
|
||||
for _, pattern := range p {
|
||||
if hasPrefixFold(value, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// inclusiveRules rules allow for rules to depend on one another
|
||||
type inclusiveRules []rule
|
||||
|
||||
// IsValid will return true if all rules are true
|
||||
func (r inclusiveRules) IsValid(value string) bool {
|
||||
for _, rule := range r {
|
||||
if !rule.IsValid(value) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
472
mongo/x/mongo/driver/auth/internal/awsv4/signer.go
Normal file
472
mongo/x/mongo/driver/auth/internal/awsv4/signer.go
Normal file
@@ -0,0 +1,472 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
|
||||
// - github.com/aws/aws-sdk-go/blob/v1.34.28/aws/request/request.go
|
||||
// - github.com/aws/aws-sdk-go/blob/v1.34.28/aws/signer/v4/v4.go
|
||||
// - github.com/aws/aws-sdk-go/blob/v1.34.28/aws/signer/v4/uri_path.go
|
||||
// - github.com/aws/aws-sdk-go/blob/v1.34.28/aws/types.go
|
||||
// See THIRD-PARTY-NOTICES for original license terms
|
||||
|
||||
package awsv4
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
authorizationHeader = "Authorization"
|
||||
authHeaderSignatureElem = "Signature="
|
||||
|
||||
authHeaderPrefix = "AWS4-HMAC-SHA256"
|
||||
timeFormat = "20060102T150405Z"
|
||||
shortTimeFormat = "20060102"
|
||||
awsV4Request = "aws4_request"
|
||||
|
||||
// emptyStringSHA256 is a SHA256 of an empty string
|
||||
emptyStringSHA256 = `e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855`
|
||||
)
|
||||
|
||||
var ignoredHeaders = rules{
|
||||
denylist{
|
||||
mapRule{
|
||||
authorizationHeader: struct{}{},
|
||||
"User-Agent": struct{}{},
|
||||
"X-Amzn-Trace-Id": struct{}{},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Signer applies AWS v4 signing to given request. Use this to sign requests
|
||||
// that need to be signed with AWS V4 Signatures.
|
||||
type Signer struct {
|
||||
Credentials *StaticProvider
|
||||
}
|
||||
|
||||
// NewSigner returns a Signer pointer configured with the credentials and optional
|
||||
// option values provided. If not options are provided the Signer will use its
|
||||
// default configuration.
|
||||
func NewSigner(credentials *StaticProvider) *Signer {
|
||||
v4 := &Signer{
|
||||
Credentials: credentials,
|
||||
}
|
||||
|
||||
return v4
|
||||
}
|
||||
|
||||
type signingCtx struct {
|
||||
ServiceName string
|
||||
Region string
|
||||
Request *http.Request
|
||||
Body io.ReadSeeker
|
||||
Query url.Values
|
||||
Time time.Time
|
||||
SignedHeaderVals http.Header
|
||||
|
||||
credValues Value
|
||||
|
||||
bodyDigest string
|
||||
signedHeaders string
|
||||
canonicalHeaders string
|
||||
canonicalString string
|
||||
credentialString string
|
||||
stringToSign string
|
||||
signature string
|
||||
authorization string
|
||||
}
|
||||
|
||||
// Sign signs AWS v4 requests with the provided body, service name, region the
|
||||
// request is made to, and time the request is signed at. The signTime allows
|
||||
// you to specify that a request is signed for the future, and cannot be
|
||||
// used until then.
|
||||
//
|
||||
// Returns a list of HTTP headers that were included in the signature or an
|
||||
// error if signing the request failed. Generally for signed requests this value
|
||||
// is not needed as the full request context will be captured by the http.Request
|
||||
// value. It is included for reference though.
|
||||
//
|
||||
// Sign will set the request's Body to be the `body` parameter passed in. If
|
||||
// the body is not already an io.ReadCloser, it will be wrapped within one. If
|
||||
// a `nil` body parameter passed to Sign, the request's Body field will be
|
||||
// also set to nil. Its important to note that this functionality will not
|
||||
// change the request's ContentLength of the request.
|
||||
//
|
||||
// Sign differs from Presign in that it will sign the request using HTTP
|
||||
// header values. This type of signing is intended for http.Request values that
|
||||
// will not be shared, or are shared in a way the header values on the request
|
||||
// will not be lost.
|
||||
//
|
||||
// The requests body is an io.ReadSeeker so the SHA256 of the body can be
|
||||
// generated. To bypass the signer computing the hash you can set the
|
||||
// "X-Amz-Content-Sha256" header with a precomputed value. The signer will
|
||||
// only compute the hash if the request header value is empty.
|
||||
func (v4 Signer) Sign(r *http.Request, body io.ReadSeeker, service, region string, signTime time.Time) (http.Header, error) {
|
||||
return v4.signWithBody(r, body, service, region, signTime)
|
||||
}
|
||||
|
||||
func (v4 Signer) signWithBody(r *http.Request, body io.ReadSeeker, service, region string, signTime time.Time) (http.Header, error) {
|
||||
ctx := &signingCtx{
|
||||
Request: r,
|
||||
Body: body,
|
||||
Query: r.URL.Query(),
|
||||
Time: signTime,
|
||||
ServiceName: service,
|
||||
Region: region,
|
||||
}
|
||||
|
||||
for key := range ctx.Query {
|
||||
sort.Strings(ctx.Query[key])
|
||||
}
|
||||
|
||||
if ctx.isRequestSigned() {
|
||||
ctx.Time = time.Now()
|
||||
}
|
||||
|
||||
var err error
|
||||
ctx.credValues, err = v4.Credentials.Retrieve()
|
||||
if err != nil {
|
||||
return http.Header{}, err
|
||||
}
|
||||
|
||||
ctx.sanitizeHostForHeader()
|
||||
ctx.assignAmzQueryValues()
|
||||
if err := ctx.build(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var reader io.ReadCloser
|
||||
if body != nil {
|
||||
var ok bool
|
||||
if reader, ok = body.(io.ReadCloser); !ok {
|
||||
reader = ioutil.NopCloser(body)
|
||||
}
|
||||
}
|
||||
r.Body = reader
|
||||
|
||||
return ctx.SignedHeaderVals, nil
|
||||
}
|
||||
|
||||
// sanitizeHostForHeader removes default port from host and updates request.Host
|
||||
func (ctx *signingCtx) sanitizeHostForHeader() {
|
||||
r := ctx.Request
|
||||
host := getHost(r)
|
||||
port := portOnly(host)
|
||||
if port != "" && isDefaultPort(r.URL.Scheme, port) {
|
||||
r.Host = stripPort(host)
|
||||
}
|
||||
}
|
||||
|
||||
func (ctx *signingCtx) assignAmzQueryValues() {
|
||||
if ctx.credValues.SessionToken != "" {
|
||||
ctx.Request.Header.Set("X-Amz-Security-Token", ctx.credValues.SessionToken)
|
||||
}
|
||||
}
|
||||
|
||||
func (ctx *signingCtx) build() error {
|
||||
ctx.buildTime() // no depends
|
||||
ctx.buildCredentialString() // no depends
|
||||
|
||||
if err := ctx.buildBodyDigest(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
unsignedHeaders := ctx.Request.Header
|
||||
|
||||
ctx.buildCanonicalHeaders(ignoredHeaders, unsignedHeaders)
|
||||
ctx.buildCanonicalString() // depends on canon headers / signed headers
|
||||
ctx.buildStringToSign() // depends on canon string
|
||||
ctx.buildSignature() // depends on string to sign
|
||||
|
||||
parts := []string{
|
||||
authHeaderPrefix + " Credential=" + ctx.credValues.AccessKeyID + "/" + ctx.credentialString,
|
||||
"SignedHeaders=" + ctx.signedHeaders,
|
||||
authHeaderSignatureElem + ctx.signature,
|
||||
}
|
||||
ctx.Request.Header.Set(authorizationHeader, strings.Join(parts, ", "))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSignedRequestSignature attempts to extract the signature of the request.
|
||||
// Returning an error if the request is unsigned, or unable to extract the
|
||||
// signature.
|
||||
func GetSignedRequestSignature(r *http.Request) ([]byte, error) {
|
||||
|
||||
if auth := r.Header.Get(authorizationHeader); len(auth) != 0 {
|
||||
ps := strings.Split(auth, ", ")
|
||||
for _, p := range ps {
|
||||
if idx := strings.Index(p, authHeaderSignatureElem); idx >= 0 {
|
||||
sig := p[len(authHeaderSignatureElem):]
|
||||
if len(sig) == 0 {
|
||||
return nil, fmt.Errorf("invalid request signature authorization header")
|
||||
}
|
||||
return hex.DecodeString(sig)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if sig := r.URL.Query().Get("X-Amz-Signature"); len(sig) != 0 {
|
||||
return hex.DecodeString(sig)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("request not signed")
|
||||
}
|
||||
|
||||
func (ctx *signingCtx) buildTime() {
|
||||
ctx.Request.Header.Set("X-Amz-Date", formatTime(ctx.Time))
|
||||
}
|
||||
|
||||
func (ctx *signingCtx) buildCredentialString() {
|
||||
ctx.credentialString = buildSigningScope(ctx.Region, ctx.ServiceName, ctx.Time)
|
||||
}
|
||||
|
||||
func (ctx *signingCtx) buildCanonicalHeaders(r rule, header http.Header) {
|
||||
headers := make([]string, 0, len(header))
|
||||
headers = append(headers, "host")
|
||||
for k, v := range header {
|
||||
if !r.IsValid(k) {
|
||||
continue // ignored header
|
||||
}
|
||||
if ctx.SignedHeaderVals == nil {
|
||||
ctx.SignedHeaderVals = make(http.Header)
|
||||
}
|
||||
|
||||
lowerCaseKey := strings.ToLower(k)
|
||||
if _, ok := ctx.SignedHeaderVals[lowerCaseKey]; ok {
|
||||
// include additional values
|
||||
ctx.SignedHeaderVals[lowerCaseKey] = append(ctx.SignedHeaderVals[lowerCaseKey], v...)
|
||||
continue
|
||||
}
|
||||
|
||||
headers = append(headers, lowerCaseKey)
|
||||
ctx.SignedHeaderVals[lowerCaseKey] = v
|
||||
}
|
||||
sort.Strings(headers)
|
||||
|
||||
ctx.signedHeaders = strings.Join(headers, ";")
|
||||
|
||||
headerValues := make([]string, len(headers))
|
||||
for i, k := range headers {
|
||||
if k == "host" {
|
||||
if ctx.Request.Host != "" {
|
||||
headerValues[i] = "host:" + ctx.Request.Host
|
||||
} else {
|
||||
headerValues[i] = "host:" + ctx.Request.URL.Host
|
||||
}
|
||||
} else {
|
||||
headerValues[i] = k + ":" +
|
||||
strings.Join(ctx.SignedHeaderVals[k], ",")
|
||||
}
|
||||
}
|
||||
stripExcessSpaces(headerValues)
|
||||
ctx.canonicalHeaders = strings.Join(headerValues, "\n")
|
||||
}
|
||||
|
||||
func getURIPath(u *url.URL) string {
|
||||
var uri string
|
||||
|
||||
if len(u.Opaque) > 0 {
|
||||
uri = "/" + strings.Join(strings.Split(u.Opaque, "/")[3:], "/")
|
||||
} else {
|
||||
uri = u.EscapedPath()
|
||||
}
|
||||
|
||||
if len(uri) == 0 {
|
||||
uri = "/"
|
||||
}
|
||||
|
||||
return uri
|
||||
}
|
||||
|
||||
func (ctx *signingCtx) buildCanonicalString() {
|
||||
ctx.Request.URL.RawQuery = strings.Replace(ctx.Query.Encode(), "+", "%20", -1)
|
||||
|
||||
uri := getURIPath(ctx.Request.URL)
|
||||
|
||||
uri = EscapePath(uri, false)
|
||||
|
||||
ctx.canonicalString = strings.Join([]string{
|
||||
ctx.Request.Method,
|
||||
uri,
|
||||
ctx.Request.URL.RawQuery,
|
||||
ctx.canonicalHeaders + "\n",
|
||||
ctx.signedHeaders,
|
||||
ctx.bodyDigest,
|
||||
}, "\n")
|
||||
}
|
||||
|
||||
func (ctx *signingCtx) buildStringToSign() {
|
||||
ctx.stringToSign = strings.Join([]string{
|
||||
authHeaderPrefix,
|
||||
formatTime(ctx.Time),
|
||||
ctx.credentialString,
|
||||
hex.EncodeToString(hashSHA256([]byte(ctx.canonicalString))),
|
||||
}, "\n")
|
||||
}
|
||||
|
||||
func (ctx *signingCtx) buildSignature() {
|
||||
creds := deriveSigningKey(ctx.Region, ctx.ServiceName, ctx.credValues.SecretAccessKey, ctx.Time)
|
||||
signature := hmacSHA256(creds, []byte(ctx.stringToSign))
|
||||
ctx.signature = hex.EncodeToString(signature)
|
||||
}
|
||||
|
||||
func (ctx *signingCtx) buildBodyDigest() error {
|
||||
hash := ctx.Request.Header.Get("X-Amz-Content-Sha256")
|
||||
if hash == "" {
|
||||
if ctx.Body == nil {
|
||||
hash = emptyStringSHA256
|
||||
} else {
|
||||
hashBytes, err := makeSha256Reader(ctx.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hash = hex.EncodeToString(hashBytes)
|
||||
}
|
||||
}
|
||||
ctx.bodyDigest = hash
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isRequestSigned returns if the request is currently signed or presigned
|
||||
func (ctx *signingCtx) isRequestSigned() bool {
|
||||
return ctx.Request.Header.Get("Authorization") != ""
|
||||
}
|
||||
|
||||
func hmacSHA256(key []byte, data []byte) []byte {
|
||||
hash := hmac.New(sha256.New, key)
|
||||
hash.Write(data)
|
||||
return hash.Sum(nil)
|
||||
}
|
||||
|
||||
func hashSHA256(data []byte) []byte {
|
||||
hash := sha256.New()
|
||||
hash.Write(data)
|
||||
return hash.Sum(nil)
|
||||
}
|
||||
|
||||
// seekerLen attempts to get the number of bytes remaining at the seeker's
|
||||
// current position. Returns the number of bytes remaining or error.
|
||||
func seekerLen(s io.Seeker) (int64, error) {
|
||||
curOffset, err := s.Seek(0, io.SeekCurrent)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
endOffset, err := s.Seek(0, io.SeekEnd)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
_, err = s.Seek(curOffset, io.SeekStart)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return endOffset - curOffset, nil
|
||||
}
|
||||
|
||||
func makeSha256Reader(reader io.ReadSeeker) (hashBytes []byte, err error) {
|
||||
hash := sha256.New()
|
||||
start, err := reader.Seek(0, io.SeekCurrent)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
// ensure error is return if unable to seek back to start of payload.
|
||||
_, err = reader.Seek(start, io.SeekStart)
|
||||
}()
|
||||
|
||||
// Use CopyN to avoid allocating the 32KB buffer in io.Copy for bodies
|
||||
// smaller than 32KB. Fall back to io.Copy if we fail to determine the size.
|
||||
size, err := seekerLen(reader)
|
||||
if err != nil {
|
||||
_, _ = io.Copy(hash, reader)
|
||||
} else {
|
||||
_, _ = io.CopyN(hash, reader, size)
|
||||
}
|
||||
|
||||
return hash.Sum(nil), nil
|
||||
}
|
||||
|
||||
const doubleSpace = " "
|
||||
|
||||
// stripExcessSpaces will rewrite the passed in slice's string values to not
|
||||
// contain multiple side-by-side spaces.
|
||||
func stripExcessSpaces(vals []string) {
|
||||
var j, k, l, m, spaces int
|
||||
for i, str := range vals {
|
||||
// Trim trailing spaces
|
||||
for j = len(str) - 1; j >= 0 && str[j] == ' '; j-- {
|
||||
}
|
||||
|
||||
// Trim leading spaces
|
||||
for k = 0; k < j && str[k] == ' '; k++ {
|
||||
}
|
||||
str = str[k : j+1]
|
||||
|
||||
// Strip multiple spaces.
|
||||
j = strings.Index(str, doubleSpace)
|
||||
if j < 0 {
|
||||
vals[i] = str
|
||||
continue
|
||||
}
|
||||
|
||||
buf := []byte(str)
|
||||
for k, m, l = j, j, len(buf); k < l; k++ {
|
||||
if buf[k] == ' ' {
|
||||
if spaces == 0 {
|
||||
// First space.
|
||||
buf[m] = buf[k]
|
||||
m++
|
||||
}
|
||||
spaces++
|
||||
} else {
|
||||
// End of multiple spaces.
|
||||
spaces = 0
|
||||
buf[m] = buf[k]
|
||||
m++
|
||||
}
|
||||
}
|
||||
|
||||
vals[i] = string(buf[:m])
|
||||
}
|
||||
}
|
||||
|
||||
func buildSigningScope(region, service string, dt time.Time) string {
|
||||
return strings.Join([]string{
|
||||
formatShortTime(dt),
|
||||
region,
|
||||
service,
|
||||
awsV4Request,
|
||||
}, "/")
|
||||
}
|
||||
|
||||
func deriveSigningKey(region, service, secretKey string, dt time.Time) []byte {
|
||||
keyDate := hmacSHA256([]byte("AWS4"+secretKey), []byte(formatShortTime(dt)))
|
||||
keyRegion := hmacSHA256(keyDate, []byte(region))
|
||||
keyService := hmacSHA256(keyRegion, []byte(service))
|
||||
signingKey := hmacSHA256(keyService, []byte(awsV4Request))
|
||||
return signingKey
|
||||
}
|
||||
|
||||
func formatShortTime(dt time.Time) string {
|
||||
return dt.UTC().Format(shortTimeFormat)
|
||||
}
|
||||
|
||||
func formatTime(dt time.Time) string {
|
||||
return dt.UTC().Format(timeFormat)
|
||||
}
|
||||
167
mongo/x/mongo/driver/auth/internal/gssapi/gss.go
Normal file
167
mongo/x/mongo/driver/auth/internal/gssapi/gss.go
Normal file
@@ -0,0 +1,167 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
//go:build gssapi && (linux || darwin)
|
||||
// +build gssapi
|
||||
// +build linux darwin
|
||||
|
||||
package gssapi
|
||||
|
||||
/*
|
||||
#cgo linux CFLAGS: -DGOOS_linux
|
||||
#cgo linux LDFLAGS: -lgssapi_krb5 -lkrb5
|
||||
#cgo darwin CFLAGS: -DGOOS_darwin
|
||||
#cgo darwin LDFLAGS: -framework GSS
|
||||
#include "gss_wrapper.h"
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"strings"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// New creates a new SaslClient. The target parameter should be a hostname with no port.
|
||||
func New(target, username, password string, passwordSet bool, props map[string]string) (*SaslClient, error) {
|
||||
serviceName := "mongodb"
|
||||
|
||||
for key, value := range props {
|
||||
switch strings.ToUpper(key) {
|
||||
case "CANONICALIZE_HOST_NAME":
|
||||
return nil, fmt.Errorf("CANONICALIZE_HOST_NAME is not supported when using gssapi on %s", runtime.GOOS)
|
||||
case "SERVICE_REALM":
|
||||
return nil, fmt.Errorf("SERVICE_REALM is not supported when using gssapi on %s", runtime.GOOS)
|
||||
case "SERVICE_NAME":
|
||||
serviceName = value
|
||||
case "SERVICE_HOST":
|
||||
target = value
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown mechanism property %s", key)
|
||||
}
|
||||
}
|
||||
|
||||
servicePrincipalName := fmt.Sprintf("%s@%s", serviceName, target)
|
||||
|
||||
return &SaslClient{
|
||||
servicePrincipalName: servicePrincipalName,
|
||||
username: username,
|
||||
password: password,
|
||||
passwordSet: passwordSet,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type SaslClient struct {
|
||||
servicePrincipalName string
|
||||
username string
|
||||
password string
|
||||
passwordSet bool
|
||||
|
||||
// state
|
||||
state C.gssapi_client_state
|
||||
contextComplete bool
|
||||
done bool
|
||||
}
|
||||
|
||||
func (sc *SaslClient) Close() {
|
||||
C.gssapi_client_destroy(&sc.state)
|
||||
}
|
||||
|
||||
func (sc *SaslClient) Start() (string, []byte, error) {
|
||||
const mechName = "GSSAPI"
|
||||
|
||||
cservicePrincipalName := C.CString(sc.servicePrincipalName)
|
||||
defer C.free(unsafe.Pointer(cservicePrincipalName))
|
||||
var cusername *C.char
|
||||
var cpassword *C.char
|
||||
if sc.username != "" {
|
||||
cusername = C.CString(sc.username)
|
||||
defer C.free(unsafe.Pointer(cusername))
|
||||
if sc.passwordSet {
|
||||
cpassword = C.CString(sc.password)
|
||||
defer C.free(unsafe.Pointer(cpassword))
|
||||
}
|
||||
}
|
||||
status := C.gssapi_client_init(&sc.state, cservicePrincipalName, cusername, cpassword)
|
||||
|
||||
if status != C.GSSAPI_OK {
|
||||
return mechName, nil, sc.getError("unable to initialize client")
|
||||
}
|
||||
|
||||
payload, err := sc.Next(nil)
|
||||
|
||||
return mechName, payload, err
|
||||
}
|
||||
|
||||
func (sc *SaslClient) Next(challenge []byte) ([]byte, error) {
|
||||
|
||||
var buf unsafe.Pointer
|
||||
var bufLen C.size_t
|
||||
var outBuf unsafe.Pointer
|
||||
var outBufLen C.size_t
|
||||
|
||||
if sc.contextComplete {
|
||||
if sc.username == "" {
|
||||
var cusername *C.char
|
||||
status := C.gssapi_client_username(&sc.state, &cusername)
|
||||
if status != C.GSSAPI_OK {
|
||||
return nil, sc.getError("unable to acquire username")
|
||||
}
|
||||
defer C.free(unsafe.Pointer(cusername))
|
||||
sc.username = C.GoString((*C.char)(unsafe.Pointer(cusername)))
|
||||
}
|
||||
|
||||
bytes := append([]byte{1, 0, 0, 0}, []byte(sc.username)...)
|
||||
buf = unsafe.Pointer(&bytes[0])
|
||||
bufLen = C.size_t(len(bytes))
|
||||
status := C.gssapi_client_wrap_msg(&sc.state, buf, bufLen, &outBuf, &outBufLen)
|
||||
if status != C.GSSAPI_OK {
|
||||
return nil, sc.getError("unable to wrap authz")
|
||||
}
|
||||
|
||||
sc.done = true
|
||||
} else {
|
||||
if len(challenge) > 0 {
|
||||
buf = unsafe.Pointer(&challenge[0])
|
||||
bufLen = C.size_t(len(challenge))
|
||||
}
|
||||
|
||||
status := C.gssapi_client_negotiate(&sc.state, buf, bufLen, &outBuf, &outBufLen)
|
||||
switch status {
|
||||
case C.GSSAPI_OK:
|
||||
sc.contextComplete = true
|
||||
case C.GSSAPI_CONTINUE:
|
||||
default:
|
||||
return nil, sc.getError("unable to negotiate with server")
|
||||
}
|
||||
}
|
||||
|
||||
if outBuf != nil {
|
||||
defer C.free(outBuf)
|
||||
}
|
||||
|
||||
return C.GoBytes(outBuf, C.int(outBufLen)), nil
|
||||
}
|
||||
|
||||
func (sc *SaslClient) Completed() bool {
|
||||
return sc.done
|
||||
}
|
||||
|
||||
func (sc *SaslClient) getError(prefix string) error {
|
||||
var desc *C.char
|
||||
|
||||
status := C.gssapi_error_desc(sc.state.maj_stat, sc.state.min_stat, &desc)
|
||||
if status != C.GSSAPI_OK {
|
||||
if desc != nil {
|
||||
C.free(unsafe.Pointer(desc))
|
||||
}
|
||||
|
||||
return fmt.Errorf("%s: (%v, %v)", prefix, sc.state.maj_stat, sc.state.min_stat)
|
||||
}
|
||||
defer C.free(unsafe.Pointer(desc))
|
||||
|
||||
return fmt.Errorf("%s: %v(%v,%v)", prefix, C.GoString(desc), int32(sc.state.maj_stat), int32(sc.state.min_stat))
|
||||
}
|
||||
254
mongo/x/mongo/driver/auth/internal/gssapi/gss_wrapper.c
Normal file
254
mongo/x/mongo/driver/auth/internal/gssapi/gss_wrapper.c
Normal file
@@ -0,0 +1,254 @@
|
||||
// Copyright (C) MongoDB, Inc. 2022-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
//+build gssapi
|
||||
//+build linux darwin
|
||||
|
||||
#include <string.h>
|
||||
#include <stdio.h>
|
||||
#include "gss_wrapper.h"
|
||||
|
||||
OM_uint32 gssapi_canonicalize_name(
|
||||
OM_uint32* minor_status,
|
||||
char *input_name,
|
||||
gss_OID input_name_type,
|
||||
gss_name_t *output_name
|
||||
)
|
||||
{
|
||||
OM_uint32 major_status;
|
||||
gss_name_t imported_name = GSS_C_NO_NAME;
|
||||
gss_buffer_desc buffer = GSS_C_EMPTY_BUFFER;
|
||||
|
||||
buffer.value = input_name;
|
||||
buffer.length = strlen(input_name);
|
||||
major_status = gss_import_name(minor_status, &buffer, input_name_type, &imported_name);
|
||||
if (GSS_ERROR(major_status)) {
|
||||
return major_status;
|
||||
}
|
||||
|
||||
major_status = gss_canonicalize_name(minor_status, imported_name, (gss_OID)gss_mech_krb5, output_name);
|
||||
if (imported_name != GSS_C_NO_NAME) {
|
||||
OM_uint32 ignored;
|
||||
gss_release_name(&ignored, &imported_name);
|
||||
}
|
||||
|
||||
return major_status;
|
||||
}
|
||||
|
||||
int gssapi_error_desc(
|
||||
OM_uint32 maj_stat,
|
||||
OM_uint32 min_stat,
|
||||
char **desc
|
||||
)
|
||||
{
|
||||
OM_uint32 stat = maj_stat;
|
||||
int stat_type = GSS_C_GSS_CODE;
|
||||
if (min_stat != 0) {
|
||||
stat = min_stat;
|
||||
stat_type = GSS_C_MECH_CODE;
|
||||
}
|
||||
|
||||
OM_uint32 local_maj_stat, local_min_stat;
|
||||
OM_uint32 msg_ctx = 0;
|
||||
gss_buffer_desc desc_buffer;
|
||||
do
|
||||
{
|
||||
local_maj_stat = gss_display_status(
|
||||
&local_min_stat,
|
||||
stat,
|
||||
stat_type,
|
||||
GSS_C_NO_OID,
|
||||
&msg_ctx,
|
||||
&desc_buffer
|
||||
);
|
||||
if (GSS_ERROR(local_maj_stat)) {
|
||||
return GSSAPI_ERROR;
|
||||
}
|
||||
|
||||
if (*desc) {
|
||||
free(*desc);
|
||||
}
|
||||
|
||||
*desc = malloc(desc_buffer.length+1);
|
||||
memcpy(*desc, desc_buffer.value, desc_buffer.length+1);
|
||||
|
||||
gss_release_buffer(&local_min_stat, &desc_buffer);
|
||||
}
|
||||
while(msg_ctx != 0);
|
||||
|
||||
return GSSAPI_OK;
|
||||
}
|
||||
|
||||
int gssapi_client_init(
|
||||
gssapi_client_state *client,
|
||||
char* spn,
|
||||
char* username,
|
||||
char* password
|
||||
)
|
||||
{
|
||||
client->cred = GSS_C_NO_CREDENTIAL;
|
||||
client->ctx = GSS_C_NO_CONTEXT;
|
||||
|
||||
client->maj_stat = gssapi_canonicalize_name(&client->min_stat, spn, GSS_C_NT_HOSTBASED_SERVICE, &client->spn);
|
||||
if (GSS_ERROR(client->maj_stat)) {
|
||||
return GSSAPI_ERROR;
|
||||
}
|
||||
|
||||
if (username) {
|
||||
gss_name_t name;
|
||||
client->maj_stat = gssapi_canonicalize_name(&client->min_stat, username, GSS_C_NT_USER_NAME, &name);
|
||||
if (GSS_ERROR(client->maj_stat)) {
|
||||
return GSSAPI_ERROR;
|
||||
}
|
||||
|
||||
if (password) {
|
||||
gss_buffer_desc password_buffer;
|
||||
password_buffer.value = password;
|
||||
password_buffer.length = strlen(password);
|
||||
client->maj_stat = gss_acquire_cred_with_password(&client->min_stat, name, &password_buffer, GSS_C_INDEFINITE, GSS_C_NO_OID_SET, GSS_C_INITIATE, &client->cred, NULL, NULL);
|
||||
} else {
|
||||
client->maj_stat = gss_acquire_cred(&client->min_stat, name, GSS_C_INDEFINITE, GSS_C_NO_OID_SET, GSS_C_INITIATE, &client->cred, NULL, NULL);
|
||||
}
|
||||
|
||||
if (GSS_ERROR(client->maj_stat)) {
|
||||
return GSSAPI_ERROR;
|
||||
}
|
||||
|
||||
OM_uint32 ignored;
|
||||
gss_release_name(&ignored, &name);
|
||||
}
|
||||
|
||||
return GSSAPI_OK;
|
||||
}
|
||||
|
||||
int gssapi_client_username(
|
||||
gssapi_client_state *client,
|
||||
char** username
|
||||
)
|
||||
{
|
||||
OM_uint32 ignored;
|
||||
gss_name_t name = GSS_C_NO_NAME;
|
||||
|
||||
client->maj_stat = gss_inquire_context(&client->min_stat, client->ctx, &name, NULL, NULL, NULL, NULL, NULL, NULL);
|
||||
if (GSS_ERROR(client->maj_stat)) {
|
||||
return GSSAPI_ERROR;
|
||||
}
|
||||
|
||||
gss_buffer_desc name_buffer;
|
||||
client->maj_stat = gss_display_name(&client->min_stat, name, &name_buffer, NULL);
|
||||
if (GSS_ERROR(client->maj_stat)) {
|
||||
gss_release_name(&ignored, &name);
|
||||
return GSSAPI_ERROR;
|
||||
}
|
||||
|
||||
*username = malloc(name_buffer.length+1);
|
||||
memcpy(*username, name_buffer.value, name_buffer.length+1);
|
||||
|
||||
gss_release_buffer(&ignored, &name_buffer);
|
||||
gss_release_name(&ignored, &name);
|
||||
return GSSAPI_OK;
|
||||
}
|
||||
|
||||
int gssapi_client_negotiate(
|
||||
gssapi_client_state *client,
|
||||
void* input,
|
||||
size_t input_length,
|
||||
void** output,
|
||||
size_t* output_length
|
||||
)
|
||||
{
|
||||
gss_buffer_desc input_buffer = GSS_C_EMPTY_BUFFER;
|
||||
gss_buffer_desc output_buffer = GSS_C_EMPTY_BUFFER;
|
||||
|
||||
if (input) {
|
||||
input_buffer.value = input;
|
||||
input_buffer.length = input_length;
|
||||
}
|
||||
|
||||
client->maj_stat = gss_init_sec_context(
|
||||
&client->min_stat,
|
||||
client->cred,
|
||||
&client->ctx,
|
||||
client->spn,
|
||||
GSS_C_NO_OID,
|
||||
GSS_C_MUTUAL_FLAG | GSS_C_SEQUENCE_FLAG,
|
||||
0,
|
||||
GSS_C_NO_CHANNEL_BINDINGS,
|
||||
&input_buffer,
|
||||
NULL,
|
||||
&output_buffer,
|
||||
NULL,
|
||||
NULL
|
||||
);
|
||||
|
||||
if (output_buffer.length) {
|
||||
*output = malloc(output_buffer.length);
|
||||
*output_length = output_buffer.length;
|
||||
memcpy(*output, output_buffer.value, output_buffer.length);
|
||||
|
||||
OM_uint32 ignored;
|
||||
gss_release_buffer(&ignored, &output_buffer);
|
||||
}
|
||||
|
||||
if (GSS_ERROR(client->maj_stat)) {
|
||||
return GSSAPI_ERROR;
|
||||
} else if (client->maj_stat == GSS_S_CONTINUE_NEEDED) {
|
||||
return GSSAPI_CONTINUE;
|
||||
}
|
||||
|
||||
return GSSAPI_OK;
|
||||
}
|
||||
|
||||
int gssapi_client_wrap_msg(
|
||||
gssapi_client_state *client,
|
||||
void* input,
|
||||
size_t input_length,
|
||||
void** output,
|
||||
size_t* output_length
|
||||
)
|
||||
{
|
||||
gss_buffer_desc input_buffer = GSS_C_EMPTY_BUFFER;
|
||||
gss_buffer_desc output_buffer = GSS_C_EMPTY_BUFFER;
|
||||
|
||||
input_buffer.value = input;
|
||||
input_buffer.length = input_length;
|
||||
|
||||
client->maj_stat = gss_wrap(&client->min_stat, client->ctx, 0, GSS_C_QOP_DEFAULT, &input_buffer, NULL, &output_buffer);
|
||||
|
||||
if (output_buffer.length) {
|
||||
*output = malloc(output_buffer.length);
|
||||
*output_length = output_buffer.length;
|
||||
memcpy(*output, output_buffer.value, output_buffer.length);
|
||||
|
||||
gss_release_buffer(&client->min_stat, &output_buffer);
|
||||
}
|
||||
|
||||
if (GSS_ERROR(client->maj_stat)) {
|
||||
return GSSAPI_ERROR;
|
||||
}
|
||||
|
||||
return GSSAPI_OK;
|
||||
}
|
||||
|
||||
int gssapi_client_destroy(
|
||||
gssapi_client_state *client
|
||||
)
|
||||
{
|
||||
OM_uint32 ignored;
|
||||
if (client->ctx != GSS_C_NO_CONTEXT) {
|
||||
gss_delete_sec_context(&ignored, &client->ctx, GSS_C_NO_BUFFER);
|
||||
}
|
||||
|
||||
if (client->spn != GSS_C_NO_NAME) {
|
||||
gss_release_name(&ignored, &client->spn);
|
||||
}
|
||||
|
||||
if (client->cred != GSS_C_NO_CREDENTIAL) {
|
||||
gss_release_cred(&ignored, &client->cred);
|
||||
}
|
||||
|
||||
return GSSAPI_OK;
|
||||
}
|
||||
72
mongo/x/mongo/driver/auth/internal/gssapi/gss_wrapper.h
Normal file
72
mongo/x/mongo/driver/auth/internal/gssapi/gss_wrapper.h
Normal file
@@ -0,0 +1,72 @@
|
||||
// Copyright (C) MongoDB, Inc. 2022-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
//+build gssapi
|
||||
//+build linux darwin
|
||||
#ifndef GSS_WRAPPER_H
|
||||
#define GSS_WRAPPER_H
|
||||
|
||||
#include <stdlib.h>
|
||||
#ifdef GOOS_linux
|
||||
#include <gssapi/gssapi.h>
|
||||
#include <gssapi/gssapi_krb5.h>
|
||||
#endif
|
||||
#ifdef GOOS_darwin
|
||||
#include <GSS/GSS.h>
|
||||
#endif
|
||||
|
||||
#define GSSAPI_OK 0
|
||||
#define GSSAPI_CONTINUE 1
|
||||
#define GSSAPI_ERROR 2
|
||||
|
||||
typedef struct {
|
||||
gss_name_t spn;
|
||||
gss_cred_id_t cred;
|
||||
gss_ctx_id_t ctx;
|
||||
|
||||
OM_uint32 maj_stat;
|
||||
OM_uint32 min_stat;
|
||||
} gssapi_client_state;
|
||||
|
||||
int gssapi_error_desc(
|
||||
OM_uint32 maj_stat,
|
||||
OM_uint32 min_stat,
|
||||
char **desc
|
||||
);
|
||||
|
||||
int gssapi_client_init(
|
||||
gssapi_client_state *client,
|
||||
char* spn,
|
||||
char* username,
|
||||
char* password
|
||||
);
|
||||
|
||||
int gssapi_client_username(
|
||||
gssapi_client_state *client,
|
||||
char** username
|
||||
);
|
||||
|
||||
int gssapi_client_negotiate(
|
||||
gssapi_client_state *client,
|
||||
void* input,
|
||||
size_t input_length,
|
||||
void** output,
|
||||
size_t* output_length
|
||||
);
|
||||
|
||||
int gssapi_client_wrap_msg(
|
||||
gssapi_client_state *client,
|
||||
void* input,
|
||||
size_t input_length,
|
||||
void** output,
|
||||
size_t* output_length
|
||||
);
|
||||
|
||||
int gssapi_client_destroy(
|
||||
gssapi_client_state *client
|
||||
);
|
||||
|
||||
#endif
|
||||
353
mongo/x/mongo/driver/auth/internal/gssapi/sspi.go
Normal file
353
mongo/x/mongo/driver/auth/internal/gssapi/sspi.go
Normal file
@@ -0,0 +1,353 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
//go:build gssapi && windows
|
||||
// +build gssapi,windows
|
||||
|
||||
package gssapi
|
||||
|
||||
// #include "sspi_wrapper.h"
|
||||
import "C"
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// New creates a new SaslClient. The target parameter should be a hostname with no port.
|
||||
func New(target, username, password string, passwordSet bool, props map[string]string) (*SaslClient, error) {
|
||||
initOnce.Do(initSSPI)
|
||||
if initError != nil {
|
||||
return nil, initError
|
||||
}
|
||||
|
||||
var err error
|
||||
serviceName := "mongodb"
|
||||
serviceRealm := ""
|
||||
canonicalizeHostName := false
|
||||
var serviceHostSet bool
|
||||
|
||||
for key, value := range props {
|
||||
switch strings.ToUpper(key) {
|
||||
case "CANONICALIZE_HOST_NAME":
|
||||
canonicalizeHostName, err = strconv.ParseBool(value)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s must be a boolean (true, false, 0, 1) but got '%s'", key, value)
|
||||
}
|
||||
|
||||
case "SERVICE_REALM":
|
||||
serviceRealm = value
|
||||
case "SERVICE_NAME":
|
||||
serviceName = value
|
||||
case "SERVICE_HOST":
|
||||
serviceHostSet = true
|
||||
target = value
|
||||
}
|
||||
}
|
||||
|
||||
if canonicalizeHostName {
|
||||
// Should not canonicalize the SERVICE_HOST
|
||||
if serviceHostSet {
|
||||
return nil, fmt.Errorf("CANONICALIZE_HOST_NAME and SERVICE_HOST canonot both be specified")
|
||||
}
|
||||
|
||||
names, err := net.LookupAddr(target)
|
||||
if err != nil || len(names) == 0 {
|
||||
return nil, fmt.Errorf("unable to canonicalize hostname: %s", err)
|
||||
}
|
||||
target = names[0]
|
||||
if target[len(target)-1] == '.' {
|
||||
target = target[:len(target)-1]
|
||||
}
|
||||
}
|
||||
|
||||
servicePrincipalName := fmt.Sprintf("%s/%s", serviceName, target)
|
||||
if serviceRealm != "" {
|
||||
servicePrincipalName += "@" + serviceRealm
|
||||
}
|
||||
|
||||
return &SaslClient{
|
||||
servicePrincipalName: servicePrincipalName,
|
||||
username: username,
|
||||
password: password,
|
||||
passwordSet: passwordSet,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type SaslClient struct {
|
||||
servicePrincipalName string
|
||||
username string
|
||||
password string
|
||||
passwordSet bool
|
||||
|
||||
// state
|
||||
state C.sspi_client_state
|
||||
contextComplete bool
|
||||
done bool
|
||||
}
|
||||
|
||||
func (sc *SaslClient) Close() {
|
||||
C.sspi_client_destroy(&sc.state)
|
||||
}
|
||||
|
||||
func (sc *SaslClient) Start() (string, []byte, error) {
|
||||
const mechName = "GSSAPI"
|
||||
|
||||
var cusername *C.char
|
||||
var cpassword *C.char
|
||||
if sc.username != "" {
|
||||
cusername = C.CString(sc.username)
|
||||
defer C.free(unsafe.Pointer(cusername))
|
||||
if sc.passwordSet {
|
||||
cpassword = C.CString(sc.password)
|
||||
defer C.free(unsafe.Pointer(cpassword))
|
||||
}
|
||||
}
|
||||
status := C.sspi_client_init(&sc.state, cusername, cpassword)
|
||||
|
||||
if status != C.SSPI_OK {
|
||||
return mechName, nil, sc.getError("unable to intitialize client")
|
||||
}
|
||||
|
||||
payload, err := sc.Next(nil)
|
||||
|
||||
return mechName, payload, err
|
||||
}
|
||||
|
||||
func (sc *SaslClient) Next(challenge []byte) ([]byte, error) {
|
||||
|
||||
var outBuf C.PVOID
|
||||
var outBufLen C.ULONG
|
||||
|
||||
if sc.contextComplete {
|
||||
if sc.username == "" {
|
||||
var cusername *C.char
|
||||
status := C.sspi_client_username(&sc.state, &cusername)
|
||||
if status != C.SSPI_OK {
|
||||
return nil, sc.getError("unable to acquire username")
|
||||
}
|
||||
defer C.free(unsafe.Pointer(cusername))
|
||||
sc.username = C.GoString((*C.char)(unsafe.Pointer(cusername)))
|
||||
}
|
||||
|
||||
bytes := append([]byte{1, 0, 0, 0}, []byte(sc.username)...)
|
||||
buf := (C.PVOID)(unsafe.Pointer(&bytes[0]))
|
||||
bufLen := C.ULONG(len(bytes))
|
||||
status := C.sspi_client_wrap_msg(&sc.state, buf, bufLen, &outBuf, &outBufLen)
|
||||
if status != C.SSPI_OK {
|
||||
return nil, sc.getError("unable to wrap authz")
|
||||
}
|
||||
|
||||
sc.done = true
|
||||
} else {
|
||||
var buf C.PVOID
|
||||
var bufLen C.ULONG
|
||||
if len(challenge) > 0 {
|
||||
buf = (C.PVOID)(unsafe.Pointer(&challenge[0]))
|
||||
bufLen = C.ULONG(len(challenge))
|
||||
}
|
||||
cservicePrincipalName := C.CString(sc.servicePrincipalName)
|
||||
defer C.free(unsafe.Pointer(cservicePrincipalName))
|
||||
|
||||
status := C.sspi_client_negotiate(&sc.state, cservicePrincipalName, buf, bufLen, &outBuf, &outBufLen)
|
||||
switch status {
|
||||
case C.SSPI_OK:
|
||||
sc.contextComplete = true
|
||||
case C.SSPI_CONTINUE:
|
||||
default:
|
||||
return nil, sc.getError("unable to negotiate with server")
|
||||
}
|
||||
}
|
||||
|
||||
if outBuf != C.PVOID(nil) {
|
||||
defer C.free(unsafe.Pointer(outBuf))
|
||||
}
|
||||
|
||||
return C.GoBytes(unsafe.Pointer(outBuf), C.int(outBufLen)), nil
|
||||
}
|
||||
|
||||
func (sc *SaslClient) Completed() bool {
|
||||
return sc.done
|
||||
}
|
||||
|
||||
func (sc *SaslClient) getError(prefix string) error {
|
||||
return getError(prefix, sc.state.status)
|
||||
}
|
||||
|
||||
var initOnce sync.Once
|
||||
var initError error
|
||||
|
||||
func initSSPI() {
|
||||
rc := C.sspi_init()
|
||||
if rc != 0 {
|
||||
initError = fmt.Errorf("error initializing sspi: %v", rc)
|
||||
}
|
||||
}
|
||||
|
||||
func getError(prefix string, status C.SECURITY_STATUS) error {
|
||||
var s string
|
||||
switch status {
|
||||
case C.SEC_E_ALGORITHM_MISMATCH:
|
||||
s = "The client and server cannot communicate because they do not possess a common algorithm."
|
||||
case C.SEC_E_BAD_BINDINGS:
|
||||
s = "The SSPI channel bindings supplied by the client are incorrect."
|
||||
case C.SEC_E_BAD_PKGID:
|
||||
s = "The requested package identifier does not exist."
|
||||
case C.SEC_E_BUFFER_TOO_SMALL:
|
||||
s = "The buffers supplied to the function are not large enough to contain the information."
|
||||
case C.SEC_E_CANNOT_INSTALL:
|
||||
s = "The security package cannot initialize successfully and should not be installed."
|
||||
case C.SEC_E_CANNOT_PACK:
|
||||
s = "The package is unable to pack the context."
|
||||
case C.SEC_E_CERT_EXPIRED:
|
||||
s = "The received certificate has expired."
|
||||
case C.SEC_E_CERT_UNKNOWN:
|
||||
s = "An unknown error occurred while processing the certificate."
|
||||
case C.SEC_E_CERT_WRONG_USAGE:
|
||||
s = "The certificate is not valid for the requested usage."
|
||||
case C.SEC_E_CONTEXT_EXPIRED:
|
||||
s = "The application is referencing a context that has already been closed. A properly written application should not receive this error."
|
||||
case C.SEC_E_CROSSREALM_DELEGATION_FAILURE:
|
||||
s = "The server attempted to make a Kerberos-constrained delegation request for a target outside the server's realm."
|
||||
case C.SEC_E_CRYPTO_SYSTEM_INVALID:
|
||||
s = "The cryptographic system or checksum function is not valid because a required function is unavailable."
|
||||
case C.SEC_E_DECRYPT_FAILURE:
|
||||
s = "The specified data could not be decrypted."
|
||||
case C.SEC_E_DELEGATION_REQUIRED:
|
||||
s = "The requested operation cannot be completed. The computer must be trusted for delegation"
|
||||
case C.SEC_E_DOWNGRADE_DETECTED:
|
||||
s = "The system detected a possible attempt to compromise security. Verify that the server that authenticated you can be contacted."
|
||||
case C.SEC_E_ENCRYPT_FAILURE:
|
||||
s = "The specified data could not be encrypted."
|
||||
case C.SEC_E_ILLEGAL_MESSAGE:
|
||||
s = "The message received was unexpected or badly formatted."
|
||||
case C.SEC_E_INCOMPLETE_CREDENTIALS:
|
||||
s = "The credentials supplied were not complete and could not be verified. The context could not be initialized."
|
||||
case C.SEC_E_INCOMPLETE_MESSAGE:
|
||||
s = "The message supplied was incomplete. The signature was not verified."
|
||||
case C.SEC_E_INSUFFICIENT_MEMORY:
|
||||
s = "Not enough memory is available to complete the request."
|
||||
case C.SEC_E_INTERNAL_ERROR:
|
||||
s = "An error occurred that did not map to an SSPI error code."
|
||||
case C.SEC_E_INVALID_HANDLE:
|
||||
s = "The handle passed to the function is not valid."
|
||||
case C.SEC_E_INVALID_TOKEN:
|
||||
s = "The token passed to the function is not valid."
|
||||
case C.SEC_E_ISSUING_CA_UNTRUSTED:
|
||||
s = "An untrusted certification authority (CA) was detected while processing the smart card certificate used for authentication."
|
||||
case C.SEC_E_ISSUING_CA_UNTRUSTED_KDC:
|
||||
s = "An untrusted CA was detected while processing the domain controller certificate used for authentication. The system event log contains additional information."
|
||||
case C.SEC_E_KDC_CERT_EXPIRED:
|
||||
s = "The domain controller certificate used for smart card logon has expired."
|
||||
case C.SEC_E_KDC_CERT_REVOKED:
|
||||
s = "The domain controller certificate used for smart card logon has been revoked."
|
||||
case C.SEC_E_KDC_INVALID_REQUEST:
|
||||
s = "A request that is not valid was sent to the KDC."
|
||||
case C.SEC_E_KDC_UNABLE_TO_REFER:
|
||||
s = "The KDC was unable to generate a referral for the service requested."
|
||||
case C.SEC_E_KDC_UNKNOWN_ETYPE:
|
||||
s = "The requested encryption type is not supported by the KDC."
|
||||
case C.SEC_E_LOGON_DENIED:
|
||||
s = "The logon has been denied"
|
||||
case C.SEC_E_MAX_REFERRALS_EXCEEDED:
|
||||
s = "The number of maximum ticket referrals has been exceeded."
|
||||
case C.SEC_E_MESSAGE_ALTERED:
|
||||
s = "The message supplied for verification has been altered."
|
||||
case C.SEC_E_MULTIPLE_ACCOUNTS:
|
||||
s = "The received certificate was mapped to multiple accounts."
|
||||
case C.SEC_E_MUST_BE_KDC:
|
||||
s = "The local computer must be a Kerberos domain controller (KDC)"
|
||||
case C.SEC_E_NO_AUTHENTICATING_AUTHORITY:
|
||||
s = "No authority could be contacted for authentication."
|
||||
case C.SEC_E_NO_CREDENTIALS:
|
||||
s = "No credentials are available."
|
||||
case C.SEC_E_NO_IMPERSONATION:
|
||||
s = "No impersonation is allowed for this context."
|
||||
case C.SEC_E_NO_IP_ADDRESSES:
|
||||
s = "Unable to accomplish the requested task because the local computer does not have any IP addresses."
|
||||
case C.SEC_E_NO_KERB_KEY:
|
||||
s = "No Kerberos key was found."
|
||||
case C.SEC_E_NO_PA_DATA:
|
||||
s = "Policy administrator (PA) data is needed to determine the encryption type"
|
||||
case C.SEC_E_NO_S4U_PROT_SUPPORT:
|
||||
s = "The Kerberos subsystem encountered an error. A service for user protocol request was made against a domain controller which does not support service for a user."
|
||||
case C.SEC_E_NO_TGT_REPLY:
|
||||
s = "The client is trying to negotiate a context and the server requires a user-to-user connection"
|
||||
case C.SEC_E_NOT_OWNER:
|
||||
s = "The caller of the function does not own the credentials."
|
||||
case C.SEC_E_OK:
|
||||
s = "The operation completed successfully."
|
||||
case C.SEC_E_OUT_OF_SEQUENCE:
|
||||
s = "The message supplied for verification is out of sequence."
|
||||
case C.SEC_E_PKINIT_CLIENT_FAILURE:
|
||||
s = "The smart card certificate used for authentication is not trusted."
|
||||
case C.SEC_E_PKINIT_NAME_MISMATCH:
|
||||
s = "The client certificate does not contain a valid UPN or does not match the client name in the logon request."
|
||||
case C.SEC_E_QOP_NOT_SUPPORTED:
|
||||
s = "The quality of protection attribute is not supported by this package."
|
||||
case C.SEC_E_REVOCATION_OFFLINE_C:
|
||||
s = "The revocation status of the smart card certificate used for authentication could not be determined."
|
||||
case C.SEC_E_REVOCATION_OFFLINE_KDC:
|
||||
s = "The revocation status of the domain controller certificate used for smart card authentication could not be determined. The system event log contains additional information."
|
||||
case C.SEC_E_SECPKG_NOT_FOUND:
|
||||
s = "The security package was not recognized."
|
||||
case C.SEC_E_SECURITY_QOS_FAILED:
|
||||
s = "The security context could not be established due to a failure in the requested quality of service (for example"
|
||||
case C.SEC_E_SHUTDOWN_IN_PROGRESS:
|
||||
s = "A system shutdown is in progress."
|
||||
case C.SEC_E_SMARTCARD_CERT_EXPIRED:
|
||||
s = "The smart card certificate used for authentication has expired."
|
||||
case C.SEC_E_SMARTCARD_CERT_REVOKED:
|
||||
s = "The smart card certificate used for authentication has been revoked. Additional information may exist in the event log."
|
||||
case C.SEC_E_SMARTCARD_LOGON_REQUIRED:
|
||||
s = "Smart card logon is required and was not used."
|
||||
case C.SEC_E_STRONG_CRYPTO_NOT_SUPPORTED:
|
||||
s = "The other end of the security negotiation requires strong cryptography"
|
||||
case C.SEC_E_TARGET_UNKNOWN:
|
||||
s = "The target was not recognized."
|
||||
case C.SEC_E_TIME_SKEW:
|
||||
s = "The clocks on the client and server computers do not match."
|
||||
case C.SEC_E_TOO_MANY_PRINCIPALS:
|
||||
s = "The KDC reply contained more than one principal name."
|
||||
case C.SEC_E_UNFINISHED_CONTEXT_DELETED:
|
||||
s = "A security context was deleted before the context was completed. This is considered a logon failure."
|
||||
case C.SEC_E_UNKNOWN_CREDENTIALS:
|
||||
s = "The credentials provided were not recognized."
|
||||
case C.SEC_E_UNSUPPORTED_FUNCTION:
|
||||
s = "The requested function is not supported."
|
||||
case C.SEC_E_UNSUPPORTED_PREAUTH:
|
||||
s = "An unsupported preauthentication mechanism was presented to the Kerberos package."
|
||||
case C.SEC_E_UNTRUSTED_ROOT:
|
||||
s = "The certificate chain was issued by an authority that is not trusted."
|
||||
case C.SEC_E_WRONG_CREDENTIAL_HANDLE:
|
||||
s = "The supplied credential handle does not match the credential associated with the security context."
|
||||
case C.SEC_E_WRONG_PRINCIPAL:
|
||||
s = "The target principal name is incorrect."
|
||||
case C.SEC_I_COMPLETE_AND_CONTINUE:
|
||||
s = "The function completed successfully"
|
||||
case C.SEC_I_COMPLETE_NEEDED:
|
||||
s = "The function completed successfully"
|
||||
case C.SEC_I_CONTEXT_EXPIRED:
|
||||
s = "The message sender has finished using the connection and has initiated a shutdown. For information about initiating or recognizing a shutdown"
|
||||
case C.SEC_I_CONTINUE_NEEDED:
|
||||
s = "The function completed successfully"
|
||||
case C.SEC_I_INCOMPLETE_CREDENTIALS:
|
||||
s = "The credentials supplied were not complete and could not be verified. Additional information can be returned from the context."
|
||||
case C.SEC_I_LOCAL_LOGON:
|
||||
s = "The logon was completed"
|
||||
case C.SEC_I_NO_LSA_CONTEXT:
|
||||
s = "There is no LSA mode context associated with this context."
|
||||
case C.SEC_I_RENEGOTIATE:
|
||||
s = "The context data must be renegotiated with the peer."
|
||||
default:
|
||||
return fmt.Errorf("%s: 0x%x", prefix, uint32(status))
|
||||
}
|
||||
|
||||
return fmt.Errorf("%s: %s(0x%x)", prefix, s, uint32(status))
|
||||
}
|
||||
249
mongo/x/mongo/driver/auth/internal/gssapi/sspi_wrapper.c
Normal file
249
mongo/x/mongo/driver/auth/internal/gssapi/sspi_wrapper.c
Normal file
@@ -0,0 +1,249 @@
|
||||
// Copyright (C) MongoDB, Inc. 2022-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
//+build gssapi,windows
|
||||
|
||||
#include "sspi_wrapper.h"
|
||||
|
||||
static HINSTANCE sspi_secur32_dll = NULL;
|
||||
static PSecurityFunctionTable sspi_functions = NULL;
|
||||
static const LPSTR SSPI_PACKAGE_NAME = "kerberos";
|
||||
|
||||
int sspi_init(
|
||||
)
|
||||
{
|
||||
// Load the secur32.dll library using its exact path. Passing the exact DLL path rather than allowing LoadLibrary to
|
||||
// search in different locations removes the possibility of DLL preloading attacks. We use GetSystemDirectoryA and
|
||||
// LoadLibraryA rather than the GetSystemDirectory/LoadLibrary aliases to ensure the ANSI versions are used so we
|
||||
// don't have to account for variations in char sizes if UNICODE is enabled.
|
||||
|
||||
// Passing a 0 size will return the required buffer length to hold the path, including the null terminator.
|
||||
int requiredLen = GetSystemDirectoryA(NULL, 0);
|
||||
if (!requiredLen) {
|
||||
return GetLastError();
|
||||
}
|
||||
|
||||
// Allocate a buffer to hold the system directory + "\secur32.dll" (length 12, not including null terminator).
|
||||
int actualLen = requiredLen + 12;
|
||||
char *directoryBuffer = (char *) calloc(1, actualLen);
|
||||
int directoryLen = GetSystemDirectoryA(directoryBuffer, actualLen);
|
||||
if (!directoryLen) {
|
||||
free(directoryBuffer);
|
||||
return GetLastError();
|
||||
}
|
||||
|
||||
// Append the DLL name to the buffer.
|
||||
char *dllName = "\\secur32.dll";
|
||||
strcpy_s(&(directoryBuffer[directoryLen]), actualLen - directoryLen, dllName);
|
||||
|
||||
sspi_secur32_dll = LoadLibraryA(directoryBuffer);
|
||||
free(directoryBuffer);
|
||||
if (!sspi_secur32_dll) {
|
||||
return GetLastError();
|
||||
}
|
||||
|
||||
INIT_SECURITY_INTERFACE init_security_interface = (INIT_SECURITY_INTERFACE)GetProcAddress(sspi_secur32_dll, SECURITY_ENTRYPOINT);
|
||||
if (!init_security_interface) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
sspi_functions = (*init_security_interface)();
|
||||
if (!sspi_functions) {
|
||||
return -2;
|
||||
}
|
||||
|
||||
return SSPI_OK;
|
||||
}
|
||||
|
||||
int sspi_client_init(
|
||||
sspi_client_state *client,
|
||||
char* username,
|
||||
char* password
|
||||
)
|
||||
{
|
||||
TimeStamp timestamp;
|
||||
|
||||
if (username) {
|
||||
if (password) {
|
||||
SEC_WINNT_AUTH_IDENTITY auth_identity;
|
||||
|
||||
#ifdef _UNICODE
|
||||
auth_identity.Flags = SEC_WINNT_AUTH_IDENTITY_UNICODE;
|
||||
#else
|
||||
auth_identity.Flags = SEC_WINNT_AUTH_IDENTITY_ANSI;
|
||||
#endif
|
||||
auth_identity.User = (LPSTR) username;
|
||||
auth_identity.UserLength = strlen(username);
|
||||
auth_identity.Password = (LPSTR) password;
|
||||
auth_identity.PasswordLength = strlen(password);
|
||||
auth_identity.Domain = NULL;
|
||||
auth_identity.DomainLength = 0;
|
||||
client->status = sspi_functions->AcquireCredentialsHandle(NULL, SSPI_PACKAGE_NAME, SECPKG_CRED_OUTBOUND, NULL, &auth_identity, NULL, NULL, &client->cred, ×tamp);
|
||||
} else {
|
||||
client->status = sspi_functions->AcquireCredentialsHandle(username, SSPI_PACKAGE_NAME, SECPKG_CRED_OUTBOUND, NULL, NULL, NULL, NULL, &client->cred, ×tamp);
|
||||
}
|
||||
} else {
|
||||
client->status = sspi_functions->AcquireCredentialsHandle(NULL, SSPI_PACKAGE_NAME, SECPKG_CRED_OUTBOUND, NULL, NULL, NULL, NULL, &client->cred, ×tamp);
|
||||
}
|
||||
|
||||
if (client->status != SEC_E_OK) {
|
||||
return SSPI_ERROR;
|
||||
}
|
||||
|
||||
return SSPI_OK;
|
||||
}
|
||||
|
||||
int sspi_client_username(
|
||||
sspi_client_state *client,
|
||||
char** username
|
||||
)
|
||||
{
|
||||
SecPkgCredentials_Names names;
|
||||
client->status = sspi_functions->QueryCredentialsAttributes(&client->cred, SECPKG_CRED_ATTR_NAMES, &names);
|
||||
|
||||
if (client->status != SEC_E_OK) {
|
||||
return SSPI_ERROR;
|
||||
}
|
||||
|
||||
int len = strlen(names.sUserName) + 1;
|
||||
*username = malloc(len);
|
||||
memcpy(*username, names.sUserName, len);
|
||||
|
||||
sspi_functions->FreeContextBuffer(names.sUserName);
|
||||
|
||||
return SSPI_OK;
|
||||
}
|
||||
|
||||
int sspi_client_negotiate(
|
||||
sspi_client_state *client,
|
||||
char* spn,
|
||||
PVOID input,
|
||||
ULONG input_length,
|
||||
PVOID* output,
|
||||
ULONG* output_length
|
||||
)
|
||||
{
|
||||
SecBufferDesc inbuf;
|
||||
SecBuffer in_bufs[1];
|
||||
SecBufferDesc outbuf;
|
||||
SecBuffer out_bufs[1];
|
||||
|
||||
if (client->has_ctx > 0) {
|
||||
inbuf.ulVersion = SECBUFFER_VERSION;
|
||||
inbuf.cBuffers = 1;
|
||||
inbuf.pBuffers = in_bufs;
|
||||
in_bufs[0].pvBuffer = input;
|
||||
in_bufs[0].cbBuffer = input_length;
|
||||
in_bufs[0].BufferType = SECBUFFER_TOKEN;
|
||||
}
|
||||
|
||||
outbuf.ulVersion = SECBUFFER_VERSION;
|
||||
outbuf.cBuffers = 1;
|
||||
outbuf.pBuffers = out_bufs;
|
||||
out_bufs[0].pvBuffer = NULL;
|
||||
out_bufs[0].cbBuffer = 0;
|
||||
out_bufs[0].BufferType = SECBUFFER_TOKEN;
|
||||
|
||||
ULONG context_attr = 0;
|
||||
|
||||
client->status = sspi_functions->InitializeSecurityContext(
|
||||
&client->cred,
|
||||
client->has_ctx > 0 ? &client->ctx : NULL,
|
||||
(LPSTR) spn,
|
||||
ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_MUTUAL_AUTH,
|
||||
0,
|
||||
SECURITY_NETWORK_DREP,
|
||||
client->has_ctx > 0 ? &inbuf : NULL,
|
||||
0,
|
||||
&client->ctx,
|
||||
&outbuf,
|
||||
&context_attr,
|
||||
NULL);
|
||||
|
||||
if (client->status != SEC_E_OK && client->status != SEC_I_CONTINUE_NEEDED) {
|
||||
return SSPI_ERROR;
|
||||
}
|
||||
|
||||
client->has_ctx = 1;
|
||||
|
||||
*output = malloc(out_bufs[0].cbBuffer);
|
||||
*output_length = out_bufs[0].cbBuffer;
|
||||
memcpy(*output, out_bufs[0].pvBuffer, *output_length);
|
||||
sspi_functions->FreeContextBuffer(out_bufs[0].pvBuffer);
|
||||
|
||||
if (client->status == SEC_I_CONTINUE_NEEDED) {
|
||||
return SSPI_CONTINUE;
|
||||
}
|
||||
|
||||
return SSPI_OK;
|
||||
}
|
||||
|
||||
int sspi_client_wrap_msg(
|
||||
sspi_client_state *client,
|
||||
PVOID input,
|
||||
ULONG input_length,
|
||||
PVOID* output,
|
||||
ULONG* output_length
|
||||
)
|
||||
{
|
||||
SecPkgContext_Sizes sizes;
|
||||
|
||||
client->status = sspi_functions->QueryContextAttributes(&client->ctx, SECPKG_ATTR_SIZES, &sizes);
|
||||
if (client->status != SEC_E_OK) {
|
||||
return SSPI_ERROR;
|
||||
}
|
||||
|
||||
char *msg = malloc((sizes.cbSecurityTrailer + input_length + sizes.cbBlockSize) * sizeof(char));
|
||||
memcpy(&msg[sizes.cbSecurityTrailer], input, input_length);
|
||||
|
||||
SecBuffer wrap_bufs[3];
|
||||
SecBufferDesc wrap_buf_desc;
|
||||
wrap_buf_desc.cBuffers = 3;
|
||||
wrap_buf_desc.pBuffers = wrap_bufs;
|
||||
wrap_buf_desc.ulVersion = SECBUFFER_VERSION;
|
||||
|
||||
wrap_bufs[0].cbBuffer = sizes.cbSecurityTrailer;
|
||||
wrap_bufs[0].BufferType = SECBUFFER_TOKEN;
|
||||
wrap_bufs[0].pvBuffer = msg;
|
||||
|
||||
wrap_bufs[1].cbBuffer = input_length;
|
||||
wrap_bufs[1].BufferType = SECBUFFER_DATA;
|
||||
wrap_bufs[1].pvBuffer = msg + sizes.cbSecurityTrailer;
|
||||
|
||||
wrap_bufs[2].cbBuffer = sizes.cbBlockSize;
|
||||
wrap_bufs[2].BufferType = SECBUFFER_PADDING;
|
||||
wrap_bufs[2].pvBuffer = msg + sizes.cbSecurityTrailer + input_length;
|
||||
|
||||
client->status = sspi_functions->EncryptMessage(&client->ctx, SECQOP_WRAP_NO_ENCRYPT, &wrap_buf_desc, 0);
|
||||
if (client->status != SEC_E_OK) {
|
||||
free(msg);
|
||||
return SSPI_ERROR;
|
||||
}
|
||||
|
||||
*output_length = wrap_bufs[0].cbBuffer + wrap_bufs[1].cbBuffer + wrap_bufs[2].cbBuffer;
|
||||
*output = malloc(*output_length);
|
||||
|
||||
memcpy(*output, wrap_bufs[0].pvBuffer, wrap_bufs[0].cbBuffer);
|
||||
memcpy(*output + wrap_bufs[0].cbBuffer, wrap_bufs[1].pvBuffer, wrap_bufs[1].cbBuffer);
|
||||
memcpy(*output + wrap_bufs[0].cbBuffer + wrap_bufs[1].cbBuffer, wrap_bufs[2].pvBuffer, wrap_bufs[2].cbBuffer);
|
||||
|
||||
free(msg);
|
||||
|
||||
return SSPI_OK;
|
||||
}
|
||||
|
||||
int sspi_client_destroy(
|
||||
sspi_client_state *client
|
||||
)
|
||||
{
|
||||
if (client->has_ctx > 0) {
|
||||
sspi_functions->DeleteSecurityContext(&client->ctx);
|
||||
}
|
||||
|
||||
sspi_functions->FreeCredentialsHandle(&client->cred);
|
||||
|
||||
return SSPI_OK;
|
||||
}
|
||||
64
mongo/x/mongo/driver/auth/internal/gssapi/sspi_wrapper.h
Normal file
64
mongo/x/mongo/driver/auth/internal/gssapi/sspi_wrapper.h
Normal file
@@ -0,0 +1,64 @@
|
||||
// Copyright (C) MongoDB, Inc. 2022-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
//+build gssapi,windows
|
||||
|
||||
#ifndef SSPI_WRAPPER_H
|
||||
#define SSPI_WRAPPER_H
|
||||
|
||||
#define SECURITY_WIN32 1 /* Required for SSPI */
|
||||
|
||||
#include <windows.h>
|
||||
#include <sspi.h>
|
||||
|
||||
#define SSPI_OK 0
|
||||
#define SSPI_CONTINUE 1
|
||||
#define SSPI_ERROR 2
|
||||
|
||||
typedef struct {
|
||||
CredHandle cred;
|
||||
CtxtHandle ctx;
|
||||
|
||||
int has_ctx;
|
||||
|
||||
SECURITY_STATUS status;
|
||||
} sspi_client_state;
|
||||
|
||||
int sspi_init();
|
||||
|
||||
int sspi_client_init(
|
||||
sspi_client_state *client,
|
||||
char* username,
|
||||
char* password
|
||||
);
|
||||
|
||||
int sspi_client_username(
|
||||
sspi_client_state *client,
|
||||
char** username
|
||||
);
|
||||
|
||||
int sspi_client_negotiate(
|
||||
sspi_client_state *client,
|
||||
char* spn,
|
||||
PVOID input,
|
||||
ULONG input_length,
|
||||
PVOID* output,
|
||||
ULONG* output_length
|
||||
);
|
||||
|
||||
int sspi_client_wrap_msg(
|
||||
sspi_client_state *client,
|
||||
PVOID input,
|
||||
ULONG input_length,
|
||||
PVOID* output,
|
||||
ULONG* output_length
|
||||
);
|
||||
|
||||
int sspi_client_destroy(
|
||||
sspi_client_state *client
|
||||
);
|
||||
|
||||
#endif
|
||||
82
mongo/x/mongo/driver/auth/mongodbaws.go
Normal file
82
mongo/x/mongo/driver/auth/mongodbaws.go
Normal file
@@ -0,0 +1,82 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
)
|
||||
|
||||
// MongoDBAWS is the mechanism name for MongoDBAWS.
|
||||
const MongoDBAWS = "MONGODB-AWS"
|
||||
|
||||
func newMongoDBAWSAuthenticator(cred *Cred) (Authenticator, error) {
|
||||
if cred.Source != "" && cred.Source != "$external" {
|
||||
return nil, newAuthError("MONGODB-AWS source must be empty or $external", nil)
|
||||
}
|
||||
return &MongoDBAWSAuthenticator{
|
||||
source: cred.Source,
|
||||
username: cred.Username,
|
||||
password: cred.Password,
|
||||
sessionToken: cred.Props["AWS_SESSION_TOKEN"],
|
||||
}, nil
|
||||
}
|
||||
|
||||
// MongoDBAWSAuthenticator uses AWS-IAM credentials over SASL to authenticate a connection.
|
||||
type MongoDBAWSAuthenticator struct {
|
||||
source string
|
||||
username string
|
||||
password string
|
||||
sessionToken string
|
||||
}
|
||||
|
||||
// Auth authenticates the connection.
|
||||
func (a *MongoDBAWSAuthenticator) Auth(ctx context.Context, cfg *Config) error {
|
||||
httpClient := cfg.HTTPClient
|
||||
if httpClient == nil {
|
||||
return errors.New("cfg.HTTPClient must not be nil")
|
||||
}
|
||||
adapter := &awsSaslAdapter{
|
||||
conversation: &awsConversation{
|
||||
username: a.username,
|
||||
password: a.password,
|
||||
token: a.sessionToken,
|
||||
httpClient: httpClient,
|
||||
},
|
||||
}
|
||||
err := ConductSaslConversation(ctx, cfg, a.source, adapter)
|
||||
if err != nil {
|
||||
return newAuthError("sasl conversation error", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type awsSaslAdapter struct {
|
||||
conversation *awsConversation
|
||||
}
|
||||
|
||||
var _ SaslClient = (*awsSaslAdapter)(nil)
|
||||
|
||||
func (a *awsSaslAdapter) Start() (string, []byte, error) {
|
||||
step, err := a.conversation.Step(nil)
|
||||
if err != nil {
|
||||
return MongoDBAWS, nil, err
|
||||
}
|
||||
return MongoDBAWS, step, nil
|
||||
}
|
||||
|
||||
func (a *awsSaslAdapter) Next(challenge []byte) ([]byte, error) {
|
||||
step, err := a.conversation.Step(challenge)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return step, nil
|
||||
}
|
||||
|
||||
func (a *awsSaslAdapter) Completed() bool {
|
||||
return a.conversation.Done()
|
||||
}
|
||||
48
mongo/x/mongo/driver/auth/mongodbaws_test.go
Normal file
48
mongo/x/mongo/driver/auth/mongodbaws_test.go
Normal file
@@ -0,0 +1,48 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"go.mongodb.org/mongo-driver/internal/testutil/assert"
|
||||
)
|
||||
|
||||
func TestGetRegion(t *testing.T) {
|
||||
longHost := make([]rune, 256)
|
||||
emptyErr := errors.New("invalid STS host: empty")
|
||||
tooLongErr := errors.New("invalid STS host: too large")
|
||||
emptyPartErr := errors.New("invalid STS host: empty part")
|
||||
testCases := []struct {
|
||||
name string
|
||||
host string
|
||||
err error
|
||||
region string
|
||||
}{
|
||||
{"success default", "sts.amazonaws.com", nil, "us-east-1"},
|
||||
{"success parse", "first.second", nil, "second"},
|
||||
{"success no region", "first", nil, "us-east-1"},
|
||||
{"error host too long", string(longHost), tooLongErr, ""},
|
||||
{"error host empty", "", emptyErr, ""},
|
||||
{"error empty middle part", "abc..def", emptyPartErr, ""},
|
||||
{"error empty part", "first.", emptyPartErr, ""},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
reg, err := getRegion(tc.host)
|
||||
if tc.err == nil {
|
||||
assert.Nil(t, err, "error getting region: %v", err)
|
||||
assert.Equal(t, tc.region, reg, "expected %v, got %v", tc.region, reg)
|
||||
return
|
||||
}
|
||||
assert.NotNil(t, err, "expected error, got nil")
|
||||
assert.Equal(t, err, tc.err, "expected error: %v, got: %v", tc.err, err)
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
110
mongo/x/mongo/driver/auth/mongodbcr.go
Normal file
110
mongo/x/mongo/driver/auth/mongodbcr.go
Normal file
@@ -0,0 +1,110 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
// Ignore gosec warning "Blocklisted import crypto/md5: weak cryptographic primitive". We need
|
||||
// to use MD5 here to implement the MONGODB-CR specification.
|
||||
/* #nosec G501 */
|
||||
"crypto/md5"
|
||||
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
|
||||
)
|
||||
|
||||
// MONGODBCR is the mechanism name for MONGODB-CR.
|
||||
//
|
||||
// The MONGODB-CR authentication mechanism is deprecated in MongoDB 3.6 and removed in
|
||||
// MongoDB 4.0.
|
||||
const MONGODBCR = "MONGODB-CR"
|
||||
|
||||
func newMongoDBCRAuthenticator(cred *Cred) (Authenticator, error) {
|
||||
return &MongoDBCRAuthenticator{
|
||||
DB: cred.Source,
|
||||
Username: cred.Username,
|
||||
Password: cred.Password,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// MongoDBCRAuthenticator uses the MONGODB-CR algorithm to authenticate a connection.
|
||||
//
|
||||
// The MONGODB-CR authentication mechanism is deprecated in MongoDB 3.6 and removed in
|
||||
// MongoDB 4.0.
|
||||
type MongoDBCRAuthenticator struct {
|
||||
DB string
|
||||
Username string
|
||||
Password string
|
||||
}
|
||||
|
||||
// Auth authenticates the connection.
|
||||
//
|
||||
// The MONGODB-CR authentication mechanism is deprecated in MongoDB 3.6 and removed in
|
||||
// MongoDB 4.0.
|
||||
func (a *MongoDBCRAuthenticator) Auth(ctx context.Context, cfg *Config) error {
|
||||
|
||||
db := a.DB
|
||||
if db == "" {
|
||||
db = defaultAuthDB
|
||||
}
|
||||
|
||||
doc := bsoncore.BuildDocumentFromElements(nil, bsoncore.AppendInt32Element(nil, "getnonce", 1))
|
||||
cmd := operation.NewCommand(doc).
|
||||
Database(db).
|
||||
Deployment(driver.SingleConnectionDeployment{cfg.Connection}).
|
||||
ClusterClock(cfg.ClusterClock).
|
||||
ServerAPI(cfg.ServerAPI)
|
||||
err := cmd.Execute(ctx)
|
||||
if err != nil {
|
||||
return newError(err, MONGODBCR)
|
||||
}
|
||||
rdr := cmd.Result()
|
||||
|
||||
var getNonceResult struct {
|
||||
Nonce string `bson:"nonce"`
|
||||
}
|
||||
|
||||
err = bson.Unmarshal(rdr, &getNonceResult)
|
||||
if err != nil {
|
||||
return newAuthError("unmarshal error", err)
|
||||
}
|
||||
|
||||
doc = bsoncore.BuildDocumentFromElements(nil,
|
||||
bsoncore.AppendInt32Element(nil, "authenticate", 1),
|
||||
bsoncore.AppendStringElement(nil, "user", a.Username),
|
||||
bsoncore.AppendStringElement(nil, "nonce", getNonceResult.Nonce),
|
||||
bsoncore.AppendStringElement(nil, "key", a.createKey(getNonceResult.Nonce)),
|
||||
)
|
||||
cmd = operation.NewCommand(doc).
|
||||
Database(db).
|
||||
Deployment(driver.SingleConnectionDeployment{cfg.Connection}).
|
||||
ClusterClock(cfg.ClusterClock).
|
||||
ServerAPI(cfg.ServerAPI)
|
||||
err = cmd.Execute(ctx)
|
||||
if err != nil {
|
||||
return newError(err, MONGODBCR)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *MongoDBCRAuthenticator) createKey(nonce string) string {
|
||||
// Ignore gosec warning "Use of weak cryptographic primitive". We need to use MD5 here to
|
||||
// implement the MONGODB-CR specification.
|
||||
/* #nosec G401 */
|
||||
h := md5.New()
|
||||
|
||||
_, _ = io.WriteString(h, nonce)
|
||||
_, _ = io.WriteString(h, a.Username)
|
||||
_, _ = io.WriteString(h, mongoPasswordDigest(a.Username, a.Password))
|
||||
return fmt.Sprintf("%x", h.Sum(nil))
|
||||
}
|
||||
114
mongo/x/mongo/driver/auth/mongodbcr_test.go
Normal file
114
mongo/x/mongo/driver/auth/mongodbcr_test.go
Normal file
@@ -0,0 +1,114 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package auth_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"strings"
|
||||
|
||||
"go.mongodb.org/mongo-driver/mongo/description"
|
||||
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
|
||||
. "go.mongodb.org/mongo-driver/x/mongo/driver/auth"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/drivertest"
|
||||
)
|
||||
|
||||
func TestMongoDBCRAuthenticator_Fails(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
authenticator := MongoDBCRAuthenticator{
|
||||
DB: "source",
|
||||
Username: "user",
|
||||
Password: "pencil",
|
||||
}
|
||||
|
||||
resps := make(chan []byte, 2)
|
||||
writeReplies(resps, bsoncore.BuildDocumentFromElements(nil,
|
||||
bsoncore.AppendInt32Element(nil, "ok", 1),
|
||||
bsoncore.AppendStringElement(nil, "nonce", "2375531c32080ae8"),
|
||||
), bsoncore.BuildDocumentFromElements(nil,
|
||||
bsoncore.AppendInt32Element(nil, "ok", 0),
|
||||
))
|
||||
|
||||
desc := description.Server{
|
||||
WireVersion: &description.VersionRange{
|
||||
Max: 6,
|
||||
},
|
||||
}
|
||||
c := &drivertest.ChannelConn{
|
||||
Written: make(chan []byte, 2),
|
||||
ReadResp: resps,
|
||||
Desc: desc,
|
||||
}
|
||||
|
||||
err := authenticator.Auth(context.Background(), &Config{Description: desc, Connection: c})
|
||||
if err == nil {
|
||||
t.Fatalf("expected an error but got none")
|
||||
}
|
||||
|
||||
errPrefix := "unable to authenticate using mechanism \"MONGODB-CR\""
|
||||
if !strings.HasPrefix(err.Error(), errPrefix) {
|
||||
t.Fatalf("expected an err starting with \"%s\" but got \"%s\"", errPrefix, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMongoDBCRAuthenticator_Succeeds(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
authenticator := MongoDBCRAuthenticator{
|
||||
DB: "source",
|
||||
Username: "user",
|
||||
Password: "pencil",
|
||||
}
|
||||
|
||||
resps := make(chan []byte, 2)
|
||||
writeReplies(resps, bsoncore.BuildDocumentFromElements(nil,
|
||||
bsoncore.AppendInt32Element(nil, "ok", 1),
|
||||
bsoncore.AppendStringElement(nil, "nonce", "2375531c32080ae8"),
|
||||
), bsoncore.BuildDocumentFromElements(nil,
|
||||
bsoncore.AppendInt32Element(nil, "ok", 1),
|
||||
))
|
||||
|
||||
desc := description.Server{
|
||||
WireVersion: &description.VersionRange{
|
||||
Max: 6,
|
||||
},
|
||||
}
|
||||
c := &drivertest.ChannelConn{
|
||||
Written: make(chan []byte, 2),
|
||||
ReadResp: resps,
|
||||
Desc: desc,
|
||||
}
|
||||
|
||||
err := authenticator.Auth(context.Background(), &Config{Description: desc, Connection: c})
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error but got \"%s\"", err)
|
||||
}
|
||||
|
||||
if len(c.Written) != 2 {
|
||||
t.Fatalf("expected 2 messages to be sent but had %d", len(c.Written))
|
||||
}
|
||||
|
||||
want := bsoncore.BuildDocumentFromElements(nil, bsoncore.AppendInt32Element(nil, "getnonce", 1))
|
||||
compareResponses(t, <-c.Written, want, "source")
|
||||
|
||||
expectedAuthenticateDoc := bsoncore.BuildDocumentFromElements(nil,
|
||||
bsoncore.AppendInt32Element(nil, "authenticate", 1),
|
||||
bsoncore.AppendStringElement(nil, "user", "user"),
|
||||
bsoncore.AppendStringElement(nil, "nonce", "2375531c32080ae8"),
|
||||
bsoncore.AppendStringElement(nil, "key", "21742f26431831d5cfca035a08c5bdf6"),
|
||||
)
|
||||
compareResponses(t, <-c.Written, expectedAuthenticateDoc, "source")
|
||||
}
|
||||
|
||||
func writeReplies(c chan []byte, docs ...bsoncore.Document) {
|
||||
for _, doc := range docs {
|
||||
reply := drivertest.MakeReply(doc)
|
||||
c <- reply
|
||||
}
|
||||
}
|
||||
55
mongo/x/mongo/driver/auth/plain.go
Normal file
55
mongo/x/mongo/driver/auth/plain.go
Normal file
@@ -0,0 +1,55 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// PLAIN is the mechanism name for PLAIN.
|
||||
const PLAIN = "PLAIN"
|
||||
|
||||
func newPlainAuthenticator(cred *Cred) (Authenticator, error) {
|
||||
return &PlainAuthenticator{
|
||||
Username: cred.Username,
|
||||
Password: cred.Password,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// PlainAuthenticator uses the PLAIN algorithm over SASL to authenticate a connection.
|
||||
type PlainAuthenticator struct {
|
||||
Username string
|
||||
Password string
|
||||
}
|
||||
|
||||
// Auth authenticates the connection.
|
||||
func (a *PlainAuthenticator) Auth(ctx context.Context, cfg *Config) error {
|
||||
return ConductSaslConversation(ctx, cfg, "$external", &plainSaslClient{
|
||||
username: a.Username,
|
||||
password: a.Password,
|
||||
})
|
||||
}
|
||||
|
||||
type plainSaslClient struct {
|
||||
username string
|
||||
password string
|
||||
}
|
||||
|
||||
var _ SaslClient = (*plainSaslClient)(nil)
|
||||
|
||||
func (c *plainSaslClient) Start() (string, []byte, error) {
|
||||
b := []byte("\x00" + c.username + "\x00" + c.password)
|
||||
return PLAIN, b, nil
|
||||
}
|
||||
|
||||
func (c *plainSaslClient) Next(challenge []byte) ([]byte, error) {
|
||||
return nil, newAuthError("unexpected server challenge", nil)
|
||||
}
|
||||
|
||||
func (c *plainSaslClient) Completed() bool {
|
||||
return true
|
||||
}
|
||||
147
mongo/x/mongo/driver/auth/plain_test.go
Normal file
147
mongo/x/mongo/driver/auth/plain_test.go
Normal file
@@ -0,0 +1,147 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package auth_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"encoding/base64"
|
||||
|
||||
"go.mongodb.org/mongo-driver/mongo/description"
|
||||
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
|
||||
. "go.mongodb.org/mongo-driver/x/mongo/driver/auth"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/drivertest"
|
||||
)
|
||||
|
||||
func TestPlainAuthenticator_Fails(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
authenticator := PlainAuthenticator{
|
||||
Username: "user",
|
||||
Password: "pencil",
|
||||
}
|
||||
|
||||
resps := make(chan []byte, 1)
|
||||
writeReplies(resps, bsoncore.BuildDocumentFromElements(nil,
|
||||
bsoncore.AppendInt32Element(nil, "ok", 1),
|
||||
bsoncore.AppendInt32Element(nil, "conversationId", 1),
|
||||
bsoncore.AppendBinaryElement(nil, "payload", 0x00, []byte{}),
|
||||
bsoncore.AppendInt32Element(nil, "code", 143),
|
||||
bsoncore.AppendBooleanElement(nil, "done", true),
|
||||
))
|
||||
|
||||
desc := description.Server{
|
||||
WireVersion: &description.VersionRange{
|
||||
Max: 6,
|
||||
},
|
||||
}
|
||||
c := &drivertest.ChannelConn{
|
||||
Written: make(chan []byte, 1),
|
||||
ReadResp: resps,
|
||||
Desc: desc,
|
||||
}
|
||||
|
||||
err := authenticator.Auth(context.Background(), &Config{Description: desc, Connection: c})
|
||||
if err == nil {
|
||||
t.Fatalf("expected an error but got none")
|
||||
}
|
||||
|
||||
errPrefix := "unable to authenticate using mechanism \"PLAIN\""
|
||||
if !strings.HasPrefix(err.Error(), errPrefix) {
|
||||
t.Fatalf("expected an err starting with \"%s\" but got \"%s\"", errPrefix, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlainAuthenticator_Extra_server_message(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
authenticator := PlainAuthenticator{
|
||||
Username: "user",
|
||||
Password: "pencil",
|
||||
}
|
||||
|
||||
resps := make(chan []byte, 2)
|
||||
writeReplies(resps, bsoncore.BuildDocumentFromElements(nil,
|
||||
bsoncore.AppendInt32Element(nil, "ok", 1),
|
||||
bsoncore.AppendInt32Element(nil, "conversationId", 1),
|
||||
bsoncore.AppendBinaryElement(nil, "payload", 0x00, []byte{}),
|
||||
bsoncore.AppendBooleanElement(nil, "done", false),
|
||||
), bsoncore.BuildDocumentFromElements(nil,
|
||||
bsoncore.AppendInt32Element(nil, "ok", 1),
|
||||
bsoncore.AppendInt32Element(nil, "conversationId", 1),
|
||||
bsoncore.AppendBinaryElement(nil, "payload", 0x00, []byte{}),
|
||||
bsoncore.AppendBooleanElement(nil, "done", true),
|
||||
))
|
||||
|
||||
desc := description.Server{
|
||||
WireVersion: &description.VersionRange{
|
||||
Max: 6,
|
||||
},
|
||||
}
|
||||
c := &drivertest.ChannelConn{
|
||||
Written: make(chan []byte, 1),
|
||||
ReadResp: resps,
|
||||
Desc: desc,
|
||||
}
|
||||
|
||||
err := authenticator.Auth(context.Background(), &Config{Description: desc, Connection: c})
|
||||
if err == nil {
|
||||
t.Fatalf("expected an error but got none")
|
||||
}
|
||||
|
||||
errPrefix := "unable to authenticate using mechanism \"PLAIN\": unexpected server challenge"
|
||||
if !strings.HasPrefix(err.Error(), errPrefix) {
|
||||
t.Fatalf("expected an err starting with \"%s\" but got \"%s\"", errPrefix, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlainAuthenticator_Succeeds(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
authenticator := PlainAuthenticator{
|
||||
Username: "user",
|
||||
Password: "pencil",
|
||||
}
|
||||
|
||||
resps := make(chan []byte, 1)
|
||||
writeReplies(resps, bsoncore.BuildDocumentFromElements(nil,
|
||||
bsoncore.AppendInt32Element(nil, "ok", 1),
|
||||
bsoncore.AppendInt32Element(nil, "conversationId", 1),
|
||||
bsoncore.AppendBinaryElement(nil, "payload", 0x00, []byte{}),
|
||||
bsoncore.AppendBooleanElement(nil, "done", true),
|
||||
))
|
||||
|
||||
desc := description.Server{
|
||||
WireVersion: &description.VersionRange{
|
||||
Max: 6,
|
||||
},
|
||||
}
|
||||
c := &drivertest.ChannelConn{
|
||||
Written: make(chan []byte, 1),
|
||||
ReadResp: resps,
|
||||
Desc: desc,
|
||||
}
|
||||
|
||||
err := authenticator.Auth(context.Background(), &Config{Description: desc, Connection: c})
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error but got \"%s\"", err)
|
||||
}
|
||||
|
||||
if len(c.Written) != 1 {
|
||||
t.Fatalf("expected 1 messages to be sent but had %d", len(c.Written))
|
||||
}
|
||||
|
||||
payload, _ := base64.StdEncoding.DecodeString("AHVzZXIAcGVuY2ls")
|
||||
expectedCmd := bsoncore.BuildDocumentFromElements(nil,
|
||||
bsoncore.AppendInt32Element(nil, "saslStart", 1),
|
||||
bsoncore.AppendStringElement(nil, "mechanism", "PLAIN"),
|
||||
bsoncore.AppendBinaryElement(nil, "payload", 0x00, payload),
|
||||
)
|
||||
compareResponses(t, <-c.Written, expectedCmd, "$external")
|
||||
}
|
||||
174
mongo/x/mongo/driver/auth/sasl.go
Normal file
174
mongo/x/mongo/driver/auth/sasl.go
Normal file
@@ -0,0 +1,174 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
|
||||
)
|
||||
|
||||
// SaslClient is the client piece of a sasl conversation.
|
||||
type SaslClient interface {
|
||||
Start() (string, []byte, error)
|
||||
Next(challenge []byte) ([]byte, error)
|
||||
Completed() bool
|
||||
}
|
||||
|
||||
// SaslClientCloser is a SaslClient that has resources to clean up.
|
||||
type SaslClientCloser interface {
|
||||
SaslClient
|
||||
Close()
|
||||
}
|
||||
|
||||
// ExtraOptionsSaslClient is a SaslClient that appends options to the saslStart command.
|
||||
type ExtraOptionsSaslClient interface {
|
||||
StartCommandOptions() bsoncore.Document
|
||||
}
|
||||
|
||||
// saslConversation represents a SASL conversation. This type implements the SpeculativeConversation interface so the
|
||||
// conversation can be executed in multi-step speculative fashion.
|
||||
type saslConversation struct {
|
||||
client SaslClient
|
||||
source string
|
||||
mechanism string
|
||||
speculative bool
|
||||
}
|
||||
|
||||
var _ SpeculativeConversation = (*saslConversation)(nil)
|
||||
|
||||
func newSaslConversation(client SaslClient, source string, speculative bool) *saslConversation {
|
||||
authSource := source
|
||||
if authSource == "" {
|
||||
authSource = defaultAuthDB
|
||||
}
|
||||
return &saslConversation{
|
||||
client: client,
|
||||
source: authSource,
|
||||
speculative: speculative,
|
||||
}
|
||||
}
|
||||
|
||||
// FirstMessage returns the first message to be sent to the server. This message contains a "db" field so it can be used
|
||||
// for speculative authentication.
|
||||
func (sc *saslConversation) FirstMessage() (bsoncore.Document, error) {
|
||||
var payload []byte
|
||||
var err error
|
||||
sc.mechanism, payload, err = sc.client.Start()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
saslCmdElements := [][]byte{
|
||||
bsoncore.AppendInt32Element(nil, "saslStart", 1),
|
||||
bsoncore.AppendStringElement(nil, "mechanism", sc.mechanism),
|
||||
bsoncore.AppendBinaryElement(nil, "payload", 0x00, payload),
|
||||
}
|
||||
if sc.speculative {
|
||||
// The "db" field is only appended for speculative auth because the hello command is executed against admin
|
||||
// so this is needed to tell the server the user's auth source. For a non-speculative attempt, the SASL commands
|
||||
// will be executed against the auth source.
|
||||
saslCmdElements = append(saslCmdElements, bsoncore.AppendStringElement(nil, "db", sc.source))
|
||||
}
|
||||
if extraOptionsClient, ok := sc.client.(ExtraOptionsSaslClient); ok {
|
||||
optionsDoc := extraOptionsClient.StartCommandOptions()
|
||||
saslCmdElements = append(saslCmdElements, bsoncore.AppendDocumentElement(nil, "options", optionsDoc))
|
||||
}
|
||||
|
||||
return bsoncore.BuildDocumentFromElements(nil, saslCmdElements...), nil
|
||||
}
|
||||
|
||||
type saslResponse struct {
|
||||
ConversationID int `bson:"conversationId"`
|
||||
Code int `bson:"code"`
|
||||
Done bool `bson:"done"`
|
||||
Payload []byte `bson:"payload"`
|
||||
}
|
||||
|
||||
// Finish completes the conversation based on the first server response to authenticate the given connection.
|
||||
func (sc *saslConversation) Finish(ctx context.Context, cfg *Config, firstResponse bsoncore.Document) error {
|
||||
if closer, ok := sc.client.(SaslClientCloser); ok {
|
||||
defer closer.Close()
|
||||
}
|
||||
|
||||
var saslResp saslResponse
|
||||
err := bson.Unmarshal(firstResponse, &saslResp)
|
||||
if err != nil {
|
||||
fullErr := fmt.Errorf("unmarshal error: %v", err)
|
||||
return newError(fullErr, sc.mechanism)
|
||||
}
|
||||
|
||||
cid := saslResp.ConversationID
|
||||
var payload []byte
|
||||
var rdr bsoncore.Document
|
||||
for {
|
||||
if saslResp.Code != 0 {
|
||||
return newError(err, sc.mechanism)
|
||||
}
|
||||
|
||||
if saslResp.Done && sc.client.Completed() {
|
||||
return nil
|
||||
}
|
||||
|
||||
payload, err = sc.client.Next(saslResp.Payload)
|
||||
if err != nil {
|
||||
return newError(err, sc.mechanism)
|
||||
}
|
||||
|
||||
if saslResp.Done && sc.client.Completed() {
|
||||
return nil
|
||||
}
|
||||
|
||||
doc := bsoncore.BuildDocumentFromElements(nil,
|
||||
bsoncore.AppendInt32Element(nil, "saslContinue", 1),
|
||||
bsoncore.AppendInt32Element(nil, "conversationId", int32(cid)),
|
||||
bsoncore.AppendBinaryElement(nil, "payload", 0x00, payload),
|
||||
)
|
||||
saslContinueCmd := operation.NewCommand(doc).
|
||||
Database(sc.source).
|
||||
Deployment(driver.SingleConnectionDeployment{cfg.Connection}).
|
||||
ClusterClock(cfg.ClusterClock).
|
||||
ServerAPI(cfg.ServerAPI)
|
||||
|
||||
err = saslContinueCmd.Execute(ctx)
|
||||
if err != nil {
|
||||
return newError(err, sc.mechanism)
|
||||
}
|
||||
rdr = saslContinueCmd.Result()
|
||||
|
||||
err = bson.Unmarshal(rdr, &saslResp)
|
||||
if err != nil {
|
||||
fullErr := fmt.Errorf("unmarshal error: %v", err)
|
||||
return newError(fullErr, sc.mechanism)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ConductSaslConversation runs a full SASL conversation to authenticate the given connection.
|
||||
func ConductSaslConversation(ctx context.Context, cfg *Config, authSource string, client SaslClient) error {
|
||||
// Create a non-speculative SASL conversation.
|
||||
conversation := newSaslConversation(client, authSource, false)
|
||||
|
||||
saslStartDoc, err := conversation.FirstMessage()
|
||||
if err != nil {
|
||||
return newError(err, conversation.mechanism)
|
||||
}
|
||||
saslStartCmd := operation.NewCommand(saslStartDoc).
|
||||
Database(authSource).
|
||||
Deployment(driver.SingleConnectionDeployment{cfg.Connection}).
|
||||
ClusterClock(cfg.ClusterClock).
|
||||
ServerAPI(cfg.ServerAPI)
|
||||
if err := saslStartCmd.Execute(ctx); err != nil {
|
||||
return newError(err, conversation.mechanism)
|
||||
}
|
||||
|
||||
return conversation.Finish(ctx, cfg, saslStartCmd.Result())
|
||||
}
|
||||
130
mongo/x/mongo/driver/auth/scram.go
Normal file
130
mongo/x/mongo/driver/auth/scram.go
Normal file
@@ -0,0 +1,130 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
// Copyright (C) MongoDB, Inc. 2018-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/xdg-go/scram"
|
||||
"github.com/xdg-go/stringprep"
|
||||
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
|
||||
)
|
||||
|
||||
const (
|
||||
// SCRAMSHA1 holds the mechanism name "SCRAM-SHA-1"
|
||||
SCRAMSHA1 = "SCRAM-SHA-1"
|
||||
|
||||
// SCRAMSHA256 holds the mechanism name "SCRAM-SHA-256"
|
||||
SCRAMSHA256 = "SCRAM-SHA-256"
|
||||
)
|
||||
|
||||
var (
|
||||
// Additional options for the saslStart command to enable a shorter SCRAM conversation
|
||||
scramStartOptions bsoncore.Document = bsoncore.BuildDocumentFromElements(nil,
|
||||
bsoncore.AppendBooleanElement(nil, "skipEmptyExchange", true),
|
||||
)
|
||||
)
|
||||
|
||||
func newScramSHA1Authenticator(cred *Cred) (Authenticator, error) {
|
||||
passdigest := mongoPasswordDigest(cred.Username, cred.Password)
|
||||
client, err := scram.SHA1.NewClientUnprepped(cred.Username, passdigest, "")
|
||||
if err != nil {
|
||||
return nil, newAuthError("error initializing SCRAM-SHA-1 client", err)
|
||||
}
|
||||
client.WithMinIterations(4096)
|
||||
return &ScramAuthenticator{
|
||||
mechanism: SCRAMSHA1,
|
||||
source: cred.Source,
|
||||
client: client,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newScramSHA256Authenticator(cred *Cred) (Authenticator, error) {
|
||||
passprep, err := stringprep.SASLprep.Prepare(cred.Password)
|
||||
if err != nil {
|
||||
return nil, newAuthError(fmt.Sprintf("error SASLprepping password '%s'", cred.Password), err)
|
||||
}
|
||||
client, err := scram.SHA256.NewClientUnprepped(cred.Username, passprep, "")
|
||||
if err != nil {
|
||||
return nil, newAuthError("error initializing SCRAM-SHA-256 client", err)
|
||||
}
|
||||
client.WithMinIterations(4096)
|
||||
return &ScramAuthenticator{
|
||||
mechanism: SCRAMSHA256,
|
||||
source: cred.Source,
|
||||
client: client,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ScramAuthenticator uses the SCRAM algorithm over SASL to authenticate a connection.
|
||||
type ScramAuthenticator struct {
|
||||
mechanism string
|
||||
source string
|
||||
client *scram.Client
|
||||
}
|
||||
|
||||
var _ SpeculativeAuthenticator = (*ScramAuthenticator)(nil)
|
||||
|
||||
// Auth authenticates the provided connection by conducting a full SASL conversation.
|
||||
func (a *ScramAuthenticator) Auth(ctx context.Context, cfg *Config) error {
|
||||
err := ConductSaslConversation(ctx, cfg, a.source, a.createSaslClient())
|
||||
if err != nil {
|
||||
return newAuthError("sasl conversation error", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateSpeculativeConversation creates a speculative conversation for SCRAM authentication.
|
||||
func (a *ScramAuthenticator) CreateSpeculativeConversation() (SpeculativeConversation, error) {
|
||||
return newSaslConversation(a.createSaslClient(), a.source, true), nil
|
||||
}
|
||||
|
||||
func (a *ScramAuthenticator) createSaslClient() SaslClient {
|
||||
return &scramSaslAdapter{
|
||||
conversation: a.client.NewConversation(),
|
||||
mechanism: a.mechanism,
|
||||
}
|
||||
}
|
||||
|
||||
type scramSaslAdapter struct {
|
||||
mechanism string
|
||||
conversation *scram.ClientConversation
|
||||
}
|
||||
|
||||
var _ SaslClient = (*scramSaslAdapter)(nil)
|
||||
var _ ExtraOptionsSaslClient = (*scramSaslAdapter)(nil)
|
||||
|
||||
func (a *scramSaslAdapter) Start() (string, []byte, error) {
|
||||
step, err := a.conversation.Step("")
|
||||
if err != nil {
|
||||
return a.mechanism, nil, err
|
||||
}
|
||||
return a.mechanism, []byte(step), nil
|
||||
}
|
||||
|
||||
func (a *scramSaslAdapter) Next(challenge []byte) ([]byte, error) {
|
||||
step, err := a.conversation.Step(string(challenge))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []byte(step), nil
|
||||
}
|
||||
|
||||
func (a *scramSaslAdapter) Completed() bool {
|
||||
return a.conversation.Done()
|
||||
}
|
||||
|
||||
func (*scramSaslAdapter) StartCommandOptions() bsoncore.Document {
|
||||
return scramStartOptions
|
||||
}
|
||||
120
mongo/x/mongo/driver/auth/scram_test.go
Normal file
120
mongo/x/mongo/driver/auth/scram_test.go
Normal file
@@ -0,0 +1,120 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"go.mongodb.org/mongo-driver/internal/testutil/assert"
|
||||
"go.mongodb.org/mongo-driver/mongo/description"
|
||||
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/drivertest"
|
||||
)
|
||||
|
||||
const (
|
||||
scramSha1Nonce = "fyko+d2lbbFgONRv9qkxdawL"
|
||||
scramSha256Nonce = "rOprNGfwEbeRWgbNEkqO"
|
||||
)
|
||||
|
||||
var (
|
||||
scramSha1ShortPayloads = [][]byte{
|
||||
[]byte("r=fyko+d2lbbFgONRv9qkxdawLHo+Vgk7qvUOKUwuWLIWg4l/9SraGMHEE,s=rQ9ZY3MntBeuP3E1TDVC4w==,i=10000"),
|
||||
[]byte("v=UMWeI25JD1yNYZRMpZ4VHvhZ9e0="),
|
||||
}
|
||||
scramSha256ShortPayloads = [][]byte{
|
||||
[]byte("r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,s=W22ZaJ0SNY7soEsUEjb6gQ==,i=4096"),
|
||||
[]byte("v=6rriTRBi23WpRR/wtup+mMhUZUn/dB5nLTJRsjl95G4="),
|
||||
}
|
||||
scramSha1LongPayloads = append(scramSha1ShortPayloads, []byte{})
|
||||
scramSha256LongPayloads = append(scramSha256ShortPayloads, []byte{})
|
||||
)
|
||||
|
||||
func TestSCRAM(t *testing.T) {
|
||||
t.Run("conversation", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
createAuthenticatorFn func(*Cred) (Authenticator, error)
|
||||
payloads [][]byte
|
||||
nonce string
|
||||
}{
|
||||
{"scram-sha-1 short conversation", newScramSHA1Authenticator, scramSha1ShortPayloads, scramSha1Nonce},
|
||||
{"scram-sha-256 short conversation", newScramSHA256Authenticator, scramSha256ShortPayloads, scramSha256Nonce},
|
||||
{"scram-sha-1 long conversation", newScramSHA1Authenticator, scramSha1LongPayloads, scramSha1Nonce},
|
||||
{"scram-sha-256 long conversation", newScramSHA256Authenticator, scramSha256LongPayloads, scramSha256Nonce},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
authenticator, err := tc.createAuthenticatorFn(&Cred{
|
||||
Username: "user",
|
||||
Password: "pencil",
|
||||
Source: "admin",
|
||||
})
|
||||
assert.Nil(t, err, "error creating authenticator: %v", err)
|
||||
sa, _ := authenticator.(*ScramAuthenticator)
|
||||
sa.client = sa.client.WithNonceGenerator(func() string {
|
||||
return tc.nonce
|
||||
})
|
||||
|
||||
responses := make(chan []byte, len(tc.payloads))
|
||||
writeReplies(responses, createSCRAMConversation(tc.payloads)...)
|
||||
|
||||
desc := description.Server{
|
||||
WireVersion: &description.VersionRange{
|
||||
Max: 4,
|
||||
},
|
||||
}
|
||||
conn := &drivertest.ChannelConn{
|
||||
Written: make(chan []byte, len(tc.payloads)),
|
||||
ReadResp: responses,
|
||||
Desc: desc,
|
||||
}
|
||||
|
||||
err = authenticator.Auth(context.Background(), &Config{Description: desc, Connection: conn})
|
||||
assert.Nil(t, err, "Auth error: %v\n", err)
|
||||
|
||||
// Verify that the first command sent is saslStart.
|
||||
assert.True(t, len(conn.Written) > 1, "wire messages were written to the connection")
|
||||
startCmd, err := drivertest.GetCommandFromQueryWireMessage(<-conn.Written)
|
||||
assert.Nil(t, err, "error parsing wire message: %v", err)
|
||||
cmdName := startCmd.Index(0).Key()
|
||||
assert.Equal(t, cmdName, "saslStart", "cmd name mismatch; expected 'saslStart', got %v", cmdName)
|
||||
|
||||
// Verify that the saslStart command always has {options: {skipEmptyExchange: true}}
|
||||
optionsVal, err := startCmd.LookupErr("options")
|
||||
assert.Nil(t, err, "no options found in saslStart command")
|
||||
optionsDoc := optionsVal.Document()
|
||||
assert.Equal(t, optionsDoc, scramStartOptions, "expected options %v, got %v", scramStartOptions, optionsDoc)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func createSCRAMConversation(payloads [][]byte) []bsoncore.Document {
|
||||
responses := make([]bsoncore.Document, len(payloads))
|
||||
for idx, payload := range payloads {
|
||||
res := createSCRAMServerResponse(payload, idx == len(payloads)-1)
|
||||
responses[idx] = res
|
||||
}
|
||||
return responses
|
||||
}
|
||||
|
||||
func createSCRAMServerResponse(payload []byte, done bool) bsoncore.Document {
|
||||
return bsoncore.BuildDocumentFromElements(nil,
|
||||
bsoncore.AppendInt32Element(nil, "conversationId", 1),
|
||||
bsoncore.AppendBinaryElement(nil, "payload", 0x00, payload),
|
||||
bsoncore.AppendBooleanElement(nil, "done", done),
|
||||
bsoncore.AppendInt32Element(nil, "ok", 1),
|
||||
)
|
||||
}
|
||||
|
||||
func writeReplies(c chan []byte, docs ...bsoncore.Document) {
|
||||
for _, doc := range docs {
|
||||
reply := drivertest.MakeReply(doc)
|
||||
c <- reply
|
||||
}
|
||||
}
|
||||
253
mongo/x/mongo/driver/auth/speculative_scram_test.go
Normal file
253
mongo/x/mongo/driver/auth/speculative_scram_test.go
Normal file
@@ -0,0 +1,253 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/internal"
|
||||
"go.mongodb.org/mongo-driver/internal/testutil/assert"
|
||||
"go.mongodb.org/mongo-driver/mongo/address"
|
||||
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/drivertest"
|
||||
)
|
||||
|
||||
var (
|
||||
// The base elements for a hello response.
|
||||
handshakeHelloElements = [][]byte{
|
||||
bsoncore.AppendInt32Element(nil, "ok", 1),
|
||||
bsoncore.AppendBooleanElement(nil, internal.LegacyHelloLowercase, true),
|
||||
bsoncore.AppendInt32Element(nil, "maxBsonObjectSize", 16777216),
|
||||
bsoncore.AppendInt32Element(nil, "maxMessageSizeBytes", 48000000),
|
||||
bsoncore.AppendInt32Element(nil, "minWireVersion", 0),
|
||||
bsoncore.AppendInt32Element(nil, "maxWireVersion", 6),
|
||||
}
|
||||
// The first payload sent by the driver for SCRAM-SHA-1/256 authentication.
|
||||
firstScramSha1ClientPayload = []byte("n,,n=user,r=fyko+d2lbbFgONRv9qkxdawL")
|
||||
firstScramSha256ClientPayload = []byte("n,,n=user,r=rOprNGfwEbeRWgbNEkqO")
|
||||
)
|
||||
|
||||
func TestSpeculativeSCRAM(t *testing.T) {
|
||||
cred := &Cred{
|
||||
Username: "user",
|
||||
Password: "pencil",
|
||||
PasswordSet: true,
|
||||
Source: "admin",
|
||||
}
|
||||
|
||||
t.Run("speculative response included", func(t *testing.T) {
|
||||
// Tests for SCRAM-SHA1 and SCRAM-SHA-256 when the hello response contains a reply to the speculative
|
||||
// authentication attempt. The driver should only send a saslContinue after the hello to complete
|
||||
// authentication.
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
mechanism string
|
||||
firstClientPayload []byte
|
||||
payloads [][]byte
|
||||
nonce string
|
||||
}{
|
||||
{"SCRAM-SHA-1", "SCRAM-SHA-1", firstScramSha1ClientPayload, scramSha1ShortPayloads, scramSha1Nonce},
|
||||
{"SCRAM-SHA-256", "SCRAM-SHA-256", firstScramSha256ClientPayload, scramSha256ShortPayloads, scramSha256Nonce},
|
||||
{"Default", "", firstScramSha256ClientPayload, scramSha256ShortPayloads, scramSha256Nonce},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create a SCRAM authenticator and overwrite the nonce generator to make the conversation
|
||||
// deterministic.
|
||||
authenticator, err := CreateAuthenticator(tc.mechanism, cred)
|
||||
assert.Nil(t, err, "CreateAuthenticator error: %v", err)
|
||||
setNonce(t, authenticator, tc.nonce)
|
||||
|
||||
// Create a Handshaker and fake connection to authenticate.
|
||||
handshaker := Handshaker(nil, &HandshakeOptions{
|
||||
Authenticator: authenticator,
|
||||
DBUser: "admin.user",
|
||||
})
|
||||
responses := make(chan []byte, len(tc.payloads))
|
||||
writeReplies(responses, createSpeculativeSCRAMHandshake(tc.payloads)...)
|
||||
|
||||
conn := &drivertest.ChannelConn{
|
||||
Written: make(chan []byte, len(tc.payloads)),
|
||||
ReadResp: responses,
|
||||
}
|
||||
|
||||
// Do both parts of the handshake.
|
||||
info, err := handshaker.GetHandshakeInformation(context.Background(), address.Address("localhost:27017"), conn)
|
||||
assert.Nil(t, err, "GetHandshakeInformation error: %v", err)
|
||||
assert.NotNil(t, info.SpeculativeAuthenticate, "desc.SpeculativeAuthenticate not set")
|
||||
conn.Desc = info.Description // Set conn.Desc so the new description will be used for the authentication.
|
||||
|
||||
err = handshaker.FinishHandshake(context.Background(), conn)
|
||||
assert.Nil(t, err, "FinishHandshake error: %v", err)
|
||||
assert.Equal(t, 0, len(conn.ReadResp), "%d messages left unread", len(conn.ReadResp))
|
||||
|
||||
// Assert that the driver sent hello with the speculative authentication message.
|
||||
assert.Equal(t, len(tc.payloads), len(conn.Written), "expected %d wire messages to be sent, got %d",
|
||||
len(tc.payloads), (conn.Written))
|
||||
helloCmd, err := drivertest.GetCommandFromQueryWireMessage(<-conn.Written)
|
||||
assert.Nil(t, err, "error parsing hello command: %v", err)
|
||||
assertCommandName(t, helloCmd, internal.LegacyHello)
|
||||
|
||||
// Assert that the correct document was sent for speculative authentication.
|
||||
authDocVal, err := helloCmd.LookupErr("speculativeAuthenticate")
|
||||
assert.Nil(t, err, "expected command %s to contain 'speculativeAuthenticate'", bson.Raw(helloCmd))
|
||||
authDoc := authDocVal.Document()
|
||||
sentMechanism := tc.mechanism
|
||||
if sentMechanism == "" {
|
||||
sentMechanism = "SCRAM-SHA-256"
|
||||
}
|
||||
|
||||
expectedAuthDoc := bsoncore.BuildDocumentFromElements(nil,
|
||||
bsoncore.AppendInt32Element(nil, "saslStart", 1),
|
||||
bsoncore.AppendStringElement(nil, "mechanism", sentMechanism),
|
||||
bsoncore.AppendBinaryElement(nil, "payload", 0x00, tc.firstClientPayload),
|
||||
bsoncore.AppendStringElement(nil, "db", "admin"),
|
||||
bsoncore.AppendDocumentElement(nil, "options", bsoncore.BuildDocumentFromElements(nil,
|
||||
bsoncore.AppendBooleanElement(nil, "skipEmptyExchange", true),
|
||||
)),
|
||||
)
|
||||
assert.True(t, bytes.Equal(expectedAuthDoc, authDoc),
|
||||
"expected speculative auth document %s, got %s",
|
||||
bson.Raw(expectedAuthDoc),
|
||||
authDoc,
|
||||
)
|
||||
|
||||
// Assert that the last command sent in the handshake is saslContinue.
|
||||
|
||||
saslContinueCmd, err := drivertest.GetCommandFromMsgWireMessage(<-conn.Written)
|
||||
assert.Nil(t, err, "error parsing saslContinue command: %v", err)
|
||||
assertCommandName(t, saslContinueCmd, "saslContinue")
|
||||
})
|
||||
}
|
||||
})
|
||||
t.Run("speculative response not included", func(t *testing.T) {
|
||||
// Tests for SCRAM-SHA-1 and SCRAM-SHA-256 when the hello response does not contain a reply to the
|
||||
// speculative authentication attempt. The driver should send both saslStart and saslContinue after the initial
|
||||
// hello.
|
||||
|
||||
// There is no test for the default mechanism because we can't control the nonce used for the actual
|
||||
// authentication attempt after the speculative attempt fails.
|
||||
|
||||
testCases := []struct {
|
||||
mechanism string
|
||||
payloads [][]byte
|
||||
nonce string
|
||||
}{
|
||||
{"SCRAM-SHA-1", scramSha1ShortPayloads, scramSha1Nonce},
|
||||
{"SCRAM-SHA-256", scramSha256ShortPayloads, scramSha256Nonce},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.mechanism, func(t *testing.T) {
|
||||
authenticator, err := CreateAuthenticator(tc.mechanism, cred)
|
||||
assert.Nil(t, err, "CreateAuthenticator error: %v", err)
|
||||
setNonce(t, authenticator, tc.nonce)
|
||||
|
||||
handshaker := Handshaker(nil, &HandshakeOptions{
|
||||
Authenticator: authenticator,
|
||||
DBUser: "admin.user",
|
||||
})
|
||||
numResponses := len(tc.payloads) + 1 // +1 for hello response
|
||||
responses := make(chan []byte, numResponses)
|
||||
writeReplies(responses, createRegularSCRAMHandshake(tc.payloads)...)
|
||||
|
||||
conn := &drivertest.ChannelConn{
|
||||
Written: make(chan []byte, numResponses),
|
||||
ReadResp: responses,
|
||||
}
|
||||
|
||||
info, err := handshaker.GetHandshakeInformation(context.Background(), address.Address("localhost:27017"), conn)
|
||||
assert.Nil(t, err, "GetHandshakeInformation error: %v", err)
|
||||
assert.Nil(t, info.SpeculativeAuthenticate, "expected desc.SpeculativeAuthenticate to be unset, got %s",
|
||||
bson.Raw(info.SpeculativeAuthenticate))
|
||||
conn.Desc = info.Description
|
||||
|
||||
err = handshaker.FinishHandshake(context.Background(), conn)
|
||||
assert.Nil(t, err, "FinishHandshake error: %v", err)
|
||||
assert.Equal(t, 0, len(conn.ReadResp), "%d messages left unread", len(conn.ReadResp))
|
||||
|
||||
assert.Equal(t, numResponses, len(conn.Written), "expected %d wire messages to be sent, got %d",
|
||||
numResponses, len(conn.Written))
|
||||
hello, err := drivertest.GetCommandFromQueryWireMessage(<-conn.Written)
|
||||
assert.Nil(t, err, "error parsing hello command: %v", err)
|
||||
assertCommandName(t, hello, internal.LegacyHello)
|
||||
_, err = hello.LookupErr("speculativeAuthenticate")
|
||||
assert.Nil(t, err, "expected command %s to contain 'speculativeAuthenticate'", bson.Raw(hello))
|
||||
|
||||
saslStart, err := drivertest.GetCommandFromMsgWireMessage(<-conn.Written)
|
||||
assert.Nil(t, err, "error parsing saslStart command: %v", err)
|
||||
assertCommandName(t, saslStart, "saslStart")
|
||||
|
||||
saslContinue, err := drivertest.GetCommandFromMsgWireMessage(<-conn.Written)
|
||||
assert.Nil(t, err, "error parsing saslContinue command: %v", err)
|
||||
assertCommandName(t, saslContinue, "saslContinue")
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func setNonce(t *testing.T, authenticator Authenticator, nonce string) {
|
||||
t.Helper()
|
||||
nonceGenerator := func() string {
|
||||
return nonce
|
||||
}
|
||||
|
||||
switch converted := authenticator.(type) {
|
||||
case *ScramAuthenticator:
|
||||
converted.client = converted.client.WithNonceGenerator(nonceGenerator)
|
||||
case *DefaultAuthenticator:
|
||||
sa := converted.speculativeAuthenticator.(*ScramAuthenticator)
|
||||
sa.client = sa.client.WithNonceGenerator(nonceGenerator)
|
||||
default:
|
||||
t.Fatalf("invalid authenticator type %T", authenticator)
|
||||
}
|
||||
}
|
||||
|
||||
// createSpeculativeSCRAMHandshake creates the server replies for a successful speculative SCRAM authentication attempt.
|
||||
// There are two replies:
|
||||
//
|
||||
// 1. hello reply containing a "speculativeAuthenticate" document.
|
||||
// 2. saslContinue reply with done:true
|
||||
func createSpeculativeSCRAMHandshake(payloads [][]byte) []bsoncore.Document {
|
||||
firstAuthResponse := createSCRAMServerResponse(payloads[0], false)
|
||||
firstAuthElem := bsoncore.AppendDocumentElement(nil, "speculativeAuthenticate", firstAuthResponse)
|
||||
hello := bsoncore.BuildDocumentFromElements(nil, append(handshakeHelloElements, firstAuthElem)...)
|
||||
|
||||
responses := []bsoncore.Document{hello}
|
||||
for idx := 1; idx < len(payloads); idx++ {
|
||||
responses = append(responses, createSCRAMServerResponse(payloads[idx], idx == len(payloads)-1))
|
||||
}
|
||||
return responses
|
||||
}
|
||||
|
||||
// createRegularSCRAMHandshake creates the server replies for a handshake + SCRAM authentication attempt. There are
|
||||
// three replies:
|
||||
//
|
||||
// 1. hello reply
|
||||
// 2. saslStart reply with done:false
|
||||
// 3. saslContinue reply with done:true
|
||||
func createRegularSCRAMHandshake(payloads [][]byte) []bsoncore.Document {
|
||||
hello := bsoncore.BuildDocumentFromElements(nil, handshakeHelloElements...)
|
||||
responses := []bsoncore.Document{hello}
|
||||
|
||||
for idx, payload := range payloads {
|
||||
responses = append(responses, createSCRAMServerResponse(payload, idx == len(payloads)-1))
|
||||
}
|
||||
return responses
|
||||
}
|
||||
|
||||
func assertCommandName(t *testing.T, cmd bsoncore.Document, expectedName string) {
|
||||
t.Helper()
|
||||
|
||||
actualName := cmd.Index(0).Key()
|
||||
assert.Equal(t, expectedName, actualName, "expected command name '%s', got '%s'", expectedName, actualName)
|
||||
}
|
||||
136
mongo/x/mongo/driver/auth/speculative_x509_test.go
Normal file
136
mongo/x/mongo/driver/auth/speculative_x509_test.go
Normal file
@@ -0,0 +1,136 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/internal"
|
||||
"go.mongodb.org/mongo-driver/internal/testutil/assert"
|
||||
"go.mongodb.org/mongo-driver/mongo/address"
|
||||
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/drivertest"
|
||||
)
|
||||
|
||||
var (
|
||||
x509Response bsoncore.Document = bsoncore.BuildDocumentFromElements(nil,
|
||||
bsoncore.AppendStringElement(nil, "dbname", "$external"),
|
||||
bsoncore.AppendStringElement(nil, "user", "username"),
|
||||
bsoncore.AppendInt32Element(nil, "ok", 1),
|
||||
)
|
||||
)
|
||||
|
||||
func TestSpeculativeX509(t *testing.T) {
|
||||
t.Run("speculative response included", func(t *testing.T) {
|
||||
// Tests for X509 when the hello response contains a reply to the speculative authentication attempt. The
|
||||
// driver should not send any more commands after the hello.
|
||||
|
||||
authenticator, err := CreateAuthenticator("MONGODB-X509", &Cred{})
|
||||
assert.Nil(t, err, "CreateAuthenticator error: %v", err)
|
||||
handshaker := Handshaker(nil, &HandshakeOptions{
|
||||
Authenticator: authenticator,
|
||||
})
|
||||
|
||||
numResponses := 1
|
||||
responses := make(chan []byte, numResponses)
|
||||
writeReplies(responses, createSpeculativeX509Handshake()...)
|
||||
|
||||
conn := &drivertest.ChannelConn{
|
||||
Written: make(chan []byte, numResponses),
|
||||
ReadResp: responses,
|
||||
}
|
||||
|
||||
info, err := handshaker.GetHandshakeInformation(context.Background(), address.Address("localhost:27017"), conn)
|
||||
assert.Nil(t, err, "GetDescription error: %v", err)
|
||||
assert.NotNil(t, info.SpeculativeAuthenticate, "desc.SpeculativeAuthenticate not set")
|
||||
conn.Desc = info.Description
|
||||
|
||||
err = handshaker.FinishHandshake(context.Background(), conn)
|
||||
assert.Nil(t, err, "FinishHandshake error: %v", err)
|
||||
assert.Equal(t, 0, len(conn.ReadResp), "%d messages left unread", len(conn.ReadResp))
|
||||
|
||||
assert.Equal(t, numResponses, len(conn.Written), "expected %d wire messages to be sent, got %d",
|
||||
numResponses, len(conn.Written))
|
||||
hello, err := drivertest.GetCommandFromQueryWireMessage(<-conn.Written)
|
||||
assert.Nil(t, err, "error parsing hello command: %v", err)
|
||||
assertCommandName(t, hello, internal.LegacyHello)
|
||||
|
||||
authDocVal, err := hello.LookupErr("speculativeAuthenticate")
|
||||
assert.Nil(t, err, "expected command %s to contain 'speculativeAuthenticate'", bson.Raw(hello))
|
||||
authDoc := authDocVal.Document()
|
||||
expectedAuthDoc := bsoncore.BuildDocumentFromElements(nil,
|
||||
bsoncore.AppendInt32Element(nil, "authenticate", 1),
|
||||
bsoncore.AppendStringElement(nil, "mechanism", "MONGODB-X509"),
|
||||
)
|
||||
assert.True(t, bytes.Equal(expectedAuthDoc, authDoc), "expected speculative auth document %s, got %s",
|
||||
expectedAuthDoc, authDoc)
|
||||
})
|
||||
t.Run("speculative response not included", func(t *testing.T) {
|
||||
// Tests for X509 when the hello response does not contain a reply to the speculative authentication attempt.
|
||||
// The driver should send an authenticate command after the hello.
|
||||
|
||||
authenticator, err := CreateAuthenticator("MONGODB-X509", &Cred{})
|
||||
assert.Nil(t, err, "CreateAuthenticator error: %v", err)
|
||||
handshaker := Handshaker(nil, &HandshakeOptions{
|
||||
Authenticator: authenticator,
|
||||
})
|
||||
|
||||
numResponses := 2
|
||||
responses := make(chan []byte, numResponses)
|
||||
writeReplies(responses, createRegularX509Handshake()...)
|
||||
|
||||
conn := &drivertest.ChannelConn{
|
||||
Written: make(chan []byte, numResponses),
|
||||
ReadResp: responses,
|
||||
}
|
||||
|
||||
info, err := handshaker.GetHandshakeInformation(context.Background(), address.Address("localhost:27017"), conn)
|
||||
assert.Nil(t, err, "GetDescription error: %v", err)
|
||||
assert.Nil(t, info.SpeculativeAuthenticate, "expected desc.SpeculativeAuthenticate to be unset, got %s",
|
||||
bson.Raw(info.SpeculativeAuthenticate))
|
||||
conn.Desc = info.Description
|
||||
|
||||
err = handshaker.FinishHandshake(context.Background(), conn)
|
||||
assert.Nil(t, err, "FinishHandshake error: %v", err)
|
||||
assert.Equal(t, 0, len(conn.ReadResp), "%d messages left unread", len(conn.ReadResp))
|
||||
|
||||
assert.Equal(t, numResponses, len(conn.Written), "expected %d wire messages to be sent, got %d",
|
||||
numResponses, len(conn.Written))
|
||||
hello, err := drivertest.GetCommandFromQueryWireMessage(<-conn.Written)
|
||||
assert.Nil(t, err, "error parsing hello command: %v", err)
|
||||
assertCommandName(t, hello, internal.LegacyHello)
|
||||
_, err = hello.LookupErr("speculativeAuthenticate")
|
||||
assert.Nil(t, err, "expected command %s to contain 'speculativeAuthenticate'", bson.Raw(hello))
|
||||
|
||||
authenticate, err := drivertest.GetCommandFromMsgWireMessage(<-conn.Written)
|
||||
assert.Nil(t, err, "error parsing authenticate command: %v", err)
|
||||
assertCommandName(t, authenticate, "authenticate")
|
||||
})
|
||||
}
|
||||
|
||||
// createSpeculativeX509Handshake creates the server replies for a successful speculative X509 authentication attempt.
|
||||
// There is only one reply:
|
||||
//
|
||||
// 1. hello reply containing a "speculativeAuthenticate" document.
|
||||
func createSpeculativeX509Handshake() []bsoncore.Document {
|
||||
firstAuthElem := bsoncore.AppendDocumentElement(nil, "speculativeAuthenticate", x509Response)
|
||||
hello := bsoncore.BuildDocumentFromElements(nil, append(handshakeHelloElements, firstAuthElem)...)
|
||||
return []bsoncore.Document{hello}
|
||||
}
|
||||
|
||||
// createSpeculativeX509Handshake creates the server replies for a handshake + X509 authentication attempt.
|
||||
// There are two replies:
|
||||
//
|
||||
// 1. hello reply
|
||||
// 2. authenticate reply
|
||||
func createRegularX509Handshake() []bsoncore.Document {
|
||||
hello := bsoncore.BuildDocumentFromElements(nil, handshakeHelloElements...)
|
||||
return []bsoncore.Document{hello, x509Response}
|
||||
}
|
||||
30
mongo/x/mongo/driver/auth/util.go
Normal file
30
mongo/x/mongo/driver/auth/util.go
Normal file
@@ -0,0 +1,30 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
// Ignore gosec warning "Blocklisted import crypto/md5: weak cryptographic primitive". We need
|
||||
// to use MD5 here to implement the SCRAM specification.
|
||||
/* #nosec G501 */
|
||||
"crypto/md5"
|
||||
)
|
||||
|
||||
const defaultAuthDB = "admin"
|
||||
|
||||
func mongoPasswordDigest(username, password string) string {
|
||||
// Ignore gosec warning "Use of weak cryptographic primitive". We need to use MD5 here to
|
||||
// implement the SCRAM specification.
|
||||
/* #nosec G401 */
|
||||
h := md5.New()
|
||||
_, _ = io.WriteString(h, username)
|
||||
_, _ = io.WriteString(h, ":mongo:")
|
||||
_, _ = io.WriteString(h, password)
|
||||
return fmt.Sprintf("%x", h.Sum(nil))
|
||||
}
|
||||
85
mongo/x/mongo/driver/auth/x509.go
Normal file
85
mongo/x/mongo/driver/auth/x509.go
Normal file
@@ -0,0 +1,85 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.mongodb.org/mongo-driver/mongo/description"
|
||||
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
|
||||
)
|
||||
|
||||
// MongoDBX509 is the mechanism name for MongoDBX509.
|
||||
const MongoDBX509 = "MONGODB-X509"
|
||||
|
||||
func newMongoDBX509Authenticator(cred *Cred) (Authenticator, error) {
|
||||
return &MongoDBX509Authenticator{User: cred.Username}, nil
|
||||
}
|
||||
|
||||
// MongoDBX509Authenticator uses X.509 certificates over TLS to authenticate a connection.
|
||||
type MongoDBX509Authenticator struct {
|
||||
User string
|
||||
}
|
||||
|
||||
var _ SpeculativeAuthenticator = (*MongoDBX509Authenticator)(nil)
|
||||
|
||||
// x509 represents a X509 authentication conversation. This type implements the SpeculativeConversation interface so the
|
||||
// conversation can be executed in multi-step speculative fashion.
|
||||
type x509Conversation struct{}
|
||||
|
||||
var _ SpeculativeConversation = (*x509Conversation)(nil)
|
||||
|
||||
// FirstMessage returns the first message to be sent to the server.
|
||||
func (c *x509Conversation) FirstMessage() (bsoncore.Document, error) {
|
||||
return createFirstX509Message(description.Server{}, ""), nil
|
||||
}
|
||||
|
||||
// createFirstX509Message creates the first message for the X509 conversation.
|
||||
func createFirstX509Message(desc description.Server, user string) bsoncore.Document {
|
||||
elements := [][]byte{
|
||||
bsoncore.AppendInt32Element(nil, "authenticate", 1),
|
||||
bsoncore.AppendStringElement(nil, "mechanism", MongoDBX509),
|
||||
}
|
||||
|
||||
// Server versions < 3.4 require the username to be included in the message. Versions >= 3.4 will extract the
|
||||
// username from the certificate.
|
||||
if desc.WireVersion != nil && desc.WireVersion.Max < 5 {
|
||||
elements = append(elements, bsoncore.AppendStringElement(nil, "user", user))
|
||||
}
|
||||
|
||||
return bsoncore.BuildDocument(nil, elements...)
|
||||
}
|
||||
|
||||
// Finish implements the SpeculativeConversation interface and is a no-op because an X509 conversation only has one
|
||||
// step.
|
||||
func (c *x509Conversation) Finish(context.Context, *Config, bsoncore.Document) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateSpeculativeConversation creates a speculative conversation for X509 authentication.
|
||||
func (a *MongoDBX509Authenticator) CreateSpeculativeConversation() (SpeculativeConversation, error) {
|
||||
return &x509Conversation{}, nil
|
||||
}
|
||||
|
||||
// Auth authenticates the provided connection by conducting an X509 authentication conversation.
|
||||
func (a *MongoDBX509Authenticator) Auth(ctx context.Context, cfg *Config) error {
|
||||
requestDoc := createFirstX509Message(cfg.Description, a.User)
|
||||
authCmd := operation.
|
||||
NewCommand(requestDoc).
|
||||
Database("$external").
|
||||
Deployment(driver.SingleConnectionDeployment{cfg.Connection}).
|
||||
ClusterClock(cfg.ClusterClock).
|
||||
ServerAPI(cfg.ServerAPI)
|
||||
err := authCmd.Execute(ctx)
|
||||
if err != nil {
|
||||
return newAuthError("round trip error", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
479
mongo/x/mongo/driver/batch_cursor.go
Normal file
479
mongo/x/mongo/driver/batch_cursor.go
Normal file
@@ -0,0 +1,479 @@
|
||||
// Copyright (C) MongoDB, Inc. 2022-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package driver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"go.mongodb.org/mongo-driver/bson/bsontype"
|
||||
"go.mongodb.org/mongo-driver/event"
|
||||
"go.mongodb.org/mongo-driver/internal"
|
||||
"go.mongodb.org/mongo-driver/mongo/description"
|
||||
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
|
||||
)
|
||||
|
||||
// BatchCursor is a batch implementation of a cursor. It returns documents in entire batches instead
|
||||
// of one at a time. An individual document cursor can be built on top of this batch cursor.
|
||||
type BatchCursor struct {
|
||||
clientSession *session.Client
|
||||
clock *session.ClusterClock
|
||||
comment bsoncore.Value
|
||||
database string
|
||||
collection string
|
||||
id int64
|
||||
err error
|
||||
server Server
|
||||
serverDescription description.Server
|
||||
errorProcessor ErrorProcessor // This will only be set when pinning to a connection.
|
||||
connection PinnedConnection
|
||||
batchSize int32
|
||||
maxTimeMS int64
|
||||
currentBatch *bsoncore.DocumentSequence
|
||||
firstBatch bool
|
||||
cmdMonitor *event.CommandMonitor
|
||||
postBatchResumeToken bsoncore.Document
|
||||
crypt Crypt
|
||||
serverAPI *ServerAPIOptions
|
||||
|
||||
// legacy server (< 3.2) fields
|
||||
limit int32
|
||||
numReturned int32 // number of docs returned by server
|
||||
}
|
||||
|
||||
// CursorResponse represents the response from a command the results in a cursor. A BatchCursor can
|
||||
// be constructed from a CursorResponse.
|
||||
type CursorResponse struct {
|
||||
Server Server
|
||||
ErrorProcessor ErrorProcessor // This will only be set when pinning to a connection.
|
||||
Connection PinnedConnection
|
||||
Desc description.Server
|
||||
FirstBatch *bsoncore.DocumentSequence
|
||||
Database string
|
||||
Collection string
|
||||
ID int64
|
||||
postBatchResumeToken bsoncore.Document
|
||||
}
|
||||
|
||||
// NewCursorResponse constructs a cursor response from the given response and server. This method
|
||||
// can be used within the ProcessResponse method for an operation.
|
||||
func NewCursorResponse(info ResponseInfo) (CursorResponse, error) {
|
||||
response := info.ServerResponse
|
||||
cur, ok := response.Lookup("cursor").DocumentOK()
|
||||
if !ok {
|
||||
return CursorResponse{}, fmt.Errorf("cursor should be an embedded document but is of BSON type %s", response.Lookup("cursor").Type)
|
||||
}
|
||||
elems, err := cur.Elements()
|
||||
if err != nil {
|
||||
return CursorResponse{}, err
|
||||
}
|
||||
curresp := CursorResponse{Server: info.Server, Desc: info.ConnectionDescription}
|
||||
|
||||
for _, elem := range elems {
|
||||
switch elem.Key() {
|
||||
case "firstBatch":
|
||||
arr, ok := elem.Value().ArrayOK()
|
||||
if !ok {
|
||||
return CursorResponse{}, fmt.Errorf("firstBatch should be an array but is a BSON %s", elem.Value().Type)
|
||||
}
|
||||
curresp.FirstBatch = &bsoncore.DocumentSequence{Style: bsoncore.ArrayStyle, Data: arr}
|
||||
case "ns":
|
||||
ns, ok := elem.Value().StringValueOK()
|
||||
if !ok {
|
||||
return CursorResponse{}, fmt.Errorf("ns should be a string but is a BSON %s", elem.Value().Type)
|
||||
}
|
||||
index := strings.Index(ns, ".")
|
||||
if index == -1 {
|
||||
return CursorResponse{}, errors.New("ns field must contain a valid namespace, but is missing '.'")
|
||||
}
|
||||
curresp.Database = ns[:index]
|
||||
curresp.Collection = ns[index+1:]
|
||||
case "id":
|
||||
curresp.ID, ok = elem.Value().Int64OK()
|
||||
if !ok {
|
||||
return CursorResponse{}, fmt.Errorf("id should be an int64 but it is a BSON %s", elem.Value().Type)
|
||||
}
|
||||
case "postBatchResumeToken":
|
||||
curresp.postBatchResumeToken, ok = elem.Value().DocumentOK()
|
||||
if !ok {
|
||||
return CursorResponse{}, fmt.Errorf("post batch resume token should be a document but it is a BSON %s", elem.Value().Type)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If the deployment is behind a load balancer and the cursor has a non-zero ID, pin the cursor to a connection and
|
||||
// use the same connection to execute getMore and killCursors commands.
|
||||
if curresp.Desc.LoadBalanced() && curresp.ID != 0 {
|
||||
// Cache the server as an ErrorProcessor to use when constructing deployments for cursor commands.
|
||||
ep, ok := curresp.Server.(ErrorProcessor)
|
||||
if !ok {
|
||||
return CursorResponse{}, fmt.Errorf("expected Server used to establish a cursor to implement ErrorProcessor, but got %T", curresp.Server)
|
||||
}
|
||||
curresp.ErrorProcessor = ep
|
||||
|
||||
refConn, ok := info.Connection.(PinnedConnection)
|
||||
if !ok {
|
||||
return CursorResponse{}, fmt.Errorf("expected Connection used to establish a cursor to implement PinnedConnection, but got %T", info.Connection)
|
||||
}
|
||||
if err := refConn.PinToCursor(); err != nil {
|
||||
return CursorResponse{}, fmt.Errorf("error incrementing connection reference count when creating a cursor: %v", err)
|
||||
}
|
||||
curresp.Connection = refConn
|
||||
}
|
||||
|
||||
return curresp, nil
|
||||
}
|
||||
|
||||
// CursorOptions are extra options that are required to construct a BatchCursor.
|
||||
type CursorOptions struct {
|
||||
BatchSize int32
|
||||
Comment bsoncore.Value
|
||||
MaxTimeMS int64
|
||||
Limit int32
|
||||
CommandMonitor *event.CommandMonitor
|
||||
Crypt Crypt
|
||||
ServerAPI *ServerAPIOptions
|
||||
}
|
||||
|
||||
// NewBatchCursor creates a new BatchCursor from the provided parameters.
|
||||
func NewBatchCursor(cr CursorResponse, clientSession *session.Client, clock *session.ClusterClock, opts CursorOptions) (*BatchCursor, error) {
|
||||
ds := cr.FirstBatch
|
||||
bc := &BatchCursor{
|
||||
clientSession: clientSession,
|
||||
clock: clock,
|
||||
comment: opts.Comment,
|
||||
database: cr.Database,
|
||||
collection: cr.Collection,
|
||||
id: cr.ID,
|
||||
server: cr.Server,
|
||||
connection: cr.Connection,
|
||||
errorProcessor: cr.ErrorProcessor,
|
||||
batchSize: opts.BatchSize,
|
||||
maxTimeMS: opts.MaxTimeMS,
|
||||
cmdMonitor: opts.CommandMonitor,
|
||||
firstBatch: true,
|
||||
postBatchResumeToken: cr.postBatchResumeToken,
|
||||
crypt: opts.Crypt,
|
||||
serverAPI: opts.ServerAPI,
|
||||
serverDescription: cr.Desc,
|
||||
}
|
||||
|
||||
if ds != nil {
|
||||
bc.numReturned = int32(ds.DocumentCount())
|
||||
}
|
||||
if cr.Desc.WireVersion == nil || cr.Desc.WireVersion.Max < 4 {
|
||||
bc.limit = opts.Limit
|
||||
|
||||
// Take as many documents from the batch as needed.
|
||||
if bc.limit != 0 && bc.limit < bc.numReturned {
|
||||
for i := int32(0); i < bc.limit; i++ {
|
||||
_, err := ds.Next()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
ds.Data = ds.Data[:ds.Pos]
|
||||
ds.ResetIterator()
|
||||
}
|
||||
}
|
||||
|
||||
bc.currentBatch = ds
|
||||
return bc, nil
|
||||
}
|
||||
|
||||
// NewEmptyBatchCursor returns a batch cursor that is empty.
|
||||
func NewEmptyBatchCursor() *BatchCursor {
|
||||
return &BatchCursor{currentBatch: new(bsoncore.DocumentSequence)}
|
||||
}
|
||||
|
||||
// NewBatchCursorFromDocuments returns a batch cursor with current batch set to a sequence-style
|
||||
// DocumentSequence containing the provided documents.
|
||||
func NewBatchCursorFromDocuments(documents []byte) *BatchCursor {
|
||||
return &BatchCursor{
|
||||
currentBatch: &bsoncore.DocumentSequence{
|
||||
Data: documents,
|
||||
Style: bsoncore.SequenceStyle,
|
||||
},
|
||||
// BatchCursors created with this function have no associated ID nor server, so no getMore
|
||||
// calls will be made.
|
||||
id: 0,
|
||||
server: nil,
|
||||
}
|
||||
}
|
||||
|
||||
// ID returns the cursor ID for this batch cursor.
|
||||
func (bc *BatchCursor) ID() int64 {
|
||||
return bc.id
|
||||
}
|
||||
|
||||
// Next indicates if there is another batch available. Returning false does not necessarily indicate
|
||||
// that the cursor is closed. This method will return false when an empty batch is returned.
|
||||
//
|
||||
// If Next returns true, there is a valid batch of documents available. If Next returns false, there
|
||||
// is not a valid batch of documents available.
|
||||
func (bc *BatchCursor) Next(ctx context.Context) bool {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
if bc.firstBatch {
|
||||
bc.firstBatch = false
|
||||
return !bc.currentBatch.Empty()
|
||||
}
|
||||
|
||||
if bc.id == 0 || bc.server == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
bc.getMore(ctx)
|
||||
|
||||
return !bc.currentBatch.Empty()
|
||||
}
|
||||
|
||||
// Batch will return a DocumentSequence for the current batch of documents. The returned
|
||||
// DocumentSequence is only valid until the next call to Next or Close.
|
||||
func (bc *BatchCursor) Batch() *bsoncore.DocumentSequence { return bc.currentBatch }
|
||||
|
||||
// Err returns the latest error encountered.
|
||||
func (bc *BatchCursor) Err() error { return bc.err }
|
||||
|
||||
// Close closes this batch cursor.
|
||||
func (bc *BatchCursor) Close(ctx context.Context) error {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
err := bc.KillCursor(ctx)
|
||||
bc.id = 0
|
||||
bc.currentBatch.Data = nil
|
||||
bc.currentBatch.Style = 0
|
||||
bc.currentBatch.ResetIterator()
|
||||
|
||||
connErr := bc.unpinConnection()
|
||||
if err == nil {
|
||||
err = connErr
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (bc *BatchCursor) unpinConnection() error {
|
||||
if bc.connection == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := bc.connection.UnpinFromCursor()
|
||||
closeErr := bc.connection.Close()
|
||||
if err == nil && closeErr != nil {
|
||||
err = closeErr
|
||||
}
|
||||
bc.connection = nil
|
||||
return err
|
||||
}
|
||||
|
||||
// Server returns the server for this cursor.
|
||||
func (bc *BatchCursor) Server() Server {
|
||||
return bc.server
|
||||
}
|
||||
|
||||
func (bc *BatchCursor) clearBatch() {
|
||||
bc.currentBatch.Data = bc.currentBatch.Data[:0]
|
||||
}
|
||||
|
||||
// KillCursor kills cursor on server without closing batch cursor
|
||||
func (bc *BatchCursor) KillCursor(ctx context.Context) error {
|
||||
if bc.server == nil || bc.id == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return Operation{
|
||||
CommandFn: func(dst []byte, desc description.SelectedServer) ([]byte, error) {
|
||||
dst = bsoncore.AppendStringElement(dst, "killCursors", bc.collection)
|
||||
dst = bsoncore.BuildArrayElement(dst, "cursors", bsoncore.Value{Type: bsontype.Int64, Data: bsoncore.AppendInt64(nil, bc.id)})
|
||||
return dst, nil
|
||||
},
|
||||
Database: bc.database,
|
||||
Deployment: bc.getOperationDeployment(),
|
||||
Client: bc.clientSession,
|
||||
Clock: bc.clock,
|
||||
Legacy: LegacyKillCursors,
|
||||
CommandMonitor: bc.cmdMonitor,
|
||||
ServerAPI: bc.serverAPI,
|
||||
}.Execute(ctx)
|
||||
}
|
||||
|
||||
// calcGetMoreBatchSize calculates the number of documents to return in the
|
||||
// response of a "getMore" operation based on the given limit, batchSize, and
|
||||
// number of documents already returned. Returns false if a non-trivial limit is
|
||||
// lower than or equal to the number of documents already returned.
|
||||
func calcGetMoreBatchSize(bc BatchCursor) (int32, bool) {
|
||||
gmBatchSize := bc.batchSize
|
||||
|
||||
// Account for legacy operations that don't support setting a limit.
|
||||
if bc.limit != 0 && bc.numReturned+bc.batchSize >= bc.limit {
|
||||
gmBatchSize = bc.limit - bc.numReturned
|
||||
if gmBatchSize <= 0 {
|
||||
return gmBatchSize, false
|
||||
}
|
||||
}
|
||||
|
||||
return gmBatchSize, true
|
||||
}
|
||||
|
||||
func (bc *BatchCursor) getMore(ctx context.Context) {
|
||||
bc.clearBatch()
|
||||
if bc.id == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
numToReturn, ok := calcGetMoreBatchSize(*bc)
|
||||
if !ok {
|
||||
if err := bc.Close(ctx); err != nil {
|
||||
bc.err = err
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
bc.err = Operation{
|
||||
CommandFn: func(dst []byte, desc description.SelectedServer) ([]byte, error) {
|
||||
dst = bsoncore.AppendInt64Element(dst, "getMore", bc.id)
|
||||
dst = bsoncore.AppendStringElement(dst, "collection", bc.collection)
|
||||
if numToReturn > 0 {
|
||||
dst = bsoncore.AppendInt32Element(dst, "batchSize", numToReturn)
|
||||
}
|
||||
if bc.maxTimeMS > 0 {
|
||||
dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", bc.maxTimeMS)
|
||||
}
|
||||
// The getMore command does not support commenting pre-4.4.
|
||||
if bc.comment.Type != bsontype.Type(0) && bc.serverDescription.WireVersion.Max >= 9 {
|
||||
dst = bsoncore.AppendValueElement(dst, "comment", bc.comment)
|
||||
}
|
||||
return dst, nil
|
||||
},
|
||||
Database: bc.database,
|
||||
Deployment: bc.getOperationDeployment(),
|
||||
ProcessResponseFn: func(info ResponseInfo) error {
|
||||
response := info.ServerResponse
|
||||
id, ok := response.Lookup("cursor", "id").Int64OK()
|
||||
if !ok {
|
||||
return fmt.Errorf("cursor.id should be an int64 but is a BSON %s", response.Lookup("cursor", "id").Type)
|
||||
}
|
||||
bc.id = id
|
||||
|
||||
batch, ok := response.Lookup("cursor", "nextBatch").ArrayOK()
|
||||
if !ok {
|
||||
return fmt.Errorf("cursor.nextBatch should be an array but is a BSON %s", response.Lookup("cursor", "nextBatch").Type)
|
||||
}
|
||||
bc.currentBatch.Style = bsoncore.ArrayStyle
|
||||
bc.currentBatch.Data = batch
|
||||
bc.currentBatch.ResetIterator()
|
||||
bc.numReturned += int32(bc.currentBatch.DocumentCount()) // Required for legacy operations which don't support limit.
|
||||
|
||||
pbrt, err := response.LookupErr("cursor", "postBatchResumeToken")
|
||||
if err != nil {
|
||||
// I don't really understand why we don't set bc.err here
|
||||
return nil
|
||||
}
|
||||
|
||||
pbrtDoc, ok := pbrt.DocumentOK()
|
||||
if !ok {
|
||||
bc.err = fmt.Errorf("expected BSON type for post batch resume token to be EmbeddedDocument but got %s", pbrt.Type)
|
||||
return nil
|
||||
}
|
||||
|
||||
bc.postBatchResumeToken = pbrtDoc
|
||||
|
||||
return nil
|
||||
},
|
||||
Client: bc.clientSession,
|
||||
Clock: bc.clock,
|
||||
Legacy: LegacyGetMore,
|
||||
CommandMonitor: bc.cmdMonitor,
|
||||
Crypt: bc.crypt,
|
||||
ServerAPI: bc.serverAPI,
|
||||
}.Execute(ctx)
|
||||
|
||||
// Once the cursor has been drained, we can unpin the connection if one is currently pinned.
|
||||
if bc.id == 0 {
|
||||
err := bc.unpinConnection()
|
||||
if err != nil && bc.err == nil {
|
||||
bc.err = err
|
||||
}
|
||||
}
|
||||
|
||||
// If we're in load balanced mode and the pinned connection encounters a network error, we should not use it for
|
||||
// future commands. Per the spec, the connection will not be unpinned until the cursor is actually closed, but
|
||||
// we set the cursor ID to 0 to ensure the Close() call will not execute a killCursors command.
|
||||
if driverErr, ok := bc.err.(Error); ok && driverErr.NetworkError() && bc.connection != nil {
|
||||
bc.id = 0
|
||||
}
|
||||
|
||||
// Required for legacy operations which don't support limit.
|
||||
if bc.limit != 0 && bc.numReturned >= bc.limit {
|
||||
// call KillCursor instead of Close because Close will clear out the data for the current batch.
|
||||
err := bc.KillCursor(ctx)
|
||||
if err != nil && bc.err == nil {
|
||||
bc.err = err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// PostBatchResumeToken returns the latest seen post batch resume token.
|
||||
func (bc *BatchCursor) PostBatchResumeToken() bsoncore.Document {
|
||||
return bc.postBatchResumeToken
|
||||
}
|
||||
|
||||
// SetBatchSize sets the batchSize for future getMores.
|
||||
func (bc *BatchCursor) SetBatchSize(size int32) {
|
||||
bc.batchSize = size
|
||||
}
|
||||
|
||||
func (bc *BatchCursor) getOperationDeployment() Deployment {
|
||||
if bc.connection != nil {
|
||||
return &loadBalancedCursorDeployment{
|
||||
errorProcessor: bc.errorProcessor,
|
||||
conn: bc.connection,
|
||||
}
|
||||
}
|
||||
return SingleServerDeployment{bc.server}
|
||||
}
|
||||
|
||||
// loadBalancedCursorDeployment is used as a Deployment for getMore and killCursors commands when pinning to a
|
||||
// connection in load balanced mode. This type also functions as an ErrorProcessor to ensure that SDAM errors are
|
||||
// handled for these commands in this mode.
|
||||
type loadBalancedCursorDeployment struct {
|
||||
errorProcessor ErrorProcessor
|
||||
conn PinnedConnection
|
||||
}
|
||||
|
||||
var _ Deployment = (*loadBalancedCursorDeployment)(nil)
|
||||
var _ Server = (*loadBalancedCursorDeployment)(nil)
|
||||
var _ ErrorProcessor = (*loadBalancedCursorDeployment)(nil)
|
||||
|
||||
func (lbcd *loadBalancedCursorDeployment) SelectServer(_ context.Context, _ description.ServerSelector) (Server, error) {
|
||||
return lbcd, nil
|
||||
}
|
||||
|
||||
func (lbcd *loadBalancedCursorDeployment) Kind() description.TopologyKind {
|
||||
return description.LoadBalanced
|
||||
}
|
||||
|
||||
func (lbcd *loadBalancedCursorDeployment) Connection(_ context.Context) (Connection, error) {
|
||||
return lbcd.conn, nil
|
||||
}
|
||||
|
||||
// RTTMonitor implements the driver.Server interface.
|
||||
func (lbcd *loadBalancedCursorDeployment) RTTMonitor() RTTMonitor {
|
||||
return &internal.ZeroRTTMonitor{}
|
||||
}
|
||||
|
||||
func (lbcd *loadBalancedCursorDeployment) ProcessError(err error, conn Connection) ProcessErrorResult {
|
||||
return lbcd.errorProcessor.ProcessError(err, conn)
|
||||
}
|
||||
92
mongo/x/mongo/driver/batch_cursor_test.go
Normal file
92
mongo/x/mongo/driver/batch_cursor_test.go
Normal file
@@ -0,0 +1,92 @@
|
||||
// Copyright (C) MongoDB, Inc. 2022-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package driver
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"go.mongodb.org/mongo-driver/internal/testutil/assert"
|
||||
)
|
||||
|
||||
func TestBatchCursor(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("setBatchSize", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var size int32
|
||||
bc := &BatchCursor{
|
||||
batchSize: size,
|
||||
}
|
||||
assert.Equal(t, size, bc.batchSize, "expected batchSize %v, got %v", size, bc.batchSize)
|
||||
|
||||
size = int32(4)
|
||||
bc.SetBatchSize(size)
|
||||
assert.Equal(t, size, bc.batchSize, "expected batchSize %v, got %v", size, bc.batchSize)
|
||||
})
|
||||
|
||||
t.Run("calcGetMoreBatchSize", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for _, tcase := range []struct {
|
||||
name string
|
||||
size, limit, numReturned, expected int32
|
||||
ok bool
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
expected: 0,
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
name: "batchSize NEQ 0",
|
||||
size: 4,
|
||||
expected: 4,
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
name: "limit NEQ 0",
|
||||
limit: 4,
|
||||
expected: 0,
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
name: "limit NEQ and batchSize + numReturned EQ limit",
|
||||
size: 4,
|
||||
limit: 8,
|
||||
numReturned: 4,
|
||||
expected: 4,
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
name: "limit makes batchSize negative",
|
||||
numReturned: 4,
|
||||
limit: 2,
|
||||
expected: -2,
|
||||
ok: false,
|
||||
},
|
||||
} {
|
||||
tcase := tcase
|
||||
t.Run(tcase.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
bc := &BatchCursor{
|
||||
limit: tcase.limit,
|
||||
batchSize: tcase.size,
|
||||
numReturned: tcase.numReturned,
|
||||
}
|
||||
|
||||
bc.SetBatchSize(tcase.size)
|
||||
|
||||
size, ok := calcGetMoreBatchSize(*bc)
|
||||
|
||||
assert.Equal(t, tcase.expected, size, "expected batchSize %v, got %v", tcase.expected, size)
|
||||
assert.Equal(t, tcase.ok, ok, "expected ok %v, got %v", tcase.ok, ok)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
76
mongo/x/mongo/driver/batches.go
Normal file
76
mongo/x/mongo/driver/batches.go
Normal file
@@ -0,0 +1,76 @@
|
||||
// Copyright (C) MongoDB, Inc. 2022-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package driver
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
|
||||
)
|
||||
|
||||
// ErrDocumentTooLarge occurs when a document that is larger than the maximum size accepted by a
|
||||
// server is passed to an insert command.
|
||||
var ErrDocumentTooLarge = errors.New("an inserted document is too large")
|
||||
|
||||
// Batches contains the necessary information to batch split an operation. This is only used for write
|
||||
// oeprations.
|
||||
type Batches struct {
|
||||
Identifier string
|
||||
Documents []bsoncore.Document
|
||||
Current []bsoncore.Document
|
||||
Ordered *bool
|
||||
}
|
||||
|
||||
// Valid returns true if Batches contains both an identifier and the length of Documents is greater
|
||||
// than zero.
|
||||
func (b *Batches) Valid() bool { return b != nil && b.Identifier != "" && len(b.Documents) > 0 }
|
||||
|
||||
// ClearBatch clears the Current batch. This must be called before AdvanceBatch will advance to the
|
||||
// next batch.
|
||||
func (b *Batches) ClearBatch() { b.Current = b.Current[:0] }
|
||||
|
||||
// AdvanceBatch splits the next batch using maxCount and targetBatchSize. This method will do nothing if
|
||||
// the current batch has not been cleared. We do this so that when this is called during execute we
|
||||
// can call it without first needing to check if we already have a batch, which makes the code
|
||||
// simpler and makes retrying easier.
|
||||
// The maxDocSize parameter is used to check that any one document is not too large. If the first document is bigger
|
||||
// than targetBatchSize but smaller than maxDocSize, a batch of size 1 containing that document will be created.
|
||||
func (b *Batches) AdvanceBatch(maxCount, targetBatchSize, maxDocSize int) error {
|
||||
if len(b.Current) > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if maxCount <= 0 {
|
||||
maxCount = 1
|
||||
}
|
||||
|
||||
splitAfter := 0
|
||||
size := 0
|
||||
for i, doc := range b.Documents {
|
||||
if i == maxCount {
|
||||
break
|
||||
}
|
||||
if len(doc) > maxDocSize {
|
||||
return ErrDocumentTooLarge
|
||||
}
|
||||
if size+len(doc) > targetBatchSize {
|
||||
break
|
||||
}
|
||||
|
||||
size += len(doc)
|
||||
splitAfter++
|
||||
}
|
||||
|
||||
// if there are no documents, take the first one.
|
||||
// this can happen if there is a document that is smaller than maxDocSize but greater than targetBatchSize.
|
||||
if splitAfter == 0 {
|
||||
splitAfter = 1
|
||||
}
|
||||
|
||||
b.Current, b.Documents = b.Documents[:splitAfter], b.Documents[splitAfter:]
|
||||
return nil
|
||||
}
|
||||
137
mongo/x/mongo/driver/batches_test.go
Normal file
137
mongo/x/mongo/driver/batches_test.go
Normal file
@@ -0,0 +1,137 @@
|
||||
// Copyright (C) MongoDB, Inc. 2022-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package driver
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"go.mongodb.org/mongo-driver/internal/testutil/assert"
|
||||
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
|
||||
)
|
||||
|
||||
func TestBatches(t *testing.T) {
|
||||
t.Run("Valid", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
batches *Batches
|
||||
want bool
|
||||
}{
|
||||
{"nil", nil, false},
|
||||
{"missing identifier", &Batches{}, false},
|
||||
{"no documents", &Batches{Identifier: "documents"}, false},
|
||||
{"valid", &Batches{Identifier: "documents", Documents: make([]bsoncore.Document, 5)}, true},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
want := tc.want
|
||||
got := tc.batches.Valid()
|
||||
if got != want {
|
||||
t.Errorf("Did not get expected result from Valid. got %t; want %t", got, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
t.Run("ClearBatch", func(t *testing.T) {
|
||||
batches := &Batches{Identifier: "documents", Current: make([]bsoncore.Document, 2, 10)}
|
||||
if len(batches.Current) != 2 {
|
||||
t.Fatalf("Length of current batch should be 2, but is %d", len(batches.Current))
|
||||
}
|
||||
batches.ClearBatch()
|
||||
if len(batches.Current) != 0 {
|
||||
t.Fatalf("Length of current batch should be 0, but is %d", len(batches.Current))
|
||||
}
|
||||
})
|
||||
t.Run("AdvanceBatch", func(t *testing.T) {
|
||||
documents := make([]bsoncore.Document, 0)
|
||||
for i := 0; i < 5; i++ {
|
||||
doc := make(bsoncore.Document, 100)
|
||||
documents = append(documents, doc)
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
batches *Batches
|
||||
maxCount int
|
||||
targetBatchSize int
|
||||
maxDocSize int
|
||||
err error
|
||||
want *Batches
|
||||
}{
|
||||
{
|
||||
"current batch non-zero",
|
||||
&Batches{Current: make([]bsoncore.Document, 2, 10)},
|
||||
0, 0, 0, nil,
|
||||
&Batches{Current: make([]bsoncore.Document, 2, 10)},
|
||||
},
|
||||
{
|
||||
// all of the documents in the batch fit in targetBatchSize so the batch is created successfully
|
||||
"documents fit in targetBatchSize",
|
||||
&Batches{Documents: documents},
|
||||
10, 600, 1000, nil,
|
||||
&Batches{Documents: documents[:0], Current: documents[0:]},
|
||||
},
|
||||
{
|
||||
// the first doc is bigger than targetBatchSize but smaller than maxDocSize so it is taken alone
|
||||
"first document larger than targetBatchSize, smaller than maxDocSize",
|
||||
&Batches{Documents: documents},
|
||||
10, 5, 100, nil,
|
||||
&Batches{Documents: documents[1:], Current: documents[:1]},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := tc.batches.AdvanceBatch(tc.maxCount, tc.targetBatchSize, tc.maxDocSize)
|
||||
if !cmp.Equal(err, tc.err, cmp.Comparer(compareErrors)) {
|
||||
t.Errorf("Errors do not match. got %v; want %v", err, tc.err)
|
||||
}
|
||||
if !cmp.Equal(tc.batches, tc.want) {
|
||||
t.Errorf("Batches is not in correct state after AdvanceBatch. got %v; want %v", tc.batches, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("middle document larger than targetBatchSize, smaller than maxDocSize", func(t *testing.T) {
|
||||
// a batch is made but one document is too big, so everything before it is taken.
|
||||
// on the second call to AdvanceBatch, only the large document is taken
|
||||
|
||||
middleLargeDoc := make([]bsoncore.Document, 0)
|
||||
for i := 0; i < 5; i++ {
|
||||
doc := make(bsoncore.Document, 100)
|
||||
middleLargeDoc = append(middleLargeDoc, doc)
|
||||
}
|
||||
largeDoc := make(bsoncore.Document, 900)
|
||||
middleLargeDoc[2] = largeDoc
|
||||
batches := &Batches{Documents: middleLargeDoc}
|
||||
maxCount := 10
|
||||
targetSize := 600
|
||||
maxDocSize := 1000
|
||||
|
||||
// first batch should take first 2 docs (size 100 each)
|
||||
err := batches.AdvanceBatch(maxCount, targetSize, maxDocSize)
|
||||
assert.Nil(t, err, "AdvanceBatch error: %v", err)
|
||||
want := &Batches{Current: middleLargeDoc[:2], Documents: middleLargeDoc[2:]}
|
||||
assert.Equal(t, want, batches, "expected batches %v, got %v", want, batches)
|
||||
|
||||
// second batch should take single large doc (size 900)
|
||||
batches.ClearBatch()
|
||||
err = batches.AdvanceBatch(maxCount, targetSize, maxDocSize)
|
||||
assert.Nil(t, err, "AdvanceBatch error: %v", err)
|
||||
want = &Batches{Current: middleLargeDoc[2:3], Documents: middleLargeDoc[3:]}
|
||||
assert.Equal(t, want, batches, "expected batches %v, got %v", want, batches)
|
||||
|
||||
// last batch should take last 2 docs (size 100 each)
|
||||
batches.ClearBatch()
|
||||
err = batches.AdvanceBatch(maxCount, targetSize, maxDocSize)
|
||||
assert.Nil(t, err, "AdvanceBatch error: %v", err)
|
||||
want = &Batches{Current: middleLargeDoc[3:], Documents: middleLargeDoc[:0]}
|
||||
assert.Equal(t, want, batches, "expected batches %v, got %v", want, batches)
|
||||
})
|
||||
})
|
||||
}
|
||||
53
mongo/x/mongo/driver/command_monitoring_test.go
Normal file
53
mongo/x/mongo/driver/command_monitoring_test.go
Normal file
@@ -0,0 +1,53 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package driver
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"go.mongodb.org/mongo-driver/internal"
|
||||
"go.mongodb.org/mongo-driver/internal/testutil/assert"
|
||||
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
|
||||
)
|
||||
|
||||
func TestCommandMonitoring(t *testing.T) {
|
||||
t.Run("redactCommand", func(t *testing.T) {
|
||||
emptyDoc := bsoncore.BuildDocumentFromElements(nil)
|
||||
legacyHello := bsoncore.BuildDocumentFromElements(nil,
|
||||
bsoncore.AppendInt32Element(nil, internal.LegacyHello, 1),
|
||||
)
|
||||
legacyHelloLowercase := bsoncore.BuildDocumentFromElements(nil,
|
||||
bsoncore.AppendInt32Element(nil, internal.LegacyHelloLowercase, 1),
|
||||
)
|
||||
legacyHelloSpeculative := bsoncore.BuildDocumentFromElements(nil,
|
||||
bsoncore.AppendInt32Element(nil, internal.LegacyHello, 1),
|
||||
bsoncore.AppendDocumentElement(nil, "speculativeAuthenticate", emptyDoc),
|
||||
)
|
||||
legacyHelloSpeculativeLowercase := bsoncore.BuildDocumentFromElements(nil,
|
||||
bsoncore.AppendInt32Element(nil, internal.LegacyHelloLowercase, 1),
|
||||
bsoncore.AppendDocumentElement(nil, "speculativeAuthenticate", emptyDoc),
|
||||
)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
commandName string
|
||||
command bsoncore.Document
|
||||
redacted bool
|
||||
}{
|
||||
{"legacy hello", internal.LegacyHello, legacyHello, false},
|
||||
{"legacy hello lowercase", internal.LegacyHelloLowercase, legacyHelloLowercase, false},
|
||||
{"legacy hello speculative auth", internal.LegacyHello, legacyHelloSpeculative, true},
|
||||
{"legacy hello speculative auth lowercase", internal.LegacyHello, legacyHelloSpeculativeLowercase, true},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
canMonitor := (&Operation{}).redactCommand(tc.commandName, tc.command)
|
||||
assert.Equal(t, tc.redacted, canMonitor, "expected redacted %v, got %v", tc.redacted, canMonitor)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
145
mongo/x/mongo/driver/compression.go
Normal file
145
mongo/x/mongo/driver/compression.go
Normal file
@@ -0,0 +1,145 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package driver
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/zlib"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"github.com/golang/snappy"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
|
||||
)
|
||||
|
||||
// CompressionOpts holds settings for how to compress a payload
|
||||
type CompressionOpts struct {
|
||||
Compressor wiremessage.CompressorID
|
||||
ZlibLevel int
|
||||
ZstdLevel int
|
||||
UncompressedSize int32
|
||||
}
|
||||
|
||||
var zstdEncoders sync.Map // map[zstd.EncoderLevel]*zstd.Encoder
|
||||
|
||||
func getZstdEncoder(level zstd.EncoderLevel) (*zstd.Encoder, error) {
|
||||
if v, ok := zstdEncoders.Load(level); ok {
|
||||
return v.(*zstd.Encoder), nil
|
||||
}
|
||||
encoder, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(level))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
zstdEncoders.Store(level, encoder)
|
||||
return encoder, nil
|
||||
}
|
||||
|
||||
var zlibEncoders sync.Map // map[int /*level*/]*zlibEncoder
|
||||
|
||||
func getZlibEncoder(level int) (*zlibEncoder, error) {
|
||||
if v, ok := zlibEncoders.Load(level); ok {
|
||||
return v.(*zlibEncoder), nil
|
||||
}
|
||||
writer, err := zlib.NewWriterLevel(nil, level)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
encoder := &zlibEncoder{writer: writer, buf: new(bytes.Buffer)}
|
||||
zlibEncoders.Store(level, encoder)
|
||||
|
||||
return encoder, nil
|
||||
}
|
||||
|
||||
type zlibEncoder struct {
|
||||
mu sync.Mutex
|
||||
writer *zlib.Writer
|
||||
buf *bytes.Buffer
|
||||
}
|
||||
|
||||
func (e *zlibEncoder) Encode(dst, src []byte) ([]byte, error) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
e.buf.Reset()
|
||||
e.writer.Reset(e.buf)
|
||||
|
||||
_, err := e.writer.Write(src)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = e.writer.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dst = append(dst[:0], e.buf.Bytes()...)
|
||||
return dst, nil
|
||||
}
|
||||
|
||||
// CompressPayload takes a byte slice and compresses it according to the options passed
|
||||
func CompressPayload(in []byte, opts CompressionOpts) ([]byte, error) {
|
||||
switch opts.Compressor {
|
||||
case wiremessage.CompressorNoOp:
|
||||
return in, nil
|
||||
case wiremessage.CompressorSnappy:
|
||||
return snappy.Encode(nil, in), nil
|
||||
case wiremessage.CompressorZLib:
|
||||
encoder, err := getZlibEncoder(opts.ZlibLevel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return encoder.Encode(nil, in)
|
||||
case wiremessage.CompressorZstd:
|
||||
encoder, err := getZstdEncoder(zstd.EncoderLevelFromZstd(opts.ZstdLevel))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return encoder.EncodeAll(in, nil), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown compressor ID %v", opts.Compressor)
|
||||
}
|
||||
}
|
||||
|
||||
// DecompressPayload takes a byte slice that has been compressed and undoes it according to the options passed
|
||||
func DecompressPayload(in []byte, opts CompressionOpts) (uncompressed []byte, err error) {
|
||||
switch opts.Compressor {
|
||||
case wiremessage.CompressorNoOp:
|
||||
return in, nil
|
||||
case wiremessage.CompressorSnappy:
|
||||
uncompressed = make([]byte, opts.UncompressedSize)
|
||||
return snappy.Decode(uncompressed, in)
|
||||
case wiremessage.CompressorZLib:
|
||||
r, err := zlib.NewReader(bytes.NewReader(in))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
err = r.Close()
|
||||
}()
|
||||
uncompressed = make([]byte, opts.UncompressedSize)
|
||||
_, err = io.ReadFull(r, uncompressed)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return uncompressed, nil
|
||||
case wiremessage.CompressorZstd:
|
||||
r, err := zstd.NewReader(bytes.NewBuffer(in))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer r.Close()
|
||||
uncompressed = make([]byte, opts.UncompressedSize)
|
||||
_, err = io.ReadFull(r, uncompressed)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return uncompressed, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown compressor ID %v", opts.Compressor)
|
||||
}
|
||||
}
|
||||
80
mongo/x/mongo/driver/compression_test.go
Normal file
80
mongo/x/mongo/driver/compression_test.go
Normal file
@@ -0,0 +1,80 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package driver
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
|
||||
)
|
||||
|
||||
func TestCompression(t *testing.T) {
|
||||
compressors := []wiremessage.CompressorID{
|
||||
wiremessage.CompressorNoOp,
|
||||
wiremessage.CompressorSnappy,
|
||||
wiremessage.CompressorZLib,
|
||||
wiremessage.CompressorZstd,
|
||||
}
|
||||
|
||||
for _, compressor := range compressors {
|
||||
t.Run(compressor.String(), func(t *testing.T) {
|
||||
payload := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt")
|
||||
opts := CompressionOpts{
|
||||
Compressor: compressor,
|
||||
ZlibLevel: wiremessage.DefaultZlibLevel,
|
||||
ZstdLevel: wiremessage.DefaultZstdLevel,
|
||||
UncompressedSize: int32(len(payload)),
|
||||
}
|
||||
compressed, err := CompressPayload(payload, opts)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqual(t, 0, len(compressed))
|
||||
decompressed, err := DecompressPayload(compressed, opts)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, payload, decompressed)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCompressPayload(b *testing.B) {
|
||||
payload := func() []byte {
|
||||
buf, err := os.ReadFile("compression.go")
|
||||
if err != nil {
|
||||
b.Log(err)
|
||||
b.FailNow()
|
||||
}
|
||||
for i := 1; i < 10; i++ {
|
||||
buf = append(buf, buf...)
|
||||
}
|
||||
return buf
|
||||
}()
|
||||
|
||||
compressors := []wiremessage.CompressorID{
|
||||
wiremessage.CompressorSnappy,
|
||||
wiremessage.CompressorZLib,
|
||||
wiremessage.CompressorZstd,
|
||||
}
|
||||
|
||||
for _, compressor := range compressors {
|
||||
b.Run(compressor.String(), func(b *testing.B) {
|
||||
opts := CompressionOpts{
|
||||
Compressor: compressor,
|
||||
ZlibLevel: wiremessage.DefaultZlibLevel,
|
||||
ZstdLevel: wiremessage.DefaultZstdLevel,
|
||||
}
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
_, err := CompressPayload(payload, opts)
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
1036
mongo/x/mongo/driver/connstring/connstring.go
Normal file
1036
mongo/x/mongo/driver/connstring/connstring.go
Normal file
File diff suppressed because it is too large
Load Diff
356
mongo/x/mongo/driver/connstring/connstring_spec_test.go
Normal file
356
mongo/x/mongo/driver/connstring/connstring_spec_test.go
Normal file
@@ -0,0 +1,356 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package connstring_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"math"
|
||||
"path"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.mongodb.org/mongo-driver/internal/testutil/helpers"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
|
||||
)
|
||||
|
||||
type host struct {
|
||||
Type string
|
||||
Host string
|
||||
Port json.Number
|
||||
}
|
||||
|
||||
type auth struct {
|
||||
Username string
|
||||
Password *string
|
||||
DB string
|
||||
}
|
||||
|
||||
type testCase struct {
|
||||
Description string
|
||||
URI string
|
||||
Valid bool
|
||||
Warning bool
|
||||
Hosts []host
|
||||
Auth *auth
|
||||
Options map[string]interface{}
|
||||
}
|
||||
|
||||
type testContainer struct {
|
||||
Tests []testCase
|
||||
}
|
||||
|
||||
const connstringTestsDir = "../../../../testdata/connection-string/"
|
||||
const urioptionsTestDir = "../../../../testdata/uri-options/"
|
||||
|
||||
func (h *host) toString() string {
|
||||
switch h.Type {
|
||||
case "unix":
|
||||
return h.Host
|
||||
case "ip_literal":
|
||||
if len(h.Port) == 0 {
|
||||
return "[" + h.Host + "]"
|
||||
}
|
||||
return "[" + h.Host + "]" + ":" + string(h.Port)
|
||||
case "ipv4":
|
||||
fallthrough
|
||||
case "hostname":
|
||||
if len(h.Port) == 0 {
|
||||
return h.Host
|
||||
}
|
||||
return h.Host + ":" + string(h.Port)
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func hostsToStrings(hosts []host) []string {
|
||||
out := make([]string, len(hosts))
|
||||
|
||||
for i, host := range hosts {
|
||||
out[i] = host.toString()
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func runTestsInFile(t *testing.T, dirname string, filename string, warningsError bool) {
|
||||
filepath := path.Join(dirname, filename)
|
||||
content, err := ioutil.ReadFile(filepath)
|
||||
require.NoError(t, err)
|
||||
|
||||
var container testContainer
|
||||
require.NoError(t, json.Unmarshal(content, &container))
|
||||
|
||||
// Remove ".json" from filename.
|
||||
filename = filename[:len(filename)-5]
|
||||
|
||||
for _, testCase := range container.Tests {
|
||||
runTest(t, filename, testCase, warningsError)
|
||||
}
|
||||
}
|
||||
|
||||
var skipDescriptions = map[string]struct{}{
|
||||
"Valid options specific to single-threaded drivers are parsed correctly": {},
|
||||
}
|
||||
|
||||
var skipKeywords = []string{
|
||||
"tlsAllowInvalidHostnames",
|
||||
"tlsAllowInvalidCertificates",
|
||||
"tlsDisableCertificateRevocationCheck",
|
||||
"serverSelectionTryOnce",
|
||||
}
|
||||
|
||||
func runTest(t *testing.T, filename string, test testCase, warningsError bool) {
|
||||
t.Run(filename+"/"+test.Description, func(t *testing.T) {
|
||||
if _, skip := skipDescriptions[test.Description]; skip {
|
||||
t.Skip()
|
||||
}
|
||||
for _, keyword := range skipKeywords {
|
||||
if strings.Contains(test.Description, keyword) {
|
||||
t.Skipf("skipping because keyword %s", keyword)
|
||||
}
|
||||
}
|
||||
|
||||
cs, err := connstring.ParseAndValidate(test.URI)
|
||||
// Since we don't have warnings in Go, we return warnings as errors.
|
||||
//
|
||||
// This is a bit unfortunate, but since we do raise warnings as errors with the newer
|
||||
// URI options, but don't with some of the older things, we do a switch on the filename
|
||||
// here. We are trying to not break existing user applications that have unrecognized
|
||||
// options.
|
||||
if test.Valid && !(test.Warning && warningsError) {
|
||||
require.NoError(t, err)
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.Equal(t, test.URI, cs.Original)
|
||||
|
||||
if test.Hosts != nil {
|
||||
require.Equal(t, hostsToStrings(test.Hosts), cs.Hosts)
|
||||
}
|
||||
|
||||
if test.Auth != nil {
|
||||
require.Equal(t, test.Auth.Username, cs.Username)
|
||||
|
||||
if test.Auth.Password == nil {
|
||||
require.False(t, cs.PasswordSet)
|
||||
} else {
|
||||
require.True(t, cs.PasswordSet)
|
||||
require.Equal(t, *test.Auth.Password, cs.Password)
|
||||
}
|
||||
|
||||
if test.Auth.DB != cs.Database {
|
||||
require.Equal(t, test.Auth.DB, cs.AuthSource)
|
||||
} else {
|
||||
require.Equal(t, test.Auth.DB, cs.Database)
|
||||
}
|
||||
}
|
||||
|
||||
// Check that all options are present.
|
||||
verifyConnStringOptions(t, cs, test.Options)
|
||||
|
||||
// Check that non-present options are unset. This will be redundant with the above checks
|
||||
// for options that are present.
|
||||
var ok bool
|
||||
|
||||
_, ok = test.Options["maxpoolsize"]
|
||||
require.Equal(t, ok, cs.MaxPoolSizeSet)
|
||||
})
|
||||
}
|
||||
|
||||
// Test case for all connection string spec tests.
|
||||
func TestConnStringSpec(t *testing.T) {
|
||||
for _, file := range helpers.FindJSONFilesInDir(t, connstringTestsDir) {
|
||||
runTestsInFile(t, connstringTestsDir, file, false)
|
||||
}
|
||||
}
|
||||
|
||||
func TestURIOptionsSpec(t *testing.T) {
|
||||
for _, file := range helpers.FindJSONFilesInDir(t, urioptionsTestDir) {
|
||||
runTestsInFile(t, urioptionsTestDir, file, true)
|
||||
}
|
||||
}
|
||||
|
||||
// verifyConnStringOptions verifies the options on the connection string.
|
||||
func verifyConnStringOptions(t *testing.T, cs connstring.ConnString, options map[string]interface{}) {
|
||||
// Check that all options are present.
|
||||
for key, value := range options {
|
||||
|
||||
key = strings.ToLower(key)
|
||||
switch key {
|
||||
case "appname":
|
||||
require.Equal(t, value, cs.AppName)
|
||||
case "authsource":
|
||||
require.Equal(t, value, cs.AuthSource)
|
||||
case "authmechanism":
|
||||
require.Equal(t, value, cs.AuthMechanism)
|
||||
case "authmechanismproperties":
|
||||
convertedMap := value.(map[string]interface{})
|
||||
require.Equal(t,
|
||||
mapInterfaceToString(convertedMap),
|
||||
cs.AuthMechanismProperties)
|
||||
case "compressors":
|
||||
require.Equal(t, convertToStringSlice(value), cs.Compressors)
|
||||
case "connecttimeoutms":
|
||||
require.Equal(t, value, float64(cs.ConnectTimeout/time.Millisecond))
|
||||
case "directconnection":
|
||||
require.True(t, cs.DirectConnectionSet)
|
||||
require.Equal(t, value, cs.DirectConnection)
|
||||
case "heartbeatfrequencyms":
|
||||
require.Equal(t, value, float64(cs.HeartbeatInterval/time.Millisecond))
|
||||
case "journal":
|
||||
require.True(t, cs.JSet)
|
||||
require.Equal(t, value, cs.J)
|
||||
case "loadbalanced":
|
||||
require.True(t, cs.LoadBalancedSet)
|
||||
require.Equal(t, value, cs.LoadBalanced)
|
||||
case "localthresholdms":
|
||||
require.True(t, cs.LocalThresholdSet)
|
||||
require.Equal(t, value, float64(cs.LocalThreshold/time.Millisecond))
|
||||
case "maxidletimems":
|
||||
require.Equal(t, value, float64(cs.MaxConnIdleTime/time.Millisecond))
|
||||
case "maxpoolsize":
|
||||
require.True(t, cs.MaxPoolSizeSet)
|
||||
require.Equal(t, value, cs.MaxPoolSize)
|
||||
case "maxstalenessseconds":
|
||||
require.True(t, cs.MaxStalenessSet)
|
||||
require.Equal(t, value, float64(cs.MaxStaleness/time.Second))
|
||||
case "minpoolsize":
|
||||
require.True(t, cs.MinPoolSizeSet)
|
||||
require.Equal(t, value, int64(cs.MinPoolSize))
|
||||
case "readpreference":
|
||||
require.Equal(t, value, cs.ReadPreference)
|
||||
case "readpreferencetags":
|
||||
sm, ok := value.([]interface{})
|
||||
require.True(t, ok)
|
||||
tags := make([]map[string]string, 0, len(sm))
|
||||
for _, i := range sm {
|
||||
m, ok := i.(map[string]interface{})
|
||||
require.True(t, ok)
|
||||
tags = append(tags, mapInterfaceToString(m))
|
||||
}
|
||||
require.Equal(t, tags, cs.ReadPreferenceTagSets)
|
||||
case "readconcernlevel":
|
||||
require.Equal(t, value, cs.ReadConcernLevel)
|
||||
case "replicaset":
|
||||
require.Equal(t, value, cs.ReplicaSet)
|
||||
case "retrywrites":
|
||||
require.True(t, cs.RetryWritesSet)
|
||||
require.Equal(t, value, cs.RetryWrites)
|
||||
case "serverselectiontimeoutms":
|
||||
require.Equal(t, value, float64(cs.ServerSelectionTimeout/time.Millisecond))
|
||||
case "srvmaxhosts":
|
||||
require.Equal(t, value, float64(cs.SRVMaxHosts))
|
||||
case "srvservicename":
|
||||
require.Equal(t, value, cs.SRVServiceName)
|
||||
case "ssl", "tls":
|
||||
require.Equal(t, value, cs.SSL)
|
||||
case "sockettimeoutms":
|
||||
require.Equal(t, value, float64(cs.SocketTimeout/time.Millisecond))
|
||||
case "tlsallowinvalidcertificates", "tlsallowinvalidhostnames", "tlsinsecure":
|
||||
require.True(t, cs.SSLInsecureSet)
|
||||
require.Equal(t, value, cs.SSLInsecure)
|
||||
case "tlscafile":
|
||||
require.True(t, cs.SSLCaFileSet)
|
||||
require.Equal(t, value, cs.SSLCaFile)
|
||||
case "tlscertificatekeyfile":
|
||||
require.True(t, cs.SSLClientCertificateKeyFileSet)
|
||||
require.Equal(t, value, cs.SSLClientCertificateKeyFile)
|
||||
case "tlscertificatekeyfilepassword":
|
||||
require.True(t, cs.SSLClientCertificateKeyPasswordSet)
|
||||
require.Equal(t, value, cs.SSLClientCertificateKeyPassword())
|
||||
case "w":
|
||||
if cs.WNumberSet {
|
||||
valueInt := getIntFromInterface(value)
|
||||
require.NotNil(t, valueInt)
|
||||
require.Equal(t, *valueInt, int64(cs.WNumber))
|
||||
} else {
|
||||
require.Equal(t, value, cs.WString)
|
||||
}
|
||||
case "wtimeoutms":
|
||||
require.Equal(t, value, float64(cs.WTimeout/time.Millisecond))
|
||||
case "waitqueuetimeoutms":
|
||||
case "zlibcompressionlevel":
|
||||
require.Equal(t, value, float64(cs.ZlibLevel))
|
||||
case "zstdcompressionlevel":
|
||||
require.Equal(t, value, float64(cs.ZstdLevel))
|
||||
case "tlsdisableocspendpointcheck":
|
||||
require.Equal(t, value, cs.SSLDisableOCSPEndpointCheck)
|
||||
default:
|
||||
opt, ok := cs.UnknownOptions[key]
|
||||
require.True(t, ok)
|
||||
require.Contains(t, opt, fmt.Sprint(value))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert each interface{} value in the map to a string.
|
||||
func mapInterfaceToString(m map[string]interface{}) map[string]string {
|
||||
out := make(map[string]string)
|
||||
|
||||
for key, value := range m {
|
||||
out[key] = fmt.Sprint(value)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// getIntFromInterface attempts to convert an empty interface value to an integer.
|
||||
//
|
||||
// Returns nil if it is not possible.
|
||||
func getIntFromInterface(i interface{}) *int64 {
|
||||
var out int64
|
||||
|
||||
switch v := i.(type) {
|
||||
case int:
|
||||
out = int64(v)
|
||||
case int32:
|
||||
out = int64(v)
|
||||
case int64:
|
||||
out = v
|
||||
case float32:
|
||||
f := float64(v)
|
||||
if math.Floor(f) != f || f > float64(math.MaxInt64) {
|
||||
break
|
||||
}
|
||||
|
||||
out = int64(f)
|
||||
|
||||
case float64:
|
||||
if math.Floor(v) != v || v > float64(math.MaxInt64) {
|
||||
break
|
||||
}
|
||||
|
||||
out = int64(v)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
return &out
|
||||
}
|
||||
|
||||
func convertToStringSlice(i interface{}) []string {
|
||||
s, ok := i.([]interface{})
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
ret := make([]string, 0, len(s))
|
||||
for _, v := range s {
|
||||
str, ok := v.(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
ret = append(ret, str)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
631
mongo/x/mongo/driver/connstring/connstring_test.go
Normal file
631
mongo/x/mongo/driver/connstring/connstring_test.go
Normal file
@@ -0,0 +1,631 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package connstring_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
|
||||
)
|
||||
|
||||
func TestAppName(t *testing.T) {
|
||||
tests := []struct {
|
||||
s string
|
||||
expected string
|
||||
err bool
|
||||
}{
|
||||
{s: "appName=Funny", expected: "Funny"},
|
||||
{s: "appName=awesome", expected: "awesome"},
|
||||
{s: "appName=", expected: ""},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
s := fmt.Sprintf("mongodb://localhost/?%s", test.s)
|
||||
t.Run(s, func(t *testing.T) {
|
||||
cs, err := connstring.ParseAndValidate(s)
|
||||
if test.err {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, test.expected, cs.AppName)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMechanism(t *testing.T) {
|
||||
tests := []struct {
|
||||
s string
|
||||
expected string
|
||||
err bool
|
||||
}{
|
||||
{s: "authMechanism=scram-sha-1", expected: "scram-sha-1"},
|
||||
{s: "authMechanism=scram-sha-256", expected: "scram-sha-256"},
|
||||
{s: "authMechanism=mongodb-CR", expected: "mongodb-CR"},
|
||||
{s: "authMechanism=plain", expected: "plain"},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
s := fmt.Sprintf("mongodb://user:pass@localhost/?%s", test.s)
|
||||
t.Run(s, func(t *testing.T) {
|
||||
cs, err := connstring.ParseAndValidate(s)
|
||||
if test.err {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, test.expected, cs.AuthMechanism)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthSource(t *testing.T) {
|
||||
tests := []struct {
|
||||
s string
|
||||
expected string
|
||||
err bool
|
||||
}{
|
||||
{s: "foobar?authSource=bazqux", expected: "bazqux"},
|
||||
{s: "foobar", expected: "foobar"},
|
||||
{s: "", expected: "admin"},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
s := fmt.Sprintf("mongodb://user:pass@localhost/%s", test.s)
|
||||
t.Run(s, func(t *testing.T) {
|
||||
cs, err := connstring.ParseAndValidate(s)
|
||||
if test.err {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, test.expected, cs.AuthSource)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnect(t *testing.T) {
|
||||
tests := []struct {
|
||||
s string
|
||||
expected connstring.ConnectMode
|
||||
err bool
|
||||
}{
|
||||
{s: "connect=automatic", expected: connstring.AutoConnect},
|
||||
{s: "connect=AUTOMATIC", expected: connstring.AutoConnect},
|
||||
{s: "connect=direct", expected: connstring.SingleConnect},
|
||||
{s: "connect=blah", err: true},
|
||||
// Combinations of connect and directConnection where connect is set first - conflicting combinations must
|
||||
// error.
|
||||
{s: "connect=automatic&directConnection=true", err: true},
|
||||
{s: "connect=automatic&directConnection=false", expected: connstring.AutoConnect},
|
||||
{s: "connect=direct&directConnection=true", expected: connstring.SingleConnect},
|
||||
{s: "connect=direct&directConnection=false", err: true},
|
||||
// Combinations of connect and directConnection where directConnection is set first.
|
||||
{s: "directConnection=true&connect=automatic", err: true},
|
||||
{s: "directConnection=false&connect=automatic", expected: connstring.AutoConnect},
|
||||
{s: "directConnection=true&connect=direct", expected: connstring.SingleConnect},
|
||||
{s: "directConnection=false&connect=direct", err: true},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
s := fmt.Sprintf("mongodb://localhost/?%s", test.s)
|
||||
t.Run(s, func(t *testing.T) {
|
||||
cs, err := connstring.ParseAndValidate(s)
|
||||
if test.err {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, test.expected, cs.Connect)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDirectConnection(t *testing.T) {
|
||||
testCases := []struct {
|
||||
s string
|
||||
expected bool
|
||||
err bool
|
||||
}{
|
||||
{"directConnection=true", true, false},
|
||||
{"directConnection=false", false, false},
|
||||
{"directConnection=TRUE", true, false},
|
||||
{"directConnection=FALSE", false, false},
|
||||
{"directConnection=blah", false, true},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s := fmt.Sprintf("mongodb://localhost/?%s", tc.s)
|
||||
t.Run(s, func(t *testing.T) {
|
||||
cs, err := connstring.ParseAndValidate(s)
|
||||
if tc.err {
|
||||
assert.NotNil(t, err, "expected error, got nil")
|
||||
return
|
||||
}
|
||||
|
||||
assert.Nil(t, err, "expected no error, got %v", err)
|
||||
assert.Equal(t, tc.expected, cs.DirectConnection, "expected DirectConnection value %v, got %v", tc.expected,
|
||||
cs.DirectConnection)
|
||||
assert.True(t, cs.DirectConnectionSet, "expected DirectConnectionSet to be true, got false")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectTimeout(t *testing.T) {
|
||||
tests := []struct {
|
||||
s string
|
||||
expected time.Duration
|
||||
err bool
|
||||
}{
|
||||
{s: "connectTimeoutMS=10", expected: time.Duration(10) * time.Millisecond},
|
||||
{s: "connectTimeoutMS=100", expected: time.Duration(100) * time.Millisecond},
|
||||
{s: "connectTimeoutMS=-2", err: true},
|
||||
{s: "connectTimeoutMS=gsdge", err: true},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
s := fmt.Sprintf("mongodb://localhost/?%s", test.s)
|
||||
t.Run(s, func(t *testing.T) {
|
||||
cs, err := connstring.ParseAndValidate(s)
|
||||
if test.err {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, test.expected, cs.ConnectTimeout)
|
||||
require.True(t, cs.ConnectTimeoutSet)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeartbeatInterval(t *testing.T) {
|
||||
tests := []struct {
|
||||
s string
|
||||
expected time.Duration
|
||||
err bool
|
||||
}{
|
||||
{s: "heartbeatIntervalMS=10", expected: time.Duration(10) * time.Millisecond},
|
||||
{s: "heartbeatIntervalMS=100", expected: time.Duration(100) * time.Millisecond},
|
||||
{s: "heartbeatIntervalMS=-2", err: true},
|
||||
{s: "heartbeatIntervalMS=gsdge", err: true},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
s := fmt.Sprintf("mongodb://localhost/?%s", test.s)
|
||||
t.Run(s, func(t *testing.T) {
|
||||
cs, err := connstring.ParseAndValidate(s)
|
||||
if test.err {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, test.expected, cs.HeartbeatInterval)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalThreshold(t *testing.T) {
|
||||
tests := []struct {
|
||||
s string
|
||||
expected time.Duration
|
||||
err bool
|
||||
}{
|
||||
{s: "localThresholdMS=0", expected: time.Duration(0) * time.Millisecond},
|
||||
{s: "localThresholdMS=10", expected: time.Duration(10) * time.Millisecond},
|
||||
{s: "localThresholdMS=100", expected: time.Duration(100) * time.Millisecond},
|
||||
{s: "localThresholdMS=-2", err: true},
|
||||
{s: "localThresholdMS=gsdge", err: true},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
s := fmt.Sprintf("mongodb://localhost/?%s", test.s)
|
||||
t.Run(s, func(t *testing.T) {
|
||||
cs, err := connstring.ParseAndValidate(s)
|
||||
if test.err {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, test.expected, cs.LocalThreshold)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaxConnIdleTime(t *testing.T) {
|
||||
tests := []struct {
|
||||
s string
|
||||
expected time.Duration
|
||||
err bool
|
||||
}{
|
||||
{s: "maxIdleTimeMS=10", expected: time.Duration(10) * time.Millisecond},
|
||||
{s: "maxIdleTimeMS=100", expected: time.Duration(100) * time.Millisecond},
|
||||
{s: "maxIdleTimeMS=-2", err: true},
|
||||
{s: "maxIdleTimeMS=gsdge", err: true},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
s := fmt.Sprintf("mongodb://localhost/?%s", test.s)
|
||||
t.Run(s, func(t *testing.T) {
|
||||
cs, err := connstring.ParseAndValidate(s)
|
||||
if test.err {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, test.expected, cs.MaxConnIdleTime)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaxPoolSize(t *testing.T) {
|
||||
tests := []struct {
|
||||
s string
|
||||
expected uint64
|
||||
err bool
|
||||
}{
|
||||
{s: "maxPoolSize=10", expected: 10},
|
||||
{s: "maxPoolSize=100", expected: 100},
|
||||
{s: "maxPoolSize=-2", err: true},
|
||||
{s: "maxPoolSize=gsdge", err: true},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
s := fmt.Sprintf("mongodb://localhost/?%s", test.s)
|
||||
t.Run(s, func(t *testing.T) {
|
||||
cs, err := connstring.ParseAndValidate(s)
|
||||
if test.err {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.True(t, cs.MaxPoolSizeSet)
|
||||
require.Equal(t, test.expected, cs.MaxPoolSize)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMinPoolSize(t *testing.T) {
|
||||
tests := []struct {
|
||||
s string
|
||||
expected uint64
|
||||
err bool
|
||||
}{
|
||||
{s: "minPoolSize=10", expected: 10},
|
||||
{s: "minPoolSize=100", expected: 100},
|
||||
{s: "minPoolSize=-2", err: true},
|
||||
{s: "minPoolSize=gsdge", err: true},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
s := fmt.Sprintf("mongodb://localhost/?%s", test.s)
|
||||
t.Run(s, func(t *testing.T) {
|
||||
cs, err := connstring.ParseAndValidate(s)
|
||||
if test.err {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.True(t, cs.MinPoolSizeSet)
|
||||
require.Equal(t, test.expected, cs.MinPoolSize)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaxConnecting(t *testing.T) {
|
||||
tests := []struct {
|
||||
s string
|
||||
expected uint64
|
||||
err bool
|
||||
}{
|
||||
{s: "maxConnecting=10", expected: 10},
|
||||
{s: "maxConnecting=100", expected: 100},
|
||||
{s: "maxConnecting=-2", err: true},
|
||||
{s: "maxConnecting=gsdge", err: true},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
s := fmt.Sprintf("mongodb://localhost/?%s", test.s)
|
||||
t.Run(s, func(t *testing.T) {
|
||||
cs, err := connstring.ParseAndValidate(s)
|
||||
if test.err {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.True(t, cs.MaxConnectingSet)
|
||||
require.Equal(t, test.expected, cs.MaxConnecting)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadPreference(t *testing.T) {
|
||||
tests := []struct {
|
||||
s string
|
||||
expected string
|
||||
err bool
|
||||
}{
|
||||
{s: "readPreference=primary", expected: "primary"},
|
||||
{s: "readPreference=secondaryPreferred", expected: "secondaryPreferred"},
|
||||
{s: "readPreference=something", expected: "something"}, // we don't validate here
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
s := fmt.Sprintf("mongodb://localhost/?%s", test.s)
|
||||
t.Run(s, func(t *testing.T) {
|
||||
cs, err := connstring.ParseAndValidate(s)
|
||||
if test.err {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, test.expected, cs.ReadPreference)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadPreferenceTags(t *testing.T) {
|
||||
tests := []struct {
|
||||
s string
|
||||
expected []map[string]string
|
||||
err bool
|
||||
}{
|
||||
{s: "", expected: nil},
|
||||
{s: "readPreferenceTags=one:1", expected: []map[string]string{{"one": "1"}}},
|
||||
{s: "readPreferenceTags=one:1,two:2", expected: []map[string]string{{"one": "1", "two": "2"}}},
|
||||
{s: "readPreferenceTags=one:1&readPreferenceTags=two:2", expected: []map[string]string{{"one": "1"}, {"two": "2"}}},
|
||||
{s: "readPreferenceTags=one:1:3,two:2", err: true},
|
||||
{s: "readPreferenceTags=one:1&readPreferenceTags=two:2&readPreferenceTags=", expected: []map[string]string{{"one": "1"}, {"two": "2"}, {}}},
|
||||
{s: "readPreferenceTags=", expected: []map[string]string{{}}},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
s := fmt.Sprintf("mongodb://localhost/?%s", test.s)
|
||||
t.Run(s, func(t *testing.T) {
|
||||
cs, err := connstring.ParseAndValidate(s)
|
||||
if test.err {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, test.expected, cs.ReadPreferenceTagSets)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaxStaleness(t *testing.T) {
|
||||
tests := []struct {
|
||||
s string
|
||||
expected time.Duration
|
||||
err bool
|
||||
}{
|
||||
{s: "maxStaleness=10", expected: time.Duration(10) * time.Second},
|
||||
{s: "maxStaleness=100", expected: time.Duration(100) * time.Second},
|
||||
{s: "maxStaleness=-2", err: true},
|
||||
{s: "maxStaleness=gsdge", err: true},
|
||||
}
|
||||
for _, test := range tests {
|
||||
s := fmt.Sprintf("mongodb://localhost/?%s", test.s)
|
||||
t.Run(s, func(t *testing.T) {
|
||||
cs, err := connstring.ParseAndValidate(s)
|
||||
if test.err {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, test.expected, cs.MaxStaleness)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplicaSet(t *testing.T) {
|
||||
tests := []struct {
|
||||
s string
|
||||
expected string
|
||||
err bool
|
||||
}{
|
||||
{s: "replicaSet=auto", expected: "auto"},
|
||||
{s: "replicaSet=rs0", expected: "rs0"},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
s := fmt.Sprintf("mongodb://localhost/?%s", test.s)
|
||||
t.Run(s, func(t *testing.T) {
|
||||
cs, err := connstring.ParseAndValidate(s)
|
||||
if test.err {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, test.expected, cs.ReplicaSet)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryWrites(t *testing.T) {
|
||||
tests := []struct {
|
||||
s string
|
||||
expected bool
|
||||
err bool
|
||||
}{
|
||||
{s: "retryWrites=true", expected: true},
|
||||
{s: "retryWrites=false", expected: false},
|
||||
{s: "retryWrites=foobar", expected: false, err: true},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
s := fmt.Sprintf("mongodb://localhost/?%s", test.s)
|
||||
t.Run(s, func(t *testing.T) {
|
||||
cs, err := connstring.ParseAndValidate(s)
|
||||
if test.err {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, test.expected, cs.RetryWrites)
|
||||
require.True(t, cs.RetryWritesSet)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryReads(t *testing.T) {
|
||||
tests := []struct {
|
||||
s string
|
||||
expected bool
|
||||
err bool
|
||||
}{
|
||||
{s: "retryReads=true", expected: true},
|
||||
{s: "retryReads=false", expected: false},
|
||||
{s: "retryReads=foobar", expected: false, err: true},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
s := fmt.Sprintf("mongodb://localhost/?%s", test.s)
|
||||
t.Run(s, func(t *testing.T) {
|
||||
cs, err := connstring.ParseAndValidate(s)
|
||||
if test.err {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, test.expected, cs.RetryReads)
|
||||
require.True(t, cs.RetryReadsSet)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestScheme(t *testing.T) {
|
||||
// Can't unit test 'mongodb+srv' because that requires networking. Tested
|
||||
// in x/mongo/driver/topology/initial_dns_seedlist_discovery_test.go
|
||||
cs, err := connstring.ParseAndValidate("mongodb://localhost/")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, cs.Scheme, "mongodb")
|
||||
require.Equal(t, cs.Scheme, connstring.SchemeMongoDB)
|
||||
}
|
||||
|
||||
func TestServerSelectionTimeout(t *testing.T) {
|
||||
tests := []struct {
|
||||
s string
|
||||
expected time.Duration
|
||||
err bool
|
||||
}{
|
||||
{s: "serverSelectionTimeoutMS=10", expected: time.Duration(10) * time.Millisecond},
|
||||
{s: "serverSelectionTimeoutMS=100", expected: time.Duration(100) * time.Millisecond},
|
||||
{s: "serverSelectionTimeoutMS=-2", err: true},
|
||||
{s: "serverSelectionTimeoutMS=gsdge", err: true},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
s := fmt.Sprintf("mongodb://localhost/?%s", test.s)
|
||||
t.Run(s, func(t *testing.T) {
|
||||
cs, err := connstring.ParseAndValidate(s)
|
||||
if test.err {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, test.expected, cs.ServerSelectionTimeout)
|
||||
require.True(t, cs.ServerSelectionTimeoutSet)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSocketTimeout(t *testing.T) {
|
||||
tests := []struct {
|
||||
s string
|
||||
expected time.Duration
|
||||
err bool
|
||||
}{
|
||||
{s: "socketTimeoutMS=10", expected: time.Duration(10) * time.Millisecond},
|
||||
{s: "socketTimeoutMS=100", expected: time.Duration(100) * time.Millisecond},
|
||||
{s: "socketTimeoutMS=-2", err: true},
|
||||
{s: "socketTimeoutMS=gsdge", err: true},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
s := fmt.Sprintf("mongodb://localhost/?%s", test.s)
|
||||
t.Run(s, func(t *testing.T) {
|
||||
cs, err := connstring.ParseAndValidate(s)
|
||||
if test.err {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, test.expected, cs.SocketTimeout)
|
||||
require.True(t, cs.SocketTimeoutSet)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWTimeout(t *testing.T) {
|
||||
tests := []struct {
|
||||
s string
|
||||
expected time.Duration
|
||||
err bool
|
||||
}{
|
||||
{s: "wtimeoutMS=10", expected: time.Duration(10) * time.Millisecond},
|
||||
{s: "wtimeoutMS=100", expected: time.Duration(100) * time.Millisecond},
|
||||
{s: "wtimeoutMS=-2", err: true},
|
||||
{s: "wtimeoutMS=gsdge", err: true},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
s := fmt.Sprintf("mongodb://localhost/?%s", test.s)
|
||||
t.Run(s, func(t *testing.T) {
|
||||
cs, err := connstring.ParseAndValidate(s)
|
||||
if test.err {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, test.expected, cs.WTimeout)
|
||||
require.True(t, cs.WTimeoutSet)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompressionOptions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
uriOptions string
|
||||
compressors []string
|
||||
zlibLevel int
|
||||
zstdLevel int
|
||||
err bool
|
||||
}{
|
||||
{name: "SingleCompressor", uriOptions: "compressors=zlib", compressors: []string{"zlib"}},
|
||||
{name: "MultiCompressors", uriOptions: "compressors=snappy,zlib", compressors: []string{"snappy", "zlib"}},
|
||||
{name: "ZlibWithLevel", uriOptions: "compressors=zlib&zlibCompressionLevel=7", compressors: []string{"zlib"}, zlibLevel: 7},
|
||||
{name: "DefaultZlibLevel", uriOptions: "compressors=zlib&zlibCompressionLevel=-1", compressors: []string{"zlib"}, zlibLevel: 6},
|
||||
{name: "InvalidZlibLevel", uriOptions: "compressors=zlib&zlibCompressionLevel=-2", compressors: []string{"zlib"}, err: true},
|
||||
{name: "ZstdWithLevel", uriOptions: "compressors=zstd&zstdCompressionLevel=20", compressors: []string{"zstd"}, zstdLevel: 20},
|
||||
{name: "DefaultZstdLevel", uriOptions: "compressors=zstd&zstdCompressionLevel=-1", compressors: []string{"zstd"}, zstdLevel: 6},
|
||||
{name: "InvalidZstdLevel", uriOptions: "compressors=zstd&zstdCompressionLevel=30", compressors: []string{"zstd"}, err: true},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
uri := fmt.Sprintf("mongodb://localhost/?%s", tc.uriOptions)
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
cs, err := connstring.ParseAndValidate(uri)
|
||||
if tc.err {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tc.compressors, cs.Compressors)
|
||||
if tc.zlibLevel != 0 {
|
||||
assert.Equal(t, tc.zlibLevel, cs.ZlibLevel)
|
||||
}
|
||||
if tc.zstdLevel != 0 {
|
||||
assert.Equal(t, tc.zstdLevel, cs.ZstdLevel)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
473
mongo/x/mongo/driver/crypt.go
Normal file
473
mongo/x/mongo/driver/crypt.go
Normal file
@@ -0,0 +1,473 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package driver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.mongodb.org/mongo-driver/bson/bsontype"
|
||||
"go.mongodb.org/mongo-driver/internal"
|
||||
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt/options"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultKmsPort = 443
|
||||
defaultKmsTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
// CollectionInfoFn is a callback used to retrieve collection information.
|
||||
type CollectionInfoFn func(ctx context.Context, db string, filter bsoncore.Document) (bsoncore.Document, error)
|
||||
|
||||
// KeyRetrieverFn is a callback used to retrieve keys from the key vault.
|
||||
type KeyRetrieverFn func(ctx context.Context, filter bsoncore.Document) ([]bsoncore.Document, error)
|
||||
|
||||
// MarkCommandFn is a callback used to add encryption markings to a command.
|
||||
type MarkCommandFn func(ctx context.Context, db string, cmd bsoncore.Document) (bsoncore.Document, error)
|
||||
|
||||
// CryptOptions specifies options to configure a Crypt instance.
|
||||
type CryptOptions struct {
|
||||
MongoCrypt *mongocrypt.MongoCrypt
|
||||
CollInfoFn CollectionInfoFn
|
||||
KeyFn KeyRetrieverFn
|
||||
MarkFn MarkCommandFn
|
||||
TLSConfig map[string]*tls.Config
|
||||
HTTPClient *http.Client
|
||||
BypassAutoEncryption bool
|
||||
BypassQueryAnalysis bool
|
||||
}
|
||||
|
||||
// Crypt is an interface implemented by types that can encrypt and decrypt instances of
|
||||
// bsoncore.Document.
|
||||
//
|
||||
// Users should rely on the driver's crypt type (used by default) for encryption and decryption
|
||||
// unless they are perfectly confident in another implementation of Crypt.
|
||||
type Crypt interface {
|
||||
// Encrypt encrypts the given command.
|
||||
Encrypt(ctx context.Context, db string, cmd bsoncore.Document) (bsoncore.Document, error)
|
||||
// Decrypt decrypts the given command response.
|
||||
Decrypt(ctx context.Context, cmdResponse bsoncore.Document) (bsoncore.Document, error)
|
||||
// CreateDataKey creates a data key using the given KMS provider and options.
|
||||
CreateDataKey(ctx context.Context, kmsProvider string, opts *options.DataKeyOptions) (bsoncore.Document, error)
|
||||
// EncryptExplicit encrypts the given value with the given options.
|
||||
EncryptExplicit(ctx context.Context, val bsoncore.Value, opts *options.ExplicitEncryptionOptions) (byte, []byte, error)
|
||||
// DecryptExplicit decrypts the given encrypted value.
|
||||
DecryptExplicit(ctx context.Context, subtype byte, data []byte) (bsoncore.Value, error)
|
||||
// Close cleans up any resources associated with the Crypt instance.
|
||||
Close()
|
||||
// BypassAutoEncryption returns true if auto-encryption should be bypassed.
|
||||
BypassAutoEncryption() bool
|
||||
// RewrapDataKey attempts to rewrap the document data keys matching the filter, preparing the re-wrapped documents
|
||||
// to be returned as a slice of bsoncore.Document.
|
||||
RewrapDataKey(ctx context.Context, filter []byte, opts *options.RewrapManyDataKeyOptions) ([]bsoncore.Document, error)
|
||||
}
|
||||
|
||||
// crypt consumes the libmongocrypt.MongoCrypt type to iterate the mongocrypt state machine and perform encryption
|
||||
// and decryption.
|
||||
type crypt struct {
|
||||
mongoCrypt *mongocrypt.MongoCrypt
|
||||
collInfoFn CollectionInfoFn
|
||||
keyFn KeyRetrieverFn
|
||||
markFn MarkCommandFn
|
||||
tlsConfig map[string]*tls.Config
|
||||
httpClient *http.Client
|
||||
|
||||
bypassAutoEncryption bool
|
||||
}
|
||||
|
||||
// NewCrypt creates a new Crypt instance configured with the given AutoEncryptionOptions.
|
||||
func NewCrypt(opts *CryptOptions) Crypt {
|
||||
c := &crypt{
|
||||
mongoCrypt: opts.MongoCrypt,
|
||||
collInfoFn: opts.CollInfoFn,
|
||||
keyFn: opts.KeyFn,
|
||||
markFn: opts.MarkFn,
|
||||
tlsConfig: opts.TLSConfig,
|
||||
httpClient: opts.HTTPClient,
|
||||
bypassAutoEncryption: opts.BypassAutoEncryption,
|
||||
}
|
||||
if c.httpClient == nil {
|
||||
c.httpClient = internal.DefaultHTTPClient
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// Encrypt encrypts the given command.
|
||||
func (c *crypt) Encrypt(ctx context.Context, db string, cmd bsoncore.Document) (bsoncore.Document, error) {
|
||||
if c.bypassAutoEncryption {
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
cryptCtx, err := c.mongoCrypt.CreateEncryptionContext(db, cmd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer cryptCtx.Close()
|
||||
|
||||
return c.executeStateMachine(ctx, cryptCtx, db)
|
||||
}
|
||||
|
||||
// Decrypt decrypts the given command response.
|
||||
func (c *crypt) Decrypt(ctx context.Context, cmdResponse bsoncore.Document) (bsoncore.Document, error) {
|
||||
cryptCtx, err := c.mongoCrypt.CreateDecryptionContext(cmdResponse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer cryptCtx.Close()
|
||||
|
||||
return c.executeStateMachine(ctx, cryptCtx, "")
|
||||
}
|
||||
|
||||
// CreateDataKey creates a data key using the given KMS provider and options.
|
||||
func (c *crypt) CreateDataKey(ctx context.Context, kmsProvider string, opts *options.DataKeyOptions) (bsoncore.Document, error) {
|
||||
cryptCtx, err := c.mongoCrypt.CreateDataKeyContext(kmsProvider, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer cryptCtx.Close()
|
||||
|
||||
return c.executeStateMachine(ctx, cryptCtx, "")
|
||||
}
|
||||
|
||||
// RewrapDataKey attempts to rewrap the document data keys matching the filter, preparing the re-wrapped documents to
|
||||
// be returned as a slice of bsoncore.Document.
|
||||
func (c *crypt) RewrapDataKey(ctx context.Context, filter []byte,
|
||||
opts *options.RewrapManyDataKeyOptions) ([]bsoncore.Document, error) {
|
||||
|
||||
cryptCtx, err := c.mongoCrypt.RewrapDataKeyContext(filter, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer cryptCtx.Close()
|
||||
|
||||
rewrappedBSON, err := c.executeStateMachine(ctx, cryptCtx, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if rewrappedBSON == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// mongocrypt_ctx_rewrap_many_datakey_init wraps the documents in a BSON of the form { "v": [(BSON document), ...] }
|
||||
// where each BSON document in the slice is a document containing a rewrapped datakey.
|
||||
rewrappedDocumentBytes, err := rewrappedBSON.LookupErr("v")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse the resulting BSON as individual documents.
|
||||
rewrappedDocsArray, ok := rewrappedDocumentBytes.ArrayOK()
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("expected results from mongocrypt_ctx_rewrap_many_datakey_init to be an array")
|
||||
}
|
||||
|
||||
rewrappedDocumentValues, err := rewrappedDocsArray.Values()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rewrappedDocuments := []bsoncore.Document{}
|
||||
for _, rewrappedDocumentValue := range rewrappedDocumentValues {
|
||||
if rewrappedDocumentValue.Type != bsontype.EmbeddedDocument {
|
||||
// If a value in the document's array returned by mongocrypt is anything other than an embedded document,
|
||||
// then something is wrong and we should terminate the routine.
|
||||
return nil, fmt.Errorf("expected value of type %q, got: %q",
|
||||
bsontype.EmbeddedDocument.String(),
|
||||
rewrappedDocumentValue.Type.String())
|
||||
}
|
||||
rewrappedDocuments = append(rewrappedDocuments, rewrappedDocumentValue.Document())
|
||||
}
|
||||
return rewrappedDocuments, nil
|
||||
}
|
||||
|
||||
// EncryptExplicit encrypts the given value with the given options.
|
||||
func (c *crypt) EncryptExplicit(ctx context.Context, val bsoncore.Value, opts *options.ExplicitEncryptionOptions) (byte, []byte, error) {
|
||||
idx, doc := bsoncore.AppendDocumentStart(nil)
|
||||
doc = bsoncore.AppendValueElement(doc, "v", val)
|
||||
doc, _ = bsoncore.AppendDocumentEnd(doc, idx)
|
||||
|
||||
cryptCtx, err := c.mongoCrypt.CreateExplicitEncryptionContext(doc, opts)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
defer cryptCtx.Close()
|
||||
|
||||
res, err := c.executeStateMachine(ctx, cryptCtx, "")
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
sub, data := res.Lookup("v").Binary()
|
||||
return sub, data, nil
|
||||
}
|
||||
|
||||
// DecryptExplicit decrypts the given encrypted value.
|
||||
func (c *crypt) DecryptExplicit(ctx context.Context, subtype byte, data []byte) (bsoncore.Value, error) {
|
||||
idx, doc := bsoncore.AppendDocumentStart(nil)
|
||||
doc = bsoncore.AppendBinaryElement(doc, "v", subtype, data)
|
||||
doc, _ = bsoncore.AppendDocumentEnd(doc, idx)
|
||||
|
||||
cryptCtx, err := c.mongoCrypt.CreateExplicitDecryptionContext(doc)
|
||||
if err != nil {
|
||||
return bsoncore.Value{}, err
|
||||
}
|
||||
defer cryptCtx.Close()
|
||||
|
||||
res, err := c.executeStateMachine(ctx, cryptCtx, "")
|
||||
if err != nil {
|
||||
return bsoncore.Value{}, err
|
||||
}
|
||||
|
||||
return res.Lookup("v"), nil
|
||||
}
|
||||
|
||||
// Close cleans up any resources associated with the Crypt instance.
|
||||
func (c *crypt) Close() {
|
||||
c.mongoCrypt.Close()
|
||||
if c.httpClient == internal.DefaultHTTPClient {
|
||||
internal.CloseIdleHTTPConnections(c.httpClient)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *crypt) BypassAutoEncryption() bool {
|
||||
return c.bypassAutoEncryption
|
||||
}
|
||||
|
||||
func (c *crypt) executeStateMachine(ctx context.Context, cryptCtx *mongocrypt.Context, db string) (bsoncore.Document, error) {
|
||||
var err error
|
||||
for {
|
||||
state := cryptCtx.State()
|
||||
switch state {
|
||||
case mongocrypt.NeedMongoCollInfo:
|
||||
err = c.collectionInfo(ctx, cryptCtx, db)
|
||||
case mongocrypt.NeedMongoMarkings:
|
||||
err = c.markCommand(ctx, cryptCtx, db)
|
||||
case mongocrypt.NeedMongoKeys:
|
||||
err = c.retrieveKeys(ctx, cryptCtx)
|
||||
case mongocrypt.NeedKms:
|
||||
err = c.decryptKeys(cryptCtx)
|
||||
case mongocrypt.Ready:
|
||||
return cryptCtx.Finish()
|
||||
case mongocrypt.Done:
|
||||
return nil, nil
|
||||
case mongocrypt.NeedKmsCredentials:
|
||||
err = c.provideKmsProviders(ctx, cryptCtx)
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid Crypt state: %v", state)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *crypt) collectionInfo(ctx context.Context, cryptCtx *mongocrypt.Context, db string) error {
|
||||
op, err := cryptCtx.NextOperation()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
collInfo, err := c.collInfoFn(ctx, db, op)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if collInfo != nil {
|
||||
if err = cryptCtx.AddOperationResult(collInfo); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return cryptCtx.CompleteOperation()
|
||||
}
|
||||
|
||||
func (c *crypt) markCommand(ctx context.Context, cryptCtx *mongocrypt.Context, db string) error {
|
||||
op, err := cryptCtx.NextOperation()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
markedCmd, err := c.markFn(ctx, db, op)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err = cryptCtx.AddOperationResult(markedCmd); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return cryptCtx.CompleteOperation()
|
||||
}
|
||||
|
||||
func (c *crypt) retrieveKeys(ctx context.Context, cryptCtx *mongocrypt.Context) error {
|
||||
op, err := cryptCtx.NextOperation()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
keys, err := c.keyFn(ctx, op)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, key := range keys {
|
||||
if err = cryptCtx.AddOperationResult(key); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return cryptCtx.CompleteOperation()
|
||||
}
|
||||
|
||||
func (c *crypt) decryptKeys(cryptCtx *mongocrypt.Context) error {
|
||||
for {
|
||||
kmsCtx := cryptCtx.NextKmsContext()
|
||||
if kmsCtx == nil {
|
||||
break
|
||||
}
|
||||
|
||||
if err := c.decryptKey(kmsCtx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return cryptCtx.FinishKmsContexts()
|
||||
}
|
||||
|
||||
func (c *crypt) decryptKey(kmsCtx *mongocrypt.KmsContext) error {
|
||||
host, err := kmsCtx.HostName()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
msg, err := kmsCtx.Message()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// add a port to the address if it's not already present
|
||||
addr := host
|
||||
if idx := strings.IndexByte(host, ':'); idx == -1 {
|
||||
addr = fmt.Sprintf("%s:%d", host, defaultKmsPort)
|
||||
}
|
||||
|
||||
kmsProvider := kmsCtx.KMSProvider()
|
||||
tlsCfg := c.tlsConfig[kmsProvider]
|
||||
if tlsCfg == nil {
|
||||
tlsCfg = &tls.Config{MinVersion: tls.VersionTLS12}
|
||||
}
|
||||
conn, err := tls.Dial("tcp", addr, tlsCfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.Close()
|
||||
}()
|
||||
|
||||
if err = conn.SetWriteDeadline(time.Now().Add(defaultKmsTimeout)); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err = conn.Write(msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for {
|
||||
bytesNeeded := kmsCtx.BytesNeeded()
|
||||
if bytesNeeded == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
res := make([]byte, bytesNeeded)
|
||||
bytesRead, err := conn.Read(res)
|
||||
if err != nil && err != io.EOF {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = kmsCtx.FeedResponse(res[:bytesRead]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// needsKmsProvider returns true if provider was initially set to an empty document.
|
||||
// An empty document signals the driver to fetch credentials.
|
||||
func needsKmsProvider(kmsProviders bsoncore.Document, provider string) bool {
|
||||
val, err := kmsProviders.LookupErr(provider)
|
||||
if err != nil {
|
||||
// KMS provider is not configured.
|
||||
return false
|
||||
}
|
||||
doc, ok := val.DocumentOK()
|
||||
// KMS provider is an empty document.
|
||||
return ok && len(doc) == 5
|
||||
}
|
||||
|
||||
func getGCPAccessToken(ctx context.Context, httpClient *http.Client) (string, error) {
|
||||
metadataHost := "metadata.google.internal"
|
||||
if envhost := os.Getenv("GCE_METADATA_HOST"); envhost != "" {
|
||||
metadataHost = envhost
|
||||
}
|
||||
url := fmt.Sprintf("http://%s/computeMetadata/v1/instance/service-accounts/default/token", metadataHost)
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
return "", internal.WrapErrorf(err, "unable to retrieve GCP credentials")
|
||||
}
|
||||
req.Header.Set("Metadata-Flavor", "Google")
|
||||
resp, err := httpClient.Do(req.WithContext(ctx))
|
||||
if err != nil {
|
||||
return "", internal.WrapErrorf(err, "unable to retrieve GCP credentials")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", internal.WrapErrorf(err, "unable to retrieve GCP credentials: error reading response body")
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", internal.WrapErrorf(err, "unable to retrieve GCP credentials: expected StatusCode 200, got StatusCode: %v. Response body: %s", resp.StatusCode, body)
|
||||
}
|
||||
var tokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
}
|
||||
// Attempt to read body as JSON
|
||||
err = json.Unmarshal(body, &tokenResponse)
|
||||
if err != nil {
|
||||
return "", internal.WrapErrorf(err, "unable to retrieve GCP credentials: error reading body JSON. Response body: %s", body)
|
||||
}
|
||||
if tokenResponse.AccessToken == "" {
|
||||
return "", fmt.Errorf("unable to retrieve GCP credentials: got unexpected empty accessToken from GCP Metadata Server. Response body: %s", body)
|
||||
}
|
||||
return tokenResponse.AccessToken, nil
|
||||
}
|
||||
|
||||
func (c *crypt) provideKmsProviders(ctx context.Context, cryptCtx *mongocrypt.Context) error {
|
||||
kmsProviders := c.mongoCrypt.GetKmsProviders()
|
||||
builder := bsoncore.NewDocumentBuilder()
|
||||
|
||||
if needsKmsProvider(kmsProviders, "gcp") {
|
||||
// "gcp" KMS provider is an empty document.
|
||||
// Attempt to fetch from GCP Instance Metadata server.
|
||||
{
|
||||
token, err := getGCPAccessToken(ctx, c.httpClient)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
builder.StartDocument("gcp").
|
||||
AppendString("accessToken", token).
|
||||
FinishDocument()
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
return cryptCtx.ProvideKmsProviders(builder.Build())
|
||||
}
|
||||
144
mongo/x/mongo/driver/dns/dns.go
Normal file
144
mongo/x/mongo/driver/dns/dns.go
Normal file
@@ -0,0 +1,144 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package dns
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"runtime"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Resolver resolves DNS records.
|
||||
type Resolver struct {
|
||||
// Holds the functions to use for DNS lookups
|
||||
LookupSRV func(string, string, string) (string, []*net.SRV, error)
|
||||
LookupTXT func(string) ([]string, error)
|
||||
}
|
||||
|
||||
// DefaultResolver is a Resolver that uses the default Resolver from the net package.
|
||||
var DefaultResolver = &Resolver{net.LookupSRV, net.LookupTXT}
|
||||
|
||||
// ParseHosts uses the srv string and service name to get the hosts.
|
||||
func (r *Resolver) ParseHosts(host string, srvName string, stopOnErr bool) ([]string, error) {
|
||||
parsedHosts := strings.Split(host, ",")
|
||||
|
||||
if len(parsedHosts) != 1 {
|
||||
return nil, fmt.Errorf("URI with SRV must include one and only one hostname")
|
||||
}
|
||||
return r.fetchSeedlistFromSRV(parsedHosts[0], srvName, stopOnErr)
|
||||
}
|
||||
|
||||
// GetConnectionArgsFromTXT gets the TXT record associated with the host and returns the connection arguments.
|
||||
func (r *Resolver) GetConnectionArgsFromTXT(host string) ([]string, error) {
|
||||
var connectionArgsFromTXT []string
|
||||
|
||||
// error ignored because not finding a TXT record should not be
|
||||
// considered an error.
|
||||
recordsFromTXT, _ := r.LookupTXT(host)
|
||||
|
||||
// This is a temporary fix to get around bug https://github.com/golang/go/issues/21472.
|
||||
// It will currently incorrectly concatenate multiple TXT records to one
|
||||
// on windows.
|
||||
if runtime.GOOS == "windows" {
|
||||
recordsFromTXT = []string{strings.Join(recordsFromTXT, "")}
|
||||
}
|
||||
|
||||
if len(recordsFromTXT) > 1 {
|
||||
return nil, errors.New("multiple records from TXT not supported")
|
||||
}
|
||||
if len(recordsFromTXT) > 0 {
|
||||
connectionArgsFromTXT = strings.FieldsFunc(recordsFromTXT[0], func(r rune) bool { return r == ';' || r == '&' })
|
||||
|
||||
err := validateTXTResult(connectionArgsFromTXT)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return connectionArgsFromTXT, nil
|
||||
}
|
||||
|
||||
func (r *Resolver) fetchSeedlistFromSRV(host string, srvName string, stopOnErr bool) ([]string, error) {
|
||||
var err error
|
||||
|
||||
_, _, err = net.SplitHostPort(host)
|
||||
|
||||
if err == nil {
|
||||
// we were able to successfully extract a port from the host,
|
||||
// but should not be able to when using SRV
|
||||
return nil, fmt.Errorf("URI with srv must not include a port number")
|
||||
}
|
||||
|
||||
// default to "mongodb" as service name if not supplied
|
||||
if srvName == "" {
|
||||
srvName = "mongodb"
|
||||
}
|
||||
_, addresses, err := r.LookupSRV(srvName, "tcp", host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
trimmedHost := strings.TrimSuffix(host, ".")
|
||||
|
||||
parsedHosts := make([]string, 0, len(addresses))
|
||||
for _, address := range addresses {
|
||||
trimmedAddressTarget := strings.TrimSuffix(address.Target, ".")
|
||||
err := validateSRVResult(trimmedAddressTarget, trimmedHost)
|
||||
if err != nil {
|
||||
if stopOnErr {
|
||||
return nil, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
parsedHosts = append(parsedHosts, fmt.Sprintf("%s:%d", trimmedAddressTarget, address.Port))
|
||||
}
|
||||
return parsedHosts, nil
|
||||
}
|
||||
|
||||
func validateSRVResult(recordFromSRV, inputHostName string) error {
|
||||
separatedInputDomain := strings.Split(inputHostName, ".")
|
||||
separatedRecord := strings.Split(recordFromSRV, ".")
|
||||
if len(separatedRecord) < 2 {
|
||||
return errors.New("DNS name must contain at least 2 labels")
|
||||
}
|
||||
if len(separatedRecord) < len(separatedInputDomain) {
|
||||
return errors.New("Domain suffix from SRV record not matched input domain")
|
||||
}
|
||||
|
||||
inputDomainSuffix := separatedInputDomain[1:]
|
||||
domainSuffixOffset := len(separatedRecord) - (len(separatedInputDomain) - 1)
|
||||
|
||||
recordDomainSuffix := separatedRecord[domainSuffixOffset:]
|
||||
for ix, label := range inputDomainSuffix {
|
||||
if label != recordDomainSuffix[ix] {
|
||||
return errors.New("Domain suffix from SRV record not matched input domain")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var allowedTXTOptions = map[string]struct{}{
|
||||
"authsource": {},
|
||||
"replicaset": {},
|
||||
"loadbalanced": {},
|
||||
}
|
||||
|
||||
func validateTXTResult(paramsFromTXT []string) error {
|
||||
for _, param := range paramsFromTXT {
|
||||
kv := strings.SplitN(param, "=", 2)
|
||||
if len(kv) != 2 {
|
||||
return errors.New("Invalid TXT record")
|
||||
}
|
||||
key := strings.ToLower(kv[0])
|
||||
if _, ok := allowedTXTOptions[key]; !ok {
|
||||
return fmt.Errorf("Cannot specify option '%s' in TXT record", kv[0])
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
269
mongo/x/mongo/driver/driver.go
Normal file
269
mongo/x/mongo/driver/driver.go
Normal file
@@ -0,0 +1,269 @@
|
||||
// Copyright (C) MongoDB, Inc. 2022-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package driver // import "go.mongodb.org/mongo-driver/x/mongo/driver"
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"go.mongodb.org/mongo-driver/internal"
|
||||
"go.mongodb.org/mongo-driver/mongo/address"
|
||||
"go.mongodb.org/mongo-driver/mongo/description"
|
||||
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
|
||||
)
|
||||
|
||||
// Deployment is implemented by types that can select a server from a deployment.
|
||||
type Deployment interface {
|
||||
SelectServer(context.Context, description.ServerSelector) (Server, error)
|
||||
Kind() description.TopologyKind
|
||||
}
|
||||
|
||||
// Connector represents a type that can connect to a server.
|
||||
type Connector interface {
|
||||
Connect() error
|
||||
}
|
||||
|
||||
// Disconnector represents a type that can disconnect from a server.
|
||||
type Disconnector interface {
|
||||
Disconnect(context.Context) error
|
||||
}
|
||||
|
||||
// Subscription represents a subscription to topology updates. A subscriber can receive updates through the
|
||||
// Updates field.
|
||||
type Subscription struct {
|
||||
Updates <-chan description.Topology
|
||||
ID uint64
|
||||
}
|
||||
|
||||
// Subscriber represents a type to which another type can subscribe. A subscription contains a channel that
|
||||
// is updated with topology descriptions.
|
||||
type Subscriber interface {
|
||||
Subscribe() (*Subscription, error)
|
||||
Unsubscribe(*Subscription) error
|
||||
}
|
||||
|
||||
// Server represents a MongoDB server. Implementations should pool connections and handle the
|
||||
// retrieving and returning of connections.
|
||||
type Server interface {
|
||||
Connection(context.Context) (Connection, error)
|
||||
|
||||
// RTTMonitor returns the round-trip time monitor associated with this server.
|
||||
RTTMonitor() RTTMonitor
|
||||
}
|
||||
|
||||
// Connection represents a connection to a MongoDB server.
|
||||
type Connection interface {
|
||||
WriteWireMessage(context.Context, []byte) error
|
||||
ReadWireMessage(ctx context.Context) ([]byte, error)
|
||||
Description() description.Server
|
||||
|
||||
// Close closes any underlying connection and returns or frees any resources held by the
|
||||
// connection. Close is idempotent and can be called multiple times, although subsequent calls
|
||||
// to Close may return an error. A connection cannot be used after it is closed.
|
||||
Close() error
|
||||
|
||||
ID() string
|
||||
ServerConnectionID() *int64
|
||||
DriverConnectionID() uint64 // TODO(GODRIVER-2824): change type to int64.
|
||||
Address() address.Address
|
||||
Stale() bool
|
||||
}
|
||||
|
||||
// RTTMonitor represents a round-trip-time monitor.
|
||||
type RTTMonitor interface {
|
||||
// EWMA returns the exponentially weighted moving average observed round-trip time.
|
||||
EWMA() time.Duration
|
||||
|
||||
// Min returns the minimum observed round-trip time over the window period.
|
||||
Min() time.Duration
|
||||
|
||||
// P90 returns the 90th percentile observed round-trip time over the window period.
|
||||
P90() time.Duration
|
||||
|
||||
// Stats returns stringified stats of the current state of the monitor.
|
||||
Stats() string
|
||||
}
|
||||
|
||||
var _ RTTMonitor = &internal.ZeroRTTMonitor{}
|
||||
|
||||
// PinnedConnection represents a Connection that can be pinned by one or more cursors or transactions. Implementations
|
||||
// of this interface should maintain the following invariants:
|
||||
//
|
||||
// 1. Each Pin* call should increment the number of references for the connection.
|
||||
// 2. Each Unpin* call should decrement the number of references for the connection.
|
||||
// 3. Calls to Close() should be ignored until all resources have unpinned the connection.
|
||||
type PinnedConnection interface {
|
||||
Connection
|
||||
PinToCursor() error
|
||||
PinToTransaction() error
|
||||
UnpinFromCursor() error
|
||||
UnpinFromTransaction() error
|
||||
}
|
||||
|
||||
// The session.LoadBalancedTransactionConnection type is a copy of PinnedConnection that was introduced to avoid
|
||||
// import cycles. This compile-time assertion ensures that these types remain in sync if the PinnedConnection interface
|
||||
// is changed in the future.
|
||||
var _ PinnedConnection = (session.LoadBalancedTransactionConnection)(nil)
|
||||
|
||||
// LocalAddresser is a type that is able to supply its local address
|
||||
type LocalAddresser interface {
|
||||
LocalAddress() address.Address
|
||||
}
|
||||
|
||||
// Expirable represents an expirable object.
|
||||
type Expirable interface {
|
||||
Expire() error
|
||||
Alive() bool
|
||||
}
|
||||
|
||||
// StreamerConnection represents a Connection that supports streaming wire protocol messages using the moreToCome and
|
||||
// exhaustAllowed flags.
|
||||
//
|
||||
// The SetStreaming and CurrentlyStreaming functions correspond to the moreToCome flag on server responses. If a
|
||||
// response has moreToCome set, SetStreaming(true) will be called and CurrentlyStreaming() should return true.
|
||||
//
|
||||
// CanStream corresponds to the exhaustAllowed flag. The operations layer will set exhaustAllowed on outgoing wire
|
||||
// messages to inform the server that the driver supports streaming.
|
||||
type StreamerConnection interface {
|
||||
Connection
|
||||
SetStreaming(bool)
|
||||
CurrentlyStreaming() bool
|
||||
SupportsStreaming() bool
|
||||
}
|
||||
|
||||
// Compressor is an interface used to compress wire messages. If a Connection supports compression
|
||||
// it should implement this interface as well. The CompressWireMessage method will be called during
|
||||
// the execution of an operation if the wire message is allowed to be compressed.
|
||||
type Compressor interface {
|
||||
CompressWireMessage(src, dst []byte) ([]byte, error)
|
||||
}
|
||||
|
||||
// ProcessErrorResult represents the result of a ErrorProcessor.ProcessError() call. Exact values for this type can be
|
||||
// checked directly (e.g. res == ServerMarkedUnknown), but it is recommended that applications use the ServerChanged()
|
||||
// function instead.
|
||||
type ProcessErrorResult int
|
||||
|
||||
const (
|
||||
// NoChange indicates that the error did not affect the state of the server.
|
||||
NoChange ProcessErrorResult = iota
|
||||
// ServerMarkedUnknown indicates that the error only resulted in the server being marked as Unknown.
|
||||
ServerMarkedUnknown
|
||||
// ConnectionPoolCleared indicates that the error resulted in the server being marked as Unknown and its connection
|
||||
// pool being cleared.
|
||||
ConnectionPoolCleared
|
||||
)
|
||||
|
||||
// ErrorProcessor implementations can handle processing errors, which may modify their internal state.
|
||||
// If this type is implemented by a Server, then Operation.Execute will call it's ProcessError
|
||||
// method after it decodes a wire message.
|
||||
type ErrorProcessor interface {
|
||||
ProcessError(err error, conn Connection) ProcessErrorResult
|
||||
}
|
||||
|
||||
// HandshakeInformation contains information extracted from a MongoDB connection handshake. This is a helper type that
|
||||
// augments description.Server by also tracking server connection ID and authentication-related fields. We use this type
|
||||
// rather than adding authentication-related fields to description.Server to avoid retaining sensitive information in a
|
||||
// user-facing type. The server connection ID is stored in this type because unlike description.Server, all handshakes are
|
||||
// correlated with a single network connection.
|
||||
type HandshakeInformation struct {
|
||||
Description description.Server
|
||||
SpeculativeAuthenticate bsoncore.Document
|
||||
ServerConnectionID *int64
|
||||
SaslSupportedMechs []string
|
||||
}
|
||||
|
||||
// Handshaker is the interface implemented by types that can perform a MongoDB
|
||||
// handshake over a provided driver.Connection. This is used during connection
|
||||
// initialization. Implementations must be goroutine safe.
|
||||
type Handshaker interface {
|
||||
GetHandshakeInformation(context.Context, address.Address, Connection) (HandshakeInformation, error)
|
||||
FinishHandshake(context.Context, Connection) error
|
||||
}
|
||||
|
||||
// SingleServerDeployment is an implementation of Deployment that always returns a single server.
|
||||
type SingleServerDeployment struct{ Server }
|
||||
|
||||
var _ Deployment = SingleServerDeployment{}
|
||||
|
||||
// SelectServer implements the Deployment interface. This method does not use the
|
||||
// description.SelectedServer provided and instead returns the embedded Server.
|
||||
func (ssd SingleServerDeployment) SelectServer(context.Context, description.ServerSelector) (Server, error) {
|
||||
return ssd.Server, nil
|
||||
}
|
||||
|
||||
// Kind implements the Deployment interface. It always returns description.Single.
|
||||
func (SingleServerDeployment) Kind() description.TopologyKind { return description.Single }
|
||||
|
||||
// SingleConnectionDeployment is an implementation of Deployment that always returns the same Connection. This
|
||||
// implementation should only be used for connection handshakes and server heartbeats as it does not implement
|
||||
// ErrorProcessor, which is necessary for application operations.
|
||||
type SingleConnectionDeployment struct{ C Connection }
|
||||
|
||||
var _ Deployment = SingleConnectionDeployment{}
|
||||
var _ Server = SingleConnectionDeployment{}
|
||||
|
||||
// SelectServer implements the Deployment interface. This method does not use the
|
||||
// description.SelectedServer provided and instead returns itself. The Connections returned from the
|
||||
// Connection method have a no-op Close method.
|
||||
func (ssd SingleConnectionDeployment) SelectServer(context.Context, description.ServerSelector) (Server, error) {
|
||||
return ssd, nil
|
||||
}
|
||||
|
||||
// Kind implements the Deployment interface. It always returns description.Single.
|
||||
func (ssd SingleConnectionDeployment) Kind() description.TopologyKind { return description.Single }
|
||||
|
||||
// Connection implements the Server interface. It always returns the embedded connection.
|
||||
func (ssd SingleConnectionDeployment) Connection(context.Context) (Connection, error) {
|
||||
return ssd.C, nil
|
||||
}
|
||||
|
||||
// RTTMonitor implements the driver.Server interface.
|
||||
func (ssd SingleConnectionDeployment) RTTMonitor() RTTMonitor {
|
||||
return &internal.ZeroRTTMonitor{}
|
||||
}
|
||||
|
||||
// TODO(GODRIVER-617): We can likely use 1 type for both the Type and the RetryMode by using 2 bits for the mode and 1
|
||||
// TODO bit for the type. Although in the practical sense, we might not want to do that since the type of retryability
|
||||
// TODO is tied to the operation itself and isn't going change, e.g. and insert operation will always be a write,
|
||||
// TODO however some operations are both reads and writes, for instance aggregate is a read but with a $out parameter
|
||||
// TODO it's a write.
|
||||
|
||||
// Type specifies whether an operation is a read, write, or unknown.
|
||||
type Type uint
|
||||
|
||||
// THese are the availables types of Type.
|
||||
const (
|
||||
_ Type = iota
|
||||
Write
|
||||
Read
|
||||
)
|
||||
|
||||
// RetryMode specifies the way that retries are handled for retryable operations.
|
||||
type RetryMode uint
|
||||
|
||||
// These are the modes available for retrying. Note that if Timeout is specified on the Client, the
|
||||
// operation will automatically retry as many times as possible within the context's deadline
|
||||
// unless RetryNone is used.
|
||||
const (
|
||||
// RetryNone disables retrying.
|
||||
RetryNone RetryMode = iota
|
||||
// RetryOnce will enable retrying the entire operation once if Timeout is not specified.
|
||||
RetryOnce
|
||||
// RetryOncePerCommand will enable retrying each command associated with an operation if Timeout
|
||||
// is not specified. For example, if an insert is batch split into 4 commands then each of
|
||||
// those commands is eligible for one retry.
|
||||
RetryOncePerCommand
|
||||
// RetryContext will enable retrying until the context.Context's deadline is exceeded or it is
|
||||
// cancelled.
|
||||
RetryContext
|
||||
)
|
||||
|
||||
// Enabled returns if this RetryMode enables retrying.
|
||||
func (rm RetryMode) Enabled() bool {
|
||||
return rm == RetryOnce || rm == RetryOncePerCommand || rm == RetryContext
|
||||
}
|
||||
153
mongo/x/mongo/driver/drivertest/channel_conn.go
Normal file
153
mongo/x/mongo/driver/drivertest/channel_conn.go
Normal file
@@ -0,0 +1,153 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package drivertest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"go.mongodb.org/mongo-driver/mongo/address"
|
||||
"go.mongodb.org/mongo-driver/mongo/description"
|
||||
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
|
||||
)
|
||||
|
||||
// ChannelConn implements the driver.Connection interface by reading and writing wire messages
|
||||
// to a channel
|
||||
type ChannelConn struct {
|
||||
WriteErr error
|
||||
Written chan []byte
|
||||
ReadResp chan []byte
|
||||
ReadErr chan error
|
||||
Desc description.Server
|
||||
}
|
||||
|
||||
// WriteWireMessage implements the driver.Connection interface.
|
||||
func (c *ChannelConn) WriteWireMessage(ctx context.Context, wm []byte) error {
|
||||
// Copy wm in case it came from a buffer pool.
|
||||
b := make([]byte, len(wm))
|
||||
copy(b, wm)
|
||||
select {
|
||||
case c.Written <- b:
|
||||
default:
|
||||
c.WriteErr = errors.New("could not write wiremessage to written channel")
|
||||
}
|
||||
return c.WriteErr
|
||||
}
|
||||
|
||||
// ReadWireMessage implements the driver.Connection interface.
|
||||
func (c *ChannelConn) ReadWireMessage(ctx context.Context) ([]byte, error) {
|
||||
var wm []byte
|
||||
var err error
|
||||
select {
|
||||
case wm = <-c.ReadResp:
|
||||
case err = <-c.ReadErr:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
return wm, err
|
||||
}
|
||||
|
||||
// Description implements the driver.Connection interface.
|
||||
func (c *ChannelConn) Description() description.Server { return c.Desc }
|
||||
|
||||
// Close implements the driver.Connection interface.
|
||||
func (c *ChannelConn) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ID implements the driver.Connection interface.
|
||||
func (c *ChannelConn) ID() string {
|
||||
return "faked"
|
||||
}
|
||||
|
||||
// DriverConnectionID implements the driver.Connection interface.
|
||||
// TODO(GODRIVER-2824): replace return type with int64.
|
||||
func (c *ChannelConn) DriverConnectionID() uint64 {
|
||||
return 0
|
||||
}
|
||||
|
||||
// ServerConnectionID implements the driver.Connection interface.
|
||||
func (c *ChannelConn) ServerConnectionID() *int64 {
|
||||
serverConnectionID := int64(42)
|
||||
return &serverConnectionID
|
||||
}
|
||||
|
||||
// Address implements the driver.Connection interface.
|
||||
func (c *ChannelConn) Address() address.Address { return address.Address("0.0.0.0") }
|
||||
|
||||
// Stale implements the driver.Connection interface.
|
||||
func (c *ChannelConn) Stale() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// MakeReply creates an OP_REPLY wiremessage from a BSON document
|
||||
func MakeReply(doc bsoncore.Document) []byte {
|
||||
var dst []byte
|
||||
idx, dst := wiremessage.AppendHeaderStart(dst, 10, 9, wiremessage.OpReply)
|
||||
dst = wiremessage.AppendReplyFlags(dst, 0)
|
||||
dst = wiremessage.AppendReplyCursorID(dst, 0)
|
||||
dst = wiremessage.AppendReplyStartingFrom(dst, 0)
|
||||
dst = wiremessage.AppendReplyNumberReturned(dst, 1)
|
||||
dst = append(dst, doc...)
|
||||
return bsoncore.UpdateLength(dst, idx, int32(len(dst[idx:])))
|
||||
}
|
||||
|
||||
// GetCommandFromQueryWireMessage returns the command sent in an OP_QUERY wire message.
|
||||
func GetCommandFromQueryWireMessage(wm []byte) (bsoncore.Document, error) {
|
||||
var ok bool
|
||||
_, _, _, _, wm, ok = wiremessage.ReadHeader(wm)
|
||||
if !ok {
|
||||
return nil, errors.New("could not read header")
|
||||
}
|
||||
_, wm, ok = wiremessage.ReadQueryFlags(wm)
|
||||
if !ok {
|
||||
return nil, errors.New("could not read flags")
|
||||
}
|
||||
_, wm, ok = wiremessage.ReadQueryFullCollectionName(wm)
|
||||
if !ok {
|
||||
return nil, errors.New("could not read fullCollectionName")
|
||||
}
|
||||
_, wm, ok = wiremessage.ReadQueryNumberToSkip(wm)
|
||||
if !ok {
|
||||
return nil, errors.New("could not read numberToSkip")
|
||||
}
|
||||
_, wm, ok = wiremessage.ReadQueryNumberToReturn(wm)
|
||||
if !ok {
|
||||
return nil, errors.New("could not read numberToReturn")
|
||||
}
|
||||
|
||||
var query bsoncore.Document
|
||||
query, wm, ok = wiremessage.ReadQueryQuery(wm)
|
||||
if !ok {
|
||||
return nil, errors.New("could not read query")
|
||||
}
|
||||
return query, nil
|
||||
}
|
||||
|
||||
// GetCommandFromMsgWireMessage returns the command document sent in an OP_MSG wire message.
|
||||
func GetCommandFromMsgWireMessage(wm []byte) (bsoncore.Document, error) {
|
||||
var ok bool
|
||||
_, _, _, _, wm, ok = wiremessage.ReadHeader(wm)
|
||||
if !ok {
|
||||
return nil, errors.New("could not read header")
|
||||
}
|
||||
|
||||
_, wm, ok = wiremessage.ReadMsgFlags(wm)
|
||||
if !ok {
|
||||
return nil, errors.New("could not read flags")
|
||||
}
|
||||
_, wm, ok = wiremessage.ReadMsgSectionType(wm)
|
||||
if !ok {
|
||||
return nil, errors.New("could not read section type")
|
||||
}
|
||||
|
||||
cmdDoc, wm, ok := wiremessage.ReadMsgSectionSingleDocument(wm)
|
||||
if !ok {
|
||||
return nil, errors.New("could not read command document")
|
||||
}
|
||||
return cmdDoc, nil
|
||||
}
|
||||
102
mongo/x/mongo/driver/drivertest/channel_netconn.go
Normal file
102
mongo/x/mongo/driver/drivertest/channel_netconn.go
Normal file
@@ -0,0 +1,102 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package drivertest
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ChannelNetConn implements the net.Conn interface by reading and writing wire messages to a channel.
|
||||
type ChannelNetConn struct {
|
||||
WriteErr error
|
||||
Written chan []byte
|
||||
ReadResp chan []byte
|
||||
ReadErr chan error
|
||||
}
|
||||
|
||||
// Read reads data from the connection
|
||||
func (c *ChannelNetConn) Read(b []byte) (int, error) {
|
||||
var wm []byte
|
||||
var err error
|
||||
select {
|
||||
case wm = <-c.ReadResp:
|
||||
case err = <-c.ReadErr:
|
||||
}
|
||||
return copy(b, wm), err
|
||||
}
|
||||
|
||||
// Write writes data to the connection.
|
||||
func (c *ChannelNetConn) Write(b []byte) (int, error) {
|
||||
copyBuf := make([]byte, len(b))
|
||||
copy(copyBuf, b)
|
||||
|
||||
select {
|
||||
case c.Written <- copyBuf:
|
||||
default:
|
||||
c.WriteErr = errors.New("could not write wm to Written channel")
|
||||
}
|
||||
return len(b), c.WriteErr
|
||||
}
|
||||
|
||||
// Close closes the connection.
|
||||
func (c *ChannelNetConn) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// LocalAddr returns the local network address.
|
||||
func (c *ChannelNetConn) LocalAddr() net.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoteAddr returns the remote network address.
|
||||
func (c *ChannelNetConn) RemoteAddr() net.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetDeadline sets the read and write deadlines associated with the connection.
|
||||
func (c *ChannelNetConn) SetDeadline(_ time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetReadDeadline sets the read and write deadlines associated with the connection.
|
||||
func (c *ChannelNetConn) SetReadDeadline(_ time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetWriteDeadline sets the read and write deadlines associated with the connection.
|
||||
func (c *ChannelNetConn) SetWriteDeadline(_ time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetWrittenMessage gets the last wire message written to the connection
|
||||
func (c *ChannelNetConn) GetWrittenMessage() []byte {
|
||||
select {
|
||||
case wm := <-c.Written:
|
||||
return wm
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// AddResponse adds a response to the connection.
|
||||
func (c *ChannelNetConn) AddResponse(resp []byte) error {
|
||||
select {
|
||||
case c.ReadResp <- resp[:4]:
|
||||
default:
|
||||
return errors.New("could not write length bytes")
|
||||
}
|
||||
|
||||
select {
|
||||
case c.ReadResp <- resp[4:]:
|
||||
default:
|
||||
return errors.New("could not write response bytes")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
520
mongo/x/mongo/driver/errors.go
Normal file
520
mongo/x/mongo/driver/errors.go
Normal file
@@ -0,0 +1,520 @@
|
||||
// Copyright (C) MongoDB, Inc. 2022-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package driver
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/internal"
|
||||
"go.mongodb.org/mongo-driver/mongo/description"
|
||||
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
|
||||
)
|
||||
|
||||
var (
|
||||
retryableCodes = []int32{11600, 11602, 10107, 13435, 13436, 189, 91, 7, 6, 89, 9001, 262}
|
||||
nodeIsRecoveringCodes = []int32{11600, 11602, 13436, 189, 91}
|
||||
notPrimaryCodes = []int32{10107, 13435, 10058}
|
||||
nodeIsShuttingDownCodes = []int32{11600, 91}
|
||||
|
||||
unknownReplWriteConcernCode = int32(79)
|
||||
unsatisfiableWriteConcernCode = int32(100)
|
||||
)
|
||||
|
||||
var (
|
||||
// UnknownTransactionCommitResult is an error label for unknown transaction commit results.
|
||||
UnknownTransactionCommitResult = "UnknownTransactionCommitResult"
|
||||
// TransientTransactionError is an error label for transient errors with transactions.
|
||||
TransientTransactionError = "TransientTransactionError"
|
||||
// NetworkError is an error label for network errors.
|
||||
NetworkError = "NetworkError"
|
||||
// RetryableWriteError is an error lable for retryable write errors.
|
||||
RetryableWriteError = "RetryableWriteError"
|
||||
// NoWritesPerformed is an error label indicated that no writes were performed for an operation.
|
||||
NoWritesPerformed = "NoWritesPerformed"
|
||||
// ErrCursorNotFound is the cursor not found error for legacy find operations.
|
||||
ErrCursorNotFound = errors.New("cursor not found")
|
||||
// ErrUnacknowledgedWrite is returned from functions that have an unacknowledged
|
||||
// write concern.
|
||||
ErrUnacknowledgedWrite = errors.New("unacknowledged write")
|
||||
// ErrUnsupportedStorageEngine is returned when a retryable write is attempted against a server
|
||||
// that uses a storage engine that does not support retryable writes
|
||||
ErrUnsupportedStorageEngine = errors.New("this MongoDB deployment does not support retryable writes. Please add retryWrites=false to your connection string")
|
||||
// ErrDeadlineWouldBeExceeded is returned when a Timeout set on an operation would be exceeded
|
||||
// if the operation were sent to the server.
|
||||
ErrDeadlineWouldBeExceeded = errors.New("operation not sent to server, as Timeout would be exceeded")
|
||||
// ErrNegativeMaxTime is returned when MaxTime on an operation is a negative value.
|
||||
ErrNegativeMaxTime = errors.New("a negative value was provided for MaxTime on an operation")
|
||||
)
|
||||
|
||||
// QueryFailureError is an error representing a command failure as a document.
|
||||
type QueryFailureError struct {
|
||||
Message string
|
||||
Response bsoncore.Document
|
||||
Wrapped error
|
||||
}
|
||||
|
||||
// Error implements the error interface.
|
||||
func (e QueryFailureError) Error() string {
|
||||
return fmt.Sprintf("%s: %v", e.Message, e.Response)
|
||||
}
|
||||
|
||||
// Unwrap returns the underlying error.
|
||||
func (e QueryFailureError) Unwrap() error {
|
||||
return e.Wrapped
|
||||
}
|
||||
|
||||
// ResponseError is an error parsing the response to a command.
|
||||
type ResponseError struct {
|
||||
Message string
|
||||
Wrapped error
|
||||
}
|
||||
|
||||
// NewCommandResponseError creates a CommandResponseError.
|
||||
func NewCommandResponseError(msg string, err error) ResponseError {
|
||||
return ResponseError{Message: msg, Wrapped: err}
|
||||
}
|
||||
|
||||
// Error implements the error interface.
|
||||
func (e ResponseError) Error() string {
|
||||
if e.Wrapped != nil {
|
||||
return fmt.Sprintf("%s: %s", e.Message, e.Wrapped)
|
||||
}
|
||||
return e.Message
|
||||
}
|
||||
|
||||
// WriteCommandError is an error for a write command.
|
||||
type WriteCommandError struct {
|
||||
WriteConcernError *WriteConcernError
|
||||
WriteErrors WriteErrors
|
||||
Labels []string
|
||||
Raw bsoncore.Document
|
||||
}
|
||||
|
||||
// UnsupportedStorageEngine returns whether or not the WriteCommandError comes from a retryable write being attempted
|
||||
// against a server that has a storage engine where they are not supported
|
||||
func (wce WriteCommandError) UnsupportedStorageEngine() bool {
|
||||
for _, writeError := range wce.WriteErrors {
|
||||
if writeError.Code == 20 && strings.HasPrefix(strings.ToLower(writeError.Message), "transaction numbers") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (wce WriteCommandError) Error() string {
|
||||
var buf bytes.Buffer
|
||||
fmt.Fprint(&buf, "write command error: [")
|
||||
fmt.Fprintf(&buf, "{%s}, ", wce.WriteErrors)
|
||||
fmt.Fprintf(&buf, "{%s}]", wce.WriteConcernError)
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// Retryable returns true if the error is retryable
|
||||
func (wce WriteCommandError) Retryable(wireVersion *description.VersionRange) bool {
|
||||
for _, label := range wce.Labels {
|
||||
if label == RetryableWriteError {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if wireVersion != nil && wireVersion.Max >= 9 {
|
||||
return false
|
||||
}
|
||||
|
||||
if wce.WriteConcernError == nil {
|
||||
return false
|
||||
}
|
||||
return (*wce.WriteConcernError).Retryable()
|
||||
}
|
||||
|
||||
// HasErrorLabel returns true if the error contains the specified label.
|
||||
func (wce WriteCommandError) HasErrorLabel(label string) bool {
|
||||
if wce.Labels != nil {
|
||||
for _, l := range wce.Labels {
|
||||
if l == label {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// WriteConcernError is a write concern failure that occurred as a result of a
|
||||
// write operation.
|
||||
type WriteConcernError struct {
|
||||
Name string
|
||||
Code int64
|
||||
Message string
|
||||
Details bsoncore.Document
|
||||
Labels []string
|
||||
TopologyVersion *description.TopologyVersion
|
||||
Raw bsoncore.Document
|
||||
}
|
||||
|
||||
func (wce WriteConcernError) Error() string {
|
||||
if wce.Name != "" {
|
||||
return fmt.Sprintf("(%v) %v", wce.Name, wce.Message)
|
||||
}
|
||||
return wce.Message
|
||||
}
|
||||
|
||||
// Retryable returns true if the error is retryable
|
||||
func (wce WriteConcernError) Retryable() bool {
|
||||
for _, code := range retryableCodes {
|
||||
if wce.Code == int64(code) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// NodeIsRecovering returns true if this error is a node is recovering error.
|
||||
func (wce WriteConcernError) NodeIsRecovering() bool {
|
||||
for _, code := range nodeIsRecoveringCodes {
|
||||
if wce.Code == int64(code) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
hasNoCode := wce.Code == 0
|
||||
return hasNoCode && strings.Contains(wce.Message, "node is recovering")
|
||||
}
|
||||
|
||||
// NodeIsShuttingDown returns true if this error is a node is shutting down error.
|
||||
func (wce WriteConcernError) NodeIsShuttingDown() bool {
|
||||
for _, code := range nodeIsShuttingDownCodes {
|
||||
if wce.Code == int64(code) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
hasNoCode := wce.Code == 0
|
||||
return hasNoCode && strings.Contains(wce.Message, "node is shutting down")
|
||||
}
|
||||
|
||||
// NotPrimary returns true if this error is a not primary error.
|
||||
func (wce WriteConcernError) NotPrimary() bool {
|
||||
for _, code := range notPrimaryCodes {
|
||||
if wce.Code == int64(code) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
hasNoCode := wce.Code == 0
|
||||
return hasNoCode && strings.Contains(wce.Message, internal.LegacyNotPrimary)
|
||||
}
|
||||
|
||||
// WriteError is a non-write concern failure that occurred as a result of a write
|
||||
// operation.
|
||||
type WriteError struct {
|
||||
Index int64
|
||||
Code int64
|
||||
Message string
|
||||
Details bsoncore.Document
|
||||
Raw bsoncore.Document
|
||||
}
|
||||
|
||||
func (we WriteError) Error() string { return we.Message }
|
||||
|
||||
// WriteErrors is a group of non-write concern failures that occurred as a result
|
||||
// of a write operation.
|
||||
type WriteErrors []WriteError
|
||||
|
||||
func (we WriteErrors) Error() string {
|
||||
var buf bytes.Buffer
|
||||
fmt.Fprint(&buf, "write errors: [")
|
||||
for idx, err := range we {
|
||||
if idx != 0 {
|
||||
fmt.Fprintf(&buf, ", ")
|
||||
}
|
||||
fmt.Fprintf(&buf, "{%s}", err)
|
||||
}
|
||||
fmt.Fprint(&buf, "]")
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// Error is a command execution error from the database.
|
||||
type Error struct {
|
||||
Code int32
|
||||
Message string
|
||||
Labels []string
|
||||
Name string
|
||||
Wrapped error
|
||||
TopologyVersion *description.TopologyVersion
|
||||
Raw bsoncore.Document
|
||||
}
|
||||
|
||||
// UnsupportedStorageEngine returns whether e came as a result of an unsupported storage engine
|
||||
func (e Error) UnsupportedStorageEngine() bool {
|
||||
return e.Code == 20 && strings.HasPrefix(strings.ToLower(e.Message), "transaction numbers")
|
||||
}
|
||||
|
||||
// Error implements the error interface.
|
||||
func (e Error) Error() string {
|
||||
if e.Name != "" {
|
||||
return fmt.Sprintf("(%v) %v", e.Name, e.Message)
|
||||
}
|
||||
return e.Message
|
||||
}
|
||||
|
||||
// Unwrap returns the underlying error.
|
||||
func (e Error) Unwrap() error {
|
||||
return e.Wrapped
|
||||
}
|
||||
|
||||
// HasErrorLabel returns true if the error contains the specified label.
|
||||
func (e Error) HasErrorLabel(label string) bool {
|
||||
if e.Labels != nil {
|
||||
for _, l := range e.Labels {
|
||||
if l == label {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// RetryableRead returns true if the error is retryable for a read operation
|
||||
func (e Error) RetryableRead() bool {
|
||||
for _, label := range e.Labels {
|
||||
if label == NetworkError {
|
||||
return true
|
||||
}
|
||||
}
|
||||
for _, code := range retryableCodes {
|
||||
if e.Code == code {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// RetryableWrite returns true if the error is retryable for a write operation
|
||||
func (e Error) RetryableWrite(wireVersion *description.VersionRange) bool {
|
||||
for _, label := range e.Labels {
|
||||
if label == NetworkError || label == RetryableWriteError {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if wireVersion != nil && wireVersion.Max >= 9 {
|
||||
return false
|
||||
}
|
||||
for _, code := range retryableCodes {
|
||||
if e.Code == code {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// NetworkError returns true if the error is a network error.
|
||||
func (e Error) NetworkError() bool {
|
||||
for _, label := range e.Labels {
|
||||
if label == NetworkError {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// NodeIsRecovering returns true if this error is a node is recovering error.
|
||||
func (e Error) NodeIsRecovering() bool {
|
||||
for _, code := range nodeIsRecoveringCodes {
|
||||
if e.Code == code {
|
||||
return true
|
||||
}
|
||||
}
|
||||
hasNoCode := e.Code == 0
|
||||
return hasNoCode && strings.Contains(e.Message, "node is recovering")
|
||||
}
|
||||
|
||||
// NodeIsShuttingDown returns true if this error is a node is shutting down error.
|
||||
func (e Error) NodeIsShuttingDown() bool {
|
||||
for _, code := range nodeIsShuttingDownCodes {
|
||||
if e.Code == code {
|
||||
return true
|
||||
}
|
||||
}
|
||||
hasNoCode := e.Code == 0
|
||||
return hasNoCode && strings.Contains(e.Message, "node is shutting down")
|
||||
}
|
||||
|
||||
// NotPrimary returns true if this error is a not primary error.
|
||||
func (e Error) NotPrimary() bool {
|
||||
for _, code := range notPrimaryCodes {
|
||||
if e.Code == code {
|
||||
return true
|
||||
}
|
||||
}
|
||||
hasNoCode := e.Code == 0
|
||||
return hasNoCode && strings.Contains(e.Message, internal.LegacyNotPrimary)
|
||||
}
|
||||
|
||||
// NamespaceNotFound returns true if this errors is a NamespaceNotFound error.
|
||||
func (e Error) NamespaceNotFound() bool {
|
||||
return e.Code == 26 || e.Message == "ns not found"
|
||||
}
|
||||
|
||||
// ExtractErrorFromServerResponse extracts an error from a server response bsoncore.Document
|
||||
// if there is one. Also used in testing for SDAM.
|
||||
func ExtractErrorFromServerResponse(doc bsoncore.Document) error {
|
||||
var errmsg, codeName string
|
||||
var code int32
|
||||
var labels []string
|
||||
var ok bool
|
||||
var tv *description.TopologyVersion
|
||||
var wcError WriteCommandError
|
||||
elems, err := doc.Elements()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, elem := range elems {
|
||||
switch elem.Key() {
|
||||
case "ok":
|
||||
switch elem.Value().Type {
|
||||
case bson.TypeInt32:
|
||||
if elem.Value().Int32() == 1 {
|
||||
ok = true
|
||||
}
|
||||
case bson.TypeInt64:
|
||||
if elem.Value().Int64() == 1 {
|
||||
ok = true
|
||||
}
|
||||
case bson.TypeDouble:
|
||||
if elem.Value().Double() == 1 {
|
||||
ok = true
|
||||
}
|
||||
}
|
||||
case "errmsg":
|
||||
if str, okay := elem.Value().StringValueOK(); okay {
|
||||
errmsg = str
|
||||
}
|
||||
case "codeName":
|
||||
if str, okay := elem.Value().StringValueOK(); okay {
|
||||
codeName = str
|
||||
}
|
||||
case "code":
|
||||
if c, okay := elem.Value().Int32OK(); okay {
|
||||
code = c
|
||||
}
|
||||
case "errorLabels":
|
||||
if arr, okay := elem.Value().ArrayOK(); okay {
|
||||
vals, err := arr.Values()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
for _, val := range vals {
|
||||
if str, ok := val.StringValueOK(); ok {
|
||||
labels = append(labels, str)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
case "writeErrors":
|
||||
arr, exists := elem.Value().ArrayOK()
|
||||
if !exists {
|
||||
break
|
||||
}
|
||||
vals, err := arr.Values()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
for _, val := range vals {
|
||||
var we WriteError
|
||||
doc, exists := val.DocumentOK()
|
||||
if !exists {
|
||||
continue
|
||||
}
|
||||
if index, exists := doc.Lookup("index").AsInt64OK(); exists {
|
||||
we.Index = index
|
||||
}
|
||||
if code, exists := doc.Lookup("code").AsInt64OK(); exists {
|
||||
we.Code = code
|
||||
}
|
||||
if msg, exists := doc.Lookup("errmsg").StringValueOK(); exists {
|
||||
we.Message = msg
|
||||
}
|
||||
if info, exists := doc.Lookup("errInfo").DocumentOK(); exists {
|
||||
we.Details = make([]byte, len(info))
|
||||
copy(we.Details, info)
|
||||
}
|
||||
we.Raw = doc
|
||||
wcError.WriteErrors = append(wcError.WriteErrors, we)
|
||||
}
|
||||
case "writeConcernError":
|
||||
doc, exists := elem.Value().DocumentOK()
|
||||
if !exists {
|
||||
break
|
||||
}
|
||||
wcError.WriteConcernError = new(WriteConcernError)
|
||||
wcError.WriteConcernError.Raw = doc
|
||||
if code, exists := doc.Lookup("code").AsInt64OK(); exists {
|
||||
wcError.WriteConcernError.Code = code
|
||||
}
|
||||
if name, exists := doc.Lookup("codeName").StringValueOK(); exists {
|
||||
wcError.WriteConcernError.Name = name
|
||||
}
|
||||
if msg, exists := doc.Lookup("errmsg").StringValueOK(); exists {
|
||||
wcError.WriteConcernError.Message = msg
|
||||
}
|
||||
if info, exists := doc.Lookup("errInfo").DocumentOK(); exists {
|
||||
wcError.WriteConcernError.Details = make([]byte, len(info))
|
||||
copy(wcError.WriteConcernError.Details, info)
|
||||
}
|
||||
if errLabels, exists := doc.Lookup("errorLabels").ArrayOK(); exists {
|
||||
vals, err := errLabels.Values()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
for _, val := range vals {
|
||||
if str, ok := val.StringValueOK(); ok {
|
||||
labels = append(labels, str)
|
||||
}
|
||||
}
|
||||
}
|
||||
case "topologyVersion":
|
||||
doc, ok := elem.Value().DocumentOK()
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
version, err := description.NewTopologyVersion(bson.Raw(doc))
|
||||
if err == nil {
|
||||
tv = version
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !ok {
|
||||
if errmsg == "" {
|
||||
errmsg = "command failed"
|
||||
}
|
||||
|
||||
return Error{
|
||||
Code: code,
|
||||
Message: errmsg,
|
||||
Name: codeName,
|
||||
Labels: labels,
|
||||
TopologyVersion: tv,
|
||||
Raw: doc,
|
||||
}
|
||||
}
|
||||
|
||||
if len(wcError.WriteErrors) > 0 || wcError.WriteConcernError != nil {
|
||||
wcError.Labels = labels
|
||||
if wcError.WriteConcernError != nil {
|
||||
wcError.WriteConcernError.TopologyVersion = tv
|
||||
}
|
||||
wcError.Raw = doc
|
||||
return wcError
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
195
mongo/x/mongo/driver/integration/aggregate_test.go
Normal file
195
mongo/x/mongo/driver/integration/aggregate_test.go
Normal file
@@ -0,0 +1,195 @@
|
||||
// Copyright (C) MongoDB, Inc. 2022-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/event"
|
||||
"go.mongodb.org/mongo-driver/internal/testutil"
|
||||
"go.mongodb.org/mongo-driver/mongo/description"
|
||||
"go.mongodb.org/mongo-driver/mongo/writeconcern"
|
||||
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/topology"
|
||||
)
|
||||
|
||||
func setUpMonitor() (*event.CommandMonitor, chan *event.CommandStartedEvent, chan *event.CommandSucceededEvent, chan *event.CommandFailedEvent) {
|
||||
started := make(chan *event.CommandStartedEvent, 1)
|
||||
succeeded := make(chan *event.CommandSucceededEvent, 1)
|
||||
failed := make(chan *event.CommandFailedEvent, 1)
|
||||
|
||||
return &event.CommandMonitor{
|
||||
Started: func(ctx context.Context, e *event.CommandStartedEvent) {
|
||||
started <- e
|
||||
},
|
||||
Succeeded: func(ctx context.Context, e *event.CommandSucceededEvent) {
|
||||
succeeded <- e
|
||||
},
|
||||
Failed: func(ctx context.Context, e *event.CommandFailedEvent) {
|
||||
failed <- e
|
||||
},
|
||||
}, started, succeeded, failed
|
||||
}
|
||||
|
||||
func skipIfBelow32(ctx context.Context, t *testing.T, topo *topology.Topology) {
|
||||
server, err := topo.SelectServer(ctx, description.WriteSelector())
|
||||
noerr(t, err)
|
||||
|
||||
versionCmd := bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "serverStatus", 1))
|
||||
serverStatus, err := testutil.RunCommand(t, server, dbName, versionCmd)
|
||||
noerr(t, err)
|
||||
version, err := serverStatus.LookupErr("version")
|
||||
noerr(t, err)
|
||||
|
||||
if testutil.CompareVersions(t, version.StringValue(), "3.2") < 0 {
|
||||
t.Skip()
|
||||
}
|
||||
}
|
||||
|
||||
func TestAggregate(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
t.Run("TestMaxTimeMSInGetMore", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
monitor, started, succeeded, failed := setUpMonitor()
|
||||
dbName := "TestAggMaxTimeDB"
|
||||
collName := "TestAggMaxTimeColl"
|
||||
top := testutil.MonitoredTopology(t, dbName, monitor)
|
||||
clearChannels(started, succeeded, failed)
|
||||
skipIfBelow32(ctx, t, top)
|
||||
|
||||
clearChannels(started, succeeded, failed)
|
||||
err := operation.NewInsert(
|
||||
bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "x", 1)),
|
||||
bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "x", 1)),
|
||||
bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "x", 1)),
|
||||
).Collection(collName).Database(dbName).
|
||||
Deployment(top).ServerSelector(description.WriteSelector()).Execute(context.Background())
|
||||
noerr(t, err)
|
||||
|
||||
clearChannels(started, succeeded, failed)
|
||||
op := operation.NewAggregate(bsoncore.BuildDocumentFromElements(nil)).
|
||||
Collection(collName).Database(dbName).Deployment(top).ServerSelector(description.WriteSelector()).
|
||||
CommandMonitor(monitor).BatchSize(2)
|
||||
err = op.Execute(context.Background())
|
||||
noerr(t, err)
|
||||
batchCursor, err := op.Result(driver.CursorOptions{MaxTimeMS: 10, BatchSize: 2, CommandMonitor: monitor})
|
||||
noerr(t, err)
|
||||
|
||||
var e *event.CommandStartedEvent
|
||||
select {
|
||||
case e = <-started:
|
||||
case <-time.After(2000 * time.Millisecond):
|
||||
t.Fatal("timed out waiting for aggregate")
|
||||
}
|
||||
|
||||
require.Equal(t, "aggregate", e.CommandName)
|
||||
|
||||
clearChannels(started, succeeded, failed)
|
||||
// first Next() should automatically return true
|
||||
require.True(t, batchCursor.Next(ctx), "expected true from first Next, got false")
|
||||
clearChannels(started, succeeded, failed)
|
||||
batchCursor.Next(ctx) // should do getMore
|
||||
|
||||
select {
|
||||
case e = <-started:
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Fatal("timed out waiting for getMore")
|
||||
}
|
||||
require.Equal(t, "getMore", e.CommandName)
|
||||
_, err = e.Command.LookupErr("maxTimeMS")
|
||||
noerr(t, err)
|
||||
})
|
||||
t.Run("Multiple Batches", func(t *testing.T) {
|
||||
ds := []bsoncore.Document{
|
||||
bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "_id", 1)),
|
||||
bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "_id", 2)),
|
||||
bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "_id", 3)),
|
||||
bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "_id", 4)),
|
||||
bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "_id", 5)),
|
||||
}
|
||||
wc := writeconcern.New(writeconcern.WMajority())
|
||||
testutil.AutoInsertDocs(t, wc, ds...)
|
||||
|
||||
op := operation.NewAggregate(bsoncore.BuildArray(nil,
|
||||
bsoncore.BuildDocumentValue(
|
||||
bsoncore.BuildDocumentElement(nil,
|
||||
"$match", bsoncore.BuildDocumentElement(nil,
|
||||
"_id", bsoncore.AppendInt32Element(nil, "$gt", 2),
|
||||
),
|
||||
),
|
||||
),
|
||||
bsoncore.BuildDocumentValue(
|
||||
bsoncore.BuildDocumentElement(nil,
|
||||
"$sort", bsoncore.AppendInt32Element(nil, "_id", 1),
|
||||
),
|
||||
),
|
||||
)).Collection(testutil.ColName(t)).Database(dbName).Deployment(testutil.Topology(t)).
|
||||
ServerSelector(description.WriteSelector()).BatchSize(2)
|
||||
err := op.Execute(context.Background())
|
||||
noerr(t, err)
|
||||
cursor, err := op.Result(driver.CursorOptions{BatchSize: 2})
|
||||
noerr(t, err)
|
||||
|
||||
var got []bsoncore.Document
|
||||
for i := 0; i < 2; i++ {
|
||||
if !cursor.Next(context.Background()) {
|
||||
t.Error("Cursor should have results, but does not have a next result")
|
||||
}
|
||||
docs, err := cursor.Batch().Documents()
|
||||
noerr(t, err)
|
||||
got = append(got, docs...)
|
||||
}
|
||||
readers := ds[2:]
|
||||
for i, g := range got {
|
||||
if !bytes.Equal(g[:len(readers[i])], readers[i]) {
|
||||
t.Errorf("Did not get expected document. got %v; want %v", bson.Raw(g[:len(readers[i])]), readers[i])
|
||||
}
|
||||
}
|
||||
|
||||
if cursor.Next(context.Background()) {
|
||||
t.Error("Cursor should be exhausted but has more results")
|
||||
}
|
||||
})
|
||||
t.Run("AllowDiskUse", func(t *testing.T) {
|
||||
ds := []bsoncore.Document{
|
||||
bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "_id", 1)),
|
||||
bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "_id", 2)),
|
||||
}
|
||||
wc := writeconcern.New(writeconcern.WMajority())
|
||||
testutil.AutoInsertDocs(t, wc, ds...)
|
||||
|
||||
op := operation.NewAggregate(bsoncore.BuildArray(nil)).Collection(testutil.ColName(t)).Database(dbName).
|
||||
Deployment(testutil.Topology(t)).ServerSelector(description.WriteSelector()).AllowDiskUse(true)
|
||||
err := op.Execute(context.Background())
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error from allowing disk use, but got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func clearChannels(s chan *event.CommandStartedEvent, succ chan *event.CommandSucceededEvent, f chan *event.CommandFailedEvent) {
|
||||
for len(s) > 0 {
|
||||
<-s
|
||||
}
|
||||
for len(succ) > 0 {
|
||||
<-succ
|
||||
}
|
||||
for len(f) > 0 {
|
||||
<-f
|
||||
}
|
||||
}
|
||||
76
mongo/x/mongo/driver/integration/compressor_test.go
Normal file
76
mongo/x/mongo/driver/integration/compressor_test.go
Normal file
@@ -0,0 +1,76 @@
|
||||
// Copyright (C) MongoDB, Inc. 2022-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/internal/testutil"
|
||||
"go.mongodb.org/mongo-driver/mongo/writeconcern"
|
||||
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
|
||||
)
|
||||
|
||||
func TestCompression(t *testing.T) {
|
||||
comp := os.Getenv("MONGO_GO_DRIVER_COMPRESSOR")
|
||||
if len(comp) == 0 {
|
||||
t.Skip("Skipping because no compressor specified")
|
||||
}
|
||||
|
||||
wc := writeconcern.New(writeconcern.WMajority())
|
||||
collOne := testutil.ColName(t)
|
||||
|
||||
testutil.DropCollection(t, testutil.DBName(t), collOne)
|
||||
testutil.InsertDocs(t, testutil.DBName(t), collOne, wc,
|
||||
bsoncore.BuildDocument(nil, bsoncore.AppendStringElement(nil, "name", "compression_test")),
|
||||
)
|
||||
|
||||
cmd := operation.NewCommand(bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "serverStatus", 1))).
|
||||
Deployment(testutil.Topology(t)).
|
||||
Database(testutil.DBName(t))
|
||||
|
||||
ctx := context.Background()
|
||||
err := cmd.Execute(ctx)
|
||||
noerr(t, err)
|
||||
result := cmd.Result()
|
||||
|
||||
serverVersion, err := result.LookupErr("version")
|
||||
noerr(t, err)
|
||||
|
||||
if testutil.CompareVersions(t, serverVersion.StringValue(), "3.4") < 0 {
|
||||
t.Skip("skipping compression test for version < 3.4")
|
||||
}
|
||||
|
||||
networkVal, err := result.LookupErr("network")
|
||||
noerr(t, err)
|
||||
|
||||
require.Equal(t, networkVal.Type, bson.TypeEmbeddedDocument)
|
||||
|
||||
compressionVal, err := networkVal.Document().LookupErr("compression")
|
||||
noerr(t, err)
|
||||
|
||||
compressorDoc, err := compressionVal.Document().LookupErr(comp)
|
||||
noerr(t, err)
|
||||
|
||||
compressorKey := "compressor"
|
||||
compareTo36 := testutil.CompareVersions(t, serverVersion.StringValue(), "3.6")
|
||||
if compareTo36 < 0 {
|
||||
compressorKey = "compressed"
|
||||
}
|
||||
compressor, err := compressorDoc.Document().LookupErr(compressorKey)
|
||||
noerr(t, err)
|
||||
|
||||
bytesIn, err := compressor.Document().LookupErr("bytesIn")
|
||||
noerr(t, err)
|
||||
|
||||
require.True(t, bytesIn.IsNumber())
|
||||
require.True(t, bytesIn.Int64() > 0)
|
||||
}
|
||||
56
mongo/x/mongo/driver/integration/insert_test.go
Normal file
56
mongo/x/mongo/driver/integration/insert_test.go
Normal file
@@ -0,0 +1,56 @@
|
||||
// Copyright (C) MongoDB, Inc. 2022-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/topology"
|
||||
)
|
||||
|
||||
func TestInsert(t *testing.T) {
|
||||
t.Skip()
|
||||
topo, err := topology.New(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Couldn't connect topology: %v", err)
|
||||
}
|
||||
_ = topo.Connect()
|
||||
|
||||
doc := bsoncore.BuildDocument(nil, bsoncore.AppendDoubleElement(nil, "pi", 3.14159))
|
||||
|
||||
iop := operation.NewInsert(doc).Database("foo").Collection("bar").Deployment(topo)
|
||||
err = iop.Execute(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Couldn't execute insert operation: %v", err)
|
||||
}
|
||||
t.Log(iop.Result())
|
||||
|
||||
fop := operation.NewFind(bsoncore.BuildDocument(nil, bsoncore.AppendDoubleElement(nil, "pi", 3.14159))).
|
||||
Database("foo").Collection("bar").Deployment(topo).BatchSize(1)
|
||||
err = fop.Execute(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Couldn't execute find operation: %v", err)
|
||||
}
|
||||
cur, err := fop.Result(driver.CursorOptions{BatchSize: 2})
|
||||
if err != nil {
|
||||
t.Fatalf("Couldn't get cursor result from find operation: %v", err)
|
||||
}
|
||||
for cur.Next(context.Background()) {
|
||||
batch := cur.Batch()
|
||||
docs, err := batch.Documents()
|
||||
if err != nil {
|
||||
t.Fatalf("Couldn't iterate batch: %v", err)
|
||||
}
|
||||
for i, doc := range docs {
|
||||
t.Log(i, doc)
|
||||
}
|
||||
}
|
||||
}
|
||||
7
mongo/x/mongo/driver/integration/integration.go
Normal file
7
mongo/x/mongo/driver/integration/integration.go
Normal file
@@ -0,0 +1,7 @@
|
||||
// Copyright (C) MongoDB, Inc. 2022-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package integration
|
||||
113
mongo/x/mongo/driver/integration/main_test.go
Normal file
113
mongo/x/mongo/driver/integration/main_test.go
Normal file
@@ -0,0 +1,113 @@
|
||||
// Copyright (C) MongoDB, Inc. 2022-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/auth"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/topology"
|
||||
)
|
||||
|
||||
var host *string
|
||||
var connectionString connstring.ConnString
|
||||
var dbName string
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
flag.Parse()
|
||||
|
||||
mongodbURI := os.Getenv("MONGODB_URI")
|
||||
if mongodbURI == "" {
|
||||
mongodbURI = "mongodb://localhost:27017"
|
||||
}
|
||||
|
||||
mongodbURI = addTLSConfigToURI(mongodbURI)
|
||||
mongodbURI = addCompressorToURI(mongodbURI)
|
||||
|
||||
var err error
|
||||
connectionString, err = connstring.ParseAndValidate(mongodbURI)
|
||||
if err != nil {
|
||||
fmt.Printf("Could not parse connection string: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
host = &connectionString.Hosts[0]
|
||||
|
||||
dbName = fmt.Sprintf("mongo-go-driver-%d", os.Getpid())
|
||||
if connectionString.Database != "" {
|
||||
dbName = connectionString.Database
|
||||
}
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func noerr(t *testing.T, err error) {
|
||||
if err != nil {
|
||||
t.Helper()
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
func autherr(t *testing.T, err error) {
|
||||
t.Helper()
|
||||
switch e := err.(type) {
|
||||
case topology.ConnectionError:
|
||||
_, ok := e.Wrapped.(*auth.Error)
|
||||
if !ok {
|
||||
t.Fatal("Expected auth error and didn't get one")
|
||||
}
|
||||
case *auth.Error:
|
||||
return
|
||||
default:
|
||||
t.Fatal("Expected auth error and didn't get one")
|
||||
}
|
||||
}
|
||||
|
||||
// addTLSConfigToURI checks for the environmental variable indicating that the tests are being run
|
||||
// on an SSL-enabled server, and if so, returns a new URI with the necessary configuration.
|
||||
func addTLSConfigToURI(uri string) string {
|
||||
caFile := os.Getenv("MONGO_GO_DRIVER_CA_FILE")
|
||||
if len(caFile) == 0 {
|
||||
return uri
|
||||
}
|
||||
|
||||
if !strings.ContainsRune(uri, '?') {
|
||||
if uri[len(uri)-1] != '/' {
|
||||
uri += "/"
|
||||
}
|
||||
|
||||
uri += "?"
|
||||
} else {
|
||||
uri += "&"
|
||||
}
|
||||
|
||||
return uri + "ssl=true&sslCertificateAuthorityFile=" + caFile
|
||||
}
|
||||
|
||||
func addCompressorToURI(uri string) string {
|
||||
comp := os.Getenv("MONGO_GO_DRIVER_COMPRESSOR")
|
||||
if len(comp) == 0 {
|
||||
return uri
|
||||
}
|
||||
|
||||
if !strings.ContainsRune(uri, '?') {
|
||||
if uri[len(uri)-1] != '/' {
|
||||
uri += "/"
|
||||
}
|
||||
|
||||
uri += "?"
|
||||
} else {
|
||||
uri += "&"
|
||||
}
|
||||
|
||||
return uri + "compressors=" + comp
|
||||
}
|
||||
178
mongo/x/mongo/driver/integration/scram_test.go
Normal file
178
mongo/x/mongo/driver/integration/scram_test.go
Normal file
@@ -0,0 +1,178 @@
|
||||
// Copyright (C) MongoDB, Inc. 2022-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"go.mongodb.org/mongo-driver/bson/bsontype"
|
||||
"go.mongodb.org/mongo-driver/internal/testutil"
|
||||
"go.mongodb.org/mongo-driver/mongo/description"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
"go.mongodb.org/mongo-driver/mongo/writeconcern"
|
||||
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver"
|
||||
)
|
||||
|
||||
type scramTestCase struct {
|
||||
username string
|
||||
password string
|
||||
mechanisms []string
|
||||
altPassword string
|
||||
}
|
||||
|
||||
func TestSCRAM(t *testing.T) {
|
||||
if os.Getenv("AUTH") != "auth" {
|
||||
t.Skip("Skipping because authentication is required")
|
||||
}
|
||||
|
||||
server, err := testutil.Topology(t).SelectServer(context.Background(), description.WriteSelector())
|
||||
noerr(t, err)
|
||||
serverConnection, err := server.Connection(context.Background())
|
||||
noerr(t, err)
|
||||
defer serverConnection.Close()
|
||||
|
||||
if !serverConnection.Description().WireVersion.Includes(7) {
|
||||
t.Skip("Skipping because MongoDB 4.0 is needed for SCRAM-SHA-256")
|
||||
}
|
||||
|
||||
// Unicode constants for testing
|
||||
var romanFour = "\u2163" // ROMAN NUMERAL FOUR -> SASL prepped is "IV"
|
||||
var romanNine = "\u2168" // ROMAN NUMERAL NINE -> SASL prepped is "IX"
|
||||
|
||||
testUsers := []scramTestCase{
|
||||
// SCRAM spec test steps 1-3
|
||||
{username: "sha1", password: "sha1", mechanisms: []string{"SCRAM-SHA-1"}},
|
||||
{username: "sha256", password: "sha256", mechanisms: []string{"SCRAM-SHA-256"}},
|
||||
{username: "both", password: "both", mechanisms: []string{"SCRAM-SHA-1", "SCRAM-SHA-256"}},
|
||||
// SCRAM spec test step 4
|
||||
{username: "IX", password: "IX", mechanisms: []string{"SCRAM-SHA-256"}, altPassword: "I\u00ADX"},
|
||||
{username: romanNine, password: romanFour, mechanisms: []string{"SCRAM-SHA-256"}, altPassword: "I\u00ADV"},
|
||||
}
|
||||
|
||||
// Verify that test (root) user is authenticated. If this fails, the
|
||||
// rest of the test can't succeed.
|
||||
wc := writeconcern.New(writeconcern.WMajority())
|
||||
collOne := testutil.ColName(t)
|
||||
testutil.DropCollection(t, testutil.DBName(t), collOne)
|
||||
testutil.InsertDocs(t, testutil.DBName(t),
|
||||
collOne, wc, bsoncore.BuildDocument(nil, bsoncore.AppendStringElement(nil, "name", "scram_test")),
|
||||
)
|
||||
|
||||
// Test step 1: Create users for test cases
|
||||
err = createScramUsers(t, server, testUsers)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Step 2 and 3a: For each auth mechanism, "SCRAM-SHA-1", "SCRAM-SHA-256"
|
||||
// and "negotiate" (a fake, placeholder mechanism), iterate over each user
|
||||
// and ensure that each mechanism that should succeed does so and each
|
||||
// that should fail does so.
|
||||
for _, m := range []string{"SCRAM-SHA-1", "SCRAM-SHA-256", "negotiate"} {
|
||||
for _, c := range testUsers {
|
||||
t.Run(
|
||||
fmt.Sprintf("%s %s", c.username, m),
|
||||
func(t *testing.T) {
|
||||
err := testScramUserAuthWithMech(t, c, m)
|
||||
if m == "negotiate" || hasAuthMech(c.mechanisms, m) {
|
||||
noerr(t, err)
|
||||
} else {
|
||||
autherr(t, err)
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3b: test non-existing user with negotiation fails with
|
||||
// an auth.Error type.
|
||||
bogus := scramTestCase{username: "eliot", password: "trustno1"}
|
||||
err = testScramUserAuthWithMech(t, bogus, "negotiate")
|
||||
autherr(t, err)
|
||||
|
||||
// XXX Step 4: test alternate password forms
|
||||
for _, c := range testUsers {
|
||||
if c.altPassword == "" {
|
||||
continue
|
||||
}
|
||||
c.password = c.altPassword
|
||||
t.Run(
|
||||
fmt.Sprintf("%s alternate password", c.username),
|
||||
func(t *testing.T) {
|
||||
err := testScramUserAuthWithMech(t, c, "SCRAM-SHA-256")
|
||||
noerr(t, err)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func hasAuthMech(mechs []string, m string) bool {
|
||||
for _, v := range mechs {
|
||||
if v == m {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func testScramUserAuthWithMech(t *testing.T, c scramTestCase, mech string) error {
|
||||
t.Helper()
|
||||
credential := options.Credential{
|
||||
Username: c.username,
|
||||
Password: c.password,
|
||||
AuthSource: testutil.DBName(t),
|
||||
}
|
||||
switch mech {
|
||||
case "negotiate":
|
||||
credential.AuthMechanism = ""
|
||||
default:
|
||||
credential.AuthMechanism = mech
|
||||
}
|
||||
return runScramAuthTest(t, credential)
|
||||
}
|
||||
|
||||
func runScramAuthTest(t *testing.T, credential options.Credential) error {
|
||||
t.Helper()
|
||||
topology := testutil.TopologyWithCredential(t, credential)
|
||||
server, err := topology.SelectServer(context.Background(), description.WriteSelector())
|
||||
noerr(t, err)
|
||||
|
||||
cmd := bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "dbstats", 1))
|
||||
_, err = testutil.RunCommand(t, server, testutil.DBName(t), cmd)
|
||||
return err
|
||||
}
|
||||
|
||||
func createScramUsers(t *testing.T, s driver.Server, cases []scramTestCase) error {
|
||||
db := testutil.DBName(t)
|
||||
for _, c := range cases {
|
||||
var values []bsoncore.Value
|
||||
for _, v := range c.mechanisms {
|
||||
values = append(values, bsoncore.Value{Type: bsontype.String, Data: bsoncore.AppendString(nil, v)})
|
||||
}
|
||||
newUserCmd := bsoncore.BuildDocumentFromElements(nil,
|
||||
bsoncore.AppendStringElement(nil, "createUser", c.username),
|
||||
bsoncore.AppendStringElement(nil, "pwd", c.password),
|
||||
bsoncore.AppendArrayElement(nil, "roles", bsoncore.BuildArray(nil,
|
||||
bsoncore.Value{Type: bsontype.EmbeddedDocument, Data: bsoncore.BuildDocumentFromElements(nil,
|
||||
bsoncore.AppendStringElement(nil, "role", "readWrite"),
|
||||
bsoncore.AppendStringElement(nil, "db", db),
|
||||
)},
|
||||
)),
|
||||
bsoncore.AppendArrayElement(nil, "mechanisms", bsoncore.BuildArray(nil, values...)),
|
||||
)
|
||||
_, err := testutil.RunCommand(t, s, db, newUserCmd)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Couldn't create user '%s' on db '%s': %v", c.username, testutil.DBName(t), err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
22
mongo/x/mongo/driver/legacy.go
Normal file
22
mongo/x/mongo/driver/legacy.go
Normal file
@@ -0,0 +1,22 @@
|
||||
// Copyright (C) MongoDB, Inc. 2022-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package driver
|
||||
|
||||
// LegacyOperationKind indicates if an operation is a legacy find, getMore, or killCursors. This is used
|
||||
// in Operation.Execute, which will create legacy OP_QUERY, OP_GET_MORE, or OP_KILL_CURSORS instead
|
||||
// of sending them as a command.
|
||||
type LegacyOperationKind uint
|
||||
|
||||
// These constants represent the three different kinds of legacy operations.
|
||||
const (
|
||||
LegacyNone LegacyOperationKind = iota
|
||||
LegacyFind
|
||||
LegacyGetMore
|
||||
LegacyKillCursors
|
||||
LegacyListCollections
|
||||
LegacyListIndexes
|
||||
)
|
||||
134
mongo/x/mongo/driver/list_collections_batch_cursor.go
Normal file
134
mongo/x/mongo/driver/list_collections_batch_cursor.go
Normal file
@@ -0,0 +1,134 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
package driver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
|
||||
)
|
||||
|
||||
// ListCollectionsBatchCursor is a special batch cursor returned from ListCollections that properly
|
||||
// handles current and legacy ListCollections operations.
|
||||
type ListCollectionsBatchCursor struct {
|
||||
legacy bool // server version < 3.0
|
||||
bc *BatchCursor
|
||||
currentBatch *bsoncore.DocumentSequence
|
||||
err error
|
||||
}
|
||||
|
||||
// NewListCollectionsBatchCursor creates a new non-legacy ListCollectionsCursor.
|
||||
func NewListCollectionsBatchCursor(bc *BatchCursor) (*ListCollectionsBatchCursor, error) {
|
||||
if bc == nil {
|
||||
return nil, errors.New("batch cursor must not be nil")
|
||||
}
|
||||
return &ListCollectionsBatchCursor{bc: bc, currentBatch: new(bsoncore.DocumentSequence)}, nil
|
||||
}
|
||||
|
||||
// NewLegacyListCollectionsBatchCursor creates a new legacy ListCollectionsCursor.
|
||||
func NewLegacyListCollectionsBatchCursor(bc *BatchCursor) (*ListCollectionsBatchCursor, error) {
|
||||
if bc == nil {
|
||||
return nil, errors.New("batch cursor must not be nil")
|
||||
}
|
||||
return &ListCollectionsBatchCursor{legacy: true, bc: bc, currentBatch: new(bsoncore.DocumentSequence)}, nil
|
||||
}
|
||||
|
||||
// ID returns the cursor ID for this batch cursor.
|
||||
func (lcbc *ListCollectionsBatchCursor) ID() int64 {
|
||||
return lcbc.bc.ID()
|
||||
}
|
||||
|
||||
// Next indicates if there is another batch available. Returning false does not necessarily indicate
|
||||
// that the cursor is closed. This method will return false when an empty batch is returned.
|
||||
//
|
||||
// If Next returns true, there is a valid batch of documents available. If Next returns false, there
|
||||
// is not a valid batch of documents available.
|
||||
func (lcbc *ListCollectionsBatchCursor) Next(ctx context.Context) bool {
|
||||
if !lcbc.bc.Next(ctx) {
|
||||
return false
|
||||
}
|
||||
|
||||
if !lcbc.legacy {
|
||||
lcbc.currentBatch.Style = lcbc.bc.currentBatch.Style
|
||||
lcbc.currentBatch.Data = lcbc.bc.currentBatch.Data
|
||||
lcbc.currentBatch.ResetIterator()
|
||||
return true
|
||||
}
|
||||
|
||||
lcbc.currentBatch.Style = bsoncore.SequenceStyle
|
||||
lcbc.currentBatch.Data = lcbc.currentBatch.Data[:0]
|
||||
|
||||
var doc bsoncore.Document
|
||||
for {
|
||||
doc, lcbc.err = lcbc.bc.currentBatch.Next()
|
||||
if lcbc.err != nil {
|
||||
if lcbc.err == io.EOF {
|
||||
lcbc.err = nil
|
||||
break
|
||||
}
|
||||
return false
|
||||
}
|
||||
doc, lcbc.err = lcbc.projectNameElement(doc)
|
||||
if lcbc.err != nil {
|
||||
return false
|
||||
}
|
||||
lcbc.currentBatch.Data = append(lcbc.currentBatch.Data, doc...)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// Batch will return a DocumentSequence for the current batch of documents. The returned
|
||||
// DocumentSequence is only valid until the next call to Next or Close.
|
||||
func (lcbc *ListCollectionsBatchCursor) Batch() *bsoncore.DocumentSequence { return lcbc.currentBatch }
|
||||
|
||||
// Server returns a pointer to the cursor's server.
|
||||
func (lcbc *ListCollectionsBatchCursor) Server() Server { return lcbc.bc.server }
|
||||
|
||||
// Err returns the latest error encountered.
|
||||
func (lcbc *ListCollectionsBatchCursor) Err() error {
|
||||
if lcbc.err != nil {
|
||||
return lcbc.err
|
||||
}
|
||||
return lcbc.bc.Err()
|
||||
}
|
||||
|
||||
// Close closes this batch cursor.
|
||||
func (lcbc *ListCollectionsBatchCursor) Close(ctx context.Context) error { return lcbc.bc.Close(ctx) }
|
||||
|
||||
// project out the database name for a legacy server
|
||||
func (*ListCollectionsBatchCursor) projectNameElement(rawDoc bsoncore.Document) (bsoncore.Document, error) {
|
||||
elems, err := rawDoc.Elements()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var filteredElems []byte
|
||||
for _, elem := range elems {
|
||||
key := elem.Key()
|
||||
if key != "name" {
|
||||
filteredElems = append(filteredElems, elem...)
|
||||
continue
|
||||
}
|
||||
|
||||
name := elem.Value().StringValue()
|
||||
collName := name[strings.Index(name, ".")+1:]
|
||||
filteredElems = bsoncore.AppendStringElement(filteredElems, "name", collName)
|
||||
}
|
||||
|
||||
var filteredDoc []byte
|
||||
filteredDoc = bsoncore.BuildDocument(filteredDoc, filteredElems)
|
||||
return filteredDoc, nil
|
||||
}
|
||||
|
||||
// SetBatchSize sets the batchSize for future getMores.
|
||||
func (lcbc *ListCollectionsBatchCursor) SetBatchSize(size int32) {
|
||||
lcbc.bc.SetBatchSize(size)
|
||||
}
|
||||
56
mongo/x/mongo/driver/mongocrypt/binary.go
Normal file
56
mongo/x/mongo/driver/mongocrypt/binary.go
Normal file
@@ -0,0 +1,56 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
//go:build cse
|
||||
// +build cse
|
||||
|
||||
package mongocrypt
|
||||
|
||||
// #include <mongocrypt.h>
|
||||
import "C"
|
||||
import (
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// binary is a wrapper type around a mongocrypt_binary_t*
|
||||
type binary struct {
|
||||
wrapped *C.mongocrypt_binary_t
|
||||
}
|
||||
|
||||
// newBinary creates an empty binary instance.
|
||||
func newBinary() *binary {
|
||||
return &binary{
|
||||
wrapped: C.mongocrypt_binary_new(),
|
||||
}
|
||||
}
|
||||
|
||||
// newBinaryFromBytes creates a binary instance from a byte buffer.
|
||||
func newBinaryFromBytes(data []byte) *binary {
|
||||
if len(data) == 0 {
|
||||
return newBinary()
|
||||
}
|
||||
|
||||
// We don't need C.CBytes here because data cannot go out of scope. Any mongocrypt function that takes a
|
||||
// mongocrypt_binary_t will make a copy of the data so the data can be garbage collected after calling.
|
||||
addr := (*C.uint8_t)(unsafe.Pointer(&data[0])) // uint8_t*
|
||||
dataLen := C.uint32_t(len(data)) // uint32_t
|
||||
return &binary{
|
||||
wrapped: C.mongocrypt_binary_new_from_data(addr, dataLen),
|
||||
}
|
||||
}
|
||||
|
||||
// toBytes converts the given binary instance to []byte.
|
||||
func (b *binary) toBytes() []byte {
|
||||
dataPtr := C.mongocrypt_binary_data(b.wrapped) // C.uint8_t*
|
||||
dataLen := C.mongocrypt_binary_len(b.wrapped) // C.uint32_t
|
||||
|
||||
return C.GoBytes(unsafe.Pointer(dataPtr), C.int(dataLen))
|
||||
}
|
||||
|
||||
// close cleans up any resources associated with the given binary instance.
|
||||
func (b *binary) close() {
|
||||
C.mongocrypt_binary_destroy(b.wrapped)
|
||||
}
|
||||
44
mongo/x/mongo/driver/mongocrypt/errors.go
Normal file
44
mongo/x/mongo/driver/mongocrypt/errors.go
Normal file
@@ -0,0 +1,44 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
//go:build cse
|
||||
// +build cse
|
||||
|
||||
package mongocrypt
|
||||
|
||||
// #include <mongocrypt.h>
|
||||
import "C"
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Error represents an error from an operation on a MongoCrypt instance.
|
||||
type Error struct {
|
||||
Code int32
|
||||
Message string
|
||||
}
|
||||
|
||||
// Error implements the error interface.
|
||||
func (e Error) Error() string {
|
||||
return fmt.Sprintf("mongocrypt error %d: %v", e.Code, e.Message)
|
||||
}
|
||||
|
||||
// errorFromStatus builds a Error from a mongocrypt_status_t object.
|
||||
func errorFromStatus(status *C.mongocrypt_status_t) error {
|
||||
cCode := C.mongocrypt_status_code(status) // uint32_t
|
||||
// mongocrypt_status_message takes uint32_t* as its second param to store the length of the returned string.
|
||||
// pass nil because the length is handled by C.GoString
|
||||
cMsg := C.mongocrypt_status_message(status, nil) // const char*
|
||||
var msg string
|
||||
if cMsg != nil {
|
||||
msg = C.GoString(cMsg)
|
||||
}
|
||||
|
||||
return Error{
|
||||
Code: int32(cCode),
|
||||
Message: msg,
|
||||
}
|
||||
}
|
||||
21
mongo/x/mongo/driver/mongocrypt/errors_not_enabled.go
Normal file
21
mongo/x/mongo/driver/mongocrypt/errors_not_enabled.go
Normal file
@@ -0,0 +1,21 @@
|
||||
// Copyright (C) MongoDB, Inc. 2017-present.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
//go:build !cse
|
||||
// +build !cse
|
||||
|
||||
package mongocrypt
|
||||
|
||||
// Error represents an error from an operation on a MongoCrypt instance.
|
||||
type Error struct {
|
||||
Code int32
|
||||
Message string
|
||||
}
|
||||
|
||||
// Error implements the error interface
|
||||
func (Error) Error() string {
|
||||
panic(cseNotSupportedMsg)
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user