Copied mongo repo (to patch it)

This commit is contained in:
2023-06-18 15:50:55 +02:00
parent 21d241f9b1
commit d471d7c396
544 changed files with 142039 additions and 1 deletions

6
mongo/x/README.md Normal file
View File

@@ -0,0 +1,6 @@
MongoDB Go Driver Unstable Libraries
====================================
This directory contains unstable MongoDB Go driver libraries and packages. The APIs of these
packages are not stable and there is no backward compatibility guarantee.
**THESE PACKAGES ARE EXPERIMENTAL AND SUBJECT TO CHANGE.**

97
mongo/x/bsonx/array.go Normal file
View File

@@ -0,0 +1,97 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonx // import "go.mongodb.org/mongo-driver/x/bsonx"
import (
"bytes"
"errors"
"fmt"
"strconv"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
// ErrNilArray indicates that an operation was attempted on a nil *Array.
var ErrNilArray = errors.New("array is nil")
// Arr represents an array in BSON.
type Arr []Val
// String implements the fmt.Stringer interface.
func (a Arr) String() string {
var buf bytes.Buffer
buf.Write([]byte("bson.Array["))
for idx, val := range a {
if idx > 0 {
buf.Write([]byte(", "))
}
fmt.Fprintf(&buf, "%s", val)
}
buf.WriteByte(']')
return buf.String()
}
// MarshalBSONValue implements the bsoncodec.ValueMarshaler interface.
func (a Arr) MarshalBSONValue() (bsontype.Type, []byte, error) {
if a == nil {
// TODO: Should we do this?
return bsontype.Null, nil, nil
}
idx, dst := bsoncore.ReserveLength(nil)
for idx, value := range a {
t, data, _ := value.MarshalBSONValue() // marshalBSONValue never returns an error.
dst = append(dst, byte(t))
dst = append(dst, strconv.Itoa(idx)...)
dst = append(dst, 0x00)
dst = append(dst, data...)
}
dst = append(dst, 0x00)
dst = bsoncore.UpdateLength(dst, idx, int32(len(dst[idx:])))
return bsontype.Array, dst, nil
}
// UnmarshalBSONValue implements the bsoncodec.ValueUnmarshaler interface.
func (a *Arr) UnmarshalBSONValue(t bsontype.Type, data []byte) error {
if a == nil {
return ErrNilArray
}
*a = (*a)[:0]
elements, err := bsoncore.Document(data).Elements()
if err != nil {
return err
}
for _, elem := range elements {
var val Val
rawval := elem.Value()
err = val.UnmarshalBSONValue(rawval.Type, rawval.Data)
if err != nil {
return err
}
*a = append(*a, val)
}
return nil
}
// Equal compares this document to another, returning true if they are equal.
func (a Arr) Equal(a2 Arr) bool {
if len(a) != len(a2) {
return false
}
for idx := range a {
if !a[idx].Equal(a2[idx]) {
return false
}
}
return true
}
func (Arr) idoc() {}

View File

@@ -0,0 +1,36 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonx
import (
"fmt"
)
func ExampleArray() {
internalVersion := "1234567"
f := func(appName string) Arr {
arr := make(Arr, 0)
arr = append(arr,
Document(Doc{{"name", String("mongo-go-driver")}, {"version", String(internalVersion)}}),
Document(Doc{{"type", String("darwin")}, {"architecture", String("amd64")}}),
String("go1.9.2"),
)
if appName != "" {
arr = append(arr, Document(MDoc{"name": String(appName)}))
}
return arr
}
_, buf, err := f("hello-world").MarshalBSONValue()
if err != nil {
fmt.Println(err)
}
fmt.Println(buf)
// Output: [154 0 0 0 3 48 0 52 0 0 0 2 110 97 109 101 0 16 0 0 0 109 111 110 103 111 45 103 111 45 100 114 105 118 101 114 0 2 118 101 114 115 105 111 110 0 8 0 0 0 49 50 51 52 53 54 55 0 0 3 49 0 46 0 0 0 2 116 121 112 101 0 7 0 0 0 100 97 114 119 105 110 0 2 97 114 99 104 105 116 101 99 116 117 114 101 0 6 0 0 0 97 109 100 54 52 0 0 2 50 0 8 0 0 0 103 111 49 46 57 46 50 0 3 51 0 27 0 0 0 2 110 97 109 101 0 12 0 0 0 104 101 108 108 111 45 119 111 114 108 100 0 0 0]
}

View File

@@ -0,0 +1,52 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonx
import (
"testing"
"go.mongodb.org/mongo-driver/bson/primitive"
)
func noerr(t *testing.T, err error) {
if err != nil {
t.Helper()
t.Errorf("Unexpected error: (%T)%v", err, err)
t.FailNow()
}
}
func compareDecimal128(d1, d2 primitive.Decimal128) bool {
d1H, d1L := d1.GetBytes()
d2H, d2L := d2.GetBytes()
if d1H != d2H {
return false
}
if d1L != d2L {
return false
}
return true
}
func compareErrors(err1, err2 error) bool {
if err1 == nil && err2 == nil {
return true
}
if err1 == nil || err2 == nil {
return false
}
if err1.Error() != err2.Error() {
return false
}
return true
}

View File

@@ -0,0 +1,164 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncore
import (
"bytes"
"fmt"
"io"
"strconv"
)
// NewArrayLengthError creates and returns an error for when the length of an array exceeds the
// bytes available.
func NewArrayLengthError(length, rem int) error {
return lengthError("array", length, rem)
}
// Array is a raw bytes representation of a BSON array.
type Array []byte
// NewArrayFromReader reads an array from r. This function will only validate the length is
// correct and that the array ends with a null byte.
func NewArrayFromReader(r io.Reader) (Array, error) {
return newBufferFromReader(r)
}
// Index searches for and retrieves the value at the given index. This method will panic if
// the array is invalid or if the index is out of bounds.
func (a Array) Index(index uint) Value {
value, err := a.IndexErr(index)
if err != nil {
panic(err)
}
return value
}
// IndexErr searches for and retrieves the value at the given index.
func (a Array) IndexErr(index uint) (Value, error) {
elem, err := indexErr(a, index)
if err != nil {
return Value{}, err
}
return elem.Value(), err
}
// DebugString outputs a human readable version of Array. It will attempt to stringify the
// valid components of the array even if the entire array is not valid.
func (a Array) DebugString() string {
if len(a) < 5 {
return "<malformed>"
}
var buf bytes.Buffer
buf.WriteString("Array")
length, rem, _ := ReadLength(a) // We know we have enough bytes to read the length
buf.WriteByte('(')
buf.WriteString(strconv.Itoa(int(length)))
length -= 4
buf.WriteString(")[")
var elem Element
var ok bool
for length > 1 {
elem, rem, ok = ReadElement(rem)
length -= int32(len(elem))
if !ok {
buf.WriteString(fmt.Sprintf("<malformed (%d)>", length))
break
}
fmt.Fprintf(&buf, "%s", elem.Value().DebugString())
if length != 1 {
buf.WriteByte(',')
}
}
buf.WriteByte(']')
return buf.String()
}
// String outputs an ExtendedJSON version of Array. If the Array is not valid, this method
// returns an empty string.
func (a Array) String() string {
if len(a) < 5 {
return ""
}
var buf bytes.Buffer
buf.WriteByte('[')
length, rem, _ := ReadLength(a) // We know we have enough bytes to read the length
length -= 4
var elem Element
var ok bool
for length > 1 {
elem, rem, ok = ReadElement(rem)
length -= int32(len(elem))
if !ok {
return ""
}
fmt.Fprintf(&buf, "%s", elem.Value().String())
if length > 1 {
buf.WriteByte(',')
}
}
if length != 1 { // Missing final null byte or inaccurate length
return ""
}
buf.WriteByte(']')
return buf.String()
}
// Values returns this array as a slice of values. The returned slice will contain valid values.
// If the array is not valid, the values up to the invalid point will be returned along with an
// error.
func (a Array) Values() ([]Value, error) {
return values(a)
}
// Validate validates the array and ensures the elements contained within are valid.
func (a Array) Validate() error {
length, rem, ok := ReadLength(a)
if !ok {
return NewInsufficientBytesError(a, rem)
}
if int(length) > len(a) {
return NewArrayLengthError(int(length), len(a))
}
if a[length-1] != 0x00 {
return ErrMissingNull
}
length -= 4
var elem Element
var keyNum int64
for length > 1 {
elem, rem, ok = ReadElement(rem)
length -= int32(len(elem))
if !ok {
return NewInsufficientBytesError(a, rem)
}
// validate element
err := elem.Validate()
if err != nil {
return err
}
// validate keys increase numerically
if fmt.Sprint(keyNum) != elem.Key() {
return fmt.Errorf("array key %q is out of order or invalid", elem.Key())
}
keyNum++
}
if len(rem) < 1 || rem[0] != 0x00 {
return ErrMissingNull
}
return nil
}

View File

@@ -0,0 +1,349 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncore
import (
"bytes"
"encoding/binary"
"errors"
"io"
"testing"
"github.com/google/go-cmp/cmp"
"go.mongodb.org/mongo-driver/bson/bsontype"
)
func TestArray(t *testing.T) {
t.Run("Validate", func(t *testing.T) {
t.Run("TooShort", func(t *testing.T) {
want := NewInsufficientBytesError(nil, nil)
got := Array{'\x00', '\x00'}.Validate()
if !compareErrors(got, want) {
t.Errorf("Did not get expected error. got %v; want %v", got, want)
}
})
t.Run("InvalidLength", func(t *testing.T) {
want := NewArrayLengthError(200, 5)
r := make(Array, 5)
binary.LittleEndian.PutUint32(r[0:4], 200)
got := r.Validate()
if !compareErrors(got, want) {
t.Errorf("Did not get expected error. got %v; want %v", got, want)
}
})
t.Run("Invalid Element", func(t *testing.T) {
want := NewInsufficientBytesError(nil, nil)
r := make(Array, 7)
binary.LittleEndian.PutUint32(r[0:4], 7)
r[4], r[5], r[6] = 0x02, 'f', 0x00
got := r.Validate()
if !compareErrors(got, want) {
t.Errorf("Did not get expected error. got %v; want %v", got, want)
}
})
t.Run("Missing Null Terminator", func(t *testing.T) {
want := ErrMissingNull
r := make(Array, 6)
binary.LittleEndian.PutUint32(r[0:4], 6)
r[4], r[5] = 0x0A, '0'
got := r.Validate()
if !compareErrors(got, want) {
t.Errorf("Did not get expected error. got %v; want %v", got, want)
}
})
testCases := []struct {
name string
r Array
want error
}{
{"array null", Array{'\x08', '\x00', '\x00', '\x00', '\x0A', '0', '\x00', '\x00'}, nil},
{"array",
Array{
'\x1B', '\x00', '\x00', '\x00',
'\x02',
'0', '\x00',
'\x04', '\x00', '\x00', '\x00',
'\x62', '\x61', '\x72', '\x00',
'\x02',
'1', '\x00',
'\x04', '\x00', '\x00', '\x00',
'\x62', '\x61', '\x7a', '\x00',
'\x00',
},
nil,
},
{"subarray",
Array{
'\x13', '\x00', '\x00', '\x00',
'\x04',
'0', '\x00',
'\x0B', '\x00', '\x00', '\x00', '\x0A', '0', '\x00',
'\x0A', '1', '\x00', '\x00', '\x00',
},
nil,
},
{"invalid key order",
Array{
'\x0B', '\x00', '\x00', '\x00', '\x0A', '2', '\x00',
'\x0A', '0', '\x00', '\x00', '\x00',
},
errors.New(`array key "2" is out of order or invalid`),
},
{"invalid key type",
Array{
'\x0B', '\x00', '\x00', '\x00', '\x0A', 'p', '\x00',
'\x0A', 'q', '\x00', '\x00', '\x00',
},
errors.New(`array key "p" is out of order or invalid`),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
got := tc.r.Validate()
if !compareErrors(got, tc.want) {
t.Errorf("Returned error does not match. got %v; want %v", got, tc.want)
}
})
}
})
t.Run("Index", func(t *testing.T) {
t.Run("Out of bounds", func(t *testing.T) {
rdr := Array{0xe, 0x0, 0x0, 0x0, 0xa, '0', 0x0, 0xa, '1', 0x0, 0xa, 0x7a, 0x0, 0x0}
_, err := rdr.IndexErr(3)
if err != ErrOutOfBounds {
t.Errorf("Out of bounds should be returned when accessing element beyond end of Array. got %v; want %v", err, ErrOutOfBounds)
}
})
t.Run("Validation Error", func(t *testing.T) {
rdr := Array{0x07, 0x00, 0x00, 0x00, 0x00}
_, got := rdr.IndexErr(1)
want := NewInsufficientBytesError(nil, nil)
if !compareErrors(got, want) {
t.Errorf("Did not receive expected error. got %v; want %v", got, want)
}
})
testArray := Array{
'\x26', '\x00', '\x00', '\x00',
'\x02',
'0', '\x00',
'\x04', '\x00', '\x00', '\x00',
'\x62', '\x61', '\x72', '\x00',
'\x02',
'1', '\x00',
'\x04', '\x00', '\x00', '\x00',
'\x62', '\x61', '\x7a', '\x00',
'\x02',
'2', '\x00',
'\x04', '\x00', '\x00', '\x00',
'\x71', '\x75', '\x78', '\x00',
'\x00',
}
testCases := []struct {
name string
index uint
want Value
}{
{"first",
0,
Value{
Type: bsontype.String,
Data: []byte{0x04, 0x00, 0x00, 0x00, 0x62, 0x61, 0x72, 0x00},
},
},
{"second",
1,
Value{
Type: bsontype.String,
Data: []byte{0x04, 0x00, 0x00, 0x00, 0x62, 0x61, 0x7a, 0x00},
},
},
{"third",
2,
Value{
Type: bsontype.String,
Data: []byte{0x04, 0x00, 0x00, 0x00, 0x71, 0x75, 0x78, 0x00},
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Run("IndexErr", func(t *testing.T) {
got, err := testArray.IndexErr(tc.index)
if err != nil {
t.Errorf("Unexpected error from IndexErr: %s", err)
}
if diff := cmp.Diff(got, tc.want); diff != "" {
t.Errorf("Arrays differ: (-got +want)\n%s", diff)
}
})
t.Run("Index", func(t *testing.T) {
defer func() {
if err := recover(); err != nil {
t.Errorf("Unexpected error: %v", err)
}
}()
got := testArray.Index(tc.index)
if diff := cmp.Diff(got, tc.want); diff != "" {
t.Errorf("Arrays differ: (-got +want)\n%s", diff)
}
})
})
}
})
t.Run("NewArrayFromReader", func(t *testing.T) {
testCases := []struct {
name string
ioReader io.Reader
arr Array
err error
}{
{
"nil reader",
nil,
nil,
ErrNilReader,
},
{
"premature end of reader",
bytes.NewBuffer([]byte{}),
nil,
io.EOF,
},
{
"empty Array",
bytes.NewBuffer([]byte{5, 0, 0, 0, 0}),
[]byte{5, 0, 0, 0, 0},
nil,
},
{
"non-empty Array",
bytes.NewBuffer([]byte{
'\x1B', '\x00', '\x00', '\x00',
'\x02',
'0', '\x00',
'\x04', '\x00', '\x00', '\x00',
'\x62', '\x61', '\x72', '\x00',
'\x02',
'1', '\x00',
'\x04', '\x00', '\x00', '\x00',
'\x62', '\x61', '\x7a', '\x00',
'\x00',
}),
[]byte{
'\x1B', '\x00', '\x00', '\x00',
'\x02',
'0', '\x00',
'\x04', '\x00', '\x00', '\x00',
'\x62', '\x61', '\x72', '\x00',
'\x02',
'1', '\x00',
'\x04', '\x00', '\x00', '\x00',
'\x62', '\x61', '\x7a', '\x00',
'\x00',
},
nil,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
arr, err := NewArrayFromReader(tc.ioReader)
if !compareErrors(err, tc.err) {
t.Errorf("errors do not match. got %v; want %v", err, tc.err)
}
if !bytes.Equal(tc.arr, arr) {
t.Errorf("Arrays differ. got %v; want %v", tc.arr, arr)
}
})
}
})
t.Run("DebugString", func(t *testing.T) {
testCases := []struct {
name string
arr Array
arrayString string
arrayDebugString string
}{
{
"array",
Array{
'\x1B', '\x00', '\x00', '\x00',
'\x02',
'0', '\x00',
'\x04', '\x00', '\x00', '\x00',
'\x62', '\x61', '\x72', '\x00',
'\x02',
'1', '\x00',
'\x04', '\x00', '\x00', '\x00',
'\x62', '\x61', '\x7a', '\x00',
'\x00',
},
`["bar","baz"]`,
`Array(27)["bar","baz"]`,
},
{
"subarray",
Array{
'\x13', '\x00', '\x00', '\x00',
'\x04',
'0', '\x00',
'\x0B', '\x00', '\x00', '\x00',
'\x0A', '0', '\x00',
'\x0A', '1', '\x00',
'\x00', '\x00',
},
`[[null,null]]`,
`Array(19)[Array(11)[null,null]]`,
},
{
"malformed--length too small",
Array{
'\x04', '\x00', '\x00', '\x00',
'\x00',
},
``,
`Array(4)[]`,
},
{
"malformed--length too large",
Array{
'\x13', '\x00', '\x00', '\x00',
'\x00',
},
``,
`Array(19)[<malformed (15)>]`,
},
{
"malformed--missing null byte",
Array{
'\x06', '\x00', '\x00', '\x00',
'\x02', '0',
},
``,
`Array(6)[<malformed (2)>]`,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
arrayString := tc.arr.String()
if arrayString != tc.arrayString {
t.Errorf("array strings do not match. got %q; want %q",
arrayString, tc.arrayString)
}
arrayDebugString := tc.arr.DebugString()
if arrayDebugString != tc.arrayDebugString {
t.Errorf("array debug strings do not match. got %q; want %q",
arrayDebugString, tc.arrayDebugString)
}
})
}
})
}

View File

@@ -0,0 +1,201 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncore
import (
"strconv"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
)
// ArrayBuilder builds a bson array
type ArrayBuilder struct {
arr []byte
indexes []int32
keys []int
}
// NewArrayBuilder creates a new ArrayBuilder
func NewArrayBuilder() *ArrayBuilder {
return (&ArrayBuilder{}).startArray()
}
// startArray reserves the array's length and sets the index to where the length begins
func (a *ArrayBuilder) startArray() *ArrayBuilder {
var index int32
index, a.arr = AppendArrayStart(a.arr)
a.indexes = append(a.indexes, index)
a.keys = append(a.keys, 0)
return a
}
// Build updates the length of the array and index to the beginning of the documents length
// bytes, then returns the array (bson bytes)
func (a *ArrayBuilder) Build() Array {
lastIndex := len(a.indexes) - 1
lastKey := len(a.keys) - 1
a.arr, _ = AppendArrayEnd(a.arr, a.indexes[lastIndex])
a.indexes = a.indexes[:lastIndex]
a.keys = a.keys[:lastKey]
return a.arr
}
// incrementKey() increments the value keys and returns the key to be used to a.appendArray* functions
func (a *ArrayBuilder) incrementKey() string {
idx := len(a.keys) - 1
key := strconv.Itoa(a.keys[idx])
a.keys[idx]++
return key
}
// AppendInt32 will append i32 to ArrayBuilder.arr
func (a *ArrayBuilder) AppendInt32(i32 int32) *ArrayBuilder {
a.arr = AppendInt32Element(a.arr, a.incrementKey(), i32)
return a
}
// AppendDocument will append doc to ArrayBuilder.arr
func (a *ArrayBuilder) AppendDocument(doc []byte) *ArrayBuilder {
a.arr = AppendDocumentElement(a.arr, a.incrementKey(), doc)
return a
}
// AppendArray will append arr to ArrayBuilder.arr
func (a *ArrayBuilder) AppendArray(arr []byte) *ArrayBuilder {
a.arr = AppendArrayElement(a.arr, a.incrementKey(), arr)
return a
}
// AppendDouble will append f to ArrayBuilder.doc
func (a *ArrayBuilder) AppendDouble(f float64) *ArrayBuilder {
a.arr = AppendDoubleElement(a.arr, a.incrementKey(), f)
return a
}
// AppendString will append str to ArrayBuilder.doc
func (a *ArrayBuilder) AppendString(str string) *ArrayBuilder {
a.arr = AppendStringElement(a.arr, a.incrementKey(), str)
return a
}
// AppendObjectID will append oid to ArrayBuilder.doc
func (a *ArrayBuilder) AppendObjectID(oid primitive.ObjectID) *ArrayBuilder {
a.arr = AppendObjectIDElement(a.arr, a.incrementKey(), oid)
return a
}
// AppendBinary will append a BSON binary element using subtype, and
// b to a.arr
func (a *ArrayBuilder) AppendBinary(subtype byte, b []byte) *ArrayBuilder {
a.arr = AppendBinaryElement(a.arr, a.incrementKey(), subtype, b)
return a
}
// AppendUndefined will append a BSON undefined element using key to a.arr
func (a *ArrayBuilder) AppendUndefined() *ArrayBuilder {
a.arr = AppendUndefinedElement(a.arr, a.incrementKey())
return a
}
// AppendBoolean will append a boolean element using b to a.arr
func (a *ArrayBuilder) AppendBoolean(b bool) *ArrayBuilder {
a.arr = AppendBooleanElement(a.arr, a.incrementKey(), b)
return a
}
// AppendDateTime will append datetime element dt to a.arr
func (a *ArrayBuilder) AppendDateTime(dt int64) *ArrayBuilder {
a.arr = AppendDateTimeElement(a.arr, a.incrementKey(), dt)
return a
}
// AppendNull will append a null element to a.arr
func (a *ArrayBuilder) AppendNull() *ArrayBuilder {
a.arr = AppendNullElement(a.arr, a.incrementKey())
return a
}
// AppendRegex will append pattern and options to a.arr
func (a *ArrayBuilder) AppendRegex(pattern, options string) *ArrayBuilder {
a.arr = AppendRegexElement(a.arr, a.incrementKey(), pattern, options)
return a
}
// AppendDBPointer will append ns and oid to a.arr
func (a *ArrayBuilder) AppendDBPointer(ns string, oid primitive.ObjectID) *ArrayBuilder {
a.arr = AppendDBPointerElement(a.arr, a.incrementKey(), ns, oid)
return a
}
// AppendJavaScript will append js to a.arr
func (a *ArrayBuilder) AppendJavaScript(js string) *ArrayBuilder {
a.arr = AppendJavaScriptElement(a.arr, a.incrementKey(), js)
return a
}
// AppendSymbol will append symbol to a.arr
func (a *ArrayBuilder) AppendSymbol(symbol string) *ArrayBuilder {
a.arr = AppendSymbolElement(a.arr, a.incrementKey(), symbol)
return a
}
// AppendCodeWithScope will append code and scope to a.arr
func (a *ArrayBuilder) AppendCodeWithScope(code string, scope Document) *ArrayBuilder {
a.arr = AppendCodeWithScopeElement(a.arr, a.incrementKey(), code, scope)
return a
}
// AppendTimestamp will append t and i to a.arr
func (a *ArrayBuilder) AppendTimestamp(t, i uint32) *ArrayBuilder {
a.arr = AppendTimestampElement(a.arr, a.incrementKey(), t, i)
return a
}
// AppendInt64 will append i64 to a.arr
func (a *ArrayBuilder) AppendInt64(i64 int64) *ArrayBuilder {
a.arr = AppendInt64Element(a.arr, a.incrementKey(), i64)
return a
}
// AppendDecimal128 will append d128 to a.arr
func (a *ArrayBuilder) AppendDecimal128(d128 primitive.Decimal128) *ArrayBuilder {
a.arr = AppendDecimal128Element(a.arr, a.incrementKey(), d128)
return a
}
// AppendMaxKey will append a max key element to a.arr
func (a *ArrayBuilder) AppendMaxKey() *ArrayBuilder {
a.arr = AppendMaxKeyElement(a.arr, a.incrementKey())
return a
}
// AppendMinKey will append a min key element to a.arr
func (a *ArrayBuilder) AppendMinKey() *ArrayBuilder {
a.arr = AppendMinKeyElement(a.arr, a.incrementKey())
return a
}
// AppendValue appends a BSON value to the array.
func (a *ArrayBuilder) AppendValue(val Value) *ArrayBuilder {
a.arr = AppendValueElement(a.arr, a.incrementKey(), val)
return a
}
// StartArray starts building an inline Array. After this document is completed,
// the user must call a.FinishArray
func (a *ArrayBuilder) StartArray() *ArrayBuilder {
a.arr = AppendHeader(a.arr, bsontype.Array, a.incrementKey())
a.startArray()
return a
}
// FinishArray builds the most recent array created
func (a *ArrayBuilder) FinishArray() *ArrayBuilder {
a.arr = a.Build()
return a
}

View File

@@ -0,0 +1,213 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncore
import (
"bytes"
"encoding/binary"
"math"
"reflect"
"testing"
"go.mongodb.org/mongo-driver/bson/primitive"
)
func TestArrayBuilder(t *testing.T) {
bits := math.Float64bits(3.14159)
pi := make([]byte, 8)
binary.LittleEndian.PutUint64(pi, bits)
testCases := []struct {
name string
fn interface{}
params []interface{}
expected []byte
}{
{
"AppendInt32",
NewArrayBuilder().AppendInt32,
[]interface{}{int32(256)},
BuildDocumentFromElements(nil, AppendInt32Element(nil, "0", int32(256))),
},
{
"AppendDouble",
NewArrayBuilder().AppendDouble,
[]interface{}{float64(3.14159)},
BuildDocumentFromElements(nil, AppendDoubleElement(nil, "0", float64(3.14159))),
},
{
"AppendString",
NewArrayBuilder().AppendString,
[]interface{}{"x"},
BuildDocumentFromElements(nil, AppendStringElement(nil, "0", "x")),
},
{
"AppendDocument",
NewArrayBuilder().AppendDocument,
[]interface{}{[]byte{0x05, 0x00, 0x00, 0x00, 0x00}},
BuildDocumentFromElements(nil, AppendDocumentElement(nil, "0", []byte{0x05, 0x00, 0x00, 0x00, 0x00})),
},
{
"AppendArray",
NewArrayBuilder().AppendArray,
[]interface{}{[]byte{0x05, 0x00, 0x00, 0x00, 0x00}},
BuildDocumentFromElements(nil, AppendArrayElement(nil, "0", []byte{0x05, 0x00, 0x00, 0x00, 0x00})),
},
{
"AppendBinary",
NewArrayBuilder().AppendBinary,
[]interface{}{byte(0x02), []byte{0x01, 0x02, 0x03}},
BuildDocumentFromElements(nil, AppendBinaryElement(nil, "0", byte(0x02), []byte{0x01, 0x02, 0x03})),
},
{
"AppendObjectID",
NewArrayBuilder().AppendObjectID,
[]interface{}{
primitive.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C},
},
BuildDocumentFromElements(nil, AppendObjectIDElement(nil, "0",
primitive.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C})),
},
{
"AppendBoolean",
NewArrayBuilder().AppendBoolean,
[]interface{}{true},
BuildDocumentFromElements(nil, AppendBooleanElement(nil, "0", true)),
},
{
"AppendDateTime",
NewArrayBuilder().AppendDateTime,
[]interface{}{int64(256)},
BuildDocumentFromElements(nil, AppendDateTimeElement(nil, "0", int64(256))),
},
{
"AppendNull",
NewArrayBuilder().AppendNull,
[]interface{}{},
BuildDocumentFromElements(nil, AppendNullElement(nil, "0")),
},
{
"AppendRegex",
NewArrayBuilder().AppendRegex,
[]interface{}{"bar", "baz"},
BuildDocumentFromElements(nil, AppendRegexElement(nil, "0", "bar", "baz")),
},
{
"AppendJavaScript",
NewArrayBuilder().AppendJavaScript,
[]interface{}{"barbaz"},
BuildDocumentFromElements(nil, AppendJavaScriptElement(nil, "0", "barbaz")),
},
{
"AppendCodeWithScope",
NewArrayBuilder().AppendCodeWithScope,
[]interface{}{"barbaz", Document([]byte{0x05, 0x00, 0x00, 0x00, 0x00})},
BuildDocumentFromElements(nil, AppendCodeWithScopeElement(nil, "0", "barbaz", Document([]byte{0x05, 0x00, 0x00, 0x00, 0x00}))),
},
{
"AppendTimestamp",
NewArrayBuilder().AppendTimestamp,
[]interface{}{uint32(65536), uint32(256)},
BuildDocumentFromElements(nil, AppendTimestampElement(nil, "0", uint32(65536), uint32(256))),
},
{
"AppendInt64",
NewArrayBuilder().AppendInt64,
[]interface{}{int64(4294967296)},
BuildDocumentFromElements(nil, AppendInt64Element(nil, "0", int64(4294967296))),
},
{
"AppendDecimal128",
NewArrayBuilder().AppendDecimal128,
[]interface{}{primitive.NewDecimal128(4294967296, 65536)},
BuildDocumentFromElements(nil, AppendDecimal128Element(nil, "0", primitive.NewDecimal128(4294967296, 65536))),
},
{
"AppendMaxKey",
NewArrayBuilder().AppendMaxKey,
[]interface{}{},
BuildDocumentFromElements(nil, AppendMaxKeyElement(nil, "0")),
},
{
"AppendMinKey",
NewArrayBuilder().AppendMinKey,
[]interface{}{},
BuildDocumentFromElements(nil, AppendMinKeyElement(nil, "0")),
},
{
"AppendSymbol",
NewArrayBuilder().AppendSymbol,
[]interface{}{"barbaz"},
BuildDocumentFromElements(nil, AppendSymbolElement(nil, "0", "barbaz")),
},
{
"AppendDBPointer",
NewArrayBuilder().AppendDBPointer,
[]interface{}{"barbaz",
primitive.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}},
BuildDocumentFromElements(nil, AppendDBPointerElement(nil, "0", "barbaz",
primitive.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C})),
},
{
"AppendUndefined",
NewArrayBuilder().AppendUndefined,
[]interface{}{},
BuildDocumentFromElements(nil, AppendUndefinedElement(nil, "0")),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
fn := reflect.ValueOf(tc.fn)
if fn.Kind() != reflect.Func {
t.Fatalf("fn must be of kind Func but is a %v", fn.Kind())
}
if fn.Type().NumIn() != len(tc.params) {
t.Fatalf("tc.params must match the number of params in tc.fn. params %d; fn %d", fn.Type().NumIn(), len(tc.params))
}
if fn.Type().NumOut() != 1 || fn.Type().Out(0) != reflect.TypeOf(&ArrayBuilder{}) {
t.Fatalf("fn must have one return parameter and it must be an ArrayBuilder.")
}
params := make([]reflect.Value, 0, len(tc.params))
for _, param := range tc.params {
params = append(params, reflect.ValueOf(param))
}
results := fn.Call(params)
got := results[0].Interface().(*ArrayBuilder).Build()
want := tc.expected
if !bytes.Equal(got, want) {
t.Errorf("Did not receive expected bytes. got %v; want %v", got, want)
}
})
}
t.Run("TestBuildTwoElementsArray", func(t *testing.T) {
intArr := BuildDocumentFromElements(nil, AppendInt32Element(nil, "0", int32(1)))
expected := BuildDocumentFromElements(nil, AppendArrayElement(AppendInt32Element(nil, "0", int32(3)), "1", intArr))
elem := NewArrayBuilder().AppendInt32(int32(1)).Build()
result := NewArrayBuilder().AppendInt32(int32(3)).AppendArray(elem).Build()
if !bytes.Equal(result, expected) {
t.Errorf("Arrays do not match. got %v; want %v", result, expected)
}
})
t.Run("TestBuildInlineArray", func(t *testing.T) {
docElement := BuildDocumentFromElements(nil, AppendInt32Element(nil, "0", int32(256)))
expected := Document(BuildDocumentFromElements(nil, AppendArrayElement(nil, "0", docElement)))
result := NewArrayBuilder().StartArray().AppendInt32(int32(256)).FinishArray().Build()
if !bytes.Equal(result, expected) {
t.Errorf("Documents do not match. got %v; want %v", result, expected)
}
})
t.Run("TestBuildNestedInlineArray", func(t *testing.T) {
docElement := BuildDocumentFromElements(nil, AppendDoubleElement(nil, "0", 3.14))
docInline := BuildDocumentFromElements(nil, AppendArrayElement(nil, "0", docElement))
expected := Document(BuildDocumentFromElements(nil, AppendArrayElement(nil, "0", docInline)))
result := NewArrayBuilder().StartArray().StartArray().AppendDouble(3.14).FinishArray().FinishArray().Build()
if !bytes.Equal(result, expected) {
t.Errorf("Documents do not match. got %v; want %v", result, expected)
}
})
}

View File

@@ -0,0 +1,189 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncore
import (
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
)
// DocumentBuilder builds a bson document
type DocumentBuilder struct {
doc []byte
indexes []int32
}
// startDocument reserves the document's length and set the index to where the length begins
func (db *DocumentBuilder) startDocument() *DocumentBuilder {
var index int32
index, db.doc = AppendDocumentStart(db.doc)
db.indexes = append(db.indexes, index)
return db
}
// NewDocumentBuilder creates a new DocumentBuilder
func NewDocumentBuilder() *DocumentBuilder {
return (&DocumentBuilder{}).startDocument()
}
// Build updates the length of the document and index to the beginning of the documents length
// bytes, then returns the document (bson bytes)
func (db *DocumentBuilder) Build() Document {
last := len(db.indexes) - 1
db.doc, _ = AppendDocumentEnd(db.doc, db.indexes[last])
db.indexes = db.indexes[:last]
return db.doc
}
// AppendInt32 will append an int32 element using key and i32 to DocumentBuilder.doc
func (db *DocumentBuilder) AppendInt32(key string, i32 int32) *DocumentBuilder {
db.doc = AppendInt32Element(db.doc, key, i32)
return db
}
// AppendDocument will append a bson embedded document element using key
// and doc to DocumentBuilder.doc
func (db *DocumentBuilder) AppendDocument(key string, doc []byte) *DocumentBuilder {
db.doc = AppendDocumentElement(db.doc, key, doc)
return db
}
// AppendArray will append a bson array using key and arr to DocumentBuilder.doc
func (db *DocumentBuilder) AppendArray(key string, arr []byte) *DocumentBuilder {
db.doc = AppendHeader(db.doc, bsontype.Array, key)
db.doc = AppendArray(db.doc, arr)
return db
}
// AppendDouble will append a double element using key and f to DocumentBuilder.doc
func (db *DocumentBuilder) AppendDouble(key string, f float64) *DocumentBuilder {
db.doc = AppendDoubleElement(db.doc, key, f)
return db
}
// AppendString will append str to DocumentBuilder.doc with the given key
func (db *DocumentBuilder) AppendString(key string, str string) *DocumentBuilder {
db.doc = AppendStringElement(db.doc, key, str)
return db
}
// AppendObjectID will append oid to DocumentBuilder.doc with the given key
func (db *DocumentBuilder) AppendObjectID(key string, oid primitive.ObjectID) *DocumentBuilder {
db.doc = AppendObjectIDElement(db.doc, key, oid)
return db
}
// AppendBinary will append a BSON binary element using key, subtype, and
// b to db.doc
func (db *DocumentBuilder) AppendBinary(key string, subtype byte, b []byte) *DocumentBuilder {
db.doc = AppendBinaryElement(db.doc, key, subtype, b)
return db
}
// AppendUndefined will append a BSON undefined element using key to db.doc
func (db *DocumentBuilder) AppendUndefined(key string) *DocumentBuilder {
db.doc = AppendUndefinedElement(db.doc, key)
return db
}
// AppendBoolean will append a boolean element using key and b to db.doc
func (db *DocumentBuilder) AppendBoolean(key string, b bool) *DocumentBuilder {
db.doc = AppendBooleanElement(db.doc, key, b)
return db
}
// AppendDateTime will append a datetime element using key and dt to db.doc
func (db *DocumentBuilder) AppendDateTime(key string, dt int64) *DocumentBuilder {
db.doc = AppendDateTimeElement(db.doc, key, dt)
return db
}
// AppendNull will append a null element using key to db.doc
func (db *DocumentBuilder) AppendNull(key string) *DocumentBuilder {
db.doc = AppendNullElement(db.doc, key)
return db
}
// AppendRegex will append pattern and options using key to db.doc
func (db *DocumentBuilder) AppendRegex(key, pattern, options string) *DocumentBuilder {
db.doc = AppendRegexElement(db.doc, key, pattern, options)
return db
}
// AppendDBPointer will append ns and oid to using key to db.doc
func (db *DocumentBuilder) AppendDBPointer(key string, ns string, oid primitive.ObjectID) *DocumentBuilder {
db.doc = AppendDBPointerElement(db.doc, key, ns, oid)
return db
}
// AppendJavaScript will append js using the provided key to db.doc
func (db *DocumentBuilder) AppendJavaScript(key, js string) *DocumentBuilder {
db.doc = AppendJavaScriptElement(db.doc, key, js)
return db
}
// AppendSymbol will append a BSON symbol element using key and symbol db.doc
func (db *DocumentBuilder) AppendSymbol(key, symbol string) *DocumentBuilder {
db.doc = AppendSymbolElement(db.doc, key, symbol)
return db
}
// AppendCodeWithScope will append code and scope using key to db.doc
func (db *DocumentBuilder) AppendCodeWithScope(key string, code string, scope Document) *DocumentBuilder {
db.doc = AppendCodeWithScopeElement(db.doc, key, code, scope)
return db
}
// AppendTimestamp will append t and i to db.doc using provided key
func (db *DocumentBuilder) AppendTimestamp(key string, t, i uint32) *DocumentBuilder {
db.doc = AppendTimestampElement(db.doc, key, t, i)
return db
}
// AppendInt64 will append i64 to dst using key to db.doc
func (db *DocumentBuilder) AppendInt64(key string, i64 int64) *DocumentBuilder {
db.doc = AppendInt64Element(db.doc, key, i64)
return db
}
// AppendDecimal128 will append d128 to db.doc using provided key
func (db *DocumentBuilder) AppendDecimal128(key string, d128 primitive.Decimal128) *DocumentBuilder {
db.doc = AppendDecimal128Element(db.doc, key, d128)
return db
}
// AppendMaxKey will append a max key element using key to db.doc
func (db *DocumentBuilder) AppendMaxKey(key string) *DocumentBuilder {
db.doc = AppendMaxKeyElement(db.doc, key)
return db
}
// AppendMinKey will append a min key element using key to db.doc
func (db *DocumentBuilder) AppendMinKey(key string) *DocumentBuilder {
db.doc = AppendMinKeyElement(db.doc, key)
return db
}
// AppendValue will append a BSON element with the provided key and value to the document.
func (db *DocumentBuilder) AppendValue(key string, val Value) *DocumentBuilder {
db.doc = AppendValueElement(db.doc, key, val)
return db
}
// StartDocument starts building an inline document element with the provided key
// After this document is completed, the user must call finishDocument
func (db *DocumentBuilder) StartDocument(key string) *DocumentBuilder {
db.doc = AppendHeader(db.doc, bsontype.EmbeddedDocument, key)
db = db.startDocument()
return db
}
// FinishDocument builds the most recent document created
func (db *DocumentBuilder) FinishDocument() *DocumentBuilder {
db.doc = db.Build()
return db
}

View File

@@ -0,0 +1,215 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncore
import (
"bytes"
"encoding/binary"
"math"
"reflect"
"testing"
"go.mongodb.org/mongo-driver/bson/primitive"
)
func TestDocumentBuilder(t *testing.T) {
bits := math.Float64bits(3.14159)
pi := make([]byte, 8)
binary.LittleEndian.PutUint64(pi, bits)
testCases := []struct {
name string
fn interface{}
params []interface{}
expected []byte
}{
{
"AppendInt32",
NewDocumentBuilder().AppendInt32,
[]interface{}{"foobar", int32(256)},
BuildDocumentFromElements(nil, AppendInt32Element(nil, "foobar", 256)),
},
{
"AppendDouble",
NewDocumentBuilder().AppendDouble,
[]interface{}{"foobar", float64(3.14159)},
BuildDocumentFromElements(nil, AppendDoubleElement(nil, "foobar", float64(3.14159))),
},
{
"AppendString",
NewDocumentBuilder().AppendString,
[]interface{}{"foobar", "x"},
BuildDocumentFromElements(nil, AppendStringElement(nil, "foobar", "x")),
},
{
"AppendDocument",
NewDocumentBuilder().AppendDocument,
[]interface{}{"foobar", []byte{0x05, 0x00, 0x00, 0x00, 0x00}},
BuildDocumentFromElements(nil, AppendDocumentElement(nil, "foobar", []byte{0x05, 0x00, 0x00, 0x00, 0x00})),
},
{
"AppendArray",
NewDocumentBuilder().AppendArray,
[]interface{}{"foobar", []byte{0x05, 0x00, 0x00, 0x00, 0x00}},
BuildDocumentFromElements(nil, AppendArrayElement(nil, "foobar", []byte{0x05, 0x00, 0x00, 0x00, 0x00})),
},
{
"AppendBinary",
NewDocumentBuilder().AppendBinary,
[]interface{}{"foobar", byte(0x02), []byte{0x01, 0x02, 0x03}},
BuildDocumentFromElements(nil, AppendBinaryElement(nil, "foobar", byte(0x02), []byte{0x01, 0x02, 0x03})),
},
{
"AppendObjectID",
NewDocumentBuilder().AppendObjectID,
[]interface{}{
"foobar",
primitive.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C},
},
BuildDocumentFromElements(nil, AppendObjectIDElement(nil, "foobar",
primitive.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C})),
},
{
"AppendBoolean",
NewDocumentBuilder().AppendBoolean,
[]interface{}{"foobar", true},
BuildDocumentFromElements(nil, AppendBooleanElement(nil, "foobar", true)),
},
{
"AppendDateTime",
NewDocumentBuilder().AppendDateTime,
[]interface{}{"foobar", int64(256)},
BuildDocumentFromElements(nil, AppendDateTimeElement(nil, "foobar", int64(256))),
},
{
"AppendNull",
NewDocumentBuilder().AppendNull,
[]interface{}{"foobar"},
BuildDocumentFromElements(nil, AppendNullElement(nil, "foobar")),
},
{
"AppendRegex",
NewDocumentBuilder().AppendRegex,
[]interface{}{"foobar", "bar", "baz"},
BuildDocumentFromElements(nil, AppendRegexElement(nil, "foobar", "bar", "baz")),
},
{
"AppendJavaScript",
NewDocumentBuilder().AppendJavaScript,
[]interface{}{"foobar", "barbaz"},
BuildDocumentFromElements(nil, AppendJavaScriptElement(nil, "foobar", "barbaz")),
},
{
"AppendCodeWithScope",
NewDocumentBuilder().AppendCodeWithScope,
[]interface{}{"foobar", "barbaz", Document([]byte{0x05, 0x00, 0x00, 0x00, 0x00})},
BuildDocumentFromElements(nil, AppendCodeWithScopeElement(nil, "foobar", "barbaz", Document([]byte{0x05, 0x00, 0x00, 0x00, 0x00}))),
},
{
"AppendTimestamp",
NewDocumentBuilder().AppendTimestamp,
[]interface{}{"foobar", uint32(65536), uint32(256)},
BuildDocumentFromElements(nil, AppendTimestampElement(nil, "foobar", uint32(65536), uint32(256))),
},
{
"AppendInt64",
NewDocumentBuilder().AppendInt64,
[]interface{}{"foobar", int64(4294967296)},
BuildDocumentFromElements(nil, AppendInt64Element(nil, "foobar", int64(4294967296))),
},
{
"AppendDecimal128",
NewDocumentBuilder().AppendDecimal128,
[]interface{}{"foobar", primitive.NewDecimal128(4294967296, 65536)},
BuildDocumentFromElements(nil, AppendDecimal128Element(nil, "foobar", primitive.NewDecimal128(4294967296, 65536))),
},
{
"AppendMaxKey",
NewDocumentBuilder().AppendMaxKey,
[]interface{}{"foobar"},
BuildDocumentFromElements(nil, AppendMaxKeyElement(nil, "foobar")),
},
{
"AppendMinKey",
NewDocumentBuilder().AppendMinKey,
[]interface{}{"foobar"},
BuildDocumentFromElements(nil, AppendMinKeyElement(nil, "foobar")),
},
{
"AppendSymbol",
NewDocumentBuilder().AppendSymbol,
[]interface{}{"foobar", "barbaz"},
BuildDocumentFromElements(nil, AppendSymbolElement(nil, "foobar", "barbaz")),
},
{
"AppendDBPointer",
NewDocumentBuilder().AppendDBPointer,
[]interface{}{"foobar", "barbaz",
primitive.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}},
BuildDocumentFromElements(nil, AppendDBPointerElement(nil, "foobar", "barbaz",
primitive.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C})),
},
{
"AppendUndefined",
NewDocumentBuilder().AppendUndefined,
[]interface{}{"foobar"},
BuildDocumentFromElements(nil, AppendUndefinedElement(nil, "foobar")),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
fn := reflect.ValueOf(tc.fn)
if fn.Kind() != reflect.Func {
t.Fatalf("fn must be of kind Func but is a %v", fn.Kind())
}
if fn.Type().NumIn() != len(tc.params) {
t.Fatalf("tc.params must match the number of params in tc.fn. params %d; fn %d", fn.Type().NumIn(), len(tc.params))
}
if fn.Type().NumOut() != 1 || fn.Type().Out(0) != reflect.TypeOf(&DocumentBuilder{}) {
t.Fatalf("fn must have one return parameter and it must be a DocumentBuilder.")
}
params := make([]reflect.Value, 0, len(tc.params))
for _, param := range tc.params {
params = append(params, reflect.ValueOf(param))
}
results := fn.Call(params)
got := results[0].Interface().(*DocumentBuilder)
doc := got.Build()
want := tc.expected
if !bytes.Equal(doc, want) {
t.Errorf("Did not receive expected bytes. got %v; want %v", got, want)
}
})
}
t.Run("TestBuildTwoElements", func(t *testing.T) {
intArr := BuildDocumentFromElements(nil, AppendInt32Element(nil, "0", int32(1)))
expected := BuildDocumentFromElements(nil, AppendArrayElement(AppendInt32Element(nil, "x", int32(3)), "y", intArr))
elem := NewArrayBuilder().AppendInt32(int32(1)).Build()
result := NewDocumentBuilder().AppendInt32("x", int32(3)).AppendArray("y", elem).Build()
if !bytes.Equal(result, expected) {
t.Errorf("Documents do not match. got %v; want %v", result, expected)
}
})
t.Run("TestBuildInlineDocument", func(t *testing.T) {
docElement := BuildDocumentFromElements(nil, AppendInt32Element(nil, "x", int32(256)))
expected := Document(BuildDocumentFromElements(nil, AppendDocumentElement(nil, "y", docElement)))
result := NewDocumentBuilder().StartDocument("y").AppendInt32("x", int32(256)).FinishDocument().Build()
if !bytes.Equal(result, expected) {
t.Errorf("Documents do not match. got %v; want %v", result, expected)
}
})
t.Run("TestBuildNestedInlineDocument", func(t *testing.T) {
docElement := BuildDocumentFromElements(nil, AppendDoubleElement(nil, "x", 3.14))
docInline := BuildDocumentFromElements(nil, AppendDocumentElement(nil, "y", docElement))
expected := Document(BuildDocumentFromElements(nil, AppendDocumentElement(nil, "z", docInline)))
result := NewDocumentBuilder().StartDocument("z").StartDocument("y").AppendDouble("x", 3.14).FinishDocument().FinishDocument().Build()
if !bytes.Equal(result, expected) {
t.Errorf("Documents do not match. got %v; want %v", result, expected)
}
})
}

View File

@@ -0,0 +1,862 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Package bsoncore contains functions that can be used to encode and decode BSON
// elements and values to or from a slice of bytes. These functions are aimed at
// allowing low level manipulation of BSON and can be used to build a higher
// level BSON library.
//
// The Read* functions within this package return the values of the element and
// a boolean indicating if the values are valid. A boolean was used instead of
// an error because any error that would be returned would be the same: not
// enough bytes. This library attempts to do no validation, it will only return
// false if there are not enough bytes for an item to be read. For example, the
// ReadDocument function checks the length, if that length is larger than the
// number of bytes available, it will return false, if there are enough bytes, it
// will return those bytes and true. It is the consumers responsibility to
// validate those bytes.
//
// The Append* functions within this package will append the type value to the
// given dst slice. If the slice has enough capacity, it will not grow the
// slice. The Append*Element functions within this package operate in the same
// way, but additionally append the BSON type and the key before the value.
package bsoncore // import "go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
import (
"bytes"
"fmt"
"math"
"strconv"
"strings"
"time"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
)
const (
// EmptyDocumentLength is the length of a document that has been started/ended but has no elements.
EmptyDocumentLength = 5
// nullTerminator is a string version of the 0 byte that is appended at the end of cstrings.
nullTerminator = string(byte(0))
invalidKeyPanicMsg = "BSON element keys cannot contain null bytes"
invalidRegexPanicMsg = "BSON regex values cannot contain null bytes"
)
// AppendType will append t to dst and return the extended buffer.
func AppendType(dst []byte, t bsontype.Type) []byte { return append(dst, byte(t)) }
// AppendKey will append key to dst and return the extended buffer.
func AppendKey(dst []byte, key string) []byte { return append(dst, key+nullTerminator...) }
// AppendHeader will append Type t and key to dst and return the extended
// buffer.
func AppendHeader(dst []byte, t bsontype.Type, key string) []byte {
if !isValidCString(key) {
panic(invalidKeyPanicMsg)
}
dst = AppendType(dst, t)
dst = append(dst, key...)
return append(dst, 0x00)
// return append(AppendType(dst, t), key+string(0x00)...)
}
// TODO(skriptble): All of the Read* functions should return src resliced to start just after what was read.
// ReadType will return the first byte of the provided []byte as a type. If
// there is no available byte, false is returned.
func ReadType(src []byte) (bsontype.Type, []byte, bool) {
if len(src) < 1 {
return 0, src, false
}
return bsontype.Type(src[0]), src[1:], true
}
// ReadKey will read a key from src. The 0x00 byte will not be present
// in the returned string. If there are not enough bytes available, false is
// returned.
func ReadKey(src []byte) (string, []byte, bool) { return readcstring(src) }
// ReadKeyBytes will read a key from src as bytes. The 0x00 byte will
// not be present in the returned string. If there are not enough bytes
// available, false is returned.
func ReadKeyBytes(src []byte) ([]byte, []byte, bool) { return readcstringbytes(src) }
// ReadHeader will read a type byte and a key from src. If both of these
// values cannot be read, false is returned.
func ReadHeader(src []byte) (t bsontype.Type, key string, rem []byte, ok bool) {
t, rem, ok = ReadType(src)
if !ok {
return 0, "", src, false
}
key, rem, ok = ReadKey(rem)
if !ok {
return 0, "", src, false
}
return t, key, rem, true
}
// ReadHeaderBytes will read a type and a key from src and the remainder of the bytes
// are returned as rem. If either the type or key cannot be red, ok will be false.
func ReadHeaderBytes(src []byte) (header []byte, rem []byte, ok bool) {
if len(src) < 1 {
return nil, src, false
}
idx := bytes.IndexByte(src[1:], 0x00)
if idx == -1 {
return nil, src, false
}
return src[:idx], src[idx+1:], true
}
// ReadElement reads the next full element from src. It returns the element, the remaining bytes in
// the slice, and a boolean indicating if the read was successful.
func ReadElement(src []byte) (Element, []byte, bool) {
if len(src) < 1 {
return nil, src, false
}
t := bsontype.Type(src[0])
idx := bytes.IndexByte(src[1:], 0x00)
if idx == -1 {
return nil, src, false
}
length, ok := valueLength(src[idx+2:], t) // We add 2 here because we called IndexByte with src[1:]
if !ok {
return nil, src, false
}
elemLength := 1 + idx + 1 + int(length)
if elemLength > len(src) {
return nil, src, false
}
if elemLength < 0 {
return nil, src, false
}
return src[:elemLength], src[elemLength:], true
}
// AppendValueElement appends value to dst as an element using key as the element's key.
func AppendValueElement(dst []byte, key string, value Value) []byte {
dst = AppendHeader(dst, value.Type, key)
dst = append(dst, value.Data...)
return dst
}
// ReadValue reads the next value as the provided types and returns a Value, the remaining bytes,
// and a boolean indicating if the read was successful.
func ReadValue(src []byte, t bsontype.Type) (Value, []byte, bool) {
data, rem, ok := readValue(src, t)
if !ok {
return Value{}, src, false
}
return Value{Type: t, Data: data}, rem, true
}
// AppendDouble will append f to dst and return the extended buffer.
func AppendDouble(dst []byte, f float64) []byte {
return appendu64(dst, math.Float64bits(f))
}
// AppendDoubleElement will append a BSON double element using key and f to dst
// and return the extended buffer.
func AppendDoubleElement(dst []byte, key string, f float64) []byte {
return AppendDouble(AppendHeader(dst, bsontype.Double, key), f)
}
// ReadDouble will read a float64 from src. If there are not enough bytes it
// will return false.
func ReadDouble(src []byte) (float64, []byte, bool) {
bits, src, ok := readu64(src)
if !ok {
return 0, src, false
}
return math.Float64frombits(bits), src, true
}
// AppendString will append s to dst and return the extended buffer.
func AppendString(dst []byte, s string) []byte {
return appendstring(dst, s)
}
// AppendStringElement will append a BSON string element using key and val to dst
// and return the extended buffer.
func AppendStringElement(dst []byte, key, val string) []byte {
return AppendString(AppendHeader(dst, bsontype.String, key), val)
}
// ReadString will read a string from src. If there are not enough bytes it
// will return false.
func ReadString(src []byte) (string, []byte, bool) {
return readstring(src)
}
// AppendDocumentStart reserves a document's length and returns the index where the length begins.
// This index can later be used to write the length of the document.
func AppendDocumentStart(dst []byte) (index int32, b []byte) {
// TODO(skriptble): We really need AppendDocumentStart and AppendDocumentEnd. AppendDocumentStart would handle calling
// TODO ReserveLength and providing the index of the start of the document. AppendDocumentEnd would handle taking that
// TODO start index, adding the null byte, calculating the length, and filling in the length at the start of the
// TODO document.
return ReserveLength(dst)
}
// AppendDocumentStartInline functions the same as AppendDocumentStart but takes a pointer to the
// index int32 which allows this function to be used inline.
func AppendDocumentStartInline(dst []byte, index *int32) []byte {
idx, doc := AppendDocumentStart(dst)
*index = idx
return doc
}
// AppendDocumentElementStart writes a document element header and then reserves the length bytes.
func AppendDocumentElementStart(dst []byte, key string) (index int32, b []byte) {
return AppendDocumentStart(AppendHeader(dst, bsontype.EmbeddedDocument, key))
}
// AppendDocumentEnd writes the null byte for a document and updates the length of the document.
// The index should be the beginning of the document's length bytes.
func AppendDocumentEnd(dst []byte, index int32) ([]byte, error) {
if int(index) > len(dst)-4 {
return dst, fmt.Errorf("not enough bytes available after index to write length")
}
dst = append(dst, 0x00)
dst = UpdateLength(dst, index, int32(len(dst[index:])))
return dst, nil
}
// AppendDocument will append doc to dst and return the extended buffer.
func AppendDocument(dst []byte, doc []byte) []byte { return append(dst, doc...) }
// AppendDocumentElement will append a BSON embedded document element using key
// and doc to dst and return the extended buffer.
func AppendDocumentElement(dst []byte, key string, doc []byte) []byte {
return AppendDocument(AppendHeader(dst, bsontype.EmbeddedDocument, key), doc)
}
// BuildDocument will create a document with the given slice of elements and will append
// it to dst and return the extended buffer.
func BuildDocument(dst []byte, elems ...[]byte) []byte {
idx, dst := ReserveLength(dst)
for _, elem := range elems {
dst = append(dst, elem...)
}
dst = append(dst, 0x00)
dst = UpdateLength(dst, idx, int32(len(dst[idx:])))
return dst
}
// BuildDocumentValue creates an Embedded Document value from the given elements.
func BuildDocumentValue(elems ...[]byte) Value {
return Value{Type: bsontype.EmbeddedDocument, Data: BuildDocument(nil, elems...)}
}
// BuildDocumentElement will append a BSON embedded document elemnt using key and the provided
// elements and return the extended buffer.
func BuildDocumentElement(dst []byte, key string, elems ...[]byte) []byte {
return BuildDocument(AppendHeader(dst, bsontype.EmbeddedDocument, key), elems...)
}
// BuildDocumentFromElements is an alaias for the BuildDocument function.
var BuildDocumentFromElements = BuildDocument
// ReadDocument will read a document from src. If there are not enough bytes it
// will return false.
func ReadDocument(src []byte) (doc Document, rem []byte, ok bool) { return readLengthBytes(src) }
// AppendArrayStart appends the length bytes to an array and then returns the index of the start
// of those length bytes.
func AppendArrayStart(dst []byte) (index int32, b []byte) { return ReserveLength(dst) }
// AppendArrayElementStart appends an array element header and then the length bytes for an array,
// returning the index where the length starts.
func AppendArrayElementStart(dst []byte, key string) (index int32, b []byte) {
return AppendArrayStart(AppendHeader(dst, bsontype.Array, key))
}
// AppendArrayEnd appends the null byte to an array and calculates the length, inserting that
// calculated length starting at index.
func AppendArrayEnd(dst []byte, index int32) ([]byte, error) { return AppendDocumentEnd(dst, index) }
// AppendArray will append arr to dst and return the extended buffer.
func AppendArray(dst []byte, arr []byte) []byte { return append(dst, arr...) }
// AppendArrayElement will append a BSON array element using key and arr to dst
// and return the extended buffer.
func AppendArrayElement(dst []byte, key string, arr []byte) []byte {
return AppendArray(AppendHeader(dst, bsontype.Array, key), arr)
}
// BuildArray will append a BSON array to dst built from values.
func BuildArray(dst []byte, values ...Value) []byte {
idx, dst := ReserveLength(dst)
for pos, val := range values {
dst = AppendValueElement(dst, strconv.Itoa(pos), val)
}
dst = append(dst, 0x00)
dst = UpdateLength(dst, idx, int32(len(dst[idx:])))
return dst
}
// BuildArrayElement will create an array element using the provided values.
func BuildArrayElement(dst []byte, key string, values ...Value) []byte {
return BuildArray(AppendHeader(dst, bsontype.Array, key), values...)
}
// ReadArray will read an array from src. If there are not enough bytes it
// will return false.
func ReadArray(src []byte) (arr Array, rem []byte, ok bool) { return readLengthBytes(src) }
// AppendBinary will append subtype and b to dst and return the extended buffer.
func AppendBinary(dst []byte, subtype byte, b []byte) []byte {
if subtype == 0x02 {
return appendBinarySubtype2(dst, subtype, b)
}
dst = append(appendLength(dst, int32(len(b))), subtype)
return append(dst, b...)
}
// AppendBinaryElement will append a BSON binary element using key, subtype, and
// b to dst and return the extended buffer.
func AppendBinaryElement(dst []byte, key string, subtype byte, b []byte) []byte {
return AppendBinary(AppendHeader(dst, bsontype.Binary, key), subtype, b)
}
// ReadBinary will read a subtype and bin from src. If there are not enough bytes it
// will return false.
func ReadBinary(src []byte) (subtype byte, bin []byte, rem []byte, ok bool) {
length, rem, ok := ReadLength(src)
if !ok {
return 0x00, nil, src, false
}
if len(rem) < 1 { // subtype
return 0x00, nil, src, false
}
subtype, rem = rem[0], rem[1:]
if len(rem) < int(length) {
return 0x00, nil, src, false
}
if subtype == 0x02 {
length, rem, ok = ReadLength(rem)
if !ok || len(rem) < int(length) {
return 0x00, nil, src, false
}
}
return subtype, rem[:length], rem[length:], true
}
// AppendUndefinedElement will append a BSON undefined element using key to dst
// and return the extended buffer.
func AppendUndefinedElement(dst []byte, key string) []byte {
return AppendHeader(dst, bsontype.Undefined, key)
}
// AppendObjectID will append oid to dst and return the extended buffer.
func AppendObjectID(dst []byte, oid primitive.ObjectID) []byte { return append(dst, oid[:]...) }
// AppendObjectIDElement will append a BSON ObjectID element using key and oid to dst
// and return the extended buffer.
func AppendObjectIDElement(dst []byte, key string, oid primitive.ObjectID) []byte {
return AppendObjectID(AppendHeader(dst, bsontype.ObjectID, key), oid)
}
// ReadObjectID will read an ObjectID from src. If there are not enough bytes it
// will return false.
func ReadObjectID(src []byte) (primitive.ObjectID, []byte, bool) {
if len(src) < 12 {
return primitive.ObjectID{}, src, false
}
var oid primitive.ObjectID
copy(oid[:], src[0:12])
return oid, src[12:], true
}
// AppendBoolean will append b to dst and return the extended buffer.
func AppendBoolean(dst []byte, b bool) []byte {
if b {
return append(dst, 0x01)
}
return append(dst, 0x00)
}
// AppendBooleanElement will append a BSON boolean element using key and b to dst
// and return the extended buffer.
func AppendBooleanElement(dst []byte, key string, b bool) []byte {
return AppendBoolean(AppendHeader(dst, bsontype.Boolean, key), b)
}
// ReadBoolean will read a bool from src. If there are not enough bytes it
// will return false.
func ReadBoolean(src []byte) (bool, []byte, bool) {
if len(src) < 1 {
return false, src, false
}
return src[0] == 0x01, src[1:], true
}
// AppendDateTime will append dt to dst and return the extended buffer.
func AppendDateTime(dst []byte, dt int64) []byte { return appendi64(dst, dt) }
// AppendDateTimeElement will append a BSON datetime element using key and dt to dst
// and return the extended buffer.
func AppendDateTimeElement(dst []byte, key string, dt int64) []byte {
return AppendDateTime(AppendHeader(dst, bsontype.DateTime, key), dt)
}
// ReadDateTime will read an int64 datetime from src. If there are not enough bytes it
// will return false.
func ReadDateTime(src []byte) (int64, []byte, bool) { return readi64(src) }
// AppendTime will append time as a BSON DateTime to dst and return the extended buffer.
func AppendTime(dst []byte, t time.Time) []byte {
return AppendDateTime(dst, t.Unix()*1000+int64(t.Nanosecond()/1e6))
}
// AppendTimeElement will append a BSON datetime element using key and dt to dst
// and return the extended buffer.
func AppendTimeElement(dst []byte, key string, t time.Time) []byte {
return AppendTime(AppendHeader(dst, bsontype.DateTime, key), t)
}
// ReadTime will read an time.Time datetime from src. If there are not enough bytes it
// will return false.
func ReadTime(src []byte) (time.Time, []byte, bool) {
dt, rem, ok := readi64(src)
return time.Unix(dt/1e3, dt%1e3*1e6), rem, ok
}
// AppendNullElement will append a BSON null element using key to dst
// and return the extended buffer.
func AppendNullElement(dst []byte, key string) []byte { return AppendHeader(dst, bsontype.Null, key) }
// AppendRegex will append pattern and options to dst and return the extended buffer.
func AppendRegex(dst []byte, pattern, options string) []byte {
if !isValidCString(pattern) || !isValidCString(options) {
panic(invalidRegexPanicMsg)
}
return append(dst, pattern+nullTerminator+options+nullTerminator...)
}
// AppendRegexElement will append a BSON regex element using key, pattern, and
// options to dst and return the extended buffer.
func AppendRegexElement(dst []byte, key, pattern, options string) []byte {
return AppendRegex(AppendHeader(dst, bsontype.Regex, key), pattern, options)
}
// ReadRegex will read a pattern and options from src. If there are not enough bytes it
// will return false.
func ReadRegex(src []byte) (pattern, options string, rem []byte, ok bool) {
pattern, rem, ok = readcstring(src)
if !ok {
return "", "", src, false
}
options, rem, ok = readcstring(rem)
if !ok {
return "", "", src, false
}
return pattern, options, rem, true
}
// AppendDBPointer will append ns and oid to dst and return the extended buffer.
func AppendDBPointer(dst []byte, ns string, oid primitive.ObjectID) []byte {
return append(appendstring(dst, ns), oid[:]...)
}
// AppendDBPointerElement will append a BSON DBPointer element using key, ns,
// and oid to dst and return the extended buffer.
func AppendDBPointerElement(dst []byte, key, ns string, oid primitive.ObjectID) []byte {
return AppendDBPointer(AppendHeader(dst, bsontype.DBPointer, key), ns, oid)
}
// ReadDBPointer will read a ns and oid from src. If there are not enough bytes it
// will return false.
func ReadDBPointer(src []byte) (ns string, oid primitive.ObjectID, rem []byte, ok bool) {
ns, rem, ok = readstring(src)
if !ok {
return "", primitive.ObjectID{}, src, false
}
oid, rem, ok = ReadObjectID(rem)
if !ok {
return "", primitive.ObjectID{}, src, false
}
return ns, oid, rem, true
}
// AppendJavaScript will append js to dst and return the extended buffer.
func AppendJavaScript(dst []byte, js string) []byte { return appendstring(dst, js) }
// AppendJavaScriptElement will append a BSON JavaScript element using key and
// js to dst and return the extended buffer.
func AppendJavaScriptElement(dst []byte, key, js string) []byte {
return AppendJavaScript(AppendHeader(dst, bsontype.JavaScript, key), js)
}
// ReadJavaScript will read a js string from src. If there are not enough bytes it
// will return false.
func ReadJavaScript(src []byte) (js string, rem []byte, ok bool) { return readstring(src) }
// AppendSymbol will append symbol to dst and return the extended buffer.
func AppendSymbol(dst []byte, symbol string) []byte { return appendstring(dst, symbol) }
// AppendSymbolElement will append a BSON symbol element using key and symbol to dst
// and return the extended buffer.
func AppendSymbolElement(dst []byte, key, symbol string) []byte {
return AppendSymbol(AppendHeader(dst, bsontype.Symbol, key), symbol)
}
// ReadSymbol will read a symbol string from src. If there are not enough bytes it
// will return false.
func ReadSymbol(src []byte) (symbol string, rem []byte, ok bool) { return readstring(src) }
// AppendCodeWithScope will append code and scope to dst and return the extended buffer.
func AppendCodeWithScope(dst []byte, code string, scope []byte) []byte {
length := int32(4 + 4 + len(code) + 1 + len(scope)) // length of cws, length of code, code, 0x00, scope
dst = appendLength(dst, length)
return append(appendstring(dst, code), scope...)
}
// AppendCodeWithScopeElement will append a BSON code with scope element using
// key, code, and scope to dst
// and return the extended buffer.
func AppendCodeWithScopeElement(dst []byte, key, code string, scope []byte) []byte {
return AppendCodeWithScope(AppendHeader(dst, bsontype.CodeWithScope, key), code, scope)
}
// ReadCodeWithScope will read code and scope from src. If there are not enough bytes it
// will return false.
func ReadCodeWithScope(src []byte) (code string, scope []byte, rem []byte, ok bool) {
length, rem, ok := ReadLength(src)
if !ok || len(src) < int(length) {
return "", nil, src, false
}
code, rem, ok = readstring(rem)
if !ok {
return "", nil, src, false
}
scope, rem, ok = ReadDocument(rem)
if !ok {
return "", nil, src, false
}
return code, scope, rem, true
}
// AppendInt32 will append i32 to dst and return the extended buffer.
func AppendInt32(dst []byte, i32 int32) []byte { return appendi32(dst, i32) }
// AppendInt32Element will append a BSON int32 element using key and i32 to dst
// and return the extended buffer.
func AppendInt32Element(dst []byte, key string, i32 int32) []byte {
return AppendInt32(AppendHeader(dst, bsontype.Int32, key), i32)
}
// ReadInt32 will read an int32 from src. If there are not enough bytes it
// will return false.
func ReadInt32(src []byte) (int32, []byte, bool) { return readi32(src) }
// AppendTimestamp will append t and i to dst and return the extended buffer.
func AppendTimestamp(dst []byte, t, i uint32) []byte {
return appendu32(appendu32(dst, i), t) // i is the lower 4 bytes, t is the higher 4 bytes
}
// AppendTimestampElement will append a BSON timestamp element using key, t, and
// i to dst and return the extended buffer.
func AppendTimestampElement(dst []byte, key string, t, i uint32) []byte {
return AppendTimestamp(AppendHeader(dst, bsontype.Timestamp, key), t, i)
}
// ReadTimestamp will read t and i from src. If there are not enough bytes it
// will return false.
func ReadTimestamp(src []byte) (t, i uint32, rem []byte, ok bool) {
i, rem, ok = readu32(src)
if !ok {
return 0, 0, src, false
}
t, rem, ok = readu32(rem)
if !ok {
return 0, 0, src, false
}
return t, i, rem, true
}
// AppendInt64 will append i64 to dst and return the extended buffer.
func AppendInt64(dst []byte, i64 int64) []byte { return appendi64(dst, i64) }
// AppendInt64Element will append a BSON int64 element using key and i64 to dst
// and return the extended buffer.
func AppendInt64Element(dst []byte, key string, i64 int64) []byte {
return AppendInt64(AppendHeader(dst, bsontype.Int64, key), i64)
}
// ReadInt64 will read an int64 from src. If there are not enough bytes it
// will return false.
func ReadInt64(src []byte) (int64, []byte, bool) { return readi64(src) }
// AppendDecimal128 will append d128 to dst and return the extended buffer.
func AppendDecimal128(dst []byte, d128 primitive.Decimal128) []byte {
high, low := d128.GetBytes()
return appendu64(appendu64(dst, low), high)
}
// AppendDecimal128Element will append a BSON primitive.28 element using key and
// d128 to dst and return the extended buffer.
func AppendDecimal128Element(dst []byte, key string, d128 primitive.Decimal128) []byte {
return AppendDecimal128(AppendHeader(dst, bsontype.Decimal128, key), d128)
}
// ReadDecimal128 will read a primitive.Decimal128 from src. If there are not enough bytes it
// will return false.
func ReadDecimal128(src []byte) (primitive.Decimal128, []byte, bool) {
l, rem, ok := readu64(src)
if !ok {
return primitive.Decimal128{}, src, false
}
h, rem, ok := readu64(rem)
if !ok {
return primitive.Decimal128{}, src, false
}
return primitive.NewDecimal128(h, l), rem, true
}
// AppendMaxKeyElement will append a BSON max key element using key to dst
// and return the extended buffer.
func AppendMaxKeyElement(dst []byte, key string) []byte {
return AppendHeader(dst, bsontype.MaxKey, key)
}
// AppendMinKeyElement will append a BSON min key element using key to dst
// and return the extended buffer.
func AppendMinKeyElement(dst []byte, key string) []byte {
return AppendHeader(dst, bsontype.MinKey, key)
}
// EqualValue will return true if the two values are equal.
func EqualValue(t1, t2 bsontype.Type, v1, v2 []byte) bool {
if t1 != t2 {
return false
}
v1, _, ok := readValue(v1, t1)
if !ok {
return false
}
v2, _, ok = readValue(v2, t2)
if !ok {
return false
}
return bytes.Equal(v1, v2)
}
// valueLength will determine the length of the next value contained in src as if it
// is type t. The returned bool will be false if there are not enough bytes in src for
// a value of type t.
func valueLength(src []byte, t bsontype.Type) (int32, bool) {
var length int32
ok := true
switch t {
case bsontype.Array, bsontype.EmbeddedDocument, bsontype.CodeWithScope:
length, _, ok = ReadLength(src)
case bsontype.Binary:
length, _, ok = ReadLength(src)
length += 4 + 1 // binary length + subtype byte
case bsontype.Boolean:
length = 1
case bsontype.DBPointer:
length, _, ok = ReadLength(src)
length += 4 + 12 // string length + ObjectID length
case bsontype.DateTime, bsontype.Double, bsontype.Int64, bsontype.Timestamp:
length = 8
case bsontype.Decimal128:
length = 16
case bsontype.Int32:
length = 4
case bsontype.JavaScript, bsontype.String, bsontype.Symbol:
length, _, ok = ReadLength(src)
length += 4
case bsontype.MaxKey, bsontype.MinKey, bsontype.Null, bsontype.Undefined:
length = 0
case bsontype.ObjectID:
length = 12
case bsontype.Regex:
regex := bytes.IndexByte(src, 0x00)
if regex < 0 {
ok = false
break
}
pattern := bytes.IndexByte(src[regex+1:], 0x00)
if pattern < 0 {
ok = false
break
}
length = int32(int64(regex) + 1 + int64(pattern) + 1)
default:
ok = false
}
return length, ok
}
func readValue(src []byte, t bsontype.Type) ([]byte, []byte, bool) {
length, ok := valueLength(src, t)
if !ok || int(length) > len(src) {
return nil, src, false
}
return src[:length], src[length:], true
}
// ReserveLength reserves the space required for length and returns the index where to write the length
// and the []byte with reserved space.
func ReserveLength(dst []byte) (int32, []byte) {
index := len(dst)
return int32(index), append(dst, 0x00, 0x00, 0x00, 0x00)
}
// UpdateLength updates the length at index with length and returns the []byte.
func UpdateLength(dst []byte, index, length int32) []byte {
dst[index] = byte(length)
dst[index+1] = byte(length >> 8)
dst[index+2] = byte(length >> 16)
dst[index+3] = byte(length >> 24)
return dst
}
func appendLength(dst []byte, l int32) []byte { return appendi32(dst, l) }
func appendi32(dst []byte, i32 int32) []byte {
return append(dst, byte(i32), byte(i32>>8), byte(i32>>16), byte(i32>>24))
}
// ReadLength reads an int32 length from src and returns the length and the remaining bytes. If
// there aren't enough bytes to read a valid length, src is returned unomdified and the returned
// bool will be false.
func ReadLength(src []byte) (int32, []byte, bool) {
ln, src, ok := readi32(src)
if ln < 0 {
return ln, src, false
}
return ln, src, ok
}
func readi32(src []byte) (int32, []byte, bool) {
if len(src) < 4 {
return 0, src, false
}
return (int32(src[0]) | int32(src[1])<<8 | int32(src[2])<<16 | int32(src[3])<<24), src[4:], true
}
func appendi64(dst []byte, i64 int64) []byte {
return append(dst,
byte(i64), byte(i64>>8), byte(i64>>16), byte(i64>>24),
byte(i64>>32), byte(i64>>40), byte(i64>>48), byte(i64>>56),
)
}
func readi64(src []byte) (int64, []byte, bool) {
if len(src) < 8 {
return 0, src, false
}
i64 := (int64(src[0]) | int64(src[1])<<8 | int64(src[2])<<16 | int64(src[3])<<24 |
int64(src[4])<<32 | int64(src[5])<<40 | int64(src[6])<<48 | int64(src[7])<<56)
return i64, src[8:], true
}
func appendu32(dst []byte, u32 uint32) []byte {
return append(dst, byte(u32), byte(u32>>8), byte(u32>>16), byte(u32>>24))
}
func readu32(src []byte) (uint32, []byte, bool) {
if len(src) < 4 {
return 0, src, false
}
return (uint32(src[0]) | uint32(src[1])<<8 | uint32(src[2])<<16 | uint32(src[3])<<24), src[4:], true
}
func appendu64(dst []byte, u64 uint64) []byte {
return append(dst,
byte(u64), byte(u64>>8), byte(u64>>16), byte(u64>>24),
byte(u64>>32), byte(u64>>40), byte(u64>>48), byte(u64>>56),
)
}
func readu64(src []byte) (uint64, []byte, bool) {
if len(src) < 8 {
return 0, src, false
}
u64 := (uint64(src[0]) | uint64(src[1])<<8 | uint64(src[2])<<16 | uint64(src[3])<<24 |
uint64(src[4])<<32 | uint64(src[5])<<40 | uint64(src[6])<<48 | uint64(src[7])<<56)
return u64, src[8:], true
}
// keep in sync with readcstringbytes
func readcstring(src []byte) (string, []byte, bool) {
idx := bytes.IndexByte(src, 0x00)
if idx < 0 {
return "", src, false
}
return string(src[:idx]), src[idx+1:], true
}
// keep in sync with readcstring
func readcstringbytes(src []byte) ([]byte, []byte, bool) {
idx := bytes.IndexByte(src, 0x00)
if idx < 0 {
return nil, src, false
}
return src[:idx], src[idx+1:], true
}
func appendstring(dst []byte, s string) []byte {
l := int32(len(s) + 1)
dst = appendLength(dst, l)
dst = append(dst, s...)
return append(dst, 0x00)
}
func readstring(src []byte) (string, []byte, bool) {
l, rem, ok := ReadLength(src)
if !ok {
return "", src, false
}
if len(src[4:]) < int(l) || l == 0 {
return "", src, false
}
return string(rem[:l-1]), rem[l:], true
}
// readLengthBytes attempts to read a length and that number of bytes. This
// function requires that the length include the four bytes for itself.
func readLengthBytes(src []byte) ([]byte, []byte, bool) {
l, _, ok := ReadLength(src)
if !ok {
return nil, src, false
}
if len(src) < int(l) {
return nil, src, false
}
return src[:l], src[l:], true
}
func appendBinarySubtype2(dst []byte, subtype byte, b []byte) []byte {
dst = appendLength(dst, int32(len(b)+4)) // The bytes we'll encode need to be 4 larger for the length bytes
dst = append(dst, subtype)
dst = appendLength(dst, int32(len(b)))
return append(dst, b...)
}
func isValidCString(cs string) bool {
return !strings.ContainsRune(cs, '\x00')
}

View File

@@ -0,0 +1,959 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncore
import (
"bytes"
"encoding/binary"
"math"
"reflect"
"testing"
"github.com/google/go-cmp/cmp"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/internal/testutil/assert"
)
func compareErrors(err1, err2 error) bool {
if err1 == nil && err2 == nil {
return true
}
if err1 == nil || err2 == nil {
return false
}
if err1.Error() != err2.Error() {
return false
}
return true
}
func TestAppend(t *testing.T) {
bits := math.Float64bits(3.14159)
pi := make([]byte, 8)
binary.LittleEndian.PutUint64(pi, bits)
testCases := []struct {
name string
fn interface{}
params []interface{}
expected []byte
}{
{
"AppendType",
AppendType,
[]interface{}{make([]byte, 0), bsontype.Null},
[]byte{byte(bsontype.Null)},
},
{
"AppendKey",
AppendKey,
[]interface{}{make([]byte, 0), "foobar"},
[]byte{'f', 'o', 'o', 'b', 'a', 'r', 0x00},
},
{
"AppendHeader",
AppendHeader,
[]interface{}{make([]byte, 0), bsontype.Null, "foobar"},
[]byte{byte(bsontype.Null), 'f', 'o', 'o', 'b', 'a', 'r', 0x00},
},
{
"AppendValueElement",
AppendValueElement,
[]interface{}{make([]byte, 0), "testing", Value{Type: bsontype.Boolean, Data: []byte{0x01}}},
[]byte{byte(bsontype.Boolean), 't', 'e', 's', 't', 'i', 'n', 'g', 0x00, 0x01},
},
{
"AppendDouble",
AppendDouble,
[]interface{}{make([]byte, 0), float64(3.14159)},
pi,
},
{
"AppendDoubleElement",
AppendDoubleElement,
[]interface{}{make([]byte, 0), "foobar", float64(3.14159)},
append([]byte{byte(bsontype.Double), 'f', 'o', 'o', 'b', 'a', 'r', 0x00}, pi...),
},
{
"AppendString",
AppendString,
[]interface{}{make([]byte, 0), "barbaz"},
[]byte{0x07, 0x00, 0x00, 0x00, 'b', 'a', 'r', 'b', 'a', 'z', 0x00},
},
{
"AppendStringElement",
AppendStringElement,
[]interface{}{make([]byte, 0), "foobar", "barbaz"},
[]byte{byte(bsontype.String),
'f', 'o', 'o', 'b', 'a', 'r', 0x00,
0x07, 0x00, 0x00, 0x00, 'b', 'a', 'r', 'b', 'a', 'z', 0x00,
},
},
{
"AppendDocument",
AppendDocument,
[]interface{}{[]byte{0x05, 0x00, 0x00, 0x00, 0x00}, []byte{0x05, 0x00, 0x00, 0x00, 0x00}},
[]byte{0x05, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x00},
},
{
"AppendDocumentElement",
AppendDocumentElement,
[]interface{}{make([]byte, 0), "foobar", []byte{0x05, 0x00, 0x00, 0x00, 0x00}},
[]byte{byte(bsontype.EmbeddedDocument),
'f', 'o', 'o', 'b', 'a', 'r', 0x00,
0x05, 0x00, 0x00, 0x00, 0x00,
},
},
{
"AppendArray",
AppendArray,
[]interface{}{[]byte{0x05, 0x00, 0x00, 0x00, 0x00}, []byte{0x05, 0x00, 0x00, 0x00, 0x00}},
[]byte{0x05, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x00},
},
{
"AppendArrayElement",
AppendArrayElement,
[]interface{}{make([]byte, 0), "foobar", []byte{0x05, 0x00, 0x00, 0x00, 0x00}},
[]byte{byte(bsontype.Array),
'f', 'o', 'o', 'b', 'a', 'r', 0x00,
0x05, 0x00, 0x00, 0x00, 0x00,
},
},
{
"BuildArray",
BuildArray,
[]interface{}{make([]byte, 0), Value{Type: bsontype.Double, Data: AppendDouble(nil, 3.14159)}},
[]byte{
0x10, 0x00, 0x00, 0x00,
byte(bsontype.Double), '0', 0x00,
pi[0], pi[1], pi[2], pi[3], pi[4], pi[5], pi[6], pi[7],
0x00,
},
},
{
"BuildArrayElement",
BuildArrayElement,
[]interface{}{make([]byte, 0), "foobar", Value{Type: bsontype.Double, Data: AppendDouble(nil, 3.14159)}},
[]byte{byte(bsontype.Array),
'f', 'o', 'o', 'b', 'a', 'r', 0x00,
0x10, 0x00, 0x00, 0x00,
byte(bsontype.Double), '0', 0x00,
pi[0], pi[1], pi[2], pi[3], pi[4], pi[5], pi[6], pi[7],
0x00,
},
},
{
"AppendBinary Subtype2",
AppendBinary,
[]interface{}{make([]byte, 0), byte(0x02), []byte{0x01, 0x02, 0x03}},
[]byte{0x07, 0x00, 0x00, 0x00, 0x02, 0x03, 0x00, 0x00, 0x00, 0x01, 0x02, 0x03},
},
{
"AppendBinaryElement Subtype 2",
AppendBinaryElement,
[]interface{}{make([]byte, 0), "foobar", byte(0x02), []byte{0x01, 0x02, 0x03}},
[]byte{byte(bsontype.Binary),
'f', 'o', 'o', 'b', 'a', 'r', 0x00,
0x07, 0x00, 0x00, 0x00,
0x02,
0x03, 0x00, 0x00, 0x00, 0x01, 0x02, 0x03,
},
},
{
"AppendBinary",
AppendBinary,
[]interface{}{make([]byte, 0), byte(0xFF), []byte{0x01, 0x02, 0x03}},
[]byte{0x03, 0x00, 0x00, 0x00, 0xFF, 0x01, 0x02, 0x03},
},
{
"AppendBinaryElement",
AppendBinaryElement,
[]interface{}{make([]byte, 0), "foobar", byte(0xFF), []byte{0x01, 0x02, 0x03}},
[]byte{byte(bsontype.Binary),
'f', 'o', 'o', 'b', 'a', 'r', 0x00,
0x03, 0x00, 0x00, 0x00,
0xFF,
0x01, 0x02, 0x03,
},
},
{
"AppendUndefinedElement",
AppendUndefinedElement,
[]interface{}{make([]byte, 0), "foobar"},
[]byte{byte(bsontype.Undefined), 'f', 'o', 'o', 'b', 'a', 'r', 0x00},
},
{
"AppendObjectID",
AppendObjectID,
[]interface{}{
make([]byte, 0),
primitive.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C},
},
[]byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C},
},
{
"AppendObjectIDElement",
AppendObjectIDElement,
[]interface{}{
make([]byte, 0), "foobar",
primitive.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C},
},
[]byte{byte(bsontype.ObjectID),
'f', 'o', 'o', 'b', 'a', 'r', 0x00,
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C,
},
},
{
"AppendBoolean (true)",
AppendBoolean,
[]interface{}{make([]byte, 0), true},
[]byte{0x01},
},
{
"AppendBoolean (false)",
AppendBoolean,
[]interface{}{make([]byte, 0), false},
[]byte{0x00},
},
{
"AppendBooleanElement",
AppendBooleanElement,
[]interface{}{make([]byte, 0), "foobar", true},
[]byte{byte(bsontype.Boolean), 'f', 'o', 'o', 'b', 'a', 'r', 0x00, 0x01},
},
{
"AppendDateTime",
AppendDateTime,
[]interface{}{make([]byte, 0), int64(256)},
[]byte{0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
},
{
"AppendDateTimeElement",
AppendDateTimeElement,
[]interface{}{make([]byte, 0), "foobar", int64(256)},
[]byte{byte(bsontype.DateTime), 'f', 'o', 'o', 'b', 'a', 'r', 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
},
{
"AppendNullElement",
AppendNullElement,
[]interface{}{make([]byte, 0), "foobar"},
[]byte{byte(bsontype.Null), 'f', 'o', 'o', 'b', 'a', 'r', 0x00},
},
{
"AppendRegex",
AppendRegex,
[]interface{}{make([]byte, 0), "bar", "baz"},
[]byte{'b', 'a', 'r', 0x00, 'b', 'a', 'z', 0x00},
},
{
"AppendRegexElement",
AppendRegexElement,
[]interface{}{make([]byte, 0), "foobar", "bar", "baz"},
[]byte{byte(bsontype.Regex),
'f', 'o', 'o', 'b', 'a', 'r', 0x00,
'b', 'a', 'r', 0x00, 'b', 'a', 'z', 0x00,
},
},
{
"AppendDBPointer",
AppendDBPointer,
[]interface{}{
make([]byte, 0),
"foobar",
primitive.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C},
},
[]byte{
0x07, 0x00, 0x00, 0x00, 'f', 'o', 'o', 'b', 'a', 'r', 0x00,
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C,
},
},
{
"AppendDBPointerElement",
AppendDBPointerElement,
[]interface{}{
make([]byte, 0), "foobar",
"barbaz",
primitive.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C},
},
[]byte{byte(bsontype.DBPointer),
'f', 'o', 'o', 'b', 'a', 'r', 0x00,
0x07, 0x00, 0x00, 0x00, 'b', 'a', 'r', 'b', 'a', 'z', 0x00,
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C,
},
},
{
"AppendJavaScript",
AppendJavaScript,
[]interface{}{make([]byte, 0), "barbaz"},
[]byte{0x07, 0x00, 0x00, 0x00, 'b', 'a', 'r', 'b', 'a', 'z', 0x00},
},
{
"AppendJavaScriptElement",
AppendJavaScriptElement,
[]interface{}{make([]byte, 0), "foobar", "barbaz"},
[]byte{byte(bsontype.JavaScript),
'f', 'o', 'o', 'b', 'a', 'r', 0x00,
0x07, 0x00, 0x00, 0x00, 'b', 'a', 'r', 'b', 'a', 'z', 0x00,
},
},
{
"AppendSymbol",
AppendSymbol,
[]interface{}{make([]byte, 0), "barbaz"},
[]byte{0x07, 0x00, 0x00, 0x00, 'b', 'a', 'r', 'b', 'a', 'z', 0x00},
},
{
"AppendSymbolElement",
AppendSymbolElement,
[]interface{}{make([]byte, 0), "foobar", "barbaz"},
[]byte{byte(bsontype.Symbol),
'f', 'o', 'o', 'b', 'a', 'r', 0x00,
0x07, 0x00, 0x00, 0x00, 'b', 'a', 'r', 'b', 'a', 'z', 0x00,
},
},
{
"AppendCodeWithScope",
AppendCodeWithScope,
[]interface{}{[]byte{0x05, 0x00, 0x00, 0x00, 0x00}, "foobar", []byte{0x05, 0x00, 0x00, 0x00, 0x00}},
[]byte{0x05, 0x00, 0x00, 0x00, 0x00,
0x14, 0x00, 0x00, 0x00,
0x07, 0x00, 0x00, 0x00, 'f', 'o', 'o', 'b', 'a', 'r', 0x00,
0x05, 0x00, 0x00, 0x00, 0x00,
},
},
{
"AppendCodeWithScopeElement",
AppendCodeWithScopeElement,
[]interface{}{make([]byte, 0), "foobar", "barbaz", []byte{0x05, 0x00, 0x00, 0x00, 0x00}},
[]byte{byte(bsontype.CodeWithScope),
'f', 'o', 'o', 'b', 'a', 'r', 0x00,
0x14, 0x00, 0x00, 0x00,
0x07, 0x00, 0x00, 0x00, 'b', 'a', 'r', 'b', 'a', 'z', 0x00,
0x05, 0x00, 0x00, 0x00, 0x00,
},
},
{
"AppendInt32",
AppendInt32,
[]interface{}{make([]byte, 0), int32(256)},
[]byte{0x00, 0x01, 0x00, 0x00},
},
{
"AppendInt32Element",
AppendInt32Element,
[]interface{}{make([]byte, 0), "foobar", int32(256)},
[]byte{byte(bsontype.Int32), 'f', 'o', 'o', 'b', 'a', 'r', 0x00, 0x00, 0x01, 0x00, 0x00},
},
{
"AppendTimestamp",
AppendTimestamp,
[]interface{}{make([]byte, 0), uint32(65536), uint32(256)},
[]byte{0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00},
},
{
"AppendTimestampElement",
AppendTimestampElement,
[]interface{}{make([]byte, 0), "foobar", uint32(65536), uint32(256)},
[]byte{byte(bsontype.Timestamp), 'f', 'o', 'o', 'b', 'a', 'r', 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00},
},
{
"AppendInt64",
AppendInt64,
[]interface{}{make([]byte, 0), int64(4294967296)},
[]byte{0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00},
},
{
"AppendInt64Element",
AppendInt64Element,
[]interface{}{make([]byte, 0), "foobar", int64(4294967296)},
[]byte{byte(bsontype.Int64), 'f', 'o', 'o', 'b', 'a', 'r', 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00},
},
{
"AppendDecimal128",
AppendDecimal128,
[]interface{}{make([]byte, 0), primitive.NewDecimal128(4294967296, 65536)},
[]byte{
0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
},
},
{
"AppendDecimal128Element",
AppendDecimal128Element,
[]interface{}{make([]byte, 0), "foobar", primitive.NewDecimal128(4294967296, 65536)},
[]byte{
byte(bsontype.Decimal128), 'f', 'o', 'o', 'b', 'a', 'r', 0x00,
0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
},
},
{
"AppendMaxKeyElement",
AppendMaxKeyElement,
[]interface{}{make([]byte, 0), "foobar"},
[]byte{byte(bsontype.MaxKey), 'f', 'o', 'o', 'b', 'a', 'r', 0x00},
},
{
"AppendMinKeyElement",
AppendMinKeyElement,
[]interface{}{make([]byte, 0), "foobar"},
[]byte{byte(bsontype.MinKey), 'f', 'o', 'o', 'b', 'a', 'r', 0x00},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
fn := reflect.ValueOf(tc.fn)
if fn.Kind() != reflect.Func {
t.Fatalf("fn must be of kind Func but is a %v", fn.Kind())
}
if fn.Type().NumIn() != len(tc.params) {
t.Fatalf("tc.params must match the number of params in tc.fn. params %d; fn %d", fn.Type().NumIn(), len(tc.params))
}
if fn.Type().NumOut() != 1 || fn.Type().Out(0) != reflect.TypeOf([]byte{}) {
t.Fatalf("fn must have one return parameter and it must be a []byte.")
}
params := make([]reflect.Value, 0, len(tc.params))
for _, param := range tc.params {
params = append(params, reflect.ValueOf(param))
}
results := fn.Call(params)
got := results[0].Interface().([]byte)
want := tc.expected
if !bytes.Equal(got, want) {
t.Errorf("Did not receive expected bytes. got %v; want %v", got, want)
}
})
}
}
func TestRead(t *testing.T) {
bits := math.Float64bits(3.14159)
pi := make([]byte, 8)
binary.LittleEndian.PutUint64(pi, bits)
testCases := []struct {
name string
fn interface{}
param []byte
expected []interface{}
}{
{
"ReadType/not enough bytes",
ReadType,
[]byte{},
[]interface{}{bsontype.Type(0), []byte{}, false},
},
{
"ReadType/success",
ReadType,
[]byte{0x0A},
[]interface{}{bsontype.Null, []byte{}, true},
},
{
"ReadKey/not enough bytes",
ReadKey,
[]byte{},
[]interface{}{"", []byte{}, false},
},
{
"ReadKey/success",
ReadKey,
[]byte{'f', 'o', 'o', 'b', 'a', 'r', 0x00},
[]interface{}{"foobar", []byte{}, true},
},
{
"ReadHeader/not enough bytes (type)",
ReadHeader,
[]byte{},
[]interface{}{bsontype.Type(0), "", []byte{}, false},
},
{
"ReadHeader/not enough bytes (key)",
ReadHeader,
[]byte{0x0A, 'f', 'o', 'o'},
[]interface{}{bsontype.Type(0), "", []byte{0x0A, 'f', 'o', 'o'}, false},
},
{
"ReadHeader/success",
ReadHeader,
[]byte{0x0A, 'f', 'o', 'o', 'b', 'a', 'r', 0x00},
[]interface{}{bsontype.Null, "foobar", []byte{}, true},
},
{
"ReadDouble/not enough bytes",
ReadDouble,
[]byte{0x01, 0x02, 0x03, 0x04},
[]interface{}{float64(0.00), []byte{0x01, 0x02, 0x03, 0x04}, false},
},
{
"ReadDouble/success",
ReadDouble,
pi,
[]interface{}{float64(3.14159), []byte{}, true},
},
{
"ReadString/not enough bytes (length)",
ReadString,
[]byte{},
[]interface{}{"", []byte{}, false},
},
{
"ReadString/not enough bytes (value)",
ReadString,
[]byte{0x0F, 0x00, 0x00, 0x00},
[]interface{}{"", []byte{0x0F, 0x00, 0x00, 0x00}, false},
},
{
"ReadString/success",
ReadString,
[]byte{0x07, 0x00, 0x00, 0x00, 'f', 'o', 'o', 'b', 'a', 'r', 0x00},
[]interface{}{"foobar", []byte{}, true},
},
{
"ReadDocument/not enough bytes (length)",
ReadDocument,
[]byte{},
[]interface{}{Document(nil), []byte{}, false},
},
{
"ReadDocument/not enough bytes (value)",
ReadDocument,
[]byte{0x0F, 0x00, 0x00, 0x00},
[]interface{}{Document(nil), []byte{0x0F, 0x00, 0x00, 0x00}, false},
},
{
"ReadDocument/success",
ReadDocument,
[]byte{0x0A, 0x00, 0x00, 0x00, 0x0A, 'f', 'o', 'o', 0x00, 0x00},
[]interface{}{Document{0x0A, 0x00, 0x00, 0x00, 0x0A, 'f', 'o', 'o', 0x00, 0x00}, []byte{}, true},
},
{
"ReadArray/not enough bytes (length)",
ReadArray,
[]byte{},
[]interface{}{Array(nil), []byte{}, false},
},
{
"ReadArray/not enough bytes (value)",
ReadArray,
[]byte{0x0F, 0x00, 0x00, 0x00},
[]interface{}{Array(nil), []byte{0x0F, 0x00, 0x00, 0x00}, false},
},
{
"ReadArray/success",
ReadArray,
[]byte{0x08, 0x00, 0x00, 0x00, 0x0A, '0', 0x00, 0x00},
[]interface{}{Array{0x08, 0x00, 0x00, 0x00, 0x0A, '0', 0x00, 0x00}, []byte{}, true},
},
{
"ReadBinary/not enough bytes (length)",
ReadBinary,
[]byte{},
[]interface{}{byte(0), []byte(nil), []byte{}, false},
},
{
"ReadBinary/not enough bytes (subtype)",
ReadBinary,
[]byte{0x0F, 0x00, 0x00, 0x00},
[]interface{}{byte(0), []byte(nil), []byte{0x0F, 0x00, 0x00, 0x00}, false},
},
{
"ReadBinary/not enough bytes (value)",
ReadBinary,
[]byte{0x0F, 0x00, 0x00, 0x00, 0x00},
[]interface{}{byte(0), []byte(nil), []byte{0x0F, 0x00, 0x00, 0x00, 0x00}, false},
},
{
"ReadBinary/not enough bytes (subtype 2 length)",
ReadBinary,
[]byte{0x03, 0x00, 0x00, 0x00, 0x02, 0x0F, 0x00, 0x00},
[]interface{}{byte(0), []byte(nil), []byte{0x03, 0x00, 0x00, 0x00, 0x02, 0x0F, 0x00, 0x00}, false},
},
{
"ReadBinary/not enough bytes (subtype 2 value)",
ReadBinary,
[]byte{0x0F, 0x00, 0x00, 0x00, 0x02, 0x0F, 0x00, 0x00, 0x00, 0x01, 0x02},
[]interface{}{
byte(0), []byte(nil),
[]byte{0x0F, 0x00, 0x00, 0x00, 0x02, 0x0F, 0x00, 0x00, 0x00, 0x01, 0x02}, false,
},
},
{
"ReadBinary/success (subtype 2)",
ReadBinary,
[]byte{0x06, 0x00, 0x00, 0x00, 0x02, 0x02, 0x00, 0x00, 0x00, 0x01, 0x02},
[]interface{}{byte(0x02), []byte{0x01, 0x02}, []byte{}, true},
},
{
"ReadBinary/success",
ReadBinary,
[]byte{0x03, 0x00, 0x00, 0x00, 0xFF, 0x01, 0x02, 0x03},
[]interface{}{byte(0xFF), []byte{0x01, 0x02, 0x03}, []byte{}, true},
},
{
"ReadObjectID/not enough bytes",
ReadObjectID,
[]byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06},
[]interface{}{primitive.ObjectID{}, []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06}, false},
},
{
"ReadObjectID/success",
ReadObjectID,
[]byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C},
[]interface{}{
primitive.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C},
[]byte{}, true,
},
},
{
"ReadBoolean/not enough bytes",
ReadBoolean,
[]byte{},
[]interface{}{false, []byte{}, false},
},
{
"ReadBoolean/success",
ReadBoolean,
[]byte{0x01},
[]interface{}{true, []byte{}, true},
},
{
"ReadDateTime/not enough bytes",
ReadDateTime,
[]byte{0x01, 0x02, 0x03, 0x04},
[]interface{}{int64(0), []byte{0x01, 0x02, 0x03, 0x04}, false},
},
{
"ReadDateTime/success",
ReadDateTime,
[]byte{0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00},
[]interface{}{int64(65536), []byte{}, true},
},
{
"ReadRegex/not enough bytes (pattern)",
ReadRegex,
[]byte{},
[]interface{}{"", "", []byte{}, false},
},
{
"ReadRegex/not enough bytes (options)",
ReadRegex,
[]byte{'f', 'o', 'o', 0x00},
[]interface{}{"", "", []byte{'f', 'o', 'o', 0x00}, false},
},
{
"ReadRegex/success",
ReadRegex,
[]byte{'f', 'o', 'o', 0x00, 'b', 'a', 'r', 0x00},
[]interface{}{"foo", "bar", []byte{}, true},
},
{
"ReadDBPointer/not enough bytes (ns)",
ReadDBPointer,
[]byte{},
[]interface{}{"", primitive.ObjectID{}, []byte{}, false},
},
{
"ReadDBPointer/not enough bytes (objectID)",
ReadDBPointer,
[]byte{0x04, 0x00, 0x00, 0x00, 'f', 'o', 'o', 0x00},
[]interface{}{"", primitive.ObjectID{}, []byte{0x04, 0x00, 0x00, 0x00, 'f', 'o', 'o', 0x00}, false},
},
{
"ReadDBPointer/success",
ReadDBPointer,
[]byte{
0x04, 0x00, 0x00, 0x00, 'f', 'o', 'o', 0x00,
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C,
},
[]interface{}{
"foo", primitive.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C},
[]byte{}, true,
},
},
{
"ReadJavaScript/not enough bytes (length)",
ReadJavaScript,
[]byte{},
[]interface{}{"", []byte{}, false},
},
{
"ReadJavaScript/not enough bytes (value)",
ReadJavaScript,
[]byte{0x0F, 0x00, 0x00, 0x00},
[]interface{}{"", []byte{0x0F, 0x00, 0x00, 0x00}, false},
},
{
"ReadJavaScript/success",
ReadJavaScript,
[]byte{0x07, 0x00, 0x00, 0x00, 'f', 'o', 'o', 'b', 'a', 'r', 0x00},
[]interface{}{"foobar", []byte{}, true},
},
{
"ReadSymbol/not enough bytes (length)",
ReadSymbol,
[]byte{},
[]interface{}{"", []byte{}, false},
},
{
"ReadSymbol/not enough bytes (value)",
ReadSymbol,
[]byte{0x0F, 0x00, 0x00, 0x00},
[]interface{}{"", []byte{0x0F, 0x00, 0x00, 0x00}, false},
},
{
"ReadSymbol/success",
ReadSymbol,
[]byte{0x07, 0x00, 0x00, 0x00, 'f', 'o', 'o', 'b', 'a', 'r', 0x00},
[]interface{}{"foobar", []byte{}, true},
},
{
"ReadCodeWithScope/ not enough bytes (length)",
ReadCodeWithScope,
[]byte{},
[]interface{}{"", []byte(nil), []byte{}, false},
},
{
"ReadCodeWithScope/ not enough bytes (value)",
ReadCodeWithScope,
[]byte{0x0F, 0x00, 0x00, 0x00},
[]interface{}{"", []byte(nil), []byte{0x0F, 0x00, 0x00, 0x00}, false},
},
{
"ReadCodeWithScope/not enough bytes (code value)",
ReadCodeWithScope,
[]byte{
0x0C, 0x00, 0x00, 0x00,
0x0F, 0x00, 0x00, 0x00,
'f', 'o', 'o', 0x00,
},
[]interface{}{
"", []byte(nil),
[]byte{
0x0C, 0x00, 0x00, 0x00,
0x0F, 0x00, 0x00, 0x00,
'f', 'o', 'o', 0x00,
},
false,
},
},
{
"ReadCodeWithScope/success",
ReadCodeWithScope,
[]byte{
0x19, 0x00, 0x00, 0x00,
0x07, 0x00, 0x00, 0x00, 'f', 'o', 'o', 'b', 'a', 'r', 0x00,
0x0A, 0x00, 0x00, 0x00, 0x0A, 'f', 'o', 'o', 0x00, 0x00,
},
[]interface{}{
"foobar", []byte{0x0A, 0x00, 0x00, 0x00, 0x0A, 'f', 'o', 'o', 0x00, 0x00},
[]byte{}, true,
},
},
{
"ReadInt32/not enough bytes",
ReadInt32,
[]byte{0x01},
[]interface{}{int32(0), []byte{0x01}, false},
},
{
"ReadInt32/success",
ReadInt32,
[]byte{0x00, 0x01, 0x00, 0x00},
[]interface{}{int32(256), []byte{}, true},
},
{
"ReadTimestamp/not enough bytes (increment)",
ReadTimestamp,
[]byte{},
[]interface{}{uint32(0), uint32(0), []byte{}, false},
},
{
"ReadTimestamp/not enough bytes (timestamp)",
ReadTimestamp,
[]byte{0x00, 0x01, 0x00, 0x00},
[]interface{}{uint32(0), uint32(0), []byte{0x00, 0x01, 0x00, 0x00}, false},
},
{
"ReadTimestamp/success",
ReadTimestamp,
[]byte{0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00},
[]interface{}{uint32(65536), uint32(256), []byte{}, true},
},
{
"ReadInt64/not enough bytes",
ReadInt64,
[]byte{0x01},
[]interface{}{int64(0), []byte{0x01}, false},
},
{
"ReadInt64/success",
ReadInt64,
[]byte{0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00},
[]interface{}{int64(4294967296), []byte{}, true},
},
{
"ReadDecimal128/not enough bytes (low)",
ReadDecimal128,
[]byte{},
[]interface{}{primitive.Decimal128{}, []byte{}, false},
},
{
"ReadDecimal128/not enough bytes (high)",
ReadDecimal128,
[]byte{0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00},
[]interface{}{primitive.Decimal128{}, []byte{0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00}, false},
},
{
"ReadDecimal128/success",
ReadDecimal128,
[]byte{
0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
},
[]interface{}{primitive.NewDecimal128(4294967296, 16777216), []byte{}, true},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
fn := reflect.ValueOf(tc.fn)
if fn.Kind() != reflect.Func {
t.Fatalf("fn must be of kind Func but it is a %v", fn.Kind())
}
if fn.Type().NumIn() != 1 || fn.Type().In(0) != reflect.TypeOf([]byte{}) {
t.Fatalf("fn must have one parameter and it must be a []byte.")
}
results := fn.Call([]reflect.Value{reflect.ValueOf(tc.param)})
if len(results) != len(tc.expected) {
t.Fatalf("Length of results does not match. got %d; want %d", len(results), len(tc.expected))
}
for idx := range results {
got := results[idx].Interface()
want := tc.expected[idx]
if !cmp.Equal(got, want, cmp.Comparer(compareDecimal128)) {
t.Errorf("Result %d does not match. got %v; want %v", idx, got, want)
}
}
})
}
}
func TestBuild(t *testing.T) {
testCases := []struct {
name string
elems [][]byte
want []byte
}{
{
"one element",
[][]byte{AppendDoubleElement(nil, "pi", 3.14159)},
[]byte{0x11, 0x00, 0x00, 0x00, 0x1, 0x70, 0x69, 0x00, 0x6e, 0x86, 0x1b, 0xf0, 0xf9, 0x21, 0x9, 0x40, 0x00},
},
{
"two elements",
[][]byte{AppendDoubleElement(nil, "pi", 3.14159), AppendStringElement(nil, "hello", "world!!")},
[]byte{
0x24, 0x00, 0x00, 0x00, 0x01, 0x70, 0x69, 0x00, 0x6e, 0x86, 0x1b, 0xf0,
0xf9, 0x21, 0x09, 0x40, 0x02, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x00, 0x08,
0x00, 0x00, 0x00, 0x77, 0x6f, 0x72, 0x6c, 0x64, 0x21, 0x21, 0x00, 0x00,
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Run("BuildDocument", func(t *testing.T) {
elems := make([]byte, 0)
for _, elem := range tc.elems {
elems = append(elems, elem...)
}
got := BuildDocument(nil, elems)
if !bytes.Equal(got, tc.want) {
t.Errorf("Documents do not match. got %v; want %v", got, tc.want)
}
})
t.Run("BuildDocumentFromElements", func(t *testing.T) {
got := BuildDocumentFromElements(nil, tc.elems...)
if !bytes.Equal(got, tc.want) {
t.Errorf("Documents do not match. got %v; want %v", got, tc.want)
}
})
})
}
}
func TestNullBytes(t *testing.T) {
// Helper function to execute the provided callback and assert that it panics with the expected message. The
// createBSONFn callback should create a BSON document/array/value and return the stringified version.
assertBSONCreationPanics := func(t *testing.T, createBSONFn func(), expected string) {
t.Helper()
defer func() {
got := recover()
assert.Equal(t, expected, got, "expected panic with error %v, got error %v", expected, got)
}()
createBSONFn()
}
t.Run("element keys", func(t *testing.T) {
createDocFn := func() {
NewDocumentBuilder().AppendString("a\x00", "foo")
}
assertBSONCreationPanics(t, createDocFn, invalidKeyPanicMsg)
})
t.Run("regex values", func(t *testing.T) {
testCases := []struct {
name string
pattern string
options string
}{
{"null bytes in pattern", "a\x00", "i"},
{"null bytes in options", "pattern", "i\x00"},
}
for _, tc := range testCases {
t.Run(tc.name+"-AppendRegexElement", func(t *testing.T) {
createDocFn := func() {
AppendRegexElement(nil, "foo", tc.pattern, tc.options)
}
assertBSONCreationPanics(t, createDocFn, invalidRegexPanicMsg)
})
t.Run(tc.name+"-AppendRegex", func(t *testing.T) {
createValFn := func() {
AppendRegex(nil, tc.pattern, tc.options)
}
assertBSONCreationPanics(t, createValFn, invalidRegexPanicMsg)
})
}
})
t.Run("sub document field name", func(t *testing.T) {
createDocFn := func() {
NewDocumentBuilder().StartDocument("foobar").AppendDocument("a\x00", []byte("foo")).FinishDocument()
}
assertBSONCreationPanics(t, createDocFn, invalidKeyPanicMsg)
})
}
func compareDecimal128(d1, d2 primitive.Decimal128) bool {
d1H, d1L := d1.GetBytes()
d2H, d2L := d2.GetBytes()
if d1H != d2H {
return false
}
if d1L != d2L {
return false
}
return true
}

View File

@@ -0,0 +1,386 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncore
import (
"bytes"
"errors"
"fmt"
"io"
"strconv"
"go.mongodb.org/mongo-driver/bson/bsontype"
)
// ValidationError is an error type returned when attempting to validate a document or array.
type ValidationError string
func (ve ValidationError) Error() string { return string(ve) }
// NewDocumentLengthError creates and returns an error for when the length of a document exceeds the
// bytes available.
func NewDocumentLengthError(length, rem int) error {
return lengthError("document", length, rem)
}
func lengthError(bufferType string, length, rem int) error {
return ValidationError(fmt.Sprintf("%v length exceeds available bytes. length=%d remainingBytes=%d",
bufferType, length, rem))
}
// InsufficientBytesError indicates that there were not enough bytes to read the next component.
type InsufficientBytesError struct {
Source []byte
Remaining []byte
}
// NewInsufficientBytesError creates a new InsufficientBytesError with the given Document and
// remaining bytes.
func NewInsufficientBytesError(src, rem []byte) InsufficientBytesError {
return InsufficientBytesError{Source: src, Remaining: rem}
}
// Error implements the error interface.
func (ibe InsufficientBytesError) Error() string {
return "too few bytes to read next component"
}
// Equal checks that err2 also is an ErrTooSmall.
func (ibe InsufficientBytesError) Equal(err2 error) bool {
switch err2.(type) {
case InsufficientBytesError:
return true
default:
return false
}
}
// InvalidDepthTraversalError is returned when attempting a recursive Lookup when one component of
// the path is neither an embedded document nor an array.
type InvalidDepthTraversalError struct {
Key string
Type bsontype.Type
}
func (idte InvalidDepthTraversalError) Error() string {
return fmt.Sprintf(
"attempt to traverse into %s, but it's type is %s, not %s nor %s",
idte.Key, idte.Type, bsontype.EmbeddedDocument, bsontype.Array,
)
}
// ErrMissingNull is returned when a document or array's last byte is not null.
const ErrMissingNull ValidationError = "document or array end is missing null byte"
// ErrInvalidLength indicates that a length in a binary representation of a BSON document or array
// is invalid.
const ErrInvalidLength ValidationError = "document or array length is invalid"
// ErrNilReader indicates that an operation was attempted on a nil io.Reader.
var ErrNilReader = errors.New("nil reader")
// ErrEmptyKey indicates that no key was provided to a Lookup method.
var ErrEmptyKey = errors.New("empty key provided")
// ErrElementNotFound indicates that an Element matching a certain condition does not exist.
var ErrElementNotFound = errors.New("element not found")
// ErrOutOfBounds indicates that an index provided to access something was invalid.
var ErrOutOfBounds = errors.New("out of bounds")
// Document is a raw bytes representation of a BSON document.
type Document []byte
// NewDocumentFromReader reads a document from r. This function will only validate the length is
// correct and that the document ends with a null byte.
func NewDocumentFromReader(r io.Reader) (Document, error) {
return newBufferFromReader(r)
}
func newBufferFromReader(r io.Reader) ([]byte, error) {
if r == nil {
return nil, ErrNilReader
}
var lengthBytes [4]byte
// ReadFull guarantees that we will have read at least len(lengthBytes) if err == nil
_, err := io.ReadFull(r, lengthBytes[:])
if err != nil {
return nil, err
}
length, _, _ := readi32(lengthBytes[:]) // ignore ok since we always have enough bytes to read a length
if length < 0 {
return nil, ErrInvalidLength
}
buffer := make([]byte, length)
copy(buffer, lengthBytes[:])
_, err = io.ReadFull(r, buffer[4:])
if err != nil {
return nil, err
}
if buffer[length-1] != 0x00 {
return nil, ErrMissingNull
}
return buffer, nil
}
// Lookup searches the document, potentially recursively, for the given key. If there are multiple
// keys provided, this method will recurse down, as long as the top and intermediate nodes are
// either documents or arrays. If an error occurs or if the value doesn't exist, an empty Value is
// returned.
func (d Document) Lookup(key ...string) Value {
val, _ := d.LookupErr(key...)
return val
}
// LookupErr is the same as Lookup, except it returns an error in addition to an empty Value.
func (d Document) LookupErr(key ...string) (Value, error) {
if len(key) < 1 {
return Value{}, ErrEmptyKey
}
length, rem, ok := ReadLength(d)
if !ok {
return Value{}, NewInsufficientBytesError(d, rem)
}
length -= 4
var elem Element
for length > 1 {
elem, rem, ok = ReadElement(rem)
length -= int32(len(elem))
if !ok {
return Value{}, NewInsufficientBytesError(d, rem)
}
// We use `KeyBytes` rather than `Key` to avoid a needless string alloc.
if string(elem.KeyBytes()) != key[0] {
continue
}
if len(key) > 1 {
tt := bsontype.Type(elem[0])
switch tt {
case bsontype.EmbeddedDocument:
val, err := elem.Value().Document().LookupErr(key[1:]...)
if err != nil {
return Value{}, err
}
return val, nil
case bsontype.Array:
// Convert to Document to continue Lookup recursion.
val, err := Document(elem.Value().Array()).LookupErr(key[1:]...)
if err != nil {
return Value{}, err
}
return val, nil
default:
return Value{}, InvalidDepthTraversalError{Key: elem.Key(), Type: tt}
}
}
return elem.ValueErr()
}
return Value{}, ErrElementNotFound
}
// Index searches for and retrieves the element at the given index. This method will panic if
// the document is invalid or if the index is out of bounds.
func (d Document) Index(index uint) Element {
elem, err := d.IndexErr(index)
if err != nil {
panic(err)
}
return elem
}
// IndexErr searches for and retrieves the element at the given index.
func (d Document) IndexErr(index uint) (Element, error) {
return indexErr(d, index)
}
func indexErr(b []byte, index uint) (Element, error) {
length, rem, ok := ReadLength(b)
if !ok {
return nil, NewInsufficientBytesError(b, rem)
}
length -= 4
var current uint
var elem Element
for length > 1 {
elem, rem, ok = ReadElement(rem)
length -= int32(len(elem))
if !ok {
return nil, NewInsufficientBytesError(b, rem)
}
if current != index {
current++
continue
}
return elem, nil
}
return nil, ErrOutOfBounds
}
// DebugString outputs a human readable version of Document. It will attempt to stringify the
// valid components of the document even if the entire document is not valid.
func (d Document) DebugString() string {
if len(d) < 5 {
return "<malformed>"
}
var buf bytes.Buffer
buf.WriteString("Document")
length, rem, _ := ReadLength(d) // We know we have enough bytes to read the length
buf.WriteByte('(')
buf.WriteString(strconv.Itoa(int(length)))
length -= 4
buf.WriteString("){")
var elem Element
var ok bool
for length > 1 {
elem, rem, ok = ReadElement(rem)
length -= int32(len(elem))
if !ok {
buf.WriteString(fmt.Sprintf("<malformed (%d)>", length))
break
}
fmt.Fprintf(&buf, "%s ", elem.DebugString())
}
buf.WriteByte('}')
return buf.String()
}
// String outputs an ExtendedJSON version of Document. If the document is not valid, this method
// returns an empty string.
func (d Document) String() string {
if len(d) < 5 {
return ""
}
var buf bytes.Buffer
buf.WriteByte('{')
length, rem, _ := ReadLength(d) // We know we have enough bytes to read the length
length -= 4
var elem Element
var ok bool
first := true
for length > 1 {
if !first {
buf.WriteByte(',')
}
elem, rem, ok = ReadElement(rem)
length -= int32(len(elem))
if !ok {
return ""
}
fmt.Fprintf(&buf, "%s", elem.String())
first = false
}
buf.WriteByte('}')
return buf.String()
}
// Elements returns this document as a slice of elements. The returned slice will contain valid
// elements. If the document is not valid, the elements up to the invalid point will be returned
// along with an error.
func (d Document) Elements() ([]Element, error) {
length, rem, ok := ReadLength(d)
if !ok {
return nil, NewInsufficientBytesError(d, rem)
}
length -= 4
var elem Element
var elems []Element
for length > 1 {
elem, rem, ok = ReadElement(rem)
length -= int32(len(elem))
if !ok {
return elems, NewInsufficientBytesError(d, rem)
}
if err := elem.Validate(); err != nil {
return elems, err
}
elems = append(elems, elem)
}
return elems, nil
}
// Values returns this document as a slice of values. The returned slice will contain valid values.
// If the document is not valid, the values up to the invalid point will be returned along with an
// error.
func (d Document) Values() ([]Value, error) {
return values(d)
}
func values(b []byte) ([]Value, error) {
length, rem, ok := ReadLength(b)
if !ok {
return nil, NewInsufficientBytesError(b, rem)
}
length -= 4
var elem Element
var vals []Value
for length > 1 {
elem, rem, ok = ReadElement(rem)
length -= int32(len(elem))
if !ok {
return vals, NewInsufficientBytesError(b, rem)
}
if err := elem.Value().Validate(); err != nil {
return vals, err
}
vals = append(vals, elem.Value())
}
return vals, nil
}
// Validate validates the document and ensures the elements contained within are valid.
func (d Document) Validate() error {
length, rem, ok := ReadLength(d)
if !ok {
return NewInsufficientBytesError(d, rem)
}
if int(length) > len(d) {
return NewDocumentLengthError(int(length), len(d))
}
if d[length-1] != 0x00 {
return ErrMissingNull
}
length -= 4
var elem Element
for length > 1 {
elem, rem, ok = ReadElement(rem)
length -= int32(len(elem))
if !ok {
return NewInsufficientBytesError(d, rem)
}
err := elem.Validate()
if err != nil {
return err
}
}
if len(rem) < 1 || rem[0] != 0x00 {
return ErrMissingNull
}
return nil
}

View File

@@ -0,0 +1,189 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncore
import (
"errors"
"io"
"go.mongodb.org/mongo-driver/bson/bsontype"
)
// DocumentSequenceStyle is used to represent how a document sequence is laid out in a slice of
// bytes.
type DocumentSequenceStyle uint32
// These constants are the valid styles for a DocumentSequence.
const (
_ DocumentSequenceStyle = iota
SequenceStyle
ArrayStyle
)
// DocumentSequence represents a sequence of documents. The Style field indicates how the documents
// are laid out inside of the Data field.
type DocumentSequence struct {
Style DocumentSequenceStyle
Data []byte
Pos int
}
// ErrCorruptedDocument is returned when a full document couldn't be read from the sequence.
var ErrCorruptedDocument = errors.New("invalid DocumentSequence: corrupted document")
// ErrNonDocument is returned when a DocumentSequence contains a non-document BSON value.
var ErrNonDocument = errors.New("invalid DocumentSequence: a non-document value was found in sequence")
// ErrInvalidDocumentSequenceStyle is returned when an unknown DocumentSequenceStyle is set on a
// DocumentSequence.
var ErrInvalidDocumentSequenceStyle = errors.New("invalid DocumentSequenceStyle")
// DocumentCount returns the number of documents in the sequence.
func (ds *DocumentSequence) DocumentCount() int {
if ds == nil {
return 0
}
switch ds.Style {
case SequenceStyle:
var count int
var ok bool
rem := ds.Data
for len(rem) > 0 {
_, rem, ok = ReadDocument(rem)
if !ok {
return 0
}
count++
}
return count
case ArrayStyle:
_, rem, ok := ReadLength(ds.Data)
if !ok {
return 0
}
var count int
for len(rem) > 1 {
_, rem, ok = ReadElement(rem)
if !ok {
return 0
}
count++
}
return count
default:
return 0
}
}
// Empty returns true if the sequence is empty. It always returns true for unknown sequence styles.
func (ds *DocumentSequence) Empty() bool {
if ds == nil {
return true
}
switch ds.Style {
case SequenceStyle:
return len(ds.Data) == 0
case ArrayStyle:
return len(ds.Data) <= 5
default:
return true
}
}
// ResetIterator resets the iteration point for the Next method to the beginning of the document
// sequence.
func (ds *DocumentSequence) ResetIterator() {
if ds == nil {
return
}
ds.Pos = 0
}
// Documents returns a slice of the documents. If nil either the Data field is also nil or could not
// be properly read.
func (ds *DocumentSequence) Documents() ([]Document, error) {
if ds == nil {
return nil, nil
}
switch ds.Style {
case SequenceStyle:
rem := ds.Data
var docs []Document
var doc Document
var ok bool
for {
doc, rem, ok = ReadDocument(rem)
if !ok {
if len(rem) == 0 {
break
}
return nil, ErrCorruptedDocument
}
docs = append(docs, doc)
}
return docs, nil
case ArrayStyle:
if len(ds.Data) == 0 {
return nil, nil
}
vals, err := Document(ds.Data).Values()
if err != nil {
return nil, ErrCorruptedDocument
}
docs := make([]Document, 0, len(vals))
for _, v := range vals {
if v.Type != bsontype.EmbeddedDocument {
return nil, ErrNonDocument
}
docs = append(docs, v.Data)
}
return docs, nil
default:
return nil, ErrInvalidDocumentSequenceStyle
}
}
// Next retrieves the next document from this sequence and returns it. This method will return
// io.EOF when it has reached the end of the sequence.
func (ds *DocumentSequence) Next() (Document, error) {
if ds == nil || ds.Pos >= len(ds.Data) {
return nil, io.EOF
}
switch ds.Style {
case SequenceStyle:
doc, _, ok := ReadDocument(ds.Data[ds.Pos:])
if !ok {
return nil, ErrCorruptedDocument
}
ds.Pos += len(doc)
return doc, nil
case ArrayStyle:
if ds.Pos < 4 {
if len(ds.Data) < 4 {
return nil, ErrCorruptedDocument
}
ds.Pos = 4 // Skip the length of the document
}
if len(ds.Data[ds.Pos:]) == 1 && ds.Data[ds.Pos] == 0x00 {
return nil, io.EOF // At the end of the document
}
elem, _, ok := ReadElement(ds.Data[ds.Pos:])
if !ok {
return nil, ErrCorruptedDocument
}
ds.Pos += len(elem)
val := elem.Value()
if val.Type != bsontype.EmbeddedDocument {
return nil, ErrNonDocument
}
return val.Data, nil
default:
return nil, ErrInvalidDocumentSequenceStyle
}
}

View File

@@ -0,0 +1,421 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncore
import (
"bytes"
"io"
"strconv"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestDocumentSequence(t *testing.T) {
genArrayStyle := func(num int) []byte {
idx, seq := AppendDocumentStart(nil)
for i := 0; i < num; i++ {
seq = AppendDocumentElement(
seq, strconv.Itoa(i),
BuildDocument(nil, AppendDoubleElement(nil, "pi", 3.14159)),
)
}
seq, _ = AppendDocumentEnd(seq, idx)
return seq
}
genSequenceStyle := func(num int) []byte {
var seq []byte
for i := 0; i < num; i++ {
seq = append(seq, BuildDocument(nil, AppendDoubleElement(nil, "pi", 3.14159))...)
}
return seq
}
idx, arrayStyle := AppendDocumentStart(nil)
idx2, arrayStyle := AppendDocumentElementStart(arrayStyle, "0")
arrayStyle = AppendDoubleElement(arrayStyle, "pi", 3.14159)
arrayStyle, _ = AppendDocumentEnd(arrayStyle, idx2)
idx2, arrayStyle = AppendDocumentElementStart(arrayStyle, "1")
arrayStyle = AppendStringElement(arrayStyle, "hello", "world")
arrayStyle, _ = AppendDocumentEnd(arrayStyle, idx2)
arrayStyle, _ = AppendDocumentEnd(arrayStyle, idx)
t.Run("Documents", func(t *testing.T) {
testCases := []struct {
name string
style DocumentSequenceStyle
data []byte
documents []Document
err error
}{
{
"SequenceStle/corrupted document",
SequenceStyle,
[]byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08},
nil,
ErrCorruptedDocument,
},
{
"SequenceStyle/success",
SequenceStyle,
BuildDocument(
BuildDocument(
nil,
AppendStringElement(AppendDoubleElement(nil, "pi", 3.14159), "hello", "world"),
),
AppendDoubleElement(AppendStringElement(nil, "hello", "world"), "pi", 3.14159),
),
[]Document{
BuildDocument(nil, AppendStringElement(AppendDoubleElement(nil, "pi", 3.14159), "hello", "world")),
BuildDocument(nil, AppendDoubleElement(AppendStringElement(nil, "hello", "world"), "pi", 3.14159)),
},
nil,
},
{
"ArrayStyle/insufficient bytes",
ArrayStyle,
[]byte{0x01, 0x02, 0x03, 0x04, 0x05},
nil,
ErrCorruptedDocument,
},
{
"ArrayStyle/non-document",
ArrayStyle,
BuildDocument(nil, AppendDoubleElement(nil, "0", 12345.67890)),
nil,
ErrNonDocument,
},
{
"ArrayStyle/success",
ArrayStyle,
arrayStyle,
[]Document{
BuildDocument(nil, AppendDoubleElement(nil, "pi", 3.14159)),
BuildDocument(nil, AppendStringElement(nil, "hello", "world")),
},
nil,
},
{"Invalid DocumentSequenceStyle", 0, nil, nil, ErrInvalidDocumentSequenceStyle},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ds := &DocumentSequence{
Style: tc.style,
Data: tc.data,
}
documents, err := ds.Documents()
if !cmp.Equal(documents, tc.documents) {
t.Errorf("Documents do not match. got %v; want %v", documents, tc.documents)
}
if err != tc.err {
t.Errorf("Errors do not match. got %v; want %v", err, tc.err)
}
})
}
})
t.Run("Next", func(t *testing.T) {
seqDoc := BuildDocument(
BuildDocument(
nil,
AppendDoubleElement(nil, "pi", 3.14159),
),
AppendStringElement(nil, "hello", "world"),
)
idx, arrayStyle := AppendDocumentStart(nil)
idx2, arrayStyle := AppendDocumentElementStart(arrayStyle, "0")
arrayStyle = AppendDoubleElement(arrayStyle, "pi", 3.14159)
arrayStyle, _ = AppendDocumentEnd(arrayStyle, idx2)
idx2, arrayStyle = AppendDocumentElementStart(arrayStyle, "1")
arrayStyle = AppendStringElement(arrayStyle, "hello", "world")
arrayStyle, _ = AppendDocumentEnd(arrayStyle, idx2)
arrayStyle, _ = AppendDocumentEnd(arrayStyle, idx)
testCases := []struct {
name string
style DocumentSequenceStyle
data []byte
pos int
document Document
err error
}{
{"io.EOF", 0, make([]byte, 10), 10, nil, io.EOF},
{
"SequenceStyle/corrupted document",
SequenceStyle,
[]byte{0x01, 0x02, 0x03, 0x04},
0,
nil,
ErrCorruptedDocument,
},
{
"SequenceStyle/success/first",
SequenceStyle,
seqDoc,
0,
BuildDocument(nil, AppendDoubleElement(nil, "pi", 3.14159)),
nil,
},
{
"SequenceStyle/success/second",
SequenceStyle,
seqDoc,
17,
BuildDocument(nil, AppendStringElement(nil, "hello", "world")),
nil,
},
{
"ArrayStyle/corrupted document/too short",
ArrayStyle,
[]byte{0x01, 0x02, 0x03},
0,
nil,
ErrCorruptedDocument,
},
{
"ArrayStyle/corrupted document/invalid element",
ArrayStyle,
[]byte{0x00, 0x00, 0x00, 0x00, 0x01, '0', 0x00, 0x01, 0x02},
0,
nil,
ErrCorruptedDocument,
},
{
"ArrayStyle/non-document",
ArrayStyle,
BuildDocument(nil, AppendDoubleElement(nil, "0", 12345.67890)),
0,
nil,
ErrNonDocument,
},
{
"ArrayStyle/success/first",
ArrayStyle,
arrayStyle,
0,
BuildDocument(nil, AppendDoubleElement(nil, "pi", 3.14159)),
nil,
},
{
"ArrayStyle/success/second",
ArrayStyle,
arrayStyle,
24,
BuildDocument(nil, AppendStringElement(nil, "hello", "world")),
nil,
},
{"Invalid DocumentSequenceStyle", 0, make([]byte, 4), 0, nil, ErrInvalidDocumentSequenceStyle},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ds := &DocumentSequence{
Style: tc.style,
Data: tc.data,
Pos: tc.pos,
}
document, err := ds.Next()
if !bytes.Equal(document, tc.document) {
t.Errorf("Documents do not match. got %v; want %v", document, tc.document)
}
if err != tc.err {
t.Errorf("Errors do not match. got %v; want %v", err, tc.err)
}
})
}
})
t.Run("Full Iteration", func(t *testing.T) {
testCases := []struct {
name string
style DocumentSequenceStyle
data []byte
count int
}{
{"SequenceStyle/success/nil", SequenceStyle, nil, 0},
{"SequenceStyle/success/0", SequenceStyle, []byte{}, 0},
{"SequenceStyle/success/1", SequenceStyle, genSequenceStyle(1), 1},
{"SequenceStyle/success/2", SequenceStyle, genSequenceStyle(2), 2},
{"SequenceStyle/success/10", SequenceStyle, genSequenceStyle(10), 10},
{"SequenceStyle/success/100", SequenceStyle, genSequenceStyle(100), 100},
{"ArrayStyle/success/nil", ArrayStyle, nil, 0},
{"ArrayStyle/success/0", ArrayStyle, []byte{0x05, 0x00, 0x00, 0x00, 0x00}, 0},
{"ArrayStyle/success/1", ArrayStyle, genArrayStyle(1), 1},
{"ArrayStyle/success/2", ArrayStyle, genArrayStyle(2), 2},
{"ArrayStyle/success/10", ArrayStyle, genArrayStyle(10), 10},
{"ArrayStyle/success/100", ArrayStyle, genArrayStyle(100), 100},
}
for _, tc := range testCases {
t.Run("Documents/"+tc.name, func(t *testing.T) {
ds := &DocumentSequence{
Style: tc.style,
Data: tc.data,
}
docs, err := ds.Documents()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
count := len(docs)
if count != tc.count {
t.Errorf("Coun't fully iterate documents, wrong count. got %v; want %v", count, tc.count)
}
})
t.Run("Next/"+tc.name, func(t *testing.T) {
ds := &DocumentSequence{
Style: tc.style,
Data: tc.data,
}
var docs []Document
for {
doc, err := ds.Next()
if err == io.EOF {
break
}
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
docs = append(docs, doc)
}
count := len(docs)
if count != tc.count {
t.Errorf("Coun't fully iterate documents, wrong count. got %v; want %v", count, tc.count)
}
})
}
})
t.Run("DocumentCount", func(t *testing.T) {
testCases := []struct {
name string
style DocumentSequenceStyle
data []byte
count int
}{
{
"SequenceStyle/corrupt document/first",
SequenceStyle,
[]byte{0x01, 0x02, 0x03},
0,
},
{
"SequenceStyle/corrupt document/second",
SequenceStyle,
[]byte{0x05, 0x00, 0x00, 0x00, 0x00, 0x01, 0x02, 0x03},
0,
},
{"SequenceStyle/success/nil", SequenceStyle, nil, 0},
{"SequenceStyle/success/0", SequenceStyle, []byte{}, 0},
{"SequenceStyle/success/1", SequenceStyle, genSequenceStyle(1), 1},
{"SequenceStyle/success/2", SequenceStyle, genSequenceStyle(2), 2},
{"SequenceStyle/success/10", SequenceStyle, genSequenceStyle(10), 10},
{"SequenceStyle/success/100", SequenceStyle, genSequenceStyle(100), 100},
{
"ArrayStyle/corrupt document/length",
ArrayStyle,
[]byte{0x01, 0x02, 0x03},
0,
},
{
"ArrayStyle/corrupt element/first",
ArrayStyle,
BuildDocument(nil, []byte{0x01, 0x00, 0x03, 0x04, 0x05}),
0,
},
{
"ArrayStyle/corrupt element/second",
ArrayStyle,
BuildDocument(nil, []byte{0x0A, 0x00, 0x01, 0x00, 0x03, 0x04, 0x05}),
0,
},
{"ArrayStyle/success/nil", ArrayStyle, nil, 0},
{"ArrayStyle/success/0", ArrayStyle, []byte{0x05, 0x00, 0x00, 0x00, 0x00}, 0},
{"ArrayStyle/success/1", ArrayStyle, genArrayStyle(1), 1},
{"ArrayStyle/success/2", ArrayStyle, genArrayStyle(2), 2},
{"ArrayStyle/success/10", ArrayStyle, genArrayStyle(10), 10},
{"ArrayStyle/success/100", ArrayStyle, genArrayStyle(100), 100},
{"Invalid DocumentSequenceStyle", 0, nil, 0},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ds := &DocumentSequence{
Style: tc.style,
Data: tc.data,
}
count := ds.DocumentCount()
if count != tc.count {
t.Errorf("Document counts don't match. got %v; want %v", count, tc.count)
}
})
}
})
t.Run("Empty", func(t *testing.T) {
testCases := []struct {
name string
ds *DocumentSequence
isEmpty bool
}{
{"ArrayStyle/is empty/nil", nil, true},
{"ArrayStyle/is empty/0", &DocumentSequence{Style: ArrayStyle, Data: []byte{0x05, 0x00, 0x00, 0x00, 0x00}}, true},
{"ArrayStyle/is not empty/non-0", &DocumentSequence{Style: ArrayStyle, Data: genArrayStyle(10)}, false},
{"SequenceStyle/is empty/nil", nil, true},
{"SequenceStyle/is empty/0", &DocumentSequence{Style: SequenceStyle, Data: []byte{}}, true},
{"SequenceStyle/is not empty/non-0", &DocumentSequence{Style: SequenceStyle, Data: genSequenceStyle(10)}, false},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
isEmpty := tc.ds.Empty()
if isEmpty != tc.isEmpty {
t.Errorf("Unexpected Empty result. got %v; want %v", isEmpty, tc.isEmpty)
}
})
}
})
t.Run("ResetIterator", func(t *testing.T) {
ds := &DocumentSequence{Pos: 1234567890}
want := 0
ds.ResetIterator()
if ds.Pos != want {
t.Errorf("Unexpected position after ResetIterator. got %d; want %d", ds.Pos, want)
}
})
t.Run("no panic on nil", func(t *testing.T) {
capturePanic := func() {
if err := recover(); err != nil {
t.Errorf("Unexpected panic. got %v; want <nil>", err)
}
}
t.Run("DocumentCount", func(t *testing.T) {
defer capturePanic()
var ds *DocumentSequence
_ = ds.DocumentCount()
})
t.Run("Empty", func(t *testing.T) {
defer capturePanic()
var ds *DocumentSequence
_ = ds.Empty()
})
t.Run("ResetIterator", func(t *testing.T) {
defer capturePanic()
var ds *DocumentSequence
ds.ResetIterator()
})
t.Run("Documents", func(t *testing.T) {
defer capturePanic()
var ds *DocumentSequence
_, _ = ds.Documents()
})
t.Run("Next", func(t *testing.T) {
defer capturePanic()
var ds *DocumentSequence
_, _ = ds.Next()
})
})
}

View File

@@ -0,0 +1,412 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncore
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"testing"
"github.com/google/go-cmp/cmp"
"go.mongodb.org/mongo-driver/bson/bsontype"
)
func ExampleDocument_Validate() {
doc := make(Document, 500)
doc[250], doc[251], doc[252], doc[253], doc[254] = 0x05, 0x00, 0x00, 0x00, 0x00
err := doc[250:].Validate()
fmt.Println(err)
// Output: <nil>
}
func BenchmarkDocumentValidate(b *testing.B) {
for i := 0; i < b.N; i++ {
doc := make(Document, 500)
doc[250], doc[251], doc[252], doc[253], doc[254] = 0x05, 0x00, 0x00, 0x00, 0x00
_ = doc[250:].Validate()
}
}
func TestDocument(t *testing.T) {
t.Run("Validate", func(t *testing.T) {
t.Run("TooShort", func(t *testing.T) {
want := NewInsufficientBytesError(nil, nil)
got := Document{'\x00', '\x00'}.Validate()
if !compareErrors(got, want) {
t.Errorf("Did not get expected error. got %v; want %v", got, want)
}
})
t.Run("InvalidLength", func(t *testing.T) {
want := NewDocumentLengthError(200, 5)
r := make(Document, 5)
binary.LittleEndian.PutUint32(r[0:4], 200)
got := r.Validate()
if !compareErrors(got, want) {
t.Errorf("Did not get expected error. got %v; want %v", got, want)
}
})
t.Run("Invalid Element", func(t *testing.T) {
want := NewInsufficientBytesError(nil, nil)
r := make(Document, 9)
binary.LittleEndian.PutUint32(r[0:4], 9)
r[4], r[5], r[6], r[7], r[8] = 0x02, 'f', 'o', 'o', 0x00
got := r.Validate()
if !compareErrors(got, want) {
t.Errorf("Did not get expected error. got %v; want %v", got, want)
}
})
t.Run("Missing Null Terminator", func(t *testing.T) {
want := ErrMissingNull
r := make(Document, 8)
binary.LittleEndian.PutUint32(r[0:4], 8)
r[4], r[5], r[6], r[7] = 0x0A, 'f', 'o', 'o'
got := r.Validate()
if !compareErrors(got, want) {
t.Errorf("Did not get expected error. got %v; want %v", got, want)
}
})
testCases := []struct {
name string
r Document
want error
}{
{"null", Document{'\x08', '\x00', '\x00', '\x00', '\x0A', 'x', '\x00', '\x00'}, nil},
{"subdocument",
Document{
'\x15', '\x00', '\x00', '\x00',
'\x03',
'f', 'o', 'o', '\x00',
'\x0B', '\x00', '\x00', '\x00', '\x0A', 'a', '\x00',
'\x0A', 'b', '\x00', '\x00', '\x00',
},
nil,
},
{"array",
Document{
'\x15', '\x00', '\x00', '\x00',
'\x04',
'f', 'o', 'o', '\x00',
'\x0B', '\x00', '\x00', '\x00', '\x0A', '1', '\x00',
'\x0A', '2', '\x00', '\x00', '\x00',
},
nil,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
got := tc.r.Validate()
if !compareErrors(got, tc.want) {
t.Errorf("Returned error does not match. got %v; want %v", got, tc.want)
}
})
}
})
t.Run("Lookup", func(t *testing.T) {
t.Run("empty-key", func(t *testing.T) {
rdr := Document{'\x05', '\x00', '\x00', '\x00', '\x00'}
_, err := rdr.LookupErr()
if err != ErrEmptyKey {
t.Errorf("Empty key lookup did not return expected result. got %v; want %v", err, ErrEmptyKey)
}
})
t.Run("corrupted-subdocument", func(t *testing.T) {
rdr := Document{
'\x0D', '\x00', '\x00', '\x00',
'\x03', 'x', '\x00',
'\x06', '\x00', '\x00', '\x00',
'\x01',
'\x00',
'\x00',
}
_, got := rdr.LookupErr("x", "y")
want := NewInsufficientBytesError(nil, nil)
if !cmp.Equal(got, want) {
t.Errorf("Empty key lookup did not return expected result. got %v; want %v", got, want)
}
})
t.Run("corrupted-array", func(t *testing.T) {
rdr := Document{
'\x0D', '\x00', '\x00', '\x00',
'\x04', 'x', '\x00',
'\x06', '\x00', '\x00', '\x00',
'\x01',
'\x00',
'\x00',
}
_, got := rdr.LookupErr("x", "y")
want := NewInsufficientBytesError(nil, nil)
if !cmp.Equal(got, want) {
t.Errorf("Empty key lookup did not return expected result. got %v; want %v", got, want)
}
})
t.Run("invalid-traversal", func(t *testing.T) {
rdr := Document{'\x08', '\x00', '\x00', '\x00', '\x0A', 'x', '\x00', '\x00'}
_, got := rdr.LookupErr("x", "y")
want := InvalidDepthTraversalError{Key: "x", Type: bsontype.Null}
if !compareErrors(got, want) {
t.Errorf("Empty key lookup did not return expected result. got %v; want %v", got, want)
}
})
testCases := []struct {
name string
r Document
key []string
want Value
err error
}{
{"first",
Document{
'\x08', '\x00', '\x00', '\x00', '\x0A', 'x', '\x00', '\x00',
},
[]string{"x"},
Value{Type: bsontype.Null, Data: []byte{}},
nil,
},
{"first-second",
Document{
'\x15', '\x00', '\x00', '\x00',
'\x03',
'f', 'o', 'o', '\x00',
'\x0B', '\x00', '\x00', '\x00', '\x0A', 'a', '\x00',
'\x0A', 'b', '\x00', '\x00', '\x00',
},
[]string{"foo", "b"},
Value{Type: bsontype.Null, Data: []byte{}},
nil,
},
{"first-second-array",
Document{
'\x15', '\x00', '\x00', '\x00',
'\x04',
'f', 'o', 'o', '\x00',
'\x0B', '\x00', '\x00', '\x00', '\x0A', '1', '\x00',
'\x0A', '2', '\x00', '\x00', '\x00',
},
[]string{"foo", "2"},
Value{Type: bsontype.Null, Data: []byte{}},
nil,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Run("Lookup", func(t *testing.T) {
got := tc.r.Lookup(tc.key...)
if !cmp.Equal(got, tc.want) {
t.Errorf("Returned value does not match expected element. got %v; want %v", got, tc.want)
}
})
t.Run("LookupErr", func(t *testing.T) {
got, err := tc.r.LookupErr(tc.key...)
if err != tc.err {
t.Errorf("Returned error does not match. got %v; want %v", err, tc.err)
}
if !cmp.Equal(got, tc.want) {
t.Errorf("Returned value does not match expected element. got %v; want %v", got, tc.want)
}
})
})
}
})
t.Run("Index", func(t *testing.T) {
t.Run("Out of bounds", func(t *testing.T) {
rdr := Document{0xe, 0x0, 0x0, 0x0, 0xa, 0x78, 0x0, 0xa, 0x79, 0x0, 0xa, 0x7a, 0x0, 0x0}
_, err := rdr.IndexErr(3)
if err != ErrOutOfBounds {
t.Errorf("Out of bounds should be returned when accessing element beyond end of document. got %v; want %v", err, ErrOutOfBounds)
}
})
t.Run("Validation Error", func(t *testing.T) {
rdr := Document{0x07, 0x00, 0x00, 0x00, 0x00}
_, got := rdr.IndexErr(1)
want := NewInsufficientBytesError(nil, nil)
if !compareErrors(got, want) {
t.Errorf("Did not receive expected error. got %v; want %v", got, want)
}
})
testCases := []struct {
name string
rdr Document
index uint
want Element
}{
{"first",
Document{0xe, 0x0, 0x0, 0x0, 0xa, 0x78, 0x0, 0xa, 0x79, 0x0, 0xa, 0x7a, 0x0, 0x0},
0, Element{0x0a, 0x78, 0x00},
},
{"second",
Document{0xe, 0x0, 0x0, 0x0, 0xa, 0x78, 0x0, 0xa, 0x79, 0x0, 0xa, 0x7a, 0x0, 0x0},
1, Element{0x0a, 0x79, 0x00},
},
{"third",
Document{0xe, 0x0, 0x0, 0x0, 0xa, 0x78, 0x0, 0xa, 0x79, 0x0, 0xa, 0x7a, 0x0, 0x0},
2, Element{0x0a, 0x7a, 0x00},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Run("IndexErr", func(t *testing.T) {
got, err := tc.rdr.IndexErr(tc.index)
if err != nil {
t.Errorf("Unexpected error from IndexErr: %s", err)
}
if diff := cmp.Diff(got, tc.want); diff != "" {
t.Errorf("Documents differ: (-got +want)\n%s", diff)
}
})
t.Run("Index", func(t *testing.T) {
defer func() {
if err := recover(); err != nil {
t.Errorf("Unexpected error: %v", err)
}
}()
got := tc.rdr.Index(tc.index)
if diff := cmp.Diff(got, tc.want); diff != "" {
t.Errorf("Documents differ: (-got +want)\n%s", diff)
}
})
})
}
})
t.Run("NewDocumentFromReader", func(t *testing.T) {
testCases := []struct {
name string
ioReader io.Reader
doc Document
err error
}{
{
"nil reader",
nil,
nil,
ErrNilReader,
},
{
"premature end of reader",
bytes.NewBuffer([]byte{}),
nil,
io.EOF,
},
{
"empty document",
bytes.NewBuffer([]byte{5, 0, 0, 0, 0}),
[]byte{5, 0, 0, 0, 0},
nil,
},
{
"non-empty document",
bytes.NewBuffer([]byte{
// length
0x17, 0x0, 0x0, 0x0,
// type - string
0x2,
// key - "foo"
0x66, 0x6f, 0x6f, 0x0,
// value - string length
0x4, 0x0, 0x0, 0x0,
// value - string "bar"
0x62, 0x61, 0x72, 0x0,
// type - null
0xa,
// key - "baz"
0x62, 0x61, 0x7a, 0x0,
// null terminator
0x0,
}),
[]byte{
// length
0x17, 0x0, 0x0, 0x0,
// type - string
0x2,
// key - "foo"
0x66, 0x6f, 0x6f, 0x0,
// value - string length
0x4, 0x0, 0x0, 0x0,
// value - string "bar"
0x62, 0x61, 0x72, 0x0,
// type - null
0xa,
// key - "baz"
0x62, 0x61, 0x7a, 0x0,
// null terminator
0x0,
},
nil,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
doc, err := NewDocumentFromReader(tc.ioReader)
if !compareErrors(err, tc.err) {
t.Errorf("errors do not match. got %v; want %v", err, tc.err)
}
if !bytes.Equal(tc.doc, doc) {
t.Errorf("documents differ. got %v; want %v", tc.doc, doc)
}
})
}
})
t.Run("Elements", func(t *testing.T) {
invalidElem := BuildDocument(nil, AppendHeader(nil, bsontype.Double, "foo"))
invalidTwoElem := BuildDocument(nil,
AppendHeader(
AppendDoubleElement(nil, "pi", 3.14159),
bsontype.Double, "foo",
),
)
oneElem := BuildDocument(nil, AppendDoubleElement(nil, "pi", 3.14159))
twoElems := BuildDocument(nil,
AppendStringElement(
AppendDoubleElement(nil, "pi", 3.14159),
"hello", "world!",
),
)
testCases := []struct {
name string
doc Document
elems []Element
err error
}{
{"Insufficient Bytes Length", Document{0x03, 0x00, 0x00}, nil, NewInsufficientBytesError(nil, nil)},
{"Insufficient Bytes First Element", invalidElem, nil, NewInsufficientBytesError(nil, nil)},
{"Insufficient Bytes Second Element", invalidTwoElem, []Element{AppendDoubleElement(nil, "pi", 3.14159)}, NewInsufficientBytesError(nil, nil)},
{"Success One Element", oneElem, []Element{AppendDoubleElement(nil, "pi", 3.14159)}, nil},
{"Success Two Elements", twoElems, []Element{AppendDoubleElement(nil, "pi", 3.14159), AppendStringElement(nil, "hello", "world!")}, nil},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
elems, err := tc.doc.Elements()
if !compareErrors(err, tc.err) {
t.Errorf("errors do not match. got %v; want %v", err, tc.err)
}
if len(elems) != len(tc.elems) {
t.Fatalf("number of elements returned does not match. got %d; want %d", len(elems), len(tc.elems))
}
for idx := range elems {
got, want := elems[idx], tc.elems[idx]
if !bytes.Equal(got, want) {
t.Errorf("Elements at index %d differ. got %v; want %v", idx, got.DebugString(), want.DebugString())
}
}
})
}
})
}

View File

@@ -0,0 +1,152 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncore
import (
"bytes"
"fmt"
"go.mongodb.org/mongo-driver/bson/bsontype"
)
// MalformedElementError represents a class of errors that RawElement methods return.
type MalformedElementError string
func (mee MalformedElementError) Error() string { return string(mee) }
// ErrElementMissingKey is returned when a RawElement is missing a key.
const ErrElementMissingKey MalformedElementError = "element is missing key"
// ErrElementMissingType is returned when a RawElement is missing a type.
const ErrElementMissingType MalformedElementError = "element is missing type"
// Element is a raw bytes representation of a BSON element.
type Element []byte
// Key returns the key for this element. If the element is not valid, this method returns an empty
// string. If knowing if the element is valid is important, use KeyErr.
func (e Element) Key() string {
key, _ := e.KeyErr()
return key
}
// KeyBytes returns the key for this element as a []byte. If the element is not valid, this method
// returns an empty string. If knowing if the element is valid is important, use KeyErr. This method
// will not include the null byte at the end of the key in the slice of bytes.
func (e Element) KeyBytes() []byte {
key, _ := e.KeyBytesErr()
return key
}
// KeyErr returns the key for this element, returning an error if the element is not valid.
func (e Element) KeyErr() (string, error) {
key, err := e.KeyBytesErr()
return string(key), err
}
// KeyBytesErr returns the key for this element as a []byte, returning an error if the element is
// not valid.
func (e Element) KeyBytesErr() ([]byte, error) {
if len(e) <= 0 {
return nil, ErrElementMissingType
}
idx := bytes.IndexByte(e[1:], 0x00)
if idx == -1 {
return nil, ErrElementMissingKey
}
return e[1 : idx+1], nil
}
// Validate ensures the element is a valid BSON element.
func (e Element) Validate() error {
if len(e) < 1 {
return ErrElementMissingType
}
idx := bytes.IndexByte(e[1:], 0x00)
if idx == -1 {
return ErrElementMissingKey
}
return Value{Type: bsontype.Type(e[0]), Data: e[idx+2:]}.Validate()
}
// CompareKey will compare this element's key to key. This method makes it easy to compare keys
// without needing to allocate a string. The key may be null terminated. If a valid key cannot be
// read this method will return false.
func (e Element) CompareKey(key []byte) bool {
if len(e) < 2 {
return false
}
idx := bytes.IndexByte(e[1:], 0x00)
if idx == -1 {
return false
}
if index := bytes.IndexByte(key, 0x00); index > -1 {
key = key[:index]
}
return bytes.Equal(e[1:idx+1], key)
}
// Value returns the value of this element. If the element is not valid, this method returns an
// empty Value. If knowing if the element is valid is important, use ValueErr.
func (e Element) Value() Value {
val, _ := e.ValueErr()
return val
}
// ValueErr returns the value for this element, returning an error if the element is not valid.
func (e Element) ValueErr() (Value, error) {
if len(e) <= 0 {
return Value{}, ErrElementMissingType
}
idx := bytes.IndexByte(e[1:], 0x00)
if idx == -1 {
return Value{}, ErrElementMissingKey
}
val, rem, exists := ReadValue(e[idx+2:], bsontype.Type(e[0]))
if !exists {
return Value{}, NewInsufficientBytesError(e, rem)
}
return val, nil
}
// String implements the fmt.String interface. The output will be in extended JSON format.
func (e Element) String() string {
if len(e) <= 0 {
return ""
}
t := bsontype.Type(e[0])
idx := bytes.IndexByte(e[1:], 0x00)
if idx == -1 {
return ""
}
key, valBytes := []byte(e[1:idx+1]), []byte(e[idx+2:])
val, _, valid := ReadValue(valBytes, t)
if !valid {
return ""
}
return fmt.Sprintf(`"%s": %v`, key, val)
}
// DebugString outputs a human readable version of RawElement. It will attempt to stringify the
// valid components of the element even if the entire element is not valid.
func (e Element) DebugString() string {
if len(e) <= 0 {
return "<malformed>"
}
t := bsontype.Type(e[0])
idx := bytes.IndexByte(e[1:], 0x00)
if idx == -1 {
return fmt.Sprintf(`bson.Element{[%s]<malformed>}`, t)
}
key, valBytes := []byte(e[1:idx+1]), []byte(e[idx+2:])
val, _, valid := ReadValue(valBytes, t)
if !valid {
return fmt.Sprintf(`bson.Element{[%s]"%s": <malformed>}`, t, key)
}
return fmt.Sprintf(`bson.Element{[%s]"%s": %v}`, t, key, val)
}

View File

@@ -0,0 +1,127 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncore
import (
"testing"
"github.com/google/go-cmp/cmp"
"go.mongodb.org/mongo-driver/bson/bsontype"
)
func TestElement(t *testing.T) {
t.Run("KeyErr", func(t *testing.T) {
testCases := []struct {
name string
elem Element
str string
err error
}{
{"No Type", Element{}, "", ErrElementMissingType},
{"No Key", Element{0x01, 'f', 'o', 'o'}, "", ErrElementMissingKey},
{"Success", AppendHeader(nil, bsontype.Double, "foo"), "foo", nil},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Run("Key", func(t *testing.T) {
str := tc.elem.Key()
if str != tc.str {
t.Errorf("returned strings do not match. got %s; want %s", str, tc.str)
}
})
t.Run("KeyErr", func(t *testing.T) {
str, err := tc.elem.KeyErr()
if !compareErrors(err, tc.err) {
t.Errorf("errors do not match. got %v; want %v", err, tc.err)
}
if str != tc.str {
t.Errorf("returned strings do not match. got %s; want %s", str, tc.str)
}
})
})
}
})
t.Run("Validate", func(t *testing.T) {
testCases := []struct {
name string
elem Element
err error
}{
{"No Type", Element{}, ErrElementMissingType},
{"No Key", Element{0x01, 'f', 'o', 'o'}, ErrElementMissingKey},
{"Insufficient Bytes", AppendHeader(nil, bsontype.Double, "foo"), NewInsufficientBytesError(nil, nil)},
{"Success", AppendDoubleElement(nil, "foo", 3.14159), nil},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := tc.elem.Validate()
if !compareErrors(err, tc.err) {
t.Errorf("errors do not match. got %v; want %v", err, tc.err)
}
})
}
})
t.Run("CompareKey", func(t *testing.T) {
testCases := []struct {
name string
elem Element
key []byte
equal bool
}{
{"Element Too Short", Element{0x02}, nil, false},
{"Element Invalid Key", Element{0x02, 'f', 'o', 'o'}, nil, false},
{"Key With Null Byte", AppendHeader(nil, bsontype.Double, "foo"), []byte{'f', 'o', 'o', 0x00}, true},
{"Key Without Null Byte", AppendHeader(nil, bsontype.Double, "pi"), []byte{'p', 'i'}, true},
{"Key With Null Byte With Extra", AppendHeader(nil, bsontype.Double, "foo"), []byte{'f', 'o', 'o', 0x00, 'b', 'a', 'r'}, true},
{"Prefix Key No Match", AppendHeader(nil, bsontype.Double, "foo"), []byte{'f', 'o', 'o', 'b', 'a', 'r'}, false},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
equal := tc.elem.CompareKey(tc.key)
if equal != tc.equal {
t.Errorf("Did not get expected equality result. got %t; want %t", equal, tc.equal)
}
})
}
})
t.Run("Value & ValueErr", func(t *testing.T) {
testCases := []struct {
name string
elem Element
val Value
err error
}{
{"No Type", Element{}, Value{}, ErrElementMissingType},
{"No Key", Element{0x01, 'f', 'o', 'o'}, Value{}, ErrElementMissingKey},
{"Insufficient Bytes", AppendHeader(nil, bsontype.Double, "foo"), Value{}, NewInsufficientBytesError(nil, nil)},
{"Success", AppendDoubleElement(nil, "foo", 3.14159), Value{Type: bsontype.Double, Data: AppendDouble(nil, 3.14159)}, nil},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Run("Value", func(t *testing.T) {
val := tc.elem.Value()
if !cmp.Equal(val, tc.val) {
t.Errorf("Values do not match. got %v; want %v", val, tc.val)
}
})
t.Run("ValueErr", func(t *testing.T) {
val, err := tc.elem.ValueErr()
if !compareErrors(err, tc.err) {
t.Errorf("errors do not match. got %v; want %v", err, tc.err)
}
if !cmp.Equal(val, tc.val) {
t.Errorf("Values do not match. got %v; want %v", val, tc.val)
}
})
})
}
})
}

View File

@@ -0,0 +1,223 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on github.com/golang/go by The Go Authors
// See THIRD-PARTY-NOTICES for original license terms.
package bsoncore
import "unicode/utf8"
// safeSet holds the value true if the ASCII character with the given array
// position can be represented inside a JSON string without any further
// escaping.
//
// All values are true except for the ASCII control characters (0-31), the
// double quote ("), and the backslash character ("\").
var safeSet = [utf8.RuneSelf]bool{
' ': true,
'!': true,
'"': false,
'#': true,
'$': true,
'%': true,
'&': true,
'\'': true,
'(': true,
')': true,
'*': true,
'+': true,
',': true,
'-': true,
'.': true,
'/': true,
'0': true,
'1': true,
'2': true,
'3': true,
'4': true,
'5': true,
'6': true,
'7': true,
'8': true,
'9': true,
':': true,
';': true,
'<': true,
'=': true,
'>': true,
'?': true,
'@': true,
'A': true,
'B': true,
'C': true,
'D': true,
'E': true,
'F': true,
'G': true,
'H': true,
'I': true,
'J': true,
'K': true,
'L': true,
'M': true,
'N': true,
'O': true,
'P': true,
'Q': true,
'R': true,
'S': true,
'T': true,
'U': true,
'V': true,
'W': true,
'X': true,
'Y': true,
'Z': true,
'[': true,
'\\': false,
']': true,
'^': true,
'_': true,
'`': true,
'a': true,
'b': true,
'c': true,
'd': true,
'e': true,
'f': true,
'g': true,
'h': true,
'i': true,
'j': true,
'k': true,
'l': true,
'm': true,
'n': true,
'o': true,
'p': true,
'q': true,
'r': true,
's': true,
't': true,
'u': true,
'v': true,
'w': true,
'x': true,
'y': true,
'z': true,
'{': true,
'|': true,
'}': true,
'~': true,
'\u007f': true,
}
// htmlSafeSet holds the value true if the ASCII character with the given
// array position can be safely represented inside a JSON string, embedded
// inside of HTML <script> tags, without any additional escaping.
//
// All values are true except for the ASCII control characters (0-31), the
// double quote ("), the backslash character ("\"), HTML opening and closing
// tags ("<" and ">"), and the ampersand ("&").
var htmlSafeSet = [utf8.RuneSelf]bool{
' ': true,
'!': true,
'"': false,
'#': true,
'$': true,
'%': true,
'&': false,
'\'': true,
'(': true,
')': true,
'*': true,
'+': true,
',': true,
'-': true,
'.': true,
'/': true,
'0': true,
'1': true,
'2': true,
'3': true,
'4': true,
'5': true,
'6': true,
'7': true,
'8': true,
'9': true,
':': true,
';': true,
'<': false,
'=': true,
'>': false,
'?': true,
'@': true,
'A': true,
'B': true,
'C': true,
'D': true,
'E': true,
'F': true,
'G': true,
'H': true,
'I': true,
'J': true,
'K': true,
'L': true,
'M': true,
'N': true,
'O': true,
'P': true,
'Q': true,
'R': true,
'S': true,
'T': true,
'U': true,
'V': true,
'W': true,
'X': true,
'Y': true,
'Z': true,
'[': true,
'\\': false,
']': true,
'^': true,
'_': true,
'`': true,
'a': true,
'b': true,
'c': true,
'd': true,
'e': true,
'f': true,
'g': true,
'h': true,
'i': true,
'j': true,
'k': true,
'l': true,
'm': true,
'n': true,
'o': true,
'p': true,
'q': true,
'r': true,
's': true,
't': true,
'u': true,
'v': true,
'w': true,
'x': true,
'y': true,
'z': true,
'{': true,
'|': true,
'}': true,
'~': true,
'\u007f': true,
}

View File

@@ -0,0 +1,972 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncore
import (
"bytes"
"encoding/base64"
"fmt"
"math"
"sort"
"strconv"
"strings"
"time"
"unicode/utf8"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
)
// ElementTypeError specifies that a method to obtain a BSON value an incorrect type was called on a bson.Value.
type ElementTypeError struct {
Method string
Type bsontype.Type
}
// Error implements the error interface.
func (ete ElementTypeError) Error() string {
return "Call of " + ete.Method + " on " + ete.Type.String() + " type"
}
// Value represents a BSON value with a type and raw bytes.
type Value struct {
Type bsontype.Type
Data []byte
}
// Validate ensures the value is a valid BSON value.
func (v Value) Validate() error {
_, _, valid := readValue(v.Data, v.Type)
if !valid {
return NewInsufficientBytesError(v.Data, v.Data)
}
return nil
}
// IsNumber returns true if the type of v is a numeric BSON type.
func (v Value) IsNumber() bool {
switch v.Type {
case bsontype.Double, bsontype.Int32, bsontype.Int64, bsontype.Decimal128:
return true
default:
return false
}
}
// AsInt32 returns a BSON number as an int32. If the BSON type is not a numeric one, this method
// will panic.
func (v Value) AsInt32() int32 {
if !v.IsNumber() {
panic(ElementTypeError{"bsoncore.Value.AsInt32", v.Type})
}
var i32 int32
switch v.Type {
case bsontype.Double:
f64, _, ok := ReadDouble(v.Data)
if !ok {
panic(NewInsufficientBytesError(v.Data, v.Data))
}
i32 = int32(f64)
case bsontype.Int32:
var ok bool
i32, _, ok = ReadInt32(v.Data)
if !ok {
panic(NewInsufficientBytesError(v.Data, v.Data))
}
case bsontype.Int64:
i64, _, ok := ReadInt64(v.Data)
if !ok {
panic(NewInsufficientBytesError(v.Data, v.Data))
}
i32 = int32(i64)
case bsontype.Decimal128:
panic(ElementTypeError{"bsoncore.Value.AsInt32", v.Type})
}
return i32
}
// AsInt32OK functions the same as AsInt32 but returns a boolean instead of panicking. False
// indicates an error.
func (v Value) AsInt32OK() (int32, bool) {
if !v.IsNumber() {
return 0, false
}
var i32 int32
switch v.Type {
case bsontype.Double:
f64, _, ok := ReadDouble(v.Data)
if !ok {
return 0, false
}
i32 = int32(f64)
case bsontype.Int32:
var ok bool
i32, _, ok = ReadInt32(v.Data)
if !ok {
return 0, false
}
case bsontype.Int64:
i64, _, ok := ReadInt64(v.Data)
if !ok {
return 0, false
}
i32 = int32(i64)
case bsontype.Decimal128:
return 0, false
}
return i32, true
}
// AsInt64 returns a BSON number as an int64. If the BSON type is not a numeric one, this method
// will panic.
func (v Value) AsInt64() int64 {
if !v.IsNumber() {
panic(ElementTypeError{"bsoncore.Value.AsInt64", v.Type})
}
var i64 int64
switch v.Type {
case bsontype.Double:
f64, _, ok := ReadDouble(v.Data)
if !ok {
panic(NewInsufficientBytesError(v.Data, v.Data))
}
i64 = int64(f64)
case bsontype.Int32:
var ok bool
i32, _, ok := ReadInt32(v.Data)
if !ok {
panic(NewInsufficientBytesError(v.Data, v.Data))
}
i64 = int64(i32)
case bsontype.Int64:
var ok bool
i64, _, ok = ReadInt64(v.Data)
if !ok {
panic(NewInsufficientBytesError(v.Data, v.Data))
}
case bsontype.Decimal128:
panic(ElementTypeError{"bsoncore.Value.AsInt64", v.Type})
}
return i64
}
// AsInt64OK functions the same as AsInt64 but returns a boolean instead of panicking. False
// indicates an error.
func (v Value) AsInt64OK() (int64, bool) {
if !v.IsNumber() {
return 0, false
}
var i64 int64
switch v.Type {
case bsontype.Double:
f64, _, ok := ReadDouble(v.Data)
if !ok {
return 0, false
}
i64 = int64(f64)
case bsontype.Int32:
var ok bool
i32, _, ok := ReadInt32(v.Data)
if !ok {
return 0, false
}
i64 = int64(i32)
case bsontype.Int64:
var ok bool
i64, _, ok = ReadInt64(v.Data)
if !ok {
return 0, false
}
case bsontype.Decimal128:
return 0, false
}
return i64, true
}
// AsFloat64 returns a BSON number as an float64. If the BSON type is not a numeric one, this method
// will panic.
//
// TODO(skriptble): Add support for Decimal128.
func (v Value) AsFloat64() float64 { return 0 }
// AsFloat64OK functions the same as AsFloat64 but returns a boolean instead of panicking. False
// indicates an error.
//
// TODO(skriptble): Add support for Decimal128.
func (v Value) AsFloat64OK() (float64, bool) { return 0, false }
// Add will add this value to another. This is currently only implemented for strings and numbers.
// If either value is a string, the other type is coerced into a string and added to the other.
//
// This method will alter v and will attempt to reuse the []byte of v. If the []byte is too small,
// it will be expanded.
func (v *Value) Add(v2 Value) error { return nil }
// Equal compaes v to v2 and returns true if they are equal.
func (v Value) Equal(v2 Value) bool {
if v.Type != v2.Type {
return false
}
return bytes.Equal(v.Data, v2.Data)
}
// String implements the fmt.String interface. This method will return values in extended JSON
// format. If the value is not valid, this returns an empty string
func (v Value) String() string {
switch v.Type {
case bsontype.Double:
f64, ok := v.DoubleOK()
if !ok {
return ""
}
return fmt.Sprintf(`{"$numberDouble":"%s"}`, formatDouble(f64))
case bsontype.String:
str, ok := v.StringValueOK()
if !ok {
return ""
}
return escapeString(str)
case bsontype.EmbeddedDocument:
doc, ok := v.DocumentOK()
if !ok {
return ""
}
return doc.String()
case bsontype.Array:
arr, ok := v.ArrayOK()
if !ok {
return ""
}
return arr.String()
case bsontype.Binary:
subtype, data, ok := v.BinaryOK()
if !ok {
return ""
}
return fmt.Sprintf(`{"$binary":{"base64":"%s","subType":"%02x"}}`, base64.StdEncoding.EncodeToString(data), subtype)
case bsontype.Undefined:
return `{"$undefined":true}`
case bsontype.ObjectID:
oid, ok := v.ObjectIDOK()
if !ok {
return ""
}
return fmt.Sprintf(`{"$oid":"%s"}`, oid.Hex())
case bsontype.Boolean:
b, ok := v.BooleanOK()
if !ok {
return ""
}
return strconv.FormatBool(b)
case bsontype.DateTime:
dt, ok := v.DateTimeOK()
if !ok {
return ""
}
return fmt.Sprintf(`{"$date":{"$numberLong":"%d"}}`, dt)
case bsontype.Null:
return "null"
case bsontype.Regex:
pattern, options, ok := v.RegexOK()
if !ok {
return ""
}
return fmt.Sprintf(
`{"$regularExpression":{"pattern":%s,"options":"%s"}}`,
escapeString(pattern), sortStringAlphebeticAscending(options),
)
case bsontype.DBPointer:
ns, pointer, ok := v.DBPointerOK()
if !ok {
return ""
}
return fmt.Sprintf(`{"$dbPointer":{"$ref":%s,"$id":{"$oid":"%s"}}}`, escapeString(ns), pointer.Hex())
case bsontype.JavaScript:
js, ok := v.JavaScriptOK()
if !ok {
return ""
}
return fmt.Sprintf(`{"$code":%s}`, escapeString(js))
case bsontype.Symbol:
symbol, ok := v.SymbolOK()
if !ok {
return ""
}
return fmt.Sprintf(`{"$symbol":%s}`, escapeString(symbol))
case bsontype.CodeWithScope:
code, scope, ok := v.CodeWithScopeOK()
if !ok {
return ""
}
return fmt.Sprintf(`{"$code":%s,"$scope":%s}`, code, scope)
case bsontype.Int32:
i32, ok := v.Int32OK()
if !ok {
return ""
}
return fmt.Sprintf(`{"$numberInt":"%d"}`, i32)
case bsontype.Timestamp:
t, i, ok := v.TimestampOK()
if !ok {
return ""
}
return fmt.Sprintf(`{"$timestamp":{"t":%v,"i":%v}}`, t, i)
case bsontype.Int64:
i64, ok := v.Int64OK()
if !ok {
return ""
}
return fmt.Sprintf(`{"$numberLong":"%d"}`, i64)
case bsontype.Decimal128:
d128, ok := v.Decimal128OK()
if !ok {
return ""
}
return fmt.Sprintf(`{"$numberDecimal":"%s"}`, d128.String())
case bsontype.MinKey:
return `{"$minKey":1}`
case bsontype.MaxKey:
return `{"$maxKey":1}`
default:
return ""
}
}
// DebugString outputs a human readable version of Document. It will attempt to stringify the
// valid components of the document even if the entire document is not valid.
func (v Value) DebugString() string {
switch v.Type {
case bsontype.String:
str, ok := v.StringValueOK()
if !ok {
return "<malformed>"
}
return escapeString(str)
case bsontype.EmbeddedDocument:
doc, ok := v.DocumentOK()
if !ok {
return "<malformed>"
}
return doc.DebugString()
case bsontype.Array:
arr, ok := v.ArrayOK()
if !ok {
return "<malformed>"
}
return arr.DebugString()
case bsontype.CodeWithScope:
code, scope, ok := v.CodeWithScopeOK()
if !ok {
return ""
}
return fmt.Sprintf(`{"$code":%s,"$scope":%s}`, code, scope.DebugString())
default:
str := v.String()
if str == "" {
return "<malformed>"
}
return str
}
}
// Double returns the float64 value for this element.
// It panics if e's BSON type is not bsontype.Double.
func (v Value) Double() float64 {
if v.Type != bsontype.Double {
panic(ElementTypeError{"bsoncore.Value.Double", v.Type})
}
f64, _, ok := ReadDouble(v.Data)
if !ok {
panic(NewInsufficientBytesError(v.Data, v.Data))
}
return f64
}
// DoubleOK is the same as Double, but returns a boolean instead of panicking.
func (v Value) DoubleOK() (float64, bool) {
if v.Type != bsontype.Double {
return 0, false
}
f64, _, ok := ReadDouble(v.Data)
if !ok {
return 0, false
}
return f64, true
}
// StringValue returns the string balue for this element.
// It panics if e's BSON type is not bsontype.String.
//
// NOTE: This method is called StringValue to avoid a collision with the String method which
// implements the fmt.Stringer interface.
func (v Value) StringValue() string {
if v.Type != bsontype.String {
panic(ElementTypeError{"bsoncore.Value.StringValue", v.Type})
}
str, _, ok := ReadString(v.Data)
if !ok {
panic(NewInsufficientBytesError(v.Data, v.Data))
}
return str
}
// StringValueOK is the same as StringValue, but returns a boolean instead of
// panicking.
func (v Value) StringValueOK() (string, bool) {
if v.Type != bsontype.String {
return "", false
}
str, _, ok := ReadString(v.Data)
if !ok {
return "", false
}
return str, true
}
// Document returns the BSON document the Value represents as a Document. It panics if the
// value is a BSON type other than document.
func (v Value) Document() Document {
if v.Type != bsontype.EmbeddedDocument {
panic(ElementTypeError{"bsoncore.Value.Document", v.Type})
}
doc, _, ok := ReadDocument(v.Data)
if !ok {
panic(NewInsufficientBytesError(v.Data, v.Data))
}
return doc
}
// DocumentOK is the same as Document, except it returns a boolean
// instead of panicking.
func (v Value) DocumentOK() (Document, bool) {
if v.Type != bsontype.EmbeddedDocument {
return nil, false
}
doc, _, ok := ReadDocument(v.Data)
if !ok {
return nil, false
}
return doc, true
}
// Array returns the BSON array the Value represents as an Array. It panics if the
// value is a BSON type other than array.
func (v Value) Array() Array {
if v.Type != bsontype.Array {
panic(ElementTypeError{"bsoncore.Value.Array", v.Type})
}
arr, _, ok := ReadArray(v.Data)
if !ok {
panic(NewInsufficientBytesError(v.Data, v.Data))
}
return arr
}
// ArrayOK is the same as Array, except it returns a boolean instead
// of panicking.
func (v Value) ArrayOK() (Array, bool) {
if v.Type != bsontype.Array {
return nil, false
}
arr, _, ok := ReadArray(v.Data)
if !ok {
return nil, false
}
return arr, true
}
// Binary returns the BSON binary value the Value represents. It panics if the value is a BSON type
// other than binary.
func (v Value) Binary() (subtype byte, data []byte) {
if v.Type != bsontype.Binary {
panic(ElementTypeError{"bsoncore.Value.Binary", v.Type})
}
subtype, data, _, ok := ReadBinary(v.Data)
if !ok {
panic(NewInsufficientBytesError(v.Data, v.Data))
}
return subtype, data
}
// BinaryOK is the same as Binary, except it returns a boolean instead of
// panicking.
func (v Value) BinaryOK() (subtype byte, data []byte, ok bool) {
if v.Type != bsontype.Binary {
return 0x00, nil, false
}
subtype, data, _, ok = ReadBinary(v.Data)
if !ok {
return 0x00, nil, false
}
return subtype, data, true
}
// ObjectID returns the BSON objectid value the Value represents. It panics if the value is a BSON
// type other than objectid.
func (v Value) ObjectID() primitive.ObjectID {
if v.Type != bsontype.ObjectID {
panic(ElementTypeError{"bsoncore.Value.ObjectID", v.Type})
}
oid, _, ok := ReadObjectID(v.Data)
if !ok {
panic(NewInsufficientBytesError(v.Data, v.Data))
}
return oid
}
// ObjectIDOK is the same as ObjectID, except it returns a boolean instead of
// panicking.
func (v Value) ObjectIDOK() (primitive.ObjectID, bool) {
if v.Type != bsontype.ObjectID {
return primitive.ObjectID{}, false
}
oid, _, ok := ReadObjectID(v.Data)
if !ok {
return primitive.ObjectID{}, false
}
return oid, true
}
// Boolean returns the boolean value the Value represents. It panics if the
// value is a BSON type other than boolean.
func (v Value) Boolean() bool {
if v.Type != bsontype.Boolean {
panic(ElementTypeError{"bsoncore.Value.Boolean", v.Type})
}
b, _, ok := ReadBoolean(v.Data)
if !ok {
panic(NewInsufficientBytesError(v.Data, v.Data))
}
return b
}
// BooleanOK is the same as Boolean, except it returns a boolean instead of
// panicking.
func (v Value) BooleanOK() (bool, bool) {
if v.Type != bsontype.Boolean {
return false, false
}
b, _, ok := ReadBoolean(v.Data)
if !ok {
return false, false
}
return b, true
}
// DateTime returns the BSON datetime value the Value represents as a
// unix timestamp. It panics if the value is a BSON type other than datetime.
func (v Value) DateTime() int64 {
if v.Type != bsontype.DateTime {
panic(ElementTypeError{"bsoncore.Value.DateTime", v.Type})
}
dt, _, ok := ReadDateTime(v.Data)
if !ok {
panic(NewInsufficientBytesError(v.Data, v.Data))
}
return dt
}
// DateTimeOK is the same as DateTime, except it returns a boolean instead of
// panicking.
func (v Value) DateTimeOK() (int64, bool) {
if v.Type != bsontype.DateTime {
return 0, false
}
dt, _, ok := ReadDateTime(v.Data)
if !ok {
return 0, false
}
return dt, true
}
// Time returns the BSON datetime value the Value represents. It panics if the value is a BSON
// type other than datetime.
func (v Value) Time() time.Time {
if v.Type != bsontype.DateTime {
panic(ElementTypeError{"bsoncore.Value.Time", v.Type})
}
dt, _, ok := ReadDateTime(v.Data)
if !ok {
panic(NewInsufficientBytesError(v.Data, v.Data))
}
return time.Unix(dt/1000, dt%1000*1000000)
}
// TimeOK is the same as Time, except it returns a boolean instead of
// panicking.
func (v Value) TimeOK() (time.Time, bool) {
if v.Type != bsontype.DateTime {
return time.Time{}, false
}
dt, _, ok := ReadDateTime(v.Data)
if !ok {
return time.Time{}, false
}
return time.Unix(dt/1000, dt%1000*1000000), true
}
// Regex returns the BSON regex value the Value represents. It panics if the value is a BSON
// type other than regex.
func (v Value) Regex() (pattern, options string) {
if v.Type != bsontype.Regex {
panic(ElementTypeError{"bsoncore.Value.Regex", v.Type})
}
pattern, options, _, ok := ReadRegex(v.Data)
if !ok {
panic(NewInsufficientBytesError(v.Data, v.Data))
}
return pattern, options
}
// RegexOK is the same as Regex, except it returns a boolean instead of
// panicking.
func (v Value) RegexOK() (pattern, options string, ok bool) {
if v.Type != bsontype.Regex {
return "", "", false
}
pattern, options, _, ok = ReadRegex(v.Data)
if !ok {
return "", "", false
}
return pattern, options, true
}
// DBPointer returns the BSON dbpointer value the Value represents. It panics if the value is a BSON
// type other than DBPointer.
func (v Value) DBPointer() (string, primitive.ObjectID) {
if v.Type != bsontype.DBPointer {
panic(ElementTypeError{"bsoncore.Value.DBPointer", v.Type})
}
ns, pointer, _, ok := ReadDBPointer(v.Data)
if !ok {
panic(NewInsufficientBytesError(v.Data, v.Data))
}
return ns, pointer
}
// DBPointerOK is the same as DBPoitner, except that it returns a boolean
// instead of panicking.
func (v Value) DBPointerOK() (string, primitive.ObjectID, bool) {
if v.Type != bsontype.DBPointer {
return "", primitive.ObjectID{}, false
}
ns, pointer, _, ok := ReadDBPointer(v.Data)
if !ok {
return "", primitive.ObjectID{}, false
}
return ns, pointer, true
}
// JavaScript returns the BSON JavaScript code value the Value represents. It panics if the value is
// a BSON type other than JavaScript code.
func (v Value) JavaScript() string {
if v.Type != bsontype.JavaScript {
panic(ElementTypeError{"bsoncore.Value.JavaScript", v.Type})
}
js, _, ok := ReadJavaScript(v.Data)
if !ok {
panic(NewInsufficientBytesError(v.Data, v.Data))
}
return js
}
// JavaScriptOK is the same as Javascript, excepti that it returns a boolean
// instead of panicking.
func (v Value) JavaScriptOK() (string, bool) {
if v.Type != bsontype.JavaScript {
return "", false
}
js, _, ok := ReadJavaScript(v.Data)
if !ok {
return "", false
}
return js, true
}
// Symbol returns the BSON symbol value the Value represents. It panics if the value is a BSON
// type other than symbol.
func (v Value) Symbol() string {
if v.Type != bsontype.Symbol {
panic(ElementTypeError{"bsoncore.Value.Symbol", v.Type})
}
symbol, _, ok := ReadSymbol(v.Data)
if !ok {
panic(NewInsufficientBytesError(v.Data, v.Data))
}
return symbol
}
// SymbolOK is the same as Symbol, excepti that it returns a boolean
// instead of panicking.
func (v Value) SymbolOK() (string, bool) {
if v.Type != bsontype.Symbol {
return "", false
}
symbol, _, ok := ReadSymbol(v.Data)
if !ok {
return "", false
}
return symbol, true
}
// CodeWithScope returns the BSON JavaScript code with scope the Value represents.
// It panics if the value is a BSON type other than JavaScript code with scope.
func (v Value) CodeWithScope() (string, Document) {
if v.Type != bsontype.CodeWithScope {
panic(ElementTypeError{"bsoncore.Value.CodeWithScope", v.Type})
}
code, scope, _, ok := ReadCodeWithScope(v.Data)
if !ok {
panic(NewInsufficientBytesError(v.Data, v.Data))
}
return code, scope
}
// CodeWithScopeOK is the same as CodeWithScope, except that it returns a boolean instead of
// panicking.
func (v Value) CodeWithScopeOK() (string, Document, bool) {
if v.Type != bsontype.CodeWithScope {
return "", nil, false
}
code, scope, _, ok := ReadCodeWithScope(v.Data)
if !ok {
return "", nil, false
}
return code, scope, true
}
// Int32 returns the int32 the Value represents. It panics if the value is a BSON type other than
// int32.
func (v Value) Int32() int32 {
if v.Type != bsontype.Int32 {
panic(ElementTypeError{"bsoncore.Value.Int32", v.Type})
}
i32, _, ok := ReadInt32(v.Data)
if !ok {
panic(NewInsufficientBytesError(v.Data, v.Data))
}
return i32
}
// Int32OK is the same as Int32, except that it returns a boolean instead of
// panicking.
func (v Value) Int32OK() (int32, bool) {
if v.Type != bsontype.Int32 {
return 0, false
}
i32, _, ok := ReadInt32(v.Data)
if !ok {
return 0, false
}
return i32, true
}
// Timestamp returns the BSON timestamp value the Value represents. It panics if the value is a
// BSON type other than timestamp.
func (v Value) Timestamp() (t, i uint32) {
if v.Type != bsontype.Timestamp {
panic(ElementTypeError{"bsoncore.Value.Timestamp", v.Type})
}
t, i, _, ok := ReadTimestamp(v.Data)
if !ok {
panic(NewInsufficientBytesError(v.Data, v.Data))
}
return t, i
}
// TimestampOK is the same as Timestamp, except that it returns a boolean
// instead of panicking.
func (v Value) TimestampOK() (t, i uint32, ok bool) {
if v.Type != bsontype.Timestamp {
return 0, 0, false
}
t, i, _, ok = ReadTimestamp(v.Data)
if !ok {
return 0, 0, false
}
return t, i, true
}
// Int64 returns the int64 the Value represents. It panics if the value is a BSON type other than
// int64.
func (v Value) Int64() int64 {
if v.Type != bsontype.Int64 {
panic(ElementTypeError{"bsoncore.Value.Int64", v.Type})
}
i64, _, ok := ReadInt64(v.Data)
if !ok {
panic(NewInsufficientBytesError(v.Data, v.Data))
}
return i64
}
// Int64OK is the same as Int64, except that it returns a boolean instead of
// panicking.
func (v Value) Int64OK() (int64, bool) {
if v.Type != bsontype.Int64 {
return 0, false
}
i64, _, ok := ReadInt64(v.Data)
if !ok {
return 0, false
}
return i64, true
}
// Decimal128 returns the decimal the Value represents. It panics if the value is a BSON type other than
// decimal.
func (v Value) Decimal128() primitive.Decimal128 {
if v.Type != bsontype.Decimal128 {
panic(ElementTypeError{"bsoncore.Value.Decimal128", v.Type})
}
d128, _, ok := ReadDecimal128(v.Data)
if !ok {
panic(NewInsufficientBytesError(v.Data, v.Data))
}
return d128
}
// Decimal128OK is the same as Decimal128, except that it returns a boolean
// instead of panicking.
func (v Value) Decimal128OK() (primitive.Decimal128, bool) {
if v.Type != bsontype.Decimal128 {
return primitive.Decimal128{}, false
}
d128, _, ok := ReadDecimal128(v.Data)
if !ok {
return primitive.Decimal128{}, false
}
return d128, true
}
var hexChars = "0123456789abcdef"
func escapeString(s string) string {
escapeHTML := true
var buf bytes.Buffer
buf.WriteByte('"')
start := 0
for i := 0; i < len(s); {
if b := s[i]; b < utf8.RuneSelf {
if htmlSafeSet[b] || (!escapeHTML && safeSet[b]) {
i++
continue
}
if start < i {
buf.WriteString(s[start:i])
}
switch b {
case '\\', '"':
buf.WriteByte('\\')
buf.WriteByte(b)
case '\n':
buf.WriteByte('\\')
buf.WriteByte('n')
case '\r':
buf.WriteByte('\\')
buf.WriteByte('r')
case '\t':
buf.WriteByte('\\')
buf.WriteByte('t')
case '\b':
buf.WriteByte('\\')
buf.WriteByte('b')
case '\f':
buf.WriteByte('\\')
buf.WriteByte('f')
default:
// This encodes bytes < 0x20 except for \t, \n and \r.
// If escapeHTML is set, it also escapes <, >, and &
// because they can lead to security holes when
// user-controlled strings are rendered into JSON
// and served to some browsers.
buf.WriteString(`\u00`)
buf.WriteByte(hexChars[b>>4])
buf.WriteByte(hexChars[b&0xF])
}
i++
start = i
continue
}
c, size := utf8.DecodeRuneInString(s[i:])
if c == utf8.RuneError && size == 1 {
if start < i {
buf.WriteString(s[start:i])
}
buf.WriteString(`\ufffd`)
i += size
start = i
continue
}
// U+2028 is LINE SEPARATOR.
// U+2029 is PARAGRAPH SEPARATOR.
// They are both technically valid characters in JSON strings,
// but don't work in JSONP, which has to be evaluated as JavaScript,
// and can lead to security holes there. It is valid JSON to
// escape them, so we do so unconditionally.
// See http://timelessrepo.com/json-isnt-a-javascript-subset for discussion.
if c == '\u2028' || c == '\u2029' {
if start < i {
buf.WriteString(s[start:i])
}
buf.WriteString(`\u202`)
buf.WriteByte(hexChars[c&0xF])
i += size
start = i
continue
}
i += size
}
if start < len(s) {
buf.WriteString(s[start:])
}
buf.WriteByte('"')
return buf.String()
}
func formatDouble(f float64) string {
var s string
if math.IsInf(f, 1) {
s = "Infinity"
} else if math.IsInf(f, -1) {
s = "-Infinity"
} else if math.IsNaN(f) {
s = "NaN"
} else {
// Print exactly one decimalType place for integers; otherwise, print as many are necessary to
// perfectly represent it.
s = strconv.FormatFloat(f, 'G', -1, 64)
if !strings.ContainsRune(s, '.') {
s += ".0"
}
}
return s
}
type sortableString []rune
func (ss sortableString) Len() int {
return len(ss)
}
func (ss sortableString) Less(i, j int) bool {
return ss[i] < ss[j]
}
func (ss sortableString) Swap(i, j int) {
oldI := ss[i]
ss[i] = ss[j]
ss[j] = oldI
}
func sortStringAlphebeticAscending(s string) string {
ss := sortableString([]rune(s))
sort.Sort(ss)
return string([]rune(ss))
}

View File

@@ -0,0 +1,678 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncore
import (
"reflect"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
)
func TestValue(t *testing.T) {
t.Parallel()
t.Run("Validate", func(t *testing.T) {
t.Parallel()
t.Run("invalid", func(t *testing.T) {
t.Parallel()
v := Value{Type: bsontype.Double, Data: []byte{0x01, 0x02, 0x03, 0x04}}
want := NewInsufficientBytesError(v.Data, v.Data)
got := v.Validate()
if !compareErrors(got, want) {
t.Errorf("Errors do not match. got %v; want %v", got, want)
}
})
t.Run("value", func(t *testing.T) {
t.Parallel()
v := Value{Type: bsontype.Double, Data: AppendDouble(nil, 3.14159)}
var want error
got := v.Validate()
if !compareErrors(got, want) {
t.Errorf("Errors do not match. got %v; want %v", got, want)
}
})
})
t.Run("IsNumber", func(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
val Value
isnum bool
}{
{"double", Value{Type: bsontype.Double}, true},
{"int32", Value{Type: bsontype.Int32}, true},
{"int64", Value{Type: bsontype.Int64}, true},
{"decimal128", Value{Type: bsontype.Decimal128}, true},
{"string", Value{Type: bsontype.String}, false},
{"regex", Value{Type: bsontype.Regex}, false},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
isnum := tc.val.IsNumber()
if isnum != tc.isnum {
t.Errorf("IsNumber did not return the expected boolean. got %t; want %t", isnum, tc.isnum)
}
})
}
})
now := time.Now().Truncate(time.Millisecond)
oid := primitive.NewObjectID()
testCases := []struct {
name string
fn interface{}
val Value
panicErr error
ret []interface{}
}{
{
"Double/Not Double", Value.Double, Value{Type: bsontype.String},
ElementTypeError{"bsoncore.Value.Double", bsontype.String},
nil,
},
{
"Double/Insufficient Bytes", Value.Double, Value{Type: bsontype.Double, Data: []byte{0x01, 0x02, 0x03, 0x04}},
NewInsufficientBytesError([]byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}),
nil,
},
{
"Double/Success", Value.Double, Value{Type: bsontype.Double, Data: AppendDouble(nil, 3.14159)},
nil,
[]interface{}{float64(3.14159)},
},
{
"DoubleOK/Not Double", Value.DoubleOK, Value{Type: bsontype.String},
nil,
[]interface{}{float64(0), false},
},
{
"DoubleOK/Insufficient Bytes", Value.DoubleOK, Value{Type: bsontype.Double, Data: []byte{0x01, 0x02, 0x03, 0x04}},
nil,
[]interface{}{float64(0), false},
},
{
"DoubleOK/Success", Value.DoubleOK, Value{Type: bsontype.Double, Data: AppendDouble(nil, 3.14159)},
nil,
[]interface{}{float64(3.14159), true},
},
{
"StringValue/Not String", Value.StringValue, Value{Type: bsontype.Double},
ElementTypeError{"bsoncore.Value.StringValue", bsontype.Double},
nil,
},
{
"StringValue/Insufficient Bytes", Value.StringValue, Value{Type: bsontype.String, Data: []byte{0x01, 0x02, 0x03, 0x04}},
NewInsufficientBytesError([]byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}),
nil,
},
{
"StringValue/Zero Length", Value.StringValue, Value{Type: bsontype.String, Data: []byte{0x00, 0x00, 0x00, 0x00}},
NewInsufficientBytesError([]byte{0x00, 0x00, 0x00, 0x00}, []byte{0x00, 0x00, 0x00, 0x00}),
nil,
},
{
"StringValue/Success", Value.StringValue, Value{Type: bsontype.String, Data: AppendString(nil, "hello, world!")},
nil,
[]interface{}{"hello, world!"},
},
{
"StringValueOK/Not String", Value.StringValueOK, Value{Type: bsontype.Double},
nil,
[]interface{}{"", false},
},
{
"StringValueOK/Insufficient Bytes", Value.StringValueOK, Value{Type: bsontype.String, Data: []byte{0x01, 0x02, 0x03, 0x04}},
nil,
[]interface{}{"", false},
},
{
"StringValueOK/Zero Length", Value.StringValueOK, Value{Type: bsontype.String, Data: []byte{0x00, 0x00, 0x00, 0x00}},
nil,
[]interface{}{"", false},
},
{
"StringValueOK/Success", Value.StringValueOK, Value{Type: bsontype.String, Data: AppendString(nil, "hello, world!")},
nil,
[]interface{}{"hello, world!", true},
},
{
"Document/Not Document", Value.Document, Value{Type: bsontype.String},
ElementTypeError{"bsoncore.Value.Document", bsontype.String},
nil,
},
{
"Document/Insufficient Bytes", Value.Document, Value{Type: bsontype.EmbeddedDocument, Data: []byte{0x01, 0x02, 0x03, 0x04}},
NewInsufficientBytesError([]byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}),
nil,
},
{
"Document/Success", Value.Document, Value{Type: bsontype.EmbeddedDocument, Data: []byte{0x05, 0x00, 0x00, 0x00, 0x00}},
nil,
[]interface{}{Document{0x05, 0x00, 0x00, 0x00, 0x00}},
},
{
"DocumentOK/Not Document", Value.DocumentOK, Value{Type: bsontype.String},
nil,
[]interface{}{Document(nil), false},
},
{
"DocumentOK/Insufficient Bytes", Value.DocumentOK, Value{Type: bsontype.EmbeddedDocument, Data: []byte{0x01, 0x02, 0x03, 0x04}},
nil,
[]interface{}{Document(nil), false},
},
{
"DocumentOK/Success", Value.DocumentOK, Value{Type: bsontype.EmbeddedDocument, Data: []byte{0x05, 0x00, 0x00, 0x00, 0x00}},
nil,
[]interface{}{Document{0x05, 0x00, 0x00, 0x00, 0x00}, true},
},
{
"Array/Not Array", Value.Array, Value{Type: bsontype.String},
ElementTypeError{"bsoncore.Value.Array", bsontype.String},
nil,
},
{
"Array/Insufficient Bytes", Value.Array, Value{Type: bsontype.Array, Data: []byte{0x01, 0x02, 0x03, 0x04}},
NewInsufficientBytesError([]byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}),
nil,
},
{
"Array/Success", Value.Array, Value{Type: bsontype.Array, Data: []byte{0x05, 0x00, 0x00, 0x00, 0x00}},
nil,
[]interface{}{Array{0x05, 0x00, 0x00, 0x00, 0x00}},
},
{
"ArrayOK/Not Array", Value.ArrayOK, Value{Type: bsontype.String},
nil,
[]interface{}{Array(nil), false},
},
{
"ArrayOK/Insufficient Bytes", Value.ArrayOK, Value{Type: bsontype.Array, Data: []byte{0x01, 0x02, 0x03, 0x04}},
nil,
[]interface{}{Array(nil), false},
},
{
"ArrayOK/Success", Value.ArrayOK, Value{Type: bsontype.Array, Data: []byte{0x05, 0x00, 0x00, 0x00, 0x00}},
nil,
[]interface{}{Array{0x05, 0x00, 0x00, 0x00, 0x00}, true},
},
{
"Binary/Not Binary", Value.Binary, Value{Type: bsontype.String},
ElementTypeError{"bsoncore.Value.Binary", bsontype.String},
nil,
},
{
"Binary/Insufficient Bytes", Value.Binary, Value{Type: bsontype.Binary, Data: []byte{0x01, 0x02, 0x03, 0x04}},
NewInsufficientBytesError([]byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}),
nil,
},
{
"Binary/Success", Value.Binary, Value{Type: bsontype.Binary, Data: AppendBinary(nil, 0xFF, []byte{0x01, 0x02, 0x03})},
nil,
[]interface{}{byte(0xFF), []byte{0x01, 0x02, 0x03}},
},
{
"BinaryOK/Not Binary", Value.BinaryOK, Value{Type: bsontype.String},
nil,
[]interface{}{byte(0x00), []byte(nil), false},
},
{
"BinaryOK/Insufficient Bytes", Value.BinaryOK, Value{Type: bsontype.Binary, Data: []byte{0x01, 0x02, 0x03, 0x04}},
nil,
[]interface{}{byte(0x00), []byte(nil), false},
},
{
"BinaryOK/Success", Value.BinaryOK, Value{Type: bsontype.Binary, Data: AppendBinary(nil, 0xFF, []byte{0x01, 0x02, 0x03})},
nil,
[]interface{}{byte(0xFF), []byte{0x01, 0x02, 0x03}, true},
},
{
"ObjectID/Not ObjectID", Value.ObjectID, Value{Type: bsontype.String},
ElementTypeError{"bsoncore.Value.ObjectID", bsontype.String},
nil,
},
{
"ObjectID/Insufficient Bytes", Value.ObjectID, Value{Type: bsontype.ObjectID, Data: []byte{0x01, 0x02, 0x03, 0x04}},
NewInsufficientBytesError([]byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}),
nil,
},
{
"ObjectID/Success", Value.ObjectID, Value{Type: bsontype.ObjectID, Data: AppendObjectID(nil, primitive.ObjectID{0x01, 0x02})},
nil,
[]interface{}{primitive.ObjectID{0x01, 0x02}},
},
{
"ObjectIDOK/Not ObjectID", Value.ObjectIDOK, Value{Type: bsontype.String},
nil,
[]interface{}{primitive.ObjectID{}, false},
},
{
"ObjectIDOK/Insufficient Bytes", Value.ObjectIDOK, Value{Type: bsontype.ObjectID, Data: []byte{0x01, 0x02, 0x03, 0x04}},
nil,
[]interface{}{primitive.ObjectID{}, false},
},
{
"ObjectIDOK/Success", Value.ObjectIDOK, Value{Type: bsontype.ObjectID, Data: AppendObjectID(nil, primitive.ObjectID{0x01, 0x02})},
nil,
[]interface{}{primitive.ObjectID{0x01, 0x02}, true},
},
{
"Boolean/Not Boolean", Value.Boolean, Value{Type: bsontype.String},
ElementTypeError{"bsoncore.Value.Boolean", bsontype.String},
nil,
},
{
"Boolean/Insufficient Bytes", Value.Boolean, Value{Type: bsontype.Boolean, Data: []byte{}},
NewInsufficientBytesError([]byte{}, []byte{}),
nil,
},
{
"Boolean/Success", Value.Boolean, Value{Type: bsontype.Boolean, Data: AppendBoolean(nil, true)},
nil,
[]interface{}{true},
},
{
"BooleanOK/Not Boolean", Value.BooleanOK, Value{Type: bsontype.String},
nil,
[]interface{}{false, false},
},
{
"BooleanOK/Insufficient Bytes", Value.BooleanOK, Value{Type: bsontype.Boolean, Data: []byte{}},
nil,
[]interface{}{false, false},
},
{
"BooleanOK/Success", Value.BooleanOK, Value{Type: bsontype.Boolean, Data: AppendBoolean(nil, true)},
nil,
[]interface{}{true, true},
},
{
"DateTime/Not DateTime", Value.DateTime, Value{Type: bsontype.String},
ElementTypeError{"bsoncore.Value.DateTime", bsontype.String},
nil,
},
{
"DateTime/Insufficient Bytes", Value.DateTime, Value{Type: bsontype.DateTime, Data: []byte{0x01, 0x02, 0x03}},
NewInsufficientBytesError([]byte{}, []byte{}),
nil,
},
{
"DateTime/Success", Value.DateTime, Value{Type: bsontype.DateTime, Data: AppendDateTime(nil, 12345)},
nil,
[]interface{}{int64(12345)},
},
{
"DateTimeOK/Not DateTime", Value.DateTimeOK, Value{Type: bsontype.String},
nil,
[]interface{}{int64(0), false},
},
{
"DateTimeOK/Insufficient Bytes", Value.DateTimeOK, Value{Type: bsontype.DateTime, Data: []byte{0x01, 0x02, 0x03}},
nil,
[]interface{}{int64(0), false},
},
{
"DateTimeOK/Success", Value.DateTimeOK, Value{Type: bsontype.DateTime, Data: AppendDateTime(nil, 12345)},
nil,
[]interface{}{int64(12345), true},
},
{
"Time/Not DateTime", Value.Time, Value{Type: bsontype.String},
ElementTypeError{"bsoncore.Value.Time", bsontype.String},
nil,
},
{
"Time/Insufficient Bytes", Value.Time, Value{Type: bsontype.DateTime, Data: []byte{0x01, 0x02, 0x03}},
NewInsufficientBytesError([]byte{0x01, 0x02, 0x03}, []byte{0x01, 0x02, 0x03}),
nil,
},
{
"Time/Success", Value.Time, Value{Type: bsontype.DateTime, Data: AppendTime(nil, now)},
nil,
[]interface{}{now},
},
{
"TimeOK/Not DateTime", Value.TimeOK, Value{Type: bsontype.String},
nil,
[]interface{}{time.Time{}, false},
},
{
"TimeOK/Insufficient Bytes", Value.TimeOK, Value{Type: bsontype.DateTime, Data: []byte{0x01, 0x02, 0x03}},
nil,
[]interface{}{time.Time{}, false},
},
{
"TimeOK/Success", Value.TimeOK, Value{Type: bsontype.DateTime, Data: AppendTime(nil, now)},
nil,
[]interface{}{now, true},
},
{
"Regex/Not Regex", Value.Regex, Value{Type: bsontype.String},
ElementTypeError{"bsoncore.Value.Regex", bsontype.String},
nil,
},
{
"Regex/Insufficient Bytes", Value.Regex, Value{Type: bsontype.Regex, Data: []byte{0x01, 0x02, 0x03}},
NewInsufficientBytesError([]byte{0x01, 0x02, 0x03}, []byte{0x01, 0x02, 0x03}),
nil,
},
{
"Regex/Success", Value.Regex, Value{Type: bsontype.Regex, Data: AppendRegex(nil, "/abcdefg/", "hijkl")},
nil,
[]interface{}{"/abcdefg/", "hijkl"},
},
{
"RegexOK/Not Regex", Value.RegexOK, Value{Type: bsontype.String},
nil,
[]interface{}{"", "", false},
},
{
"RegexOK/Insufficient Bytes", Value.RegexOK, Value{Type: bsontype.Regex, Data: []byte{0x01, 0x02, 0x03}},
nil,
[]interface{}{"", "", false},
},
{
"RegexOK/Success", Value.RegexOK, Value{Type: bsontype.Regex, Data: AppendRegex(nil, "/abcdefg/", "hijkl")},
nil,
[]interface{}{"/abcdefg/", "hijkl", true},
},
{
"DBPointer/Not DBPointer", Value.DBPointer, Value{Type: bsontype.String},
ElementTypeError{"bsoncore.Value.DBPointer", bsontype.String},
nil,
},
{
"DBPointer/Insufficient Bytes", Value.DBPointer, Value{Type: bsontype.DBPointer, Data: []byte{0x01, 0x02, 0x03}},
NewInsufficientBytesError([]byte{0x01, 0x02, 0x03}, []byte{0x01, 0x02, 0x03}),
nil,
},
{
"DBPointer/Success", Value.DBPointer, Value{Type: bsontype.DBPointer, Data: AppendDBPointer(nil, "foobar", oid)},
nil,
[]interface{}{"foobar", oid},
},
{
"DBPointerOK/Not DBPointer", Value.DBPointerOK, Value{Type: bsontype.String},
nil,
[]interface{}{"", primitive.ObjectID{}, false},
},
{
"DBPointerOK/Insufficient Bytes", Value.DBPointerOK, Value{Type: bsontype.DBPointer, Data: []byte{0x01, 0x02, 0x03}},
nil,
[]interface{}{"", primitive.ObjectID{}, false},
},
{
"DBPointerOK/Success", Value.DBPointerOK, Value{Type: bsontype.DBPointer, Data: AppendDBPointer(nil, "foobar", oid)},
nil,
[]interface{}{"foobar", oid, true},
},
{
"JavaScript/Not JavaScript", Value.JavaScript, Value{Type: bsontype.String},
ElementTypeError{"bsoncore.Value.JavaScript", bsontype.String},
nil,
},
{
"JavaScript/Insufficient Bytes", Value.JavaScript, Value{Type: bsontype.JavaScript, Data: []byte{0x01, 0x02, 0x03}},
NewInsufficientBytesError([]byte{0x01, 0x02, 0x03}, []byte{0x01, 0x02, 0x03}),
nil,
},
{
"JavaScript/Success", Value.JavaScript, Value{Type: bsontype.JavaScript, Data: AppendJavaScript(nil, "var hello = 'world';")},
nil,
[]interface{}{"var hello = 'world';"},
},
{
"JavaScriptOK/Not JavaScript", Value.JavaScriptOK, Value{Type: bsontype.String},
nil,
[]interface{}{"", false},
},
{
"JavaScriptOK/Insufficient Bytes", Value.JavaScriptOK, Value{Type: bsontype.JavaScript, Data: []byte{0x01, 0x02, 0x03}},
nil,
[]interface{}{"", false},
},
{
"JavaScriptOK/Success", Value.JavaScriptOK, Value{Type: bsontype.JavaScript, Data: AppendJavaScript(nil, "var hello = 'world';")},
nil,
[]interface{}{"var hello = 'world';", true},
},
{
"Symbol/Not Symbol", Value.Symbol, Value{Type: bsontype.String},
ElementTypeError{"bsoncore.Value.Symbol", bsontype.String},
nil,
},
{
"Symbol/Insufficient Bytes", Value.Symbol, Value{Type: bsontype.Symbol, Data: []byte{0x01, 0x02, 0x03}},
NewInsufficientBytesError([]byte{0x01, 0x02, 0x03}, []byte{0x01, 0x02, 0x03}),
nil,
},
{
"Symbol/Success", Value.Symbol, Value{Type: bsontype.Symbol, Data: AppendSymbol(nil, "symbol123456")},
nil,
[]interface{}{"symbol123456"},
},
{
"SymbolOK/Not Symbol", Value.SymbolOK, Value{Type: bsontype.String},
nil,
[]interface{}{"", false},
},
{
"SymbolOK/Insufficient Bytes", Value.SymbolOK, Value{Type: bsontype.Symbol, Data: []byte{0x01, 0x02, 0x03}},
nil,
[]interface{}{"", false},
},
{
"SymbolOK/Success", Value.SymbolOK, Value{Type: bsontype.Symbol, Data: AppendSymbol(nil, "symbol123456")},
nil,
[]interface{}{"symbol123456", true},
},
{
"CodeWithScope/Not CodeWithScope", Value.CodeWithScope, Value{Type: bsontype.String},
ElementTypeError{"bsoncore.Value.CodeWithScope", bsontype.String},
nil,
},
{
"CodeWithScope/Insufficient Bytes", Value.CodeWithScope, Value{Type: bsontype.CodeWithScope, Data: []byte{0x01, 0x02, 0x03}},
NewInsufficientBytesError([]byte{0x01, 0x02, 0x03}, []byte{0x01, 0x02, 0x03}),
nil,
},
{
"CodeWithScope/Success", Value.CodeWithScope, Value{Type: bsontype.CodeWithScope, Data: AppendCodeWithScope(nil, "var hello = 'world';", Document{0x05, 0x00, 0x00, 0x00, 0x00})},
nil,
[]interface{}{"var hello = 'world';", Document{0x05, 0x00, 0x00, 0x00, 0x00}},
},
{
"CodeWithScopeOK/Not CodeWithScope", Value.CodeWithScopeOK, Value{Type: bsontype.String},
nil,
[]interface{}{"", Document(nil), false},
},
{
"CodeWithScopeOK/Insufficient Bytes", Value.CodeWithScopeOK, Value{Type: bsontype.CodeWithScope, Data: []byte{0x01, 0x02, 0x03}},
nil,
[]interface{}{"", Document(nil), false},
},
{
"CodeWithScopeOK/Success", Value.CodeWithScopeOK, Value{Type: bsontype.CodeWithScope, Data: AppendCodeWithScope(nil, "var hello = 'world';", Document{0x05, 0x00, 0x00, 0x00, 0x00})},
nil,
[]interface{}{"var hello = 'world';", Document{0x05, 0x00, 0x00, 0x00, 0x00}, true},
},
{
"Int32/Not Int32", Value.Int32, Value{Type: bsontype.String},
ElementTypeError{"bsoncore.Value.Int32", bsontype.String},
nil,
},
{
"Int32/Insufficient Bytes", Value.Int32, Value{Type: bsontype.Int32, Data: []byte{0x01, 0x02, 0x03}},
NewInsufficientBytesError([]byte{0x01, 0x02, 0x03}, []byte{0x01, 0x02, 0x03}),
nil,
},
{
"Int32/Success", Value.Int32, Value{Type: bsontype.Int32, Data: AppendInt32(nil, 1234)},
nil,
[]interface{}{int32(1234)},
},
{
"Int32OK/Not Int32", Value.Int32OK, Value{Type: bsontype.String},
nil,
[]interface{}{int32(0), false},
},
{
"Int32OK/Insufficient Bytes", Value.Int32OK, Value{Type: bsontype.Int32, Data: []byte{0x01, 0x02, 0x03}},
nil,
[]interface{}{int32(0), false},
},
{
"Int32OK/Success", Value.Int32OK, Value{Type: bsontype.Int32, Data: AppendInt32(nil, 1234)},
nil,
[]interface{}{int32(1234), true},
},
{
"Timestamp/Not Timestamp", Value.Timestamp, Value{Type: bsontype.String},
ElementTypeError{"bsoncore.Value.Timestamp", bsontype.String},
nil,
},
{
"Timestamp/Insufficient Bytes", Value.Timestamp, Value{Type: bsontype.Timestamp, Data: []byte{0x01, 0x02, 0x03}},
NewInsufficientBytesError([]byte{0x01, 0x02, 0x03}, []byte{0x01, 0x02, 0x03}),
nil,
},
{
"Timestamp/Success", Value.Timestamp, Value{Type: bsontype.Timestamp, Data: AppendTimestamp(nil, 12345, 67890)},
nil,
[]interface{}{uint32(12345), uint32(67890)},
},
{
"TimestampOK/Not Timestamp", Value.TimestampOK, Value{Type: bsontype.String},
nil,
[]interface{}{uint32(0), uint32(0), false},
},
{
"TimestampOK/Insufficient Bytes", Value.TimestampOK, Value{Type: bsontype.Timestamp, Data: []byte{0x01, 0x02, 0x03}},
nil,
[]interface{}{uint32(0), uint32(0), false},
},
{
"TimestampOK/Success", Value.TimestampOK, Value{Type: bsontype.Timestamp, Data: AppendTimestamp(nil, 12345, 67890)},
nil,
[]interface{}{uint32(12345), uint32(67890), true},
},
{
"Int64/Not Int64", Value.Int64, Value{Type: bsontype.String},
ElementTypeError{"bsoncore.Value.Int64", bsontype.String},
nil,
},
{
"Int64/Insufficient Bytes", Value.Int64, Value{Type: bsontype.Int64, Data: []byte{0x01, 0x02, 0x03}},
NewInsufficientBytesError([]byte{0x01, 0x02, 0x03}, []byte{0x01, 0x02, 0x03}),
nil,
},
{
"Int64/Success", Value.Int64, Value{Type: bsontype.Int64, Data: AppendInt64(nil, 1234567890)},
nil,
[]interface{}{int64(1234567890)},
},
{
"Int64OK/Not Int64", Value.Int64OK, Value{Type: bsontype.String},
nil,
[]interface{}{int64(0), false},
},
{
"Int64OK/Insufficient Bytes", Value.Int64OK, Value{Type: bsontype.Int64, Data: []byte{0x01, 0x02, 0x03}},
nil,
[]interface{}{int64(0), false},
},
{
"Int64OK/Success", Value.Int64OK, Value{Type: bsontype.Int64, Data: AppendInt64(nil, 1234567890)},
nil,
[]interface{}{int64(1234567890), true},
},
{
"Decimal128/Not Decimal128", Value.Decimal128, Value{Type: bsontype.String},
ElementTypeError{"bsoncore.Value.Decimal128", bsontype.String},
nil,
},
{
"Decimal128/Insufficient Bytes", Value.Decimal128, Value{Type: bsontype.Decimal128, Data: []byte{0x01, 0x02, 0x03}},
NewInsufficientBytesError([]byte{0x01, 0x02, 0x03}, []byte{0x01, 0x02, 0x03}),
nil,
},
{
"Decimal128/Success", Value.Decimal128, Value{Type: bsontype.Decimal128, Data: AppendDecimal128(nil, primitive.NewDecimal128(12345, 67890))},
nil,
[]interface{}{primitive.NewDecimal128(12345, 67890)},
},
{
"Decimal128OK/Not Decimal128", Value.Decimal128OK, Value{Type: bsontype.String},
nil,
[]interface{}{primitive.Decimal128{}, false},
},
{
"Decimal128OK/Insufficient Bytes", Value.Decimal128OK, Value{Type: bsontype.Decimal128, Data: []byte{0x01, 0x02, 0x03}},
nil,
[]interface{}{primitive.Decimal128{}, false},
},
{
"Decimal128OK/Success", Value.Decimal128OK, Value{Type: bsontype.Decimal128, Data: AppendDecimal128(nil, primitive.NewDecimal128(12345, 67890))},
nil,
[]interface{}{primitive.NewDecimal128(12345, 67890), true},
},
{
"Timestamp.String/Success", Value.String, Value{Type: bsontype.Timestamp, Data: AppendTimestamp(nil, 12345, 67890)},
nil,
[]interface{}{"{\"$timestamp\":{\"t\":12345,\"i\":67890}}"},
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
defer func() {
err := recover()
if !cmp.Equal(err, tc.panicErr, cmp.Comparer(compareErrors)) {
t.Errorf("Did not receive expected panic error. got %v; want %v", err, tc.panicErr)
}
}()
fn := reflect.ValueOf(tc.fn)
if fn.Kind() != reflect.Func || fn.Type().NumIn() != 1 || fn.Type().In(0) != reflect.TypeOf(Value{}) {
t.Fatalf("test case field fn must be a function with 1 parameter that is a Value, but it is %v", fn.Type())
}
got := fn.Call([]reflect.Value{reflect.ValueOf(tc.val)})
want := make([]reflect.Value, 0, len(tc.ret))
for _, ret := range tc.ret {
want = append(want, reflect.ValueOf(ret))
}
if len(got) != len(want) {
t.Fatalf("incorrect number of values returned. got %d; want %d", len(got), len(want))
}
for idx := range got {
gotv, wantv := got[idx].Interface(), want[idx].Interface()
if !cmp.Equal(gotv, wantv, cmp.Comparer(compareDecimal128)) {
t.Errorf("return values at index %d are not equal. got %v; want %v", idx, gotv, wantv)
}
}
})
}
}

View File

@@ -0,0 +1,166 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonx
import (
"encoding/binary"
"math"
"time"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
)
// IDoc is the interface implemented by Doc and MDoc. It allows either of these types to be provided
// to the Document function to create a Value.
type IDoc interface {
idoc()
}
// Double constructs a BSON double Value.
func Double(f64 float64) Val {
v := Val{t: bsontype.Double}
binary.LittleEndian.PutUint64(v.bootstrap[0:8], math.Float64bits(f64))
return v
}
// String constructs a BSON string Value.
func String(str string) Val { return Val{t: bsontype.String}.writestring(str) }
// Document constructs a Value from the given IDoc. If nil is provided, a BSON Null value will be
// returned.
func Document(doc IDoc) Val {
var v Val
switch tt := doc.(type) {
case Doc:
if tt == nil {
v.t = bsontype.Null
break
}
v.t = bsontype.EmbeddedDocument
v.primitive = tt
case MDoc:
if tt == nil {
v.t = bsontype.Null
break
}
v.t = bsontype.EmbeddedDocument
v.primitive = tt
default:
v.t = bsontype.Null
}
return v
}
// Array constructs a Value from arr. If arr is nil, a BSON Null value is returned.
func Array(arr Arr) Val {
if arr == nil {
return Val{t: bsontype.Null}
}
return Val{t: bsontype.Array, primitive: arr}
}
// Binary constructs a BSON binary Value.
func Binary(subtype byte, data []byte) Val {
return Val{t: bsontype.Binary, primitive: primitive.Binary{Subtype: subtype, Data: data}}
}
// Undefined constructs a BSON binary Value.
func Undefined() Val { return Val{t: bsontype.Undefined} }
// ObjectID constructs a BSON objectid Value.
func ObjectID(oid primitive.ObjectID) Val {
v := Val{t: bsontype.ObjectID}
copy(v.bootstrap[0:12], oid[:])
return v
}
// Boolean constructs a BSON boolean Value.
func Boolean(b bool) Val {
v := Val{t: bsontype.Boolean}
if b {
v.bootstrap[0] = 0x01
}
return v
}
// DateTime constructs a BSON datetime Value.
func DateTime(dt int64) Val { return Val{t: bsontype.DateTime}.writei64(dt) }
// Time constructs a BSON datetime Value.
func Time(t time.Time) Val {
return Val{t: bsontype.DateTime}.writei64(t.Unix()*1e3 + int64(t.Nanosecond()/1e6))
}
// Null constructs a BSON binary Value.
func Null() Val { return Val{t: bsontype.Null} }
// Regex constructs a BSON regex Value.
func Regex(pattern, options string) Val {
regex := primitive.Regex{Pattern: pattern, Options: options}
return Val{t: bsontype.Regex, primitive: regex}
}
// DBPointer constructs a BSON dbpointer Value.
func DBPointer(ns string, ptr primitive.ObjectID) Val {
dbptr := primitive.DBPointer{DB: ns, Pointer: ptr}
return Val{t: bsontype.DBPointer, primitive: dbptr}
}
// JavaScript constructs a BSON javascript Value.
func JavaScript(js string) Val {
return Val{t: bsontype.JavaScript}.writestring(js)
}
// Symbol constructs a BSON symbol Value.
func Symbol(symbol string) Val {
return Val{t: bsontype.Symbol}.writestring(symbol)
}
// CodeWithScope constructs a BSON code with scope Value.
func CodeWithScope(code string, scope IDoc) Val {
cws := primitive.CodeWithScope{Code: primitive.JavaScript(code), Scope: scope}
return Val{t: bsontype.CodeWithScope, primitive: cws}
}
// Int32 constructs a BSON int32 Value.
func Int32(i32 int32) Val {
v := Val{t: bsontype.Int32}
v.bootstrap[0] = byte(i32)
v.bootstrap[1] = byte(i32 >> 8)
v.bootstrap[2] = byte(i32 >> 16)
v.bootstrap[3] = byte(i32 >> 24)
return v
}
// Timestamp constructs a BSON timestamp Value.
func Timestamp(t, i uint32) Val {
v := Val{t: bsontype.Timestamp}
v.bootstrap[0] = byte(i)
v.bootstrap[1] = byte(i >> 8)
v.bootstrap[2] = byte(i >> 16)
v.bootstrap[3] = byte(i >> 24)
v.bootstrap[4] = byte(t)
v.bootstrap[5] = byte(t >> 8)
v.bootstrap[6] = byte(t >> 16)
v.bootstrap[7] = byte(t >> 24)
return v
}
// Int64 constructs a BSON int64 Value.
func Int64(i64 int64) Val { return Val{t: bsontype.Int64}.writei64(i64) }
// Decimal128 constructs a BSON decimal128 Value.
func Decimal128(d128 primitive.Decimal128) Val {
return Val{t: bsontype.Decimal128, primitive: d128}
}
// MinKey constructs a BSON minkey Value.
func MinKey() Val { return Val{t: bsontype.MinKey} }
// MaxKey constructs a BSON maxkey Value.
func MaxKey() Val { return Val{t: bsontype.MaxKey} }

305
mongo/x/bsonx/document.go Normal file
View File

@@ -0,0 +1,305 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonx
import (
"bytes"
"errors"
"fmt"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
// ErrNilDocument indicates that an operation was attempted on a nil *bson.Document.
var ErrNilDocument = errors.New("document is nil")
// KeyNotFound is an error type returned from the Lookup methods on Document. This type contains
// information about which key was not found and if it was actually not found or if a component of
// the key except the last was not a document nor array.
type KeyNotFound struct {
Key []string // The keys that were searched for.
Depth uint // Which key either was not found or was an incorrect type.
Type bsontype.Type // The type of the key that was found but was an incorrect type.
}
func (knf KeyNotFound) Error() string {
depth := knf.Depth
if depth >= uint(len(knf.Key)) {
depth = uint(len(knf.Key)) - 1
}
if len(knf.Key) == 0 {
return "no keys were provided for lookup"
}
if knf.Type != bsontype.Type(0) {
return fmt.Sprintf(`key "%s" was found but was not valid to traverse BSON type %s`, knf.Key[depth], knf.Type)
}
return fmt.Sprintf(`key "%s" was not found`, knf.Key[depth])
}
// Doc is a type safe, concise BSON document representation.
type Doc []Elem
// ReadDoc will create a Document using the provided slice of bytes. If the
// slice of bytes is not a valid BSON document, this method will return an error.
func ReadDoc(b []byte) (Doc, error) {
doc := make(Doc, 0)
err := doc.UnmarshalBSON(b)
if err != nil {
return nil, err
}
return doc, nil
}
// Copy makes a shallow copy of this document.
func (d Doc) Copy() Doc {
d2 := make(Doc, len(d))
copy(d2, d)
return d2
}
// Append adds an element to the end of the document, creating it from the key and value provided.
func (d Doc) Append(key string, val Val) Doc {
return append(d, Elem{Key: key, Value: val})
}
// Prepend adds an element to the beginning of the document, creating it from the key and value provided.
func (d Doc) Prepend(key string, val Val) Doc {
// TODO: should we just modify d itself instead of doing an alloc here?
return append(Doc{{Key: key, Value: val}}, d...)
}
// Set replaces an element of a document. If an element with a matching key is
// found, the element will be replaced with the one provided. If the document
// does not have an element with that key, the element is appended to the
// document instead.
func (d Doc) Set(key string, val Val) Doc {
idx := d.IndexOf(key)
if idx == -1 {
return append(d, Elem{Key: key, Value: val})
}
d[idx] = Elem{Key: key, Value: val}
return d
}
// IndexOf returns the index of the first element with a key of key, or -1 if no element with a key
// was found.
func (d Doc) IndexOf(key string) int {
for i, e := range d {
if e.Key == key {
return i
}
}
return -1
}
// Delete removes the element with key if it exists and returns the updated Doc.
func (d Doc) Delete(key string) Doc {
idx := d.IndexOf(key)
if idx == -1 {
return d
}
return append(d[:idx], d[idx+1:]...)
}
// Lookup searches the document and potentially subdocuments or arrays for the
// provided key. Each key provided to this method represents a layer of depth.
//
// This method will return an empty Value if they key does not exist. To know if they key actually
// exists, use LookupErr.
func (d Doc) Lookup(key ...string) Val {
val, _ := d.LookupErr(key...)
return val
}
// LookupErr searches the document and potentially subdocuments or arrays for the
// provided key. Each key provided to this method represents a layer of depth.
func (d Doc) LookupErr(key ...string) (Val, error) {
elem, err := d.LookupElementErr(key...)
return elem.Value, err
}
// LookupElement searches the document and potentially subdocuments or arrays for the
// provided key. Each key provided to this method represents a layer of depth.
//
// This method will return an empty Element if they key does not exist. To know if they key actually
// exists, use LookupElementErr.
func (d Doc) LookupElement(key ...string) Elem {
elem, _ := d.LookupElementErr(key...)
return elem
}
// LookupElementErr searches the document and potentially subdocuments for the
// provided key. Each key provided to this method represents a layer of depth.
func (d Doc) LookupElementErr(key ...string) (Elem, error) {
// KeyNotFound operates by being created where the error happens and then the depth is
// incremented by 1 as each function unwinds. Whenever this function returns, it also assigns
// the Key slice to the key slice it has. This ensures that the proper depth is identified and
// the proper keys.
if len(key) == 0 {
return Elem{}, KeyNotFound{Key: key}
}
var elem Elem
var err error
idx := d.IndexOf(key[0])
if idx == -1 {
return Elem{}, KeyNotFound{Key: key}
}
elem = d[idx]
if len(key) == 1 {
return elem, nil
}
switch elem.Value.Type() {
case bsontype.EmbeddedDocument:
switch tt := elem.Value.primitive.(type) {
case Doc:
elem, err = tt.LookupElementErr(key[1:]...)
case MDoc:
elem, err = tt.LookupElementErr(key[1:]...)
}
default:
return Elem{}, KeyNotFound{Type: elem.Value.Type()}
}
switch tt := err.(type) {
case KeyNotFound:
tt.Depth++
tt.Key = key
return Elem{}, tt
case nil:
return elem, nil
default:
return Elem{}, err // We can't actually hit this.
}
}
// MarshalBSONValue implements the bsoncodec.ValueMarshaler interface.
//
// This method will never return an error.
func (d Doc) MarshalBSONValue() (bsontype.Type, []byte, error) {
if d == nil {
// TODO: Should we do this?
return bsontype.Null, nil, nil
}
data, _ := d.MarshalBSON()
return bsontype.EmbeddedDocument, data, nil
}
// MarshalBSON implements the Marshaler interface.
//
// This method will never return an error.
func (d Doc) MarshalBSON() ([]byte, error) { return d.AppendMarshalBSON(nil) }
// AppendMarshalBSON marshals Doc to BSON bytes, appending to dst.
//
// This method will never return an error.
func (d Doc) AppendMarshalBSON(dst []byte) ([]byte, error) {
idx, dst := bsoncore.ReserveLength(dst)
for _, elem := range d {
t, data, _ := elem.Value.MarshalBSONValue() // Value.MarshalBSONValue never returns an error.
dst = append(dst, byte(t))
dst = append(dst, elem.Key...)
dst = append(dst, 0x00)
dst = append(dst, data...)
}
dst = append(dst, 0x00)
dst = bsoncore.UpdateLength(dst, idx, int32(len(dst[idx:])))
return dst, nil
}
// UnmarshalBSON implements the Unmarshaler interface.
func (d *Doc) UnmarshalBSON(b []byte) error {
if d == nil {
return ErrNilDocument
}
if err := bsoncore.Document(b).Validate(); err != nil {
return err
}
elems, err := bsoncore.Document(b).Elements()
if err != nil {
return err
}
var val Val
for _, elem := range elems {
rawv := elem.Value()
err = val.UnmarshalBSONValue(rawv.Type, rawv.Data)
if err != nil {
return err
}
*d = d.Append(elem.Key(), val)
}
return nil
}
// UnmarshalBSONValue implements the bson.ValueUnmarshaler interface.
func (d *Doc) UnmarshalBSONValue(t bsontype.Type, data []byte) error {
if t != bsontype.EmbeddedDocument {
return fmt.Errorf("cannot unmarshal %s into a bsonx.Doc", t)
}
return d.UnmarshalBSON(data)
}
// Equal compares this document to another, returning true if they are equal.
func (d Doc) Equal(id IDoc) bool {
switch tt := id.(type) {
case Doc:
d2 := tt
if len(d) != len(d2) {
return false
}
for idx := range d {
if !d[idx].Equal(d2[idx]) {
return false
}
}
case MDoc:
unique := make(map[string]struct{})
for _, elem := range d {
unique[elem.Key] = struct{}{}
val, ok := tt[elem.Key]
if !ok {
return false
}
if !val.Equal(elem.Value) {
return false
}
}
if len(unique) != len(tt) {
return false
}
case nil:
return d == nil
default:
return false
}
return true
}
// String implements the fmt.Stringer interface.
func (d Doc) String() string {
var buf bytes.Buffer
buf.Write([]byte("bson.Document{"))
for idx, elem := range d {
if idx > 0 {
buf.Write([]byte(", "))
}
fmt.Fprintf(&buf, "%v", elem)
}
buf.WriteByte('}')
return buf.String()
}
func (Doc) idoc() {}

View File

@@ -0,0 +1,295 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonx
import (
"fmt"
"reflect"
"testing"
"github.com/google/go-cmp/cmp"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
func ExampleDocument() {
internalVersion := "1234567"
f := func(appName string) Doc {
doc := Doc{
{"driver", Document(Doc{{"name", String("mongo-go-driver")}, {"version", String(internalVersion)}})},
{"os", Document(Doc{{"type", String("darwin")}, {"architecture", String("amd64")}})},
{"platform", String("go1.11.1")},
}
if appName != "" {
doc = append(doc, Elem{"application", Document(MDoc{"name": String(appName)})})
}
return doc
}
buf, err := f("hello-world").MarshalBSON()
if err != nil {
fmt.Println(err)
}
fmt.Println(buf)
// Output: [178 0 0 0 3 100 114 105 118 101 114 0 52 0 0 0 2 110 97 109 101 0 16 0 0 0 109 111 110 103 111 45 103 111 45 100 114 105 118 101 114 0 2 118 101 114 115 105 111 110 0 8 0 0 0 49 50 51 52 53 54 55 0 0 3 111 115 0 46 0 0 0 2 116 121 112 101 0 7 0 0 0 100 97 114 119 105 110 0 2 97 114 99 104 105 116 101 99 116 117 114 101 0 6 0 0 0 97 109 100 54 52 0 0 2 112 108 97 116 102 111 114 109 0 9 0 0 0 103 111 49 46 49 49 46 49 0 3 97 112 112 108 105 99 97 116 105 111 110 0 27 0 0 0 2 110 97 109 101 0 12 0 0 0 104 101 108 108 111 45 119 111 114 108 100 0 0 0]
}
func BenchmarkDocument(b *testing.B) {
b.ReportAllocs()
internalVersion := "1234567"
for i := 0; i < b.N; i++ {
doc := Doc{
{"driver", Document(Doc{{"name", String("mongo-go-driver")}, {"version", String(internalVersion)}})},
{"os", Document(Doc{{"type", String("darwin")}, {"architecture", String("amd64")}})},
{"platform", String("go1.11.1")},
}
_, _ = doc.MarshalBSON()
}
}
func TestDocument(t *testing.T) {
t.Parallel()
t.Run("ReadDocument", func(t *testing.T) {
t.Parallel()
t.Run("UnmarshalingError", func(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
invalid []byte
}{
{"base", []byte{0x01, 0x02}},
{"fuzzed1", []byte("0\x990\xc4")}, // fuzzed
{"fuzzed2", []byte("\x10\x00\x00\x00\x10\x000000\x0600\x00\x05\x00\xff\xff\xff\u007f")}, // fuzzed
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
want := bsoncore.NewInsufficientBytesError(nil, nil)
_, got := ReadDoc(tc.invalid)
if !compareErrors(got, want) {
t.Errorf("Expected errors to match. got %v; want %v", got, want)
}
})
}
})
t.Run("success", func(t *testing.T) {
t.Parallel()
valid := bsoncore.BuildDocument(nil, bsoncore.AppendNullElement(nil, "foobar"))
var want error
wantDoc := Doc{{"foobar", Null()}}
gotDoc, got := ReadDoc(valid)
if !compareErrors(got, want) {
t.Errorf("Expected errors to match. got %v; want %v", got, want)
}
if !cmp.Equal(gotDoc, wantDoc) {
t.Errorf("Expected returned documents to match. got %v; want %v", gotDoc, wantDoc)
}
})
})
t.Run("Copy", func(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
start Doc
copy Doc
}{
{"nil", nil, Doc{}},
{"not-nil", Doc{{"foobar", Null()}}, Doc{{"foobar", Null()}}},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
copy := tc.start.Copy()
if !cmp.Equal(copy, tc.copy) {
t.Errorf("Expected copies to be equal. got %v; want %v", copy, tc.copy)
}
})
}
})
testCases := []struct {
name string
fn interface{} // method to call
params []interface{} // parameters
rets []interface{} // returns
}{
{
"Append", Doc{}.Append,
[]interface{}{"foo", Null()},
[]interface{}{Doc{{"foo", Null()}}},
},
{
"Prepend", Doc{}.Prepend,
[]interface{}{"foo", Null()},
[]interface{}{Doc{{"foo", Null()}}},
},
{
"Set/append", Doc{{"foo", Null()}}.Set,
[]interface{}{"bar", Null()},
[]interface{}{Doc{{"foo", Null()}, {"bar", Null()}}},
},
{
"Set/replace", Doc{{"foo", Null()}, {"bar", Null()}, {"baz", Double(3.14159)}}.Set,
[]interface{}{"bar", Int64(1234567890)},
[]interface{}{Doc{{"foo", Null()}, {"bar", Int64(1234567890)}, {"baz", Double(3.14159)}}},
},
{
"Delete/doesn't exist", Doc{{"foo", Null()}, {"bar", Null()}, {"baz", Double(3.14159)}}.Delete,
[]interface{}{"qux"},
[]interface{}{Doc{{"foo", Null()}, {"bar", Null()}, {"baz", Double(3.14159)}}},
},
{
"Delete/exists", Doc{{"foo", Null()}, {"bar", Null()}, {"baz", Double(3.14159)}}.Delete,
[]interface{}{"bar"},
[]interface{}{Doc{{"foo", Null()}, {"baz", Double(3.14159)}}},
},
{
"Lookup/err", Doc{}.Lookup,
[]interface{}{[]string{}},
[]interface{}{Val{}},
},
{
"Lookup/success", Doc{{"pi", Double(3.14159)}}.Lookup,
[]interface{}{[]string{"pi"}},
[]interface{}{Double(3.14159)},
},
{
"LookupErr/err", Doc{}.LookupErr,
[]interface{}{[]string{}},
[]interface{}{Val{}, KeyNotFound{Key: []string{}}},
},
{
"LookupErr/success", Doc{{"pi", Double(3.14159)}}.LookupErr,
[]interface{}{[]string{"pi"}},
[]interface{}{Double(3.14159), error(nil)},
},
{
"LookupElem/err", Doc{}.LookupElement,
[]interface{}{[]string{}},
[]interface{}{Elem{}},
},
{
"LookupElem/success", Doc{{"pi", Double(3.14159)}}.LookupElement,
[]interface{}{[]string{"pi"}},
[]interface{}{Elem{"pi", Double(3.14159)}},
},
{
"LookupElementErr/zero length key", Doc{}.LookupElementErr,
[]interface{}{[]string{}},
[]interface{}{Elem{}, KeyNotFound{Key: []string{}}},
},
{
"LookupElementErr/key not found", Doc{}.LookupElementErr,
[]interface{}{[]string{"foo"}},
[]interface{}{Elem{}, KeyNotFound{Key: []string{"foo"}}},
},
{
"LookupElementErr/key not found/depth 2", Doc{{"foo", Document(Doc{})}}.LookupElementErr,
[]interface{}{[]string{"foo", "bar"}},
[]interface{}{Elem{}, KeyNotFound{Key: []string{"foo", "bar"}, Depth: 1}},
},
{
"LookupElementErr/invalid depth 2 type", Doc{{"foo", Document(Doc{{"pi", Double(3.14159)}})}}.LookupElementErr,
[]interface{}{[]string{"foo", "pi", "baz"}},
[]interface{}{Elem{}, KeyNotFound{Key: []string{"foo", "pi", "baz"}, Depth: 1, Type: bsontype.Double}},
},
{
"LookupElementErr/success", Doc{{"pi", Double(3.14159)}}.LookupElementErr,
[]interface{}{[]string{"pi"}},
[]interface{}{Elem{"pi", Double(3.14159)}, error(nil)},
},
{
"LookupElementErr/success/depth 2", Doc{{"foo", Document(Doc{{"pi", Double(3.14159)}})}}.LookupElementErr,
[]interface{}{[]string{"foo", "pi"}},
[]interface{}{Elem{"pi", Double(3.14159)}, error(nil)},
},
{
"MarshalBSONValue/nil", Doc(nil).MarshalBSONValue,
nil,
[]interface{}{bsontype.Null, []byte(nil), error(nil)},
},
{
"MarshalBSONValue/success", Doc{{"pi", Double(3.14159)}}.MarshalBSONValue, nil,
[]interface{}{
bsontype.EmbeddedDocument,
bsoncore.BuildDocument(nil, bsoncore.AppendDoubleElement(nil, "pi", 3.14159)),
error(nil),
},
},
{
"MarshalBSON", Doc{{"pi", Double(3.14159)}}.MarshalBSON, nil,
[]interface{}{bsoncore.BuildDocument(nil, bsoncore.AppendDoubleElement(nil, "pi", 3.14159)), error(nil)},
},
{
"MarshalBSON/empty", Doc{}.MarshalBSON, nil,
[]interface{}{bsoncore.BuildDocument(nil, nil), error(nil)},
},
{
"AppendMarshalBSON", Doc{{"pi", Double(3.14159)}}.AppendMarshalBSON, []interface{}{[]byte{0x01, 0x02, 0x03}},
[]interface{}{bsoncore.BuildDocument([]byte{0x01, 0x02, 0x03}, bsoncore.AppendDoubleElement(nil, "pi", 3.14159)), error(nil)},
},
{
"AppendMarshalBSON/empty", Doc{}.AppendMarshalBSON, []interface{}{[]byte{0x01, 0x02, 0x03}},
[]interface{}{bsoncore.BuildDocument([]byte{0x01, 0x02, 0x03}, nil), error(nil)},
},
{"Equal/IDoc nil", Doc(nil).Equal, []interface{}{IDoc(nil)}, []interface{}{true}},
{"Equal/MDoc nil", Doc(nil).Equal, []interface{}{MDoc(nil)}, []interface{}{true}},
{"Equal/Doc/different length", Doc{{"pi", Double(3.14159)}}.Equal, []interface{}{Doc{}}, []interface{}{false}},
{"Equal/Doc/elems not equal", Doc{{"pi", Double(3.14159)}}.Equal, []interface{}{Doc{{"pi", Int32(1)}}}, []interface{}{false}},
{"Equal/Doc/success", Doc{{"pi", Double(3.14159)}}.Equal, []interface{}{Doc{{"pi", Double(3.14159)}}}, []interface{}{true}},
{"Equal/MDoc/elems not equal", Doc{{"pi", Double(3.14159)}}.Equal, []interface{}{MDoc{"pi": Int32(1)}}, []interface{}{false}},
{"Equal/MDoc/elems not found", Doc{{"pi", Double(3.14159)}}.Equal, []interface{}{MDoc{"foo": Int32(1)}}, []interface{}{false}},
{
"Equal/MDoc/duplicate",
Doc{{"a", Int32(1)}, {"a", Int32(1)}}.Equal, []interface{}{MDoc{"a": Int32(1), "b": Int32(2)}},
[]interface{}{false},
},
{"Equal/MDoc/success", Doc{{"pi", Double(3.14159)}}.Equal, []interface{}{MDoc{"pi": Double(3.14159)}}, []interface{}{true}},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
fn := reflect.ValueOf(tc.fn)
if fn.Kind() != reflect.Func {
t.Fatalf("property fn must be a function, but it is a %v", fn.Kind())
}
if fn.Type().NumIn() != len(tc.params) && !fn.Type().IsVariadic() {
t.Fatalf("number of parameters does not match. fn takes %d, but was provided %d", fn.Type().NumIn(), len(tc.params))
}
params := make([]reflect.Value, 0, len(tc.params))
for idx, param := range tc.params {
if param == nil {
params = append(params, reflect.New(fn.Type().In(idx)).Elem())
continue
}
params = append(params, reflect.ValueOf(param))
}
var rets []reflect.Value
if fn.Type().IsVariadic() {
rets = fn.CallSlice(params)
} else {
rets = fn.Call(params)
}
if len(rets) != len(tc.rets) {
t.Fatalf("mismatched number of returns. received %d; expected %d", len(rets), len(tc.rets))
}
for idx := range rets {
got, want := rets[idx].Interface(), tc.rets[idx]
if !cmp.Equal(got, want) {
t.Errorf("Return %d does not match. got %v; want %v", idx, got, want)
}
}
})
}
}

51
mongo/x/bsonx/element.go Normal file
View File

@@ -0,0 +1,51 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonx
import (
"fmt"
"go.mongodb.org/mongo-driver/bson/bsontype"
)
// ElementTypeError specifies that a method to obtain a BSON value an incorrect type was called on a bson.Value.
//
// TODO: rename this ValueTypeError.
type ElementTypeError struct {
Method string
Type bsontype.Type
}
// Error implements the error interface.
func (ete ElementTypeError) Error() string {
return "Call of " + ete.Method + " on " + ete.Type.String() + " type"
}
// Elem represents a BSON element.
//
// NOTE: Element cannot be the value of a map nor a property of a struct without special handling.
// The default encoders and decoders will not process Element correctly. To do so would require
// information loss since an Element contains a key, but the keys used when encoding a struct are
// the struct field names. Instead of using an Element, use a Value as a value in a map or a
// property of a struct.
type Elem struct {
Key string
Value Val
}
// Equal compares e and e2 and returns true if they are equal.
func (e Elem) Equal(e2 Elem) bool {
if e.Key != e2.Key {
return false
}
return e.Value.Equal(e2.Value)
}
func (e Elem) String() string {
// TODO(GODRIVER-612): When bsoncore has appenders for extended JSON use that here.
return fmt.Sprintf(`bson.Element{"%s": %v}`, e.Key, e.Value)
}

View File

@@ -0,0 +1,7 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonx

231
mongo/x/bsonx/mdocument.go Normal file
View File

@@ -0,0 +1,231 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonx
import (
"bytes"
"fmt"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
// MDoc is an unordered, type safe, concise BSON document representation. This type should not be
// used if you require ordering of values or duplicate keys.
type MDoc map[string]Val
// ReadMDoc will create a Doc using the provided slice of bytes. If the
// slice of bytes is not a valid BSON document, this method will return an error.
func ReadMDoc(b []byte) (MDoc, error) {
doc := make(MDoc)
err := doc.UnmarshalBSON(b)
if err != nil {
return nil, err
}
return doc, nil
}
// Copy makes a shallow copy of this document.
func (d MDoc) Copy() MDoc {
d2 := make(MDoc, len(d))
for k, v := range d {
d2[k] = v
}
return d2
}
// Lookup searches the document and potentially subdocuments or arrays for the
// provided key. Each key provided to this method represents a layer of depth.
//
// This method will return an empty Value if they key does not exist. To know if they key actually
// exists, use LookupErr.
func (d MDoc) Lookup(key ...string) Val {
val, _ := d.LookupErr(key...)
return val
}
// LookupErr searches the document and potentially subdocuments or arrays for the
// provided key. Each key provided to this method represents a layer of depth.
func (d MDoc) LookupErr(key ...string) (Val, error) {
elem, err := d.LookupElementErr(key...)
return elem.Value, err
}
// LookupElement searches the document and potentially subdocuments or arrays for the
// provided key. Each key provided to this method represents a layer of depth.
//
// This method will return an empty Element if they key does not exist. To know if they key actually
// exists, use LookupElementErr.
func (d MDoc) LookupElement(key ...string) Elem {
elem, _ := d.LookupElementErr(key...)
return elem
}
// LookupElementErr searches the document and potentially subdocuments for the
// provided key. Each key provided to this method represents a layer of depth.
func (d MDoc) LookupElementErr(key ...string) (Elem, error) {
// KeyNotFound operates by being created where the error happens and then the depth is
// incremented by 1 as each function unwinds. Whenever this function returns, it also assigns
// the Key slice to the key slice it has. This ensures that the proper depth is identified and
// the proper keys.
if len(key) == 0 {
return Elem{}, KeyNotFound{Key: key}
}
var elem Elem
var err error
val, ok := d[key[0]]
if !ok {
return Elem{}, KeyNotFound{Key: key}
}
if len(key) == 1 {
return Elem{Key: key[0], Value: val}, nil
}
switch val.Type() {
case bsontype.EmbeddedDocument:
switch tt := val.primitive.(type) {
case Doc:
elem, err = tt.LookupElementErr(key[1:]...)
case MDoc:
elem, err = tt.LookupElementErr(key[1:]...)
}
default:
return Elem{}, KeyNotFound{Type: val.Type()}
}
switch tt := err.(type) {
case KeyNotFound:
tt.Depth++
tt.Key = key
return Elem{}, tt
case nil:
return elem, nil
default:
return Elem{}, err // We can't actually hit this.
}
}
// MarshalBSONValue implements the bsoncodec.ValueMarshaler interface.
//
// This method will never return an error.
func (d MDoc) MarshalBSONValue() (bsontype.Type, []byte, error) {
if d == nil {
// TODO: Should we do this?
return bsontype.Null, nil, nil
}
data, _ := d.MarshalBSON()
return bsontype.EmbeddedDocument, data, nil
}
// MarshalBSON implements the Marshaler interface.
//
// This method will never return an error.
func (d MDoc) MarshalBSON() ([]byte, error) { return d.AppendMarshalBSON(nil) }
// AppendMarshalBSON marshals Doc to BSON bytes, appending to dst.
//
// This method will never return an error.
func (d MDoc) AppendMarshalBSON(dst []byte) ([]byte, error) {
idx, dst := bsoncore.ReserveLength(dst)
for k, v := range d {
t, data, _ := v.MarshalBSONValue() // Value.MarshalBSONValue never returns an error.
dst = append(dst, byte(t))
dst = append(dst, k...)
dst = append(dst, 0x00)
dst = append(dst, data...)
}
dst = append(dst, 0x00)
dst = bsoncore.UpdateLength(dst, idx, int32(len(dst[idx:])))
return dst, nil
}
// UnmarshalBSON implements the Unmarshaler interface.
func (d *MDoc) UnmarshalBSON(b []byte) error {
if d == nil {
return ErrNilDocument
}
if err := bsoncore.Document(b).Validate(); err != nil {
return err
}
elems, err := bsoncore.Document(b).Elements()
if err != nil {
return err
}
var val Val
for _, elem := range elems {
rawv := elem.Value()
err = val.UnmarshalBSONValue(rawv.Type, rawv.Data)
if err != nil {
return err
}
(*d)[elem.Key()] = val
}
return nil
}
// Equal compares this document to another, returning true if they are equal.
func (d MDoc) Equal(id IDoc) bool {
switch tt := id.(type) {
case MDoc:
d2 := tt
if len(d) != len(d2) {
return false
}
for key, value := range d {
value2, ok := d2[key]
if !ok {
return false
}
if !value.Equal(value2) {
return false
}
}
case Doc:
unique := make(map[string]struct{})
for _, elem := range tt {
unique[elem.Key] = struct{}{}
val, ok := d[elem.Key]
if !ok {
return false
}
if !val.Equal(elem.Value) {
return false
}
}
if len(unique) != len(d) {
return false
}
case nil:
return d == nil
default:
return false
}
return true
}
// String implements the fmt.Stringer interface.
func (d MDoc) String() string {
var buf bytes.Buffer
buf.Write([]byte("bson.Document{"))
first := true
for key, value := range d {
if !first {
buf.Write([]byte(", "))
}
fmt.Fprintf(&buf, "%v", Elem{Key: key, Value: value})
first = false
}
buf.WriteByte('}')
return buf.String()
}
func (MDoc) idoc() {}

View File

@@ -0,0 +1,223 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonx
import (
"reflect"
"testing"
"github.com/google/go-cmp/cmp"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
func TestMDoc(t *testing.T) {
t.Parallel()
t.Run("ReadMDoc", func(t *testing.T) {
t.Parallel()
t.Run("UnmarshalingError", func(t *testing.T) {
t.Parallel()
invalid := []byte{0x01, 0x02}
want := bsoncore.NewInsufficientBytesError(nil, nil)
_, got := ReadMDoc(invalid)
if !compareErrors(got, want) {
t.Errorf("Expected errors to match. got %v; want %v", got, want)
}
})
t.Run("success", func(t *testing.T) {
t.Parallel()
valid := bsoncore.BuildDocument(nil, bsoncore.AppendNullElement(nil, "foobar"))
var want error
wantDoc := MDoc{"foobar": Null()}
gotDoc, got := ReadMDoc(valid)
if !compareErrors(got, want) {
t.Errorf("Expected errors to match. got %v; want %v", got, want)
}
if !cmp.Equal(gotDoc, wantDoc) {
t.Errorf("Expected returned documents to match. got %v; want %v", gotDoc, wantDoc)
}
})
})
t.Run("Copy", func(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
start MDoc
copy MDoc
}{
{"nil", nil, MDoc{}},
{"not-nil", MDoc{"foobar": Null()}, MDoc{"foobar": Null()}},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
copy := tc.start.Copy()
if !cmp.Equal(copy, tc.copy) {
t.Errorf("Expected copies to be equal. got %v; want %v", copy, tc.copy)
}
})
}
})
testCases := []struct {
name string
fn interface{} // method to call
params []interface{} // parameters
rets []interface{} // returns
}{
{
"Lookup/err", MDoc{}.Lookup,
[]interface{}{[]string{}},
[]interface{}{Val{}},
},
{
"Lookup/success", MDoc{"pi": Double(3.14159)}.Lookup,
[]interface{}{[]string{"pi"}},
[]interface{}{Double(3.14159)},
},
{
"LookupErr/err", MDoc{}.LookupErr,
[]interface{}{[]string{}},
[]interface{}{Val{}, KeyNotFound{Key: []string{}}},
},
{
"LookupErr/success", MDoc{"pi": Double(3.14159)}.LookupErr,
[]interface{}{[]string{"pi"}},
[]interface{}{Double(3.14159), error(nil)},
},
{
"LookupElem/err", MDoc{}.LookupElement,
[]interface{}{[]string{}},
[]interface{}{Elem{}},
},
{
"LookupElem/success", MDoc{"pi": Double(3.14159)}.LookupElement,
[]interface{}{[]string{"pi"}},
[]interface{}{Elem{"pi", Double(3.14159)}},
},
{
"LookupElementErr/zero length key", MDoc{}.LookupElementErr,
[]interface{}{[]string{}},
[]interface{}{Elem{}, KeyNotFound{Key: []string{}}},
},
{
"LookupElementErr/key not found", MDoc{}.LookupElementErr,
[]interface{}{[]string{"foo"}},
[]interface{}{Elem{}, KeyNotFound{Key: []string{"foo"}}},
},
{
"LookupElementErr/key not found/depth 2", MDoc{"foo": Document(Doc{})}.LookupElementErr,
[]interface{}{[]string{"foo", "bar"}},
[]interface{}{Elem{}, KeyNotFound{Key: []string{"foo", "bar"}, Depth: 1}},
},
{
"LookupElementErr/invalid depth 2 type", MDoc{"foo": Document(MDoc{"pi": Double(3.14159)})}.LookupElementErr,
[]interface{}{[]string{"foo", "pi", "baz"}},
[]interface{}{Elem{}, KeyNotFound{Key: []string{"foo", "pi", "baz"}, Depth: 1, Type: bsontype.Double}},
},
{
"LookupElementErr/invalid depth 2 type (Doc)", MDoc{"foo": Document(Doc{{"pi", Double(3.14159)}})}.LookupElementErr,
[]interface{}{[]string{"foo", "pi", "baz"}},
[]interface{}{Elem{}, KeyNotFound{Key: []string{"foo", "pi", "baz"}, Depth: 1, Type: bsontype.Double}},
},
{
"LookupElementErr/success", MDoc{"pi": Double(3.14159)}.LookupElementErr,
[]interface{}{[]string{"pi"}},
[]interface{}{Elem{"pi", Double(3.14159)}, error(nil)},
},
{
"LookupElementErr/success/depth 2 (Doc)", MDoc{"foo": Document(Doc{{"pi", Double(3.14159)}})}.LookupElementErr,
[]interface{}{[]string{"foo", "pi"}},
[]interface{}{Elem{"pi", Double(3.14159)}, error(nil)},
},
{
"LookupElementErr/success/depth 2", MDoc{"foo": Document(MDoc{"pi": Double(3.14159)})}.LookupElementErr,
[]interface{}{[]string{"foo", "pi"}},
[]interface{}{Elem{"pi", Double(3.14159)}, error(nil)},
},
{
"MarshalBSONValue/nil", MDoc(nil).MarshalBSONValue,
nil,
[]interface{}{bsontype.Null, []byte(nil), error(nil)},
},
{
"MarshalBSONValue/success", MDoc{"pi": Double(3.14159)}.MarshalBSONValue, nil,
[]interface{}{
bsontype.EmbeddedDocument,
bsoncore.BuildDocument(nil, bsoncore.AppendDoubleElement(nil, "pi", 3.14159)),
error(nil),
},
},
{
"MarshalBSON", MDoc{"pi": Double(3.14159)}.MarshalBSON, nil,
[]interface{}{bsoncore.BuildDocument(nil, bsoncore.AppendDoubleElement(nil, "pi", 3.14159)), error(nil)},
},
{
"MarshalBSON/empty", MDoc{}.MarshalBSON, nil,
[]interface{}{bsoncore.BuildDocument(nil, nil), error(nil)},
},
{
"AppendMarshalBSON", MDoc{"pi": Double(3.14159)}.AppendMarshalBSON, []interface{}{[]byte{0x01, 0x02, 0x03}},
[]interface{}{bsoncore.BuildDocument([]byte{0x01, 0x02, 0x03}, bsoncore.AppendDoubleElement(nil, "pi", 3.14159)), error(nil)},
},
{
"AppendMarshalBSON/empty", MDoc{}.AppendMarshalBSON, []interface{}{[]byte{0x01, 0x02, 0x03}},
[]interface{}{bsoncore.BuildDocument([]byte{0x01, 0x02, 0x03}, nil), error(nil)},
},
{"Equal/IDoc nil", MDoc(nil).Equal, []interface{}{IDoc(nil)}, []interface{}{true}},
{"Equal/MDoc nil", MDoc(nil).Equal, []interface{}{Doc(nil)}, []interface{}{true}},
{"Equal/Doc/different length", MDoc{"pi": Double(3.14159)}.Equal, []interface{}{Doc{}}, []interface{}{false}},
{"Equal/Doc/elems not equal", MDoc{"pi": Double(3.14159)}.Equal, []interface{}{Doc{{"pi", Int32(1)}}}, []interface{}{false}},
{"Equal/Doc/success", MDoc{"pi": Double(3.14159)}.Equal, []interface{}{Doc{{"pi", Double(3.14159)}}}, []interface{}{true}},
{"Equal/MDoc/elems not equal", MDoc{"pi": Double(3.14159)}.Equal, []interface{}{MDoc{"pi": Int32(1)}}, []interface{}{false}},
{"Equal/MDoc/elems not found", MDoc{"pi": Double(3.14159)}.Equal, []interface{}{MDoc{"foo": Int32(1)}}, []interface{}{false}},
{
"Equal/MDoc/duplicate",
Doc{{"a", Int32(1)}, {"a", Int32(1)}}.Equal, []interface{}{MDoc{"a": Int32(1), "b": Int32(2)}},
[]interface{}{false},
},
{"Equal/MDoc/success", Doc{{"pi", Double(3.14159)}}.Equal, []interface{}{MDoc{"pi": Double(3.14159)}}, []interface{}{true}},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
fn := reflect.ValueOf(tc.fn)
if fn.Kind() != reflect.Func {
t.Fatalf("property fn must be a function, but it is a %v", fn.Kind())
}
if fn.Type().NumIn() != len(tc.params) && !fn.Type().IsVariadic() {
t.Fatalf("number of parameters does not match. fn takes %d, but was provided %d", fn.Type().NumIn(), len(tc.params))
}
params := make([]reflect.Value, 0, len(tc.params))
for idx, param := range tc.params {
if param == nil {
params = append(params, reflect.New(fn.Type().In(idx)).Elem())
continue
}
params = append(params, reflect.ValueOf(param))
}
var rets []reflect.Value
if fn.Type().IsVariadic() {
rets = fn.CallSlice(params)
} else {
rets = fn.Call(params)
}
if len(rets) != len(tc.rets) {
t.Fatalf("mismatched number of returns. received %d; expected %d", len(rets), len(tc.rets))
}
for idx := range rets {
got, want := rets[idx].Interface(), tc.rets[idx]
if !cmp.Equal(got, want) {
t.Errorf("Return %d does not match. got %v; want %v", idx, got, want)
}
}
})
}
}

View File

@@ -0,0 +1,637 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonx
import (
"errors"
"fmt"
"reflect"
"go.mongodb.org/mongo-driver/bson/bsoncodec"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/bsontype"
)
var primitiveCodecs PrimitiveCodecs
var tDocument = reflect.TypeOf((Doc)(nil))
var tArray = reflect.TypeOf((Arr)(nil))
var tValue = reflect.TypeOf(Val{})
var tElementSlice = reflect.TypeOf(([]Elem)(nil))
// PrimitiveCodecs is a namespace for all of the default bsoncodec.Codecs for the primitive types
// defined in this package.
type PrimitiveCodecs struct{}
// RegisterPrimitiveCodecs will register the encode and decode methods attached to PrimitiveCodecs
// with the provided RegistryBuilder. if rb is nil, a new empty RegistryBuilder will be created.
func (pc PrimitiveCodecs) RegisterPrimitiveCodecs(rb *bsoncodec.RegistryBuilder) {
if rb == nil {
panic(errors.New("argument to RegisterPrimitiveCodecs must not be nil"))
}
rb.
RegisterTypeEncoder(tDocument, bsoncodec.ValueEncoderFunc(pc.DocumentEncodeValue)).
RegisterTypeEncoder(tArray, bsoncodec.ValueEncoderFunc(pc.ArrayEncodeValue)).
RegisterTypeEncoder(tValue, bsoncodec.ValueEncoderFunc(pc.ValueEncodeValue)).
RegisterTypeEncoder(tElementSlice, bsoncodec.ValueEncoderFunc(pc.ElementSliceEncodeValue)).
RegisterTypeDecoder(tDocument, bsoncodec.ValueDecoderFunc(pc.DocumentDecodeValue)).
RegisterTypeDecoder(tArray, bsoncodec.ValueDecoderFunc(pc.ArrayDecodeValue)).
RegisterTypeDecoder(tValue, bsoncodec.ValueDecoderFunc(pc.ValueDecodeValue)).
RegisterTypeDecoder(tElementSlice, bsoncodec.ValueDecoderFunc(pc.ElementSliceDecodeValue))
}
// DocumentEncodeValue is the ValueEncoderFunc for *Document.
func (pc PrimitiveCodecs) DocumentEncodeValue(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tDocument {
return bsoncodec.ValueEncoderError{Name: "DocumentEncodeValue", Types: []reflect.Type{tDocument}, Received: val}
}
if val.IsNil() {
return vw.WriteNull()
}
doc := val.Interface().(Doc)
dw, err := vw.WriteDocument()
if err != nil {
return err
}
return pc.encodeDocument(ec, dw, doc)
}
// DocumentDecodeValue is the ValueDecoderFunc for *Document.
func (pc PrimitiveCodecs) DocumentDecodeValue(dctx bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
if !val.CanSet() || val.Type() != tDocument {
return bsoncodec.ValueDecoderError{Name: "DocumentDecodeValue", Types: []reflect.Type{tDocument}, Received: val}
}
return pc.documentDecodeValue(dctx, vr, val.Addr().Interface().(*Doc))
}
func (pc PrimitiveCodecs) documentDecodeValue(dctx bsoncodec.DecodeContext, vr bsonrw.ValueReader, doc *Doc) error {
dr, err := vr.ReadDocument()
if err != nil {
return err
}
return pc.decodeDocument(dctx, dr, doc)
}
// ArrayEncodeValue is the ValueEncoderFunc for *Array.
func (pc PrimitiveCodecs) ArrayEncodeValue(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tArray {
return bsoncodec.ValueEncoderError{Name: "ArrayEncodeValue", Types: []reflect.Type{tArray}, Received: val}
}
if val.IsNil() {
return vw.WriteNull()
}
arr := val.Interface().(Arr)
aw, err := vw.WriteArray()
if err != nil {
return err
}
for _, val := range arr {
dvw, err := aw.WriteArrayElement()
if err != nil {
return err
}
err = pc.encodeValue(ec, dvw, val)
if err != nil {
return err
}
}
return aw.WriteArrayEnd()
}
// ArrayDecodeValue is the ValueDecoderFunc for *Array.
func (pc PrimitiveCodecs) ArrayDecodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
if !val.CanSet() || val.Type() != tArray {
return bsoncodec.ValueDecoderError{Name: "ArrayDecodeValue", Types: []reflect.Type{tArray}, Received: val}
}
ar, err := vr.ReadArray()
if err != nil {
return err
}
if val.IsNil() {
val.Set(reflect.MakeSlice(tArray, 0, 0))
}
val.SetLen(0)
for {
vr, err := ar.ReadValue()
if err == bsonrw.ErrEOA {
break
}
if err != nil {
return err
}
var elem Val
err = pc.valueDecodeValue(dc, vr, &elem)
if err != nil {
return err
}
val.Set(reflect.Append(val, reflect.ValueOf(elem)))
}
return nil
}
// ElementSliceEncodeValue is the ValueEncoderFunc for []*Element.
func (pc PrimitiveCodecs) ElementSliceEncodeValue(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tElementSlice {
return bsoncodec.ValueEncoderError{Name: "ElementSliceEncodeValue", Types: []reflect.Type{tElementSlice}, Received: val}
}
if val.IsNil() {
return vw.WriteNull()
}
return pc.DocumentEncodeValue(ec, vw, val.Convert(tDocument))
}
// ElementSliceDecodeValue is the ValueDecoderFunc for []*Element.
func (pc PrimitiveCodecs) ElementSliceDecodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
if !val.CanSet() || val.Type() != tElementSlice {
return bsoncodec.ValueDecoderError{Name: "ElementSliceDecodeValue", Types: []reflect.Type{tElementSlice}, Received: val}
}
if val.IsNil() {
val.Set(reflect.MakeSlice(val.Type(), 0, 0))
}
val.SetLen(0)
dr, err := vr.ReadDocument()
if err != nil {
return err
}
elems := make([]reflect.Value, 0)
for {
key, vr, err := dr.ReadElement()
if err == bsonrw.ErrEOD {
break
}
if err != nil {
return err
}
var elem Elem
err = pc.elementDecodeValue(dc, vr, key, &elem)
if err != nil {
return err
}
elems = append(elems, reflect.ValueOf(elem))
}
val.Set(reflect.Append(val, elems...))
return nil
}
// ValueEncodeValue is the ValueEncoderFunc for *Value.
func (pc PrimitiveCodecs) ValueEncodeValue(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tValue {
return bsoncodec.ValueEncoderError{Name: "ValueEncodeValue", Types: []reflect.Type{tValue}, Received: val}
}
v := val.Interface().(Val)
return pc.encodeValue(ec, vw, v)
}
// ValueDecodeValue is the ValueDecoderFunc for *Value.
func (pc PrimitiveCodecs) ValueDecodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
if !val.CanSet() || val.Type() != tValue {
return bsoncodec.ValueDecoderError{Name: "ValueDecodeValue", Types: []reflect.Type{tValue}, Received: val}
}
return pc.valueDecodeValue(dc, vr, val.Addr().Interface().(*Val))
}
// encodeDocument is a separate function that we use because CodeWithScope
// returns us a DocumentWriter and we need to do the same logic that we would do
// for a document but cannot use a Codec.
func (pc PrimitiveCodecs) encodeDocument(ec bsoncodec.EncodeContext, dw bsonrw.DocumentWriter, doc Doc) error {
for _, elem := range doc {
dvw, err := dw.WriteDocumentElement(elem.Key)
if err != nil {
return err
}
err = pc.encodeValue(ec, dvw, elem.Value)
if err != nil {
return err
}
}
return dw.WriteDocumentEnd()
}
// DecodeDocument haves decoding into a Doc from a bsonrw.DocumentReader.
func (pc PrimitiveCodecs) DecodeDocument(dctx bsoncodec.DecodeContext, dr bsonrw.DocumentReader, pdoc *Doc) error {
return pc.decodeDocument(dctx, dr, pdoc)
}
func (pc PrimitiveCodecs) decodeDocument(dctx bsoncodec.DecodeContext, dr bsonrw.DocumentReader, pdoc *Doc) error {
if *pdoc == nil {
*pdoc = make(Doc, 0)
}
*pdoc = (*pdoc)[:0]
for {
key, vr, err := dr.ReadElement()
if err == bsonrw.ErrEOD {
break
}
if err != nil {
return err
}
var elem Elem
err = pc.elementDecodeValue(dctx, vr, key, &elem)
if err != nil {
return err
}
*pdoc = append(*pdoc, elem)
}
return nil
}
func (pc PrimitiveCodecs) elementDecodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, key string, elem *Elem) error {
var val Val
switch vr.Type() {
case bsontype.Double:
f64, err := vr.ReadDouble()
if err != nil {
return err
}
val = Double(f64)
case bsontype.String:
str, err := vr.ReadString()
if err != nil {
return err
}
val = String(str)
case bsontype.EmbeddedDocument:
var embeddedDoc Doc
err := pc.documentDecodeValue(dc, vr, &embeddedDoc)
if err != nil {
return err
}
val = Document(embeddedDoc)
case bsontype.Array:
arr := reflect.New(tArray).Elem()
err := pc.ArrayDecodeValue(dc, vr, arr)
if err != nil {
return err
}
val = Array(arr.Interface().(Arr))
case bsontype.Binary:
data, subtype, err := vr.ReadBinary()
if err != nil {
return err
}
val = Binary(subtype, data)
case bsontype.Undefined:
err := vr.ReadUndefined()
if err != nil {
return err
}
val = Undefined()
case bsontype.ObjectID:
oid, err := vr.ReadObjectID()
if err != nil {
return err
}
val = ObjectID(oid)
case bsontype.Boolean:
b, err := vr.ReadBoolean()
if err != nil {
return err
}
val = Boolean(b)
case bsontype.DateTime:
dt, err := vr.ReadDateTime()
if err != nil {
return err
}
val = DateTime(dt)
case bsontype.Null:
err := vr.ReadNull()
if err != nil {
return err
}
val = Null()
case bsontype.Regex:
pattern, options, err := vr.ReadRegex()
if err != nil {
return err
}
val = Regex(pattern, options)
case bsontype.DBPointer:
ns, pointer, err := vr.ReadDBPointer()
if err != nil {
return err
}
val = DBPointer(ns, pointer)
case bsontype.JavaScript:
js, err := vr.ReadJavascript()
if err != nil {
return err
}
val = JavaScript(js)
case bsontype.Symbol:
symbol, err := vr.ReadSymbol()
if err != nil {
return err
}
val = Symbol(symbol)
case bsontype.CodeWithScope:
code, scope, err := vr.ReadCodeWithScope()
if err != nil {
return err
}
var doc Doc
err = pc.decodeDocument(dc, scope, &doc)
if err != nil {
return err
}
val = CodeWithScope(code, doc)
case bsontype.Int32:
i32, err := vr.ReadInt32()
if err != nil {
return err
}
val = Int32(i32)
case bsontype.Timestamp:
t, i, err := vr.ReadTimestamp()
if err != nil {
return err
}
val = Timestamp(t, i)
case bsontype.Int64:
i64, err := vr.ReadInt64()
if err != nil {
return err
}
val = Int64(i64)
case bsontype.Decimal128:
d128, err := vr.ReadDecimal128()
if err != nil {
return err
}
val = Decimal128(d128)
case bsontype.MinKey:
err := vr.ReadMinKey()
if err != nil {
return err
}
val = MinKey()
case bsontype.MaxKey:
err := vr.ReadMaxKey()
if err != nil {
return err
}
val = MaxKey()
default:
return fmt.Errorf("Cannot read unknown BSON type %s", vr.Type())
}
*elem = Elem{Key: key, Value: val}
return nil
}
// encodeValue does not validation, and the callers must perform validation on val before calling
// this method.
func (pc PrimitiveCodecs) encodeValue(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val Val) error {
var err error
switch val.Type() {
case bsontype.Double:
err = vw.WriteDouble(val.Double())
case bsontype.String:
err = vw.WriteString(val.StringValue())
case bsontype.EmbeddedDocument:
var encoder bsoncodec.ValueEncoder
encoder, err = ec.LookupEncoder(tDocument)
if err != nil {
break
}
err = encoder.EncodeValue(ec, vw, reflect.ValueOf(val.Document()))
case bsontype.Array:
var encoder bsoncodec.ValueEncoder
encoder, err = ec.LookupEncoder(tArray)
if err != nil {
break
}
err = encoder.EncodeValue(ec, vw, reflect.ValueOf(val.Array()))
case bsontype.Binary:
// TODO: FIX THIS (╯°□°)╯︵ ┻━┻
subtype, data := val.Binary()
err = vw.WriteBinaryWithSubtype(data, subtype)
case bsontype.Undefined:
err = vw.WriteUndefined()
case bsontype.ObjectID:
err = vw.WriteObjectID(val.ObjectID())
case bsontype.Boolean:
err = vw.WriteBoolean(val.Boolean())
case bsontype.DateTime:
err = vw.WriteDateTime(val.DateTime())
case bsontype.Null:
err = vw.WriteNull()
case bsontype.Regex:
err = vw.WriteRegex(val.Regex())
case bsontype.DBPointer:
err = vw.WriteDBPointer(val.DBPointer())
case bsontype.JavaScript:
err = vw.WriteJavascript(val.JavaScript())
case bsontype.Symbol:
err = vw.WriteSymbol(val.Symbol())
case bsontype.CodeWithScope:
code, scope := val.CodeWithScope()
var cwsw bsonrw.DocumentWriter
cwsw, err = vw.WriteCodeWithScope(code)
if err != nil {
break
}
err = pc.encodeDocument(ec, cwsw, scope)
case bsontype.Int32:
err = vw.WriteInt32(val.Int32())
case bsontype.Timestamp:
err = vw.WriteTimestamp(val.Timestamp())
case bsontype.Int64:
err = vw.WriteInt64(val.Int64())
case bsontype.Decimal128:
err = vw.WriteDecimal128(val.Decimal128())
case bsontype.MinKey:
err = vw.WriteMinKey()
case bsontype.MaxKey:
err = vw.WriteMaxKey()
default:
err = fmt.Errorf("%T is not a valid BSON type to encode", val.Type())
}
return err
}
func (pc PrimitiveCodecs) valueDecodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, val *Val) error {
switch vr.Type() {
case bsontype.Double:
f64, err := vr.ReadDouble()
if err != nil {
return err
}
*val = Double(f64)
case bsontype.String:
str, err := vr.ReadString()
if err != nil {
return err
}
*val = String(str)
case bsontype.EmbeddedDocument:
var embeddedDoc Doc
err := pc.documentDecodeValue(dc, vr, &embeddedDoc)
if err != nil {
return err
}
*val = Document(embeddedDoc)
case bsontype.Array:
arr := reflect.New(tArray).Elem()
err := pc.ArrayDecodeValue(dc, vr, arr)
if err != nil {
return err
}
*val = Array(arr.Interface().(Arr))
case bsontype.Binary:
data, subtype, err := vr.ReadBinary()
if err != nil {
return err
}
*val = Binary(subtype, data)
case bsontype.Undefined:
err := vr.ReadUndefined()
if err != nil {
return err
}
*val = Undefined()
case bsontype.ObjectID:
oid, err := vr.ReadObjectID()
if err != nil {
return err
}
*val = ObjectID(oid)
case bsontype.Boolean:
b, err := vr.ReadBoolean()
if err != nil {
return err
}
*val = Boolean(b)
case bsontype.DateTime:
dt, err := vr.ReadDateTime()
if err != nil {
return err
}
*val = DateTime(dt)
case bsontype.Null:
err := vr.ReadNull()
if err != nil {
return err
}
*val = Null()
case bsontype.Regex:
pattern, options, err := vr.ReadRegex()
if err != nil {
return err
}
*val = Regex(pattern, options)
case bsontype.DBPointer:
ns, pointer, err := vr.ReadDBPointer()
if err != nil {
return err
}
*val = DBPointer(ns, pointer)
case bsontype.JavaScript:
js, err := vr.ReadJavascript()
if err != nil {
return err
}
*val = JavaScript(js)
case bsontype.Symbol:
symbol, err := vr.ReadSymbol()
if err != nil {
return err
}
*val = Symbol(symbol)
case bsontype.CodeWithScope:
code, scope, err := vr.ReadCodeWithScope()
if err != nil {
return err
}
var scopeDoc Doc
err = pc.decodeDocument(dc, scope, &scopeDoc)
if err != nil {
return err
}
*val = CodeWithScope(code, scopeDoc)
case bsontype.Int32:
i32, err := vr.ReadInt32()
if err != nil {
return err
}
*val = Int32(i32)
case bsontype.Timestamp:
t, i, err := vr.ReadTimestamp()
if err != nil {
return err
}
*val = Timestamp(t, i)
case bsontype.Int64:
i64, err := vr.ReadInt64()
if err != nil {
return err
}
*val = Int64(i64)
case bsontype.Decimal128:
d128, err := vr.ReadDecimal128()
if err != nil {
return err
}
*val = Decimal128(d128)
case bsontype.MinKey:
err := vr.ReadMinKey()
if err != nil {
return err
}
*val = MinKey()
case bsontype.MaxKey:
err := vr.ReadMaxKey()
if err != nil {
return err
}
*val = MaxKey()
default:
return fmt.Errorf("Cannot read unknown BSON type %s", vr.Type())
}
return nil
}

View File

@@ -0,0 +1,910 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonx
import (
"errors"
"fmt"
"reflect"
"testing"
"github.com/google/go-cmp/cmp"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/bsoncodec"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/bsonrw/bsonrwtest"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
func TestDefaultValueEncoders(t *testing.T) {
var pcx PrimitiveCodecs
var wrong = func(string, string) string { return "wrong" }
type subtest struct {
name string
val interface{}
ectx *bsoncodec.EncodeContext
llvrw *bsonrwtest.ValueReaderWriter
invoke bsonrwtest.Invoked
err error
}
testCases := []struct {
name string
ve bsoncodec.ValueEncoder
subtests []subtest
}{
{
"ValueEncodeValue",
bsoncodec.ValueEncoderFunc(pcx.ValueEncodeValue),
[]subtest{
{
"wrong type",
wrong,
nil,
nil,
bsonrwtest.Nothing,
bsoncodec.ValueEncoderError{Name: "ValueEncodeValue", Types: []reflect.Type{tValue}, Received: reflect.ValueOf(wrong)},
},
{"empty value", Val{}, nil, nil, bsonrwtest.WriteNull, nil},
{
"success",
Null(),
&bsoncodec.EncodeContext{Registry: DefaultRegistry},
&bsonrwtest.ValueReaderWriter{},
bsonrwtest.WriteNull,
nil,
},
},
},
{
"ElementSliceEncodeValue",
bsoncodec.ValueEncoderFunc(pcx.ElementSliceEncodeValue),
[]subtest{
{
"wrong type",
wrong,
nil,
nil,
bsonrwtest.Nothing,
bsoncodec.ValueEncoderError{
Name: "ElementSliceEncodeValue",
Types: []reflect.Type{tElementSlice},
Received: reflect.ValueOf(wrong),
},
},
},
},
{
"ArrayEncodeValue",
bsoncodec.ValueEncoderFunc(pcx.ArrayEncodeValue),
[]subtest{
{
"wrong type",
wrong,
nil,
nil,
bsonrwtest.Nothing,
bsoncodec.ValueEncoderError{Name: "ArrayEncodeValue", Types: []reflect.Type{tArray}, Received: reflect.ValueOf(wrong)},
},
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
for _, subtest := range tc.subtests {
t.Run(subtest.name, func(t *testing.T) {
var ec bsoncodec.EncodeContext
if subtest.ectx != nil {
ec = *subtest.ectx
}
llvrw := new(bsonrwtest.ValueReaderWriter)
if subtest.llvrw != nil {
llvrw = subtest.llvrw
}
llvrw.T = t
err := tc.ve.EncodeValue(ec, llvrw, reflect.ValueOf(subtest.val))
if !compareErrors(err, subtest.err) {
t.Errorf("Errors do not match. got %v; want %v", err, subtest.err)
}
invoked := llvrw.Invoked
if !cmp.Equal(invoked, subtest.invoke) {
t.Errorf("Incorrect method invoked. got %v; want %v", invoked, subtest.invoke)
}
})
}
})
}
t.Run("DocumentEncodeValue", func(t *testing.T) {
t.Run("ValueEncoderError", func(t *testing.T) {
val := reflect.ValueOf(bool(true))
want := bsoncodec.ValueEncoderError{Name: "DocumentEncodeValue", Types: []reflect.Type{tDocument}, Received: val}
got := (PrimitiveCodecs{}).DocumentEncodeValue(bsoncodec.EncodeContext{}, nil, val)
if !compareErrors(got, want) {
t.Errorf("Errors do not match. got %v; want %v", got, want)
}
})
t.Run("WriteDocument Error", func(t *testing.T) {
want := errors.New("WriteDocument Error")
llvrw := &bsonrwtest.ValueReaderWriter{
T: t,
Err: want,
ErrAfter: bsonrwtest.WriteDocument,
}
got := (PrimitiveCodecs{}).DocumentEncodeValue(bsoncodec.EncodeContext{}, llvrw, reflect.MakeSlice(tDocument, 0, 0))
if !compareErrors(got, want) {
t.Errorf("Errors do not match. got %v; want %v", got, want)
}
})
t.Run("encodeDocument errors", func(t *testing.T) {
ec := bsoncodec.EncodeContext{}
err := errors.New("encodeDocument error")
oid := primitive.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}
testCases := []struct {
name string
ec bsoncodec.EncodeContext
llvrw *bsonrwtest.ValueReaderWriter
doc Doc
err error
}{
{
"WriteDocumentElement",
ec,
&bsonrwtest.ValueReaderWriter{T: t, Err: errors.New("wde error"), ErrAfter: bsonrwtest.WriteDocumentElement},
Doc{{"foo", Null()}},
errors.New("wde error"),
},
{
"WriteDouble", ec,
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteDouble},
Doc{{"foo", Double(3.14159)}}, err,
},
{
"WriteString", ec,
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteString},
Doc{{"foo", String("bar")}}, err,
},
{
"WriteDocument (Lookup)", bsoncodec.EncodeContext{Registry: bsoncodec.NewRegistryBuilder().Build()},
&bsonrwtest.ValueReaderWriter{T: t},
Doc{{"foo", Document(Doc{{"bar", Null()}})}},
bsoncodec.ErrNoEncoder{Type: tDocument},
},
{
"WriteArray (Lookup)", bsoncodec.EncodeContext{Registry: bsoncodec.NewRegistryBuilder().Build()},
&bsonrwtest.ValueReaderWriter{T: t},
Doc{{"foo", Array(Arr{Null()})}},
bsoncodec.ErrNoEncoder{Type: tArray},
},
{
"WriteBinary", ec,
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteBinaryWithSubtype},
Doc{{"foo", Binary(0xFF, []byte{0x01, 0x02, 0x03})}}, err,
},
{
"WriteUndefined", ec,
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteUndefined},
Doc{{"foo", Undefined()}}, err,
},
{
"WriteObjectID", ec,
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteObjectID},
Doc{{"foo", ObjectID(oid)}}, err,
},
{
"WriteBoolean", ec,
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteBoolean},
Doc{{"foo", Boolean(true)}}, err,
},
{
"WriteDateTime", ec,
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteDateTime},
Doc{{"foo", DateTime(1234567890)}}, err,
},
{
"WriteNull", ec,
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteNull},
Doc{{"foo", Null()}}, err,
},
{
"WriteRegex", ec,
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteRegex},
Doc{{"foo", Regex("bar", "baz")}}, err,
},
{
"WriteDBPointer", ec,
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteDBPointer},
Doc{{"foo", DBPointer("bar", oid)}}, err,
},
{
"WriteJavascript", ec,
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteJavascript},
Doc{{"foo", JavaScript("var hello = 'world';")}}, err,
},
{
"WriteSymbol", ec,
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteSymbol},
Doc{{"foo", Symbol("symbolbaz")}}, err,
},
{
"WriteCodeWithScope (Lookup)", bsoncodec.EncodeContext{Registry: bsoncodec.NewRegistryBuilder().Build()},
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteCodeWithScope},
Doc{{"foo", CodeWithScope("var hello = 'world';", Doc{}.Append("bar", Null()))}},
err,
},
{
"WriteInt32", ec,
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteInt32},
Doc{{"foo", Int32(12345)}}, err,
},
{
"WriteInt64", ec,
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteInt64},
Doc{{"foo", Int64(1234567890)}}, err,
},
{
"WriteTimestamp", ec,
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteTimestamp},
Doc{{"foo", Timestamp(10, 20)}}, err,
},
{
"WriteDecimal128", ec,
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteDecimal128},
Doc{{"foo", Decimal128(primitive.NewDecimal128(10, 20))}}, err,
},
{
"WriteMinKey", ec,
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteMinKey},
Doc{{"foo", MinKey()}}, err,
},
{
"WriteMaxKey", ec,
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteMaxKey},
Doc{{"foo", MaxKey()}}, err,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := (PrimitiveCodecs{}).DocumentEncodeValue(tc.ec, tc.llvrw, reflect.ValueOf(tc.doc))
if !compareErrors(err, tc.err) {
t.Errorf("Errors do not match. got %v; want %v", err, tc.err)
}
})
}
})
t.Run("success", func(t *testing.T) {
oid := primitive.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}
d128 := primitive.NewDecimal128(10, 20)
want := Doc{
{"a", Double(3.14159)}, {"b", String("foo")},
{"c", Document(Doc{{"aa", Null()}})}, {"d", Array(Arr{Null()})},
{"e", Binary(0xFF, []byte{0x01, 0x02, 0x03})}, {"f", Undefined()},
{"g", ObjectID(oid)}, {"h", Boolean(true)},
{"i", DateTime(1234567890)}, {"j", Null()},
{"k", Regex("foo", "abr")},
{"l", DBPointer("foobar", oid)}, {"m", JavaScript("var hello = 'world';")},
{"n", Symbol("bazqux")},
{"o", CodeWithScope("var hello = 'world';", Doc{{"ab", Null()}})},
{"p", Int32(12345)},
{"q", Timestamp(10, 20)}, {"r", Int64(1234567890)}, {"s", Decimal128(d128)}, {"t", MinKey()}, {"u", MaxKey()},
}
slc := make(bsonrw.SliceWriter, 0, 128)
vw, err := bsonrw.NewBSONValueWriter(&slc)
noerr(t, err)
ec := bsoncodec.EncodeContext{Registry: DefaultRegistry}
err = (PrimitiveCodecs{}).DocumentEncodeValue(ec, vw, reflect.ValueOf(want))
noerr(t, err)
got, err := ReadDoc(slc)
noerr(t, err)
if !got.Equal(want) {
t.Error("Documents do not match")
t.Errorf("\ngot :%v\nwant:%v", got, want)
}
})
})
t.Run("ArrayEncodeValue", func(t *testing.T) {
t.Run("CodecEncodeError", func(t *testing.T) {
val := reflect.ValueOf(bool(true))
want := bsoncodec.ValueEncoderError{Name: "ArrayEncodeValue", Types: []reflect.Type{tArray}, Received: val}
got := (PrimitiveCodecs{}).ArrayEncodeValue(bsoncodec.EncodeContext{}, nil, val)
if !compareErrors(got, want) {
t.Errorf("Errors do not match. got %v; want %v", got, want)
}
})
t.Run("WriteArray Error", func(t *testing.T) {
want := errors.New("WriteArray Error")
llvrw := &bsonrwtest.ValueReaderWriter{
T: t,
Err: want,
ErrAfter: bsonrwtest.WriteArray,
}
got := (PrimitiveCodecs{}).ArrayEncodeValue(bsoncodec.EncodeContext{}, llvrw, reflect.MakeSlice(tArray, 0, 0))
if !compareErrors(got, want) {
t.Errorf("Errors do not match. got %v; want %v", got, want)
}
})
t.Run("encode array errors", func(t *testing.T) {
ec := bsoncodec.EncodeContext{}
err := errors.New("encode array error")
oid := primitive.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}
testCases := []struct {
name string
ec bsoncodec.EncodeContext
llvrw *bsonrwtest.ValueReaderWriter
arr Arr
err error
}{
{
"WriteDocumentElement",
ec,
&bsonrwtest.ValueReaderWriter{T: t, Err: errors.New("wde error"), ErrAfter: bsonrwtest.WriteArrayElement},
Arr{Null()},
errors.New("wde error"),
},
{
"WriteDouble", ec,
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteDouble},
Arr{Double(3.14159)}, err,
},
{
"WriteString", ec,
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteString},
Arr{String("bar")}, err,
},
{
"WriteDocument (Lookup)", bsoncodec.EncodeContext{Registry: bsoncodec.NewRegistryBuilder().Build()},
&bsonrwtest.ValueReaderWriter{T: t},
Arr{Document(Doc{{"bar", Null()}})},
bsoncodec.ErrNoEncoder{Type: tDocument},
},
{
"WriteArray (Lookup)", bsoncodec.EncodeContext{Registry: bsoncodec.NewRegistryBuilder().Build()},
&bsonrwtest.ValueReaderWriter{T: t},
Arr{Array(Arr{Null()})},
bsoncodec.ErrNoEncoder{Type: tArray},
},
{
"WriteBinary", ec,
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteBinaryWithSubtype},
Arr{Binary(0xFF, []byte{0x01, 0x02, 0x03})}, err,
},
{
"WriteUndefined", ec,
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteUndefined},
Arr{Undefined()}, err,
},
{
"WriteObjectID", ec,
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteObjectID},
Arr{ObjectID(oid)}, err,
},
{
"WriteBoolean", ec,
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteBoolean},
Arr{Boolean(true)}, err,
},
{
"WriteDateTime", ec,
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteDateTime},
Arr{DateTime(1234567890)}, err,
},
{
"WriteNull", ec,
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteNull},
Arr{Null()}, err,
},
{
"WriteRegex", ec,
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteRegex},
Arr{Regex("bar", "baz")}, err,
},
{
"WriteDBPointer", ec,
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteDBPointer},
Arr{DBPointer("bar", oid)}, err,
},
{
"WriteJavascript", ec,
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteJavascript},
Arr{JavaScript("var hello = 'world';")}, err,
},
{
"WriteSymbol", ec,
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteSymbol},
Arr{Symbol("symbolbaz")}, err,
},
{
"WriteCodeWithScope (Lookup)", bsoncodec.EncodeContext{Registry: bsoncodec.NewRegistryBuilder().Build()},
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteCodeWithScope},
Arr{CodeWithScope("var hello = 'world';", Doc{{"bar", Null()}})},
err,
},
{
"WriteInt32", ec,
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteInt32},
Arr{Int32(12345)}, err,
},
{
"WriteInt64", ec,
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteInt64},
Arr{Int64(1234567890)}, err,
},
{
"WriteTimestamp", ec,
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteTimestamp},
Arr{Timestamp(10, 20)}, err,
},
{
"WriteDecimal128", ec,
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteDecimal128},
Arr{Decimal128(primitive.NewDecimal128(10, 20))}, err,
},
{
"WriteMinKey", ec,
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteMinKey},
Arr{MinKey()}, err,
},
{
"WriteMaxKey", ec,
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.WriteMaxKey},
Arr{MaxKey()}, err,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := (PrimitiveCodecs{}).ArrayEncodeValue(tc.ec, tc.llvrw, reflect.ValueOf(tc.arr))
if !compareErrors(err, tc.err) {
t.Errorf("Errors do not match. got %v; want %v", err, tc.err)
}
})
}
})
t.Run("success", func(t *testing.T) {
oid := primitive.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}
d128 := primitive.NewDecimal128(10, 20)
want := Arr{
Double(3.14159), String("foo"), Document(Doc{{"aa", Null()}}),
Array(Arr{Null()}),
Binary(0xFF, []byte{0x01, 0x02, 0x03}), Undefined(),
ObjectID(oid), Boolean(true), DateTime(1234567890), Null(), Regex("foo", "abr"),
DBPointer("foobar", oid), JavaScript("var hello = 'world';"), Symbol("bazqux"),
CodeWithScope("var hello = 'world';", Doc{{"ab", Null()}}), Int32(12345),
Timestamp(10, 20), Int64(1234567890), Decimal128(d128), MinKey(), MaxKey(),
}
ec := bsoncodec.EncodeContext{Registry: DefaultRegistry}
slc := make(bsonrw.SliceWriter, 0, 128)
vw, err := bsonrw.NewBSONValueWriter(&slc)
noerr(t, err)
dr, err := vw.WriteDocument()
noerr(t, err)
vr, err := dr.WriteDocumentElement("foo")
noerr(t, err)
err = (PrimitiveCodecs{}).ArrayEncodeValue(ec, vr, reflect.ValueOf(want))
noerr(t, err)
err = dr.WriteDocumentEnd()
noerr(t, err)
val, err := bsoncore.Document(slc).LookupErr("foo")
noerr(t, err)
rgot := val.Array()
doc, err := ReadDoc(rgot)
noerr(t, err)
got := make(Arr, 0)
for _, elem := range doc {
got = append(got, elem.Value)
}
if !got.Equal(want) {
t.Error("Documents do not match")
t.Errorf("\ngot :%v\nwant:%v", got, want)
}
})
})
}
func TestDefaultValueDecoders(t *testing.T) {
var pcx PrimitiveCodecs
var wrong = func(string, string) string { return "wrong" }
const cansetreflectiontest = "cansetreflectiontest"
type subtest struct {
name string
val interface{}
dctx *bsoncodec.DecodeContext
llvrw *bsonrwtest.ValueReaderWriter
invoke bsonrwtest.Invoked
err error
}
testCases := []struct {
name string
vd bsoncodec.ValueDecoder
subtests []subtest
}{
{
"ValueDecodeValue",
bsoncodec.ValueDecoderFunc(pcx.ValueDecodeValue),
[]subtest{
{
"wrong type",
wrong,
nil,
nil,
bsonrwtest.Nothing,
bsoncodec.ValueDecoderError{
Name: "ValueDecodeValue",
Types: []reflect.Type{tValue},
Received: reflect.ValueOf(wrong),
},
},
{
"invalid value",
(*Val)(nil),
nil,
nil,
bsonrwtest.Nothing,
bsoncodec.ValueDecoderError{
Name: "ValueDecodeValue",
Types: []reflect.Type{tValue},
Received: reflect.ValueOf((*Val)(nil)),
},
},
{
"success",
Double(3.14159),
&bsoncodec.DecodeContext{Registry: NewRegistryBuilder().Build()},
&bsonrwtest.ValueReaderWriter{BSONType: bsontype.Double, Return: float64(3.14159)},
bsonrwtest.ReadDouble,
nil,
},
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
for _, rc := range tc.subtests {
t.Run(rc.name, func(t *testing.T) {
var dc bsoncodec.DecodeContext
if rc.dctx != nil {
dc = *rc.dctx
}
llvrw := new(bsonrwtest.ValueReaderWriter)
if rc.llvrw != nil {
llvrw = rc.llvrw
}
llvrw.T = t
// var got interface{}
if rc.val == cansetreflectiontest { // We're doing a CanSet reflection test
err := tc.vd.DecodeValue(dc, llvrw, reflect.Value{})
if !compareErrors(err, rc.err) {
t.Errorf("Errors do not match. got %v; want %v", err, rc.err)
}
val := reflect.New(reflect.TypeOf(rc.val)).Elem()
err = tc.vd.DecodeValue(dc, llvrw, val)
if !compareErrors(err, rc.err) {
t.Errorf("Errors do not match. got %v; want %v", err, rc.err)
}
return
}
var val reflect.Value
if rtype := reflect.TypeOf(rc.val); rtype != nil {
val = reflect.New(rtype).Elem()
}
want := rc.val
defer func() {
if err := recover(); err != nil {
fmt.Println(t.Name())
panic(err)
}
}()
err := tc.vd.DecodeValue(dc, llvrw, val)
if !compareErrors(err, rc.err) {
t.Errorf("Errors do not match. got %v; want %v", err, rc.err)
}
invoked := llvrw.Invoked
if !cmp.Equal(invoked, rc.invoke) {
t.Errorf("Incorrect method invoked. got %v; want %v", invoked, rc.invoke)
}
var got interface{}
if val.IsValid() && val.CanInterface() {
got = val.Interface()
}
if rc.err == nil && !cmp.Equal(got, want, cmp.Comparer(compareValues)) {
t.Errorf("Values do not match. got (%T)%v; want (%T)%v", got, got, want, want)
}
})
}
})
}
t.Run("DocumentDecodeValue", func(t *testing.T) {
t.Run("CodecDecodeError", func(t *testing.T) {
val := reflect.New(reflect.TypeOf(false)).Elem()
want := bsoncodec.ValueDecoderError{Name: "DocumentDecodeValue", Types: []reflect.Type{tDocument}, Received: val}
got := pcx.DocumentDecodeValue(bsoncodec.DecodeContext{}, &bsonrwtest.ValueReaderWriter{BSONType: bsontype.EmbeddedDocument}, val)
if !compareErrors(got, want) {
t.Errorf("Errors do not match. got %v; want %v", got, want)
}
})
t.Run("ReadDocument Error", func(t *testing.T) {
want := errors.New("ReadDocument Error")
llvrw := &bsonrwtest.ValueReaderWriter{
T: t,
Err: want,
ErrAfter: bsonrwtest.ReadDocument,
BSONType: bsontype.EmbeddedDocument,
}
got := pcx.DocumentDecodeValue(bsoncodec.DecodeContext{}, llvrw, reflect.New(reflect.TypeOf(Doc{})).Elem())
if !compareErrors(got, want) {
t.Errorf("Errors do not match. got %v; want %v", got, want)
}
})
t.Run("decodeDocument errors", func(t *testing.T) {
dc := bsoncodec.DecodeContext{}
err := errors.New("decodeDocument error")
testCases := []struct {
name string
dc bsoncodec.DecodeContext
llvrw *bsonrwtest.ValueReaderWriter
err error
}{
{
"ReadElement",
dc,
&bsonrwtest.ValueReaderWriter{T: t, Err: errors.New("re error"), ErrAfter: bsonrwtest.ReadElement},
errors.New("re error"),
},
{"ReadDouble", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadDouble, BSONType: bsontype.Double}, err},
{"ReadString", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadString, BSONType: bsontype.String}, err},
{"ReadBinary", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadBinary, BSONType: bsontype.Binary}, err},
{"ReadUndefined", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadUndefined, BSONType: bsontype.Undefined}, err},
{"ReadObjectID", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadObjectID, BSONType: bsontype.ObjectID}, err},
{"ReadBoolean", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadBoolean, BSONType: bsontype.Boolean}, err},
{"ReadDateTime", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadDateTime, BSONType: bsontype.DateTime}, err},
{"ReadNull", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadNull, BSONType: bsontype.Null}, err},
{"ReadRegex", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadRegex, BSONType: bsontype.Regex}, err},
{"ReadDBPointer", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadDBPointer, BSONType: bsontype.DBPointer}, err},
{"ReadJavascript", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadJavascript, BSONType: bsontype.JavaScript}, err},
{"ReadSymbol", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadSymbol, BSONType: bsontype.Symbol}, err},
{
"ReadCodeWithScope (Lookup)", bsoncodec.DecodeContext{Registry: bsoncodec.NewRegistryBuilder().Build()},
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadCodeWithScope, BSONType: bsontype.CodeWithScope},
err,
},
{"ReadInt32", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadInt32, BSONType: bsontype.Int32}, err},
{"ReadInt64", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadInt64, BSONType: bsontype.Int64}, err},
{"ReadTimestamp", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadTimestamp, BSONType: bsontype.Timestamp}, err},
{"ReadDecimal128", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadDecimal128, BSONType: bsontype.Decimal128}, err},
{"ReadMinKey", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadMinKey, BSONType: bsontype.MinKey}, err},
{"ReadMaxKey", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadMaxKey, BSONType: bsontype.MaxKey}, err},
{"Invalid Type", dc, &bsonrwtest.ValueReaderWriter{T: t, BSONType: bsontype.Type(0)}, fmt.Errorf("Cannot read unknown BSON type %s", bsontype.Type(0))},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := pcx.DecodeDocument(tc.dc, tc.llvrw, new(Doc))
if !compareErrors(err, tc.err) {
t.Errorf("Errors do not match. got %v; want %v", err, tc.err)
}
})
}
})
t.Run("success", func(t *testing.T) {
oid := primitive.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}
d128 := primitive.NewDecimal128(10, 20)
want := Doc{
{"a", Double(3.14159)}, {"b", String("foo")},
{"c", Document(Doc{{"aa", Null()}})},
{"d", Array(Arr{Null()})},
{"e", Binary(0xFF, []byte{0x01, 0x02, 0x03})}, {"f", Undefined()},
{"g", ObjectID(oid)}, {"h", Boolean(true)},
{"i", DateTime(1234567890)}, {"j", Null()}, {"k", Regex("foo", "bar")},
{"l", DBPointer("foobar", oid)}, {"m", JavaScript("var hello = 'world';")},
{"n", Symbol("bazqux")},
{"o", CodeWithScope("var hello = 'world';", Doc{{"ab", Null()}})},
{"p", Int32(12345)},
{"q", Timestamp(10, 20)}, {"r", Int64(1234567890)},
{"s", Decimal128(d128)}, {"t", MinKey()}, {"u", MaxKey()},
}
got := reflect.New(reflect.TypeOf(Doc{})).Elem()
dc := bsoncodec.DecodeContext{Registry: NewRegistryBuilder().Build()}
b, err := want.MarshalBSON()
noerr(t, err)
err = pcx.DocumentDecodeValue(dc, bsonrw.NewBSONDocumentReader(b), got)
noerr(t, err)
if !got.Interface().(Doc).Equal(want) {
t.Error("Documents do not match")
t.Errorf("\ngot :%v\nwant:%v", got, want)
}
})
})
t.Run("ArrayDecodeValue", func(t *testing.T) {
t.Run("CodecDecodeError", func(t *testing.T) {
val := reflect.New(reflect.TypeOf(false)).Elem()
want := bsoncodec.ValueDecoderError{Name: "ArrayDecodeValue", Types: []reflect.Type{tArray}, Received: val}
got := pcx.ArrayDecodeValue(bsoncodec.DecodeContext{}, &bsonrwtest.ValueReaderWriter{BSONType: bsontype.Array}, val)
if !compareErrors(got, want) {
t.Errorf("Errors do not match. got %v; want %v", got, want)
}
})
t.Run("ReadArray Error", func(t *testing.T) {
want := errors.New("ReadArray Error")
llvrw := &bsonrwtest.ValueReaderWriter{
T: t,
Err: want,
ErrAfter: bsonrwtest.ReadArray,
BSONType: bsontype.Array,
}
got := pcx.ArrayDecodeValue(bsoncodec.DecodeContext{}, llvrw, reflect.New(tArray).Elem())
if !compareErrors(got, want) {
t.Errorf("Errors do not match. got %v; want %v", got, want)
}
})
t.Run("decode array errors", func(t *testing.T) {
dc := bsoncodec.DecodeContext{}
err := errors.New("decode array error")
testCases := []struct {
name string
dc bsoncodec.DecodeContext
llvrw *bsonrwtest.ValueReaderWriter
err error
}{
{
"ReadValue",
dc,
&bsonrwtest.ValueReaderWriter{T: t, Err: errors.New("re error"), ErrAfter: bsonrwtest.ReadValue},
errors.New("re error"),
},
{"ReadDouble", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadDouble, BSONType: bsontype.Double}, err},
{"ReadString", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadString, BSONType: bsontype.String}, err},
{"ReadBinary", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadBinary, BSONType: bsontype.Binary}, err},
{"ReadUndefined", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadUndefined, BSONType: bsontype.Undefined}, err},
{"ReadObjectID", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadObjectID, BSONType: bsontype.ObjectID}, err},
{"ReadBoolean", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadBoolean, BSONType: bsontype.Boolean}, err},
{"ReadDateTime", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadDateTime, BSONType: bsontype.DateTime}, err},
{"ReadNull", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadNull, BSONType: bsontype.Null}, err},
{"ReadRegex", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadRegex, BSONType: bsontype.Regex}, err},
{"ReadDBPointer", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadDBPointer, BSONType: bsontype.DBPointer}, err},
{"ReadJavascript", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadJavascript, BSONType: bsontype.JavaScript}, err},
{"ReadSymbol", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadSymbol, BSONType: bsontype.Symbol}, err},
{
"ReadCodeWithScope (Lookup)", bsoncodec.DecodeContext{Registry: bsoncodec.NewRegistryBuilder().Build()},
&bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadCodeWithScope, BSONType: bsontype.CodeWithScope},
err,
},
{"ReadInt32", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadInt32, BSONType: bsontype.Int32}, err},
{"ReadInt64", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadInt64, BSONType: bsontype.Int64}, err},
{"ReadTimestamp", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadTimestamp, BSONType: bsontype.Timestamp}, err},
{"ReadDecimal128", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadDecimal128, BSONType: bsontype.Decimal128}, err},
{"ReadMinKey", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadMinKey, BSONType: bsontype.MinKey}, err},
{"ReadMaxKey", dc, &bsonrwtest.ValueReaderWriter{T: t, Err: err, ErrAfter: bsonrwtest.ReadMaxKey, BSONType: bsontype.MaxKey}, err},
{"Invalid Type", dc, &bsonrwtest.ValueReaderWriter{T: t, BSONType: bsontype.Type(0)}, fmt.Errorf("Cannot read unknown BSON type %s", bsontype.Type(0))},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := pcx.ArrayDecodeValue(tc.dc, tc.llvrw, reflect.New(tArray).Elem())
if !compareErrors(err, tc.err) {
t.Errorf("Errors do not match. got %v; want %v", err, tc.err)
}
})
}
})
t.Run("success", func(t *testing.T) {
oid := primitive.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}
d128 := primitive.NewDecimal128(10, 20)
want := Arr{
Double(3.14159), String("foo"), Document(Doc{{"aa", Null()}}),
Array(Arr{Null()}),
Binary(0xFF, []byte{0x01, 0x02, 0x03}), Undefined(),
ObjectID(oid), Boolean(true), DateTime(1234567890), Null(), Regex("foo", "bar"),
DBPointer("foobar", oid), JavaScript("var hello = 'world';"), Symbol("bazqux"),
CodeWithScope("var hello = 'world';", Doc{{"ab", Null()}}), Int32(12345),
Timestamp(10, 20), Int64(1234567890), Decimal128(d128), MinKey(), MaxKey(),
}
dc := bsoncodec.DecodeContext{Registry: NewRegistryBuilder().Build()}
b, err := Doc{{"", Array(want)}}.MarshalBSON()
noerr(t, err)
dvr := bsonrw.NewBSONDocumentReader(b)
dr, err := dvr.ReadDocument()
noerr(t, err)
_, vr, err := dr.ReadElement()
noerr(t, err)
val := reflect.New(tArray).Elem()
err = pcx.ArrayDecodeValue(dc, vr, val)
noerr(t, err)
got := val.Interface().(Arr)
if !got.Equal(want) {
t.Error("Documents do not match")
t.Errorf("\ngot :%v\nwant:%v", got, want)
}
})
})
t.Run("success path", func(t *testing.T) {
testCases := []struct {
name string
value interface{}
b []byte
err error
}{
{
"map[string][]Element",
map[string][]Elem{"Z": {{"A", Int32(1)}, {"B", Int32(2)}, {"EC", Int32(3)}}},
docToBytes(Doc{{"Z", Document(Doc{{"A", Int32(1)}, {"B", Int32(2)}, {"EC", Int32(3)}})}}),
nil,
},
{
"map[string][]Value",
map[string][]Val{"Z": {Int32(1), Int32(2), Int32(3)}},
docToBytes(Doc{{"Z", Array(Arr{Int32(1), Int32(2), Int32(3)})}}),
nil,
},
{
"map[string]*Document",
map[string]Doc{"Z": {{"foo", Null()}}},
docToBytes(Doc{{"Z", Document(Doc{{"foo", Null()}})}}),
nil,
},
}
t.Run("Decode", func(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
vr := bsonrw.NewBSONDocumentReader(tc.b)
dec, err := bson.NewDecoderWithContext(bsoncodec.DecodeContext{Registry: DefaultRegistry}, vr)
noerr(t, err)
gotVal := reflect.New(reflect.TypeOf(tc.value))
err = dec.Decode(gotVal.Interface())
noerr(t, err)
got := gotVal.Elem().Interface()
want := tc.value
if diff := cmp.Diff(
got, want,
); diff != "" {
t.Errorf("difference:\n%s", diff)
t.Errorf("Values are not equal.\ngot: %#v\nwant:%#v", got, want)
}
})
}
})
})
}
func compareValues(v1, v2 Val) bool { return v1.Equal(v2) }
func docToBytes(d Doc) []byte {
b, err := d.MarshalBSON()
if err != nil {
panic(err)
}
return b
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,123 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonx
import (
"reflect"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/internal/testutil/assert"
)
func TestReflectionFreeDCodec(t *testing.T) {
assert.RegisterOpts(reflect.TypeOf(primitive.D{}), cmp.AllowUnexported(primitive.Decimal128{}))
now := time.Now()
oid := primitive.NewObjectID()
d128 := primitive.NewDecimal128(10, 20)
js := primitive.JavaScript("js")
symbol := primitive.Symbol("sybmol")
binary := primitive.Binary{Subtype: 2, Data: []byte("binary")}
datetime := primitive.NewDateTimeFromTime(now)
regex := primitive.Regex{Pattern: "pattern", Options: "i"}
dbPointer := primitive.DBPointer{DB: "db", Pointer: oid}
timestamp := primitive.Timestamp{T: 5, I: 10}
cws := primitive.CodeWithScope{Code: js, Scope: bson.D{{"x", 1}}}
noReflectionRegistry := bson.NewRegistryBuilder().RegisterCodec(tPrimitiveD, ReflectionFreeDCodec).Build()
docWithAllTypes := primitive.D{
{"byteSlice", []byte("foobar")},
{"sliceByteSlice", [][]byte{[]byte("foobar")}},
{"timeTime", now},
{"sliceTimeTime", []time.Time{now}},
{"objectID", oid},
{"sliceObjectID", []primitive.ObjectID{oid}},
{"decimal128", d128},
{"sliceDecimal128", []primitive.Decimal128{d128}},
{"js", js},
{"sliceJS", []primitive.JavaScript{js}},
{"symbol", symbol},
{"sliceSymbol", []primitive.Symbol{symbol}},
{"binary", binary},
{"sliceBinary", []primitive.Binary{binary}},
{"undefined", primitive.Undefined{}},
{"sliceUndefined", []primitive.Undefined{{}}},
{"datetime", datetime},
{"sliceDateTime", []primitive.DateTime{datetime}},
{"null", primitive.Null{}},
{"sliceNull", []primitive.Null{{}}},
{"regex", regex},
{"sliceRegex", []primitive.Regex{regex}},
{"dbPointer", dbPointer},
{"sliceDBPointer", []primitive.DBPointer{dbPointer}},
{"timestamp", timestamp},
{"sliceTimestamp", []primitive.Timestamp{timestamp}},
{"minKey", primitive.MinKey{}},
{"sliceMinKey", []primitive.MinKey{{}}},
{"maxKey", primitive.MaxKey{}},
{"sliceMaxKey", []primitive.MaxKey{{}}},
{"cws", cws},
{"sliceCWS", []primitive.CodeWithScope{cws}},
{"bool", true},
{"sliceBool", []bool{true}},
{"int", int(10)},
{"sliceInt", []int{10}},
{"int8", int8(10)},
{"sliceInt8", []int8{10}},
{"int16", int16(10)},
{"sliceInt16", []int16{10}},
{"int32", int32(10)},
{"sliceInt32", []int32{10}},
{"int64", int64(10)},
{"sliceInt64", []int64{10}},
{"uint", uint(10)},
{"sliceUint", []uint{10}},
{"uint8", uint8(10)},
{"sliceUint8", []uint8{10}},
{"uint16", uint16(10)},
{"sliceUint16", []uint16{10}},
{"uint32", uint32(10)},
{"sliceUint32", []uint32{10}},
{"uint64", uint64(10)},
{"sliceUint64", []uint64{10}},
{"float32", float32(10)},
{"sliceFloat32", []float32{10}},
{"float64", float64(10)},
{"sliceFloat64", []float64{10}},
{"primitiveA", primitive.A{"foo", "bar"}},
}
t.Run("encode", func(t *testing.T) {
// Assert that bson.Marshal returns the same result when using the default registry and noReflectionRegistry.
expected, err := bson.Marshal(docWithAllTypes)
assert.Nil(t, err, "Marshal error with default registry: %v", err)
actual, err := bson.MarshalWithRegistry(noReflectionRegistry, docWithAllTypes)
assert.Nil(t, err, "Marshal error with noReflectionRegistry: %v", err)
assert.Equal(t, expected, actual, "expected doc %s, got %s", bson.Raw(expected), bson.Raw(actual))
})
t.Run("decode", func(t *testing.T) {
// Assert that bson.Unmarshal returns the same result when using the default registry and noReflectionRegistry.
// docWithAllTypes contains some types that can't be roundtripped. For example, any slices besides primitive.A
// would start of as []T but unmarshal to primitive.A. To get around this, we first marshal docWithAllTypes to
// raw bytes and then Unmarshal to another primitive.D rather than asserting directly against docWithAllTypes.
docBytes, err := bson.Marshal(docWithAllTypes)
assert.Nil(t, err, "Marshal error: %v", err)
var expected, actual primitive.D
err = bson.Unmarshal(docBytes, &expected)
assert.Nil(t, err, "Unmarshal error with default registry: %v", err)
err = bson.UnmarshalWithRegistry(noReflectionRegistry, docBytes, &actual)
assert.Nil(t, err, "Unmarshal error with noReflectionRegistry: %v", err)
assert.Equal(t, expected, actual, "expected document %v, got %v", expected, actual)
})
}

28
mongo/x/bsonx/registry.go Normal file
View File

@@ -0,0 +1,28 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonx
import (
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/bsoncodec"
)
// DefaultRegistry is the default bsoncodec.Registry. It contains the default codecs and the
// primitive codecs.
var DefaultRegistry = NewRegistryBuilder().Build()
// NewRegistryBuilder creates a new RegistryBuilder configured with the default encoders and
// decoders from the bsoncodec.DefaultValueEncoders and bsoncodec.DefaultValueDecoders types and the
// PrimitiveCodecs type in this package.
func NewRegistryBuilder() *bsoncodec.RegistryBuilder {
rb := bsoncodec.NewRegistryBuilder()
bsoncodec.DefaultValueEncoders{}.RegisterDefaultEncoders(rb)
bsoncodec.DefaultValueDecoders{}.RegisterDefaultDecoders(rb)
bson.PrimitiveCodecs{}.RegisterPrimitiveCodecs(rb)
primitiveCodecs.RegisterPrimitiveCodecs(rb)
return rb
}

866
mongo/x/bsonx/value.go Normal file
View File

@@ -0,0 +1,866 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonx
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"math"
"time"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
// Val represents a BSON value.
type Val struct {
// NOTE: The bootstrap is a small amount of space that'll be on the stack. At 15 bytes this
// doesn't make this type any larger, since there are 7 bytes of padding and we want an int64 to
// store small values (e.g. boolean, double, int64, etc...). The primitive property is where all
// of the larger values go. They will use either Go primitives or the primitive.* types.
t bsontype.Type
bootstrap [15]byte
primitive interface{}
}
func (v Val) string() string {
if v.primitive != nil {
return v.primitive.(string)
}
// The string will either end with a null byte or it fills the entire bootstrap space.
length := v.bootstrap[0]
return string(v.bootstrap[1 : length+1])
}
func (v Val) writestring(str string) Val {
switch {
case len(str) < 15:
v.bootstrap[0] = uint8(len(str))
copy(v.bootstrap[1:], str)
default:
v.primitive = str
}
return v
}
func (v Val) i64() int64 {
return int64(v.bootstrap[0]) | int64(v.bootstrap[1])<<8 | int64(v.bootstrap[2])<<16 |
int64(v.bootstrap[3])<<24 | int64(v.bootstrap[4])<<32 | int64(v.bootstrap[5])<<40 |
int64(v.bootstrap[6])<<48 | int64(v.bootstrap[7])<<56
}
func (v Val) writei64(i64 int64) Val {
v.bootstrap[0] = byte(i64)
v.bootstrap[1] = byte(i64 >> 8)
v.bootstrap[2] = byte(i64 >> 16)
v.bootstrap[3] = byte(i64 >> 24)
v.bootstrap[4] = byte(i64 >> 32)
v.bootstrap[5] = byte(i64 >> 40)
v.bootstrap[6] = byte(i64 >> 48)
v.bootstrap[7] = byte(i64 >> 56)
return v
}
// IsZero returns true if this value is zero or a BSON null.
func (v Val) IsZero() bool { return v.t == bsontype.Type(0) || v.t == bsontype.Null }
func (v Val) String() string {
// TODO(GODRIVER-612): When bsoncore has appenders for extended JSON use that here.
return fmt.Sprintf("%v", v.Interface())
}
// Interface returns the Go value of this Value as an empty interface.
//
// This method will return nil if it is empty, otherwise it will return a Go primitive or a
// primitive.* instance.
func (v Val) Interface() interface{} {
switch v.Type() {
case bsontype.Double:
return v.Double()
case bsontype.String:
return v.StringValue()
case bsontype.EmbeddedDocument:
switch v.primitive.(type) {
case Doc:
return v.primitive.(Doc)
case MDoc:
return v.primitive.(MDoc)
default:
return primitive.Null{}
}
case bsontype.Array:
return v.Array()
case bsontype.Binary:
return v.primitive.(primitive.Binary)
case bsontype.Undefined:
return primitive.Undefined{}
case bsontype.ObjectID:
return v.ObjectID()
case bsontype.Boolean:
return v.Boolean()
case bsontype.DateTime:
return v.DateTime()
case bsontype.Null:
return primitive.Null{}
case bsontype.Regex:
return v.primitive.(primitive.Regex)
case bsontype.DBPointer:
return v.primitive.(primitive.DBPointer)
case bsontype.JavaScript:
return v.JavaScript()
case bsontype.Symbol:
return v.Symbol()
case bsontype.CodeWithScope:
return v.primitive.(primitive.CodeWithScope)
case bsontype.Int32:
return v.Int32()
case bsontype.Timestamp:
t, i := v.Timestamp()
return primitive.Timestamp{T: t, I: i}
case bsontype.Int64:
return v.Int64()
case bsontype.Decimal128:
return v.Decimal128()
case bsontype.MinKey:
return primitive.MinKey{}
case bsontype.MaxKey:
return primitive.MaxKey{}
default:
return primitive.Null{}
}
}
// MarshalBSONValue implements the bsoncodec.ValueMarshaler interface.
func (v Val) MarshalBSONValue() (bsontype.Type, []byte, error) {
return v.MarshalAppendBSONValue(nil)
}
// MarshalAppendBSONValue is similar to MarshalBSONValue, but allows the caller to specify a slice
// to add the bytes to.
func (v Val) MarshalAppendBSONValue(dst []byte) (bsontype.Type, []byte, error) {
t := v.Type()
switch v.Type() {
case bsontype.Double:
dst = bsoncore.AppendDouble(dst, v.Double())
case bsontype.String:
dst = bsoncore.AppendString(dst, v.String())
case bsontype.EmbeddedDocument:
switch v.primitive.(type) {
case Doc:
t, dst, _ = v.primitive.(Doc).MarshalBSONValue() // Doc.MarshalBSONValue never returns an error.
case MDoc:
t, dst, _ = v.primitive.(MDoc).MarshalBSONValue() // MDoc.MarshalBSONValue never returns an error.
}
case bsontype.Array:
t, dst, _ = v.Array().MarshalBSONValue() // Arr.MarshalBSON never returns an error.
case bsontype.Binary:
subtype, bindata := v.Binary()
dst = bsoncore.AppendBinary(dst, subtype, bindata)
case bsontype.Undefined:
case bsontype.ObjectID:
dst = bsoncore.AppendObjectID(dst, v.ObjectID())
case bsontype.Boolean:
dst = bsoncore.AppendBoolean(dst, v.Boolean())
case bsontype.DateTime:
dst = bsoncore.AppendDateTime(dst, v.DateTime())
case bsontype.Null:
case bsontype.Regex:
pattern, options := v.Regex()
dst = bsoncore.AppendRegex(dst, pattern, options)
case bsontype.DBPointer:
ns, ptr := v.DBPointer()
dst = bsoncore.AppendDBPointer(dst, ns, ptr)
case bsontype.JavaScript:
dst = bsoncore.AppendJavaScript(dst, v.JavaScript())
case bsontype.Symbol:
dst = bsoncore.AppendSymbol(dst, v.Symbol())
case bsontype.CodeWithScope:
code, doc := v.CodeWithScope()
var scope []byte
scope, _ = doc.MarshalBSON() // Doc.MarshalBSON never returns an error.
dst = bsoncore.AppendCodeWithScope(dst, code, scope)
case bsontype.Int32:
dst = bsoncore.AppendInt32(dst, v.Int32())
case bsontype.Timestamp:
t, i := v.Timestamp()
dst = bsoncore.AppendTimestamp(dst, t, i)
case bsontype.Int64:
dst = bsoncore.AppendInt64(dst, v.Int64())
case bsontype.Decimal128:
dst = bsoncore.AppendDecimal128(dst, v.Decimal128())
case bsontype.MinKey:
case bsontype.MaxKey:
default:
panic(fmt.Errorf("invalid BSON type %v", t))
}
return t, dst, nil
}
// UnmarshalBSONValue implements the bsoncodec.ValueUnmarshaler interface.
func (v *Val) UnmarshalBSONValue(t bsontype.Type, data []byte) error {
if v == nil {
return errors.New("cannot unmarshal into nil Value")
}
var err error
var ok = true
var rem []byte
switch t {
case bsontype.Double:
var f64 float64
f64, rem, ok = bsoncore.ReadDouble(data)
*v = Double(f64)
case bsontype.String:
var str string
str, rem, ok = bsoncore.ReadString(data)
*v = String(str)
case bsontype.EmbeddedDocument:
var raw []byte
var doc Doc
raw, rem, ok = bsoncore.ReadDocument(data)
doc, err = ReadDoc(raw)
*v = Document(doc)
case bsontype.Array:
var raw []byte
arr := make(Arr, 0)
raw, rem, ok = bsoncore.ReadArray(data)
err = arr.UnmarshalBSONValue(t, raw)
*v = Array(arr)
case bsontype.Binary:
var subtype byte
var bindata []byte
subtype, bindata, rem, ok = bsoncore.ReadBinary(data)
*v = Binary(subtype, bindata)
case bsontype.Undefined:
*v = Undefined()
case bsontype.ObjectID:
var oid primitive.ObjectID
oid, rem, ok = bsoncore.ReadObjectID(data)
*v = ObjectID(oid)
case bsontype.Boolean:
var b bool
b, rem, ok = bsoncore.ReadBoolean(data)
*v = Boolean(b)
case bsontype.DateTime:
var dt int64
dt, rem, ok = bsoncore.ReadDateTime(data)
*v = DateTime(dt)
case bsontype.Null:
*v = Null()
case bsontype.Regex:
var pattern, options string
pattern, options, rem, ok = bsoncore.ReadRegex(data)
*v = Regex(pattern, options)
case bsontype.DBPointer:
var ns string
var ptr primitive.ObjectID
ns, ptr, rem, ok = bsoncore.ReadDBPointer(data)
*v = DBPointer(ns, ptr)
case bsontype.JavaScript:
var js string
js, rem, ok = bsoncore.ReadJavaScript(data)
*v = JavaScript(js)
case bsontype.Symbol:
var symbol string
symbol, rem, ok = bsoncore.ReadSymbol(data)
*v = Symbol(symbol)
case bsontype.CodeWithScope:
var raw []byte
var code string
var scope Doc
code, raw, rem, ok = bsoncore.ReadCodeWithScope(data)
scope, err = ReadDoc(raw)
*v = CodeWithScope(code, scope)
case bsontype.Int32:
var i32 int32
i32, rem, ok = bsoncore.ReadInt32(data)
*v = Int32(i32)
case bsontype.Timestamp:
var i, t uint32
t, i, rem, ok = bsoncore.ReadTimestamp(data)
*v = Timestamp(t, i)
case bsontype.Int64:
var i64 int64
i64, rem, ok = bsoncore.ReadInt64(data)
*v = Int64(i64)
case bsontype.Decimal128:
var d128 primitive.Decimal128
d128, rem, ok = bsoncore.ReadDecimal128(data)
*v = Decimal128(d128)
case bsontype.MinKey:
*v = MinKey()
case bsontype.MaxKey:
*v = MaxKey()
default:
err = fmt.Errorf("invalid BSON type %v", t)
}
if !ok && err == nil {
err = bsoncore.NewInsufficientBytesError(data, rem)
}
return err
}
// Type returns the BSON type of this value.
func (v Val) Type() bsontype.Type {
if v.t == bsontype.Type(0) {
return bsontype.Null
}
return v.t
}
// IsNumber returns true if the type of v is a numberic BSON type.
func (v Val) IsNumber() bool {
switch v.Type() {
case bsontype.Double, bsontype.Int32, bsontype.Int64, bsontype.Decimal128:
return true
default:
return false
}
}
// Double returns the BSON double value the Value represents. It panics if the value is a BSON type
// other than double.
func (v Val) Double() float64 {
if v.t != bsontype.Double {
panic(ElementTypeError{"bson.Value.Double", v.t})
}
return math.Float64frombits(binary.LittleEndian.Uint64(v.bootstrap[0:8]))
}
// DoubleOK is the same as Double, but returns a boolean instead of panicking.
func (v Val) DoubleOK() (float64, bool) {
if v.t != bsontype.Double {
return 0, false
}
return math.Float64frombits(binary.LittleEndian.Uint64(v.bootstrap[0:8])), true
}
// StringValue returns the BSON string the Value represents. It panics if the value is a BSON type
// other than string.
//
// NOTE: This method is called StringValue to avoid it implementing the
// fmt.Stringer interface.
func (v Val) StringValue() string {
if v.t != bsontype.String {
panic(ElementTypeError{"bson.Value.StringValue", v.t})
}
return v.string()
}
// StringValueOK is the same as StringValue, but returns a boolean instead of
// panicking.
func (v Val) StringValueOK() (string, bool) {
if v.t != bsontype.String {
return "", false
}
return v.string(), true
}
func (v Val) asDoc() Doc {
doc, ok := v.primitive.(Doc)
if ok {
return doc
}
mdoc := v.primitive.(MDoc)
for k, v := range mdoc {
doc = append(doc, Elem{k, v})
}
return doc
}
func (v Val) asMDoc() MDoc {
mdoc, ok := v.primitive.(MDoc)
if ok {
return mdoc
}
mdoc = make(MDoc)
doc := v.primitive.(Doc)
for _, elem := range doc {
mdoc[elem.Key] = elem.Value
}
return mdoc
}
// Document returns the BSON embedded document value the Value represents. It panics if the value
// is a BSON type other than embedded document.
func (v Val) Document() Doc {
if v.t != bsontype.EmbeddedDocument {
panic(ElementTypeError{"bson.Value.Document", v.t})
}
return v.asDoc()
}
// DocumentOK is the same as Document, except it returns a boolean
// instead of panicking.
func (v Val) DocumentOK() (Doc, bool) {
if v.t != bsontype.EmbeddedDocument {
return nil, false
}
return v.asDoc(), true
}
// MDocument returns the BSON embedded document value the Value represents. It panics if the value
// is a BSON type other than embedded document.
func (v Val) MDocument() MDoc {
if v.t != bsontype.EmbeddedDocument {
panic(ElementTypeError{"bson.Value.MDocument", v.t})
}
return v.asMDoc()
}
// MDocumentOK is the same as Document, except it returns a boolean
// instead of panicking.
func (v Val) MDocumentOK() (MDoc, bool) {
if v.t != bsontype.EmbeddedDocument {
return nil, false
}
return v.asMDoc(), true
}
// Array returns the BSON array value the Value represents. It panics if the value is a BSON type
// other than array.
func (v Val) Array() Arr {
if v.t != bsontype.Array {
panic(ElementTypeError{"bson.Value.Array", v.t})
}
return v.primitive.(Arr)
}
// ArrayOK is the same as Array, except it returns a boolean
// instead of panicking.
func (v Val) ArrayOK() (Arr, bool) {
if v.t != bsontype.Array {
return nil, false
}
return v.primitive.(Arr), true
}
// Binary returns the BSON binary value the Value represents. It panics if the value is a BSON type
// other than binary.
func (v Val) Binary() (byte, []byte) {
if v.t != bsontype.Binary {
panic(ElementTypeError{"bson.Value.Binary", v.t})
}
bin := v.primitive.(primitive.Binary)
return bin.Subtype, bin.Data
}
// BinaryOK is the same as Binary, except it returns a boolean instead of
// panicking.
func (v Val) BinaryOK() (byte, []byte, bool) {
if v.t != bsontype.Binary {
return 0x00, nil, false
}
bin := v.primitive.(primitive.Binary)
return bin.Subtype, bin.Data, true
}
// Undefined returns the BSON undefined the Value represents. It panics if the value is a BSON type
// other than binary.
func (v Val) Undefined() {
if v.t != bsontype.Undefined {
panic(ElementTypeError{"bson.Value.Undefined", v.t})
}
}
// UndefinedOK is the same as Undefined, except it returns a boolean instead of
// panicking.
func (v Val) UndefinedOK() bool {
return v.t == bsontype.Undefined
}
// ObjectID returns the BSON ObjectID the Value represents. It panics if the value is a BSON type
// other than ObjectID.
func (v Val) ObjectID() primitive.ObjectID {
if v.t != bsontype.ObjectID {
panic(ElementTypeError{"bson.Value.ObjectID", v.t})
}
var oid primitive.ObjectID
copy(oid[:], v.bootstrap[:12])
return oid
}
// ObjectIDOK is the same as ObjectID, except it returns a boolean instead of
// panicking.
func (v Val) ObjectIDOK() (primitive.ObjectID, bool) {
if v.t != bsontype.ObjectID {
return primitive.ObjectID{}, false
}
var oid primitive.ObjectID
copy(oid[:], v.bootstrap[:12])
return oid, true
}
// Boolean returns the BSON boolean the Value represents. It panics if the value is a BSON type
// other than boolean.
func (v Val) Boolean() bool {
if v.t != bsontype.Boolean {
panic(ElementTypeError{"bson.Value.Boolean", v.t})
}
return v.bootstrap[0] == 0x01
}
// BooleanOK is the same as Boolean, except it returns a boolean instead of
// panicking.
func (v Val) BooleanOK() (bool, bool) {
if v.t != bsontype.Boolean {
return false, false
}
return v.bootstrap[0] == 0x01, true
}
// DateTime returns the BSON datetime the Value represents. It panics if the value is a BSON type
// other than datetime.
func (v Val) DateTime() int64 {
if v.t != bsontype.DateTime {
panic(ElementTypeError{"bson.Value.DateTime", v.t})
}
return v.i64()
}
// DateTimeOK is the same as DateTime, except it returns a boolean instead of
// panicking.
func (v Val) DateTimeOK() (int64, bool) {
if v.t != bsontype.DateTime {
return 0, false
}
return v.i64(), true
}
// Time returns the BSON datetime the Value represents as time.Time. It panics if the value is a BSON
// type other than datetime.
func (v Val) Time() time.Time {
if v.t != bsontype.DateTime {
panic(ElementTypeError{"bson.Value.Time", v.t})
}
i := v.i64()
return time.Unix(i/1000, i%1000*1000000)
}
// TimeOK is the same as Time, except it returns a boolean instead of
// panicking.
func (v Val) TimeOK() (time.Time, bool) {
if v.t != bsontype.DateTime {
return time.Time{}, false
}
i := v.i64()
return time.Unix(i/1000, i%1000*1000000), true
}
// Null returns the BSON undefined the Value represents. It panics if the value is a BSON type
// other than binary.
func (v Val) Null() {
if v.t != bsontype.Null && v.t != bsontype.Type(0) {
panic(ElementTypeError{"bson.Value.Null", v.t})
}
}
// NullOK is the same as Null, except it returns a boolean instead of
// panicking.
func (v Val) NullOK() bool {
if v.t != bsontype.Null && v.t != bsontype.Type(0) {
return false
}
return true
}
// Regex returns the BSON regex the Value represents. It panics if the value is a BSON type
// other than regex.
func (v Val) Regex() (pattern, options string) {
if v.t != bsontype.Regex {
panic(ElementTypeError{"bson.Value.Regex", v.t})
}
regex := v.primitive.(primitive.Regex)
return regex.Pattern, regex.Options
}
// RegexOK is the same as Regex, except that it returns a boolean
// instead of panicking.
func (v Val) RegexOK() (pattern, options string, ok bool) {
if v.t != bsontype.Regex {
return "", "", false
}
regex := v.primitive.(primitive.Regex)
return regex.Pattern, regex.Options, true
}
// DBPointer returns the BSON dbpointer the Value represents. It panics if the value is a BSON type
// other than dbpointer.
func (v Val) DBPointer() (string, primitive.ObjectID) {
if v.t != bsontype.DBPointer {
panic(ElementTypeError{"bson.Value.DBPointer", v.t})
}
dbptr := v.primitive.(primitive.DBPointer)
return dbptr.DB, dbptr.Pointer
}
// DBPointerOK is the same as DBPoitner, except that it returns a boolean
// instead of panicking.
func (v Val) DBPointerOK() (string, primitive.ObjectID, bool) {
if v.t != bsontype.DBPointer {
return "", primitive.ObjectID{}, false
}
dbptr := v.primitive.(primitive.DBPointer)
return dbptr.DB, dbptr.Pointer, true
}
// JavaScript returns the BSON JavaScript the Value represents. It panics if the value is a BSON type
// other than JavaScript.
func (v Val) JavaScript() string {
if v.t != bsontype.JavaScript {
panic(ElementTypeError{"bson.Value.JavaScript", v.t})
}
return v.string()
}
// JavaScriptOK is the same as Javascript, except that it returns a boolean
// instead of panicking.
func (v Val) JavaScriptOK() (string, bool) {
if v.t != bsontype.JavaScript {
return "", false
}
return v.string(), true
}
// Symbol returns the BSON symbol the Value represents. It panics if the value is a BSON type
// other than symbol.
func (v Val) Symbol() string {
if v.t != bsontype.Symbol {
panic(ElementTypeError{"bson.Value.Symbol", v.t})
}
return v.string()
}
// SymbolOK is the same as Javascript, except that it returns a boolean
// instead of panicking.
func (v Val) SymbolOK() (string, bool) {
if v.t != bsontype.Symbol {
return "", false
}
return v.string(), true
}
// CodeWithScope returns the BSON code with scope value the Value represents. It panics if the
// value is a BSON type other than code with scope.
func (v Val) CodeWithScope() (string, Doc) {
if v.t != bsontype.CodeWithScope {
panic(ElementTypeError{"bson.Value.CodeWithScope", v.t})
}
cws := v.primitive.(primitive.CodeWithScope)
return string(cws.Code), cws.Scope.(Doc)
}
// CodeWithScopeOK is the same as JavascriptWithScope,
// except that it returns a boolean instead of panicking.
func (v Val) CodeWithScopeOK() (string, Doc, bool) {
if v.t != bsontype.CodeWithScope {
return "", nil, false
}
cws := v.primitive.(primitive.CodeWithScope)
return string(cws.Code), cws.Scope.(Doc), true
}
// Int32 returns the BSON int32 the Value represents. It panics if the value is a BSON type
// other than int32.
func (v Val) Int32() int32 {
if v.t != bsontype.Int32 {
panic(ElementTypeError{"bson.Value.Int32", v.t})
}
return int32(v.bootstrap[0]) | int32(v.bootstrap[1])<<8 |
int32(v.bootstrap[2])<<16 | int32(v.bootstrap[3])<<24
}
// Int32OK is the same as Int32, except that it returns a boolean instead of
// panicking.
func (v Val) Int32OK() (int32, bool) {
if v.t != bsontype.Int32 {
return 0, false
}
return int32(v.bootstrap[0]) | int32(v.bootstrap[1])<<8 |
int32(v.bootstrap[2])<<16 | int32(v.bootstrap[3])<<24,
true
}
// Timestamp returns the BSON timestamp the Value represents. It panics if the value is a
// BSON type other than timestamp.
func (v Val) Timestamp() (t, i uint32) {
if v.t != bsontype.Timestamp {
panic(ElementTypeError{"bson.Value.Timestamp", v.t})
}
return uint32(v.bootstrap[4]) | uint32(v.bootstrap[5])<<8 |
uint32(v.bootstrap[6])<<16 | uint32(v.bootstrap[7])<<24,
uint32(v.bootstrap[0]) | uint32(v.bootstrap[1])<<8 |
uint32(v.bootstrap[2])<<16 | uint32(v.bootstrap[3])<<24
}
// TimestampOK is the same as Timestamp, except that it returns a boolean
// instead of panicking.
func (v Val) TimestampOK() (t uint32, i uint32, ok bool) {
if v.t != bsontype.Timestamp {
return 0, 0, false
}
return uint32(v.bootstrap[4]) | uint32(v.bootstrap[5])<<8 |
uint32(v.bootstrap[6])<<16 | uint32(v.bootstrap[7])<<24,
uint32(v.bootstrap[0]) | uint32(v.bootstrap[1])<<8 |
uint32(v.bootstrap[2])<<16 | uint32(v.bootstrap[3])<<24,
true
}
// Int64 returns the BSON int64 the Value represents. It panics if the value is a BSON type
// other than int64.
func (v Val) Int64() int64 {
if v.t != bsontype.Int64 {
panic(ElementTypeError{"bson.Value.Int64", v.t})
}
return v.i64()
}
// Int64OK is the same as Int64, except that it returns a boolean instead of
// panicking.
func (v Val) Int64OK() (int64, bool) {
if v.t != bsontype.Int64 {
return 0, false
}
return v.i64(), true
}
// Decimal128 returns the BSON decimal128 value the Value represents. It panics if the value is a
// BSON type other than decimal128.
func (v Val) Decimal128() primitive.Decimal128 {
if v.t != bsontype.Decimal128 {
panic(ElementTypeError{"bson.Value.Decimal128", v.t})
}
return v.primitive.(primitive.Decimal128)
}
// Decimal128OK is the same as Decimal128, except that it returns a boolean
// instead of panicking.
func (v Val) Decimal128OK() (primitive.Decimal128, bool) {
if v.t != bsontype.Decimal128 {
return primitive.Decimal128{}, false
}
return v.primitive.(primitive.Decimal128), true
}
// MinKey returns the BSON minkey the Value represents. It panics if the value is a BSON type
// other than binary.
func (v Val) MinKey() {
if v.t != bsontype.MinKey {
panic(ElementTypeError{"bson.Value.MinKey", v.t})
}
}
// MinKeyOK is the same as MinKey, except it returns a boolean instead of
// panicking.
func (v Val) MinKeyOK() bool {
return v.t == bsontype.MinKey
}
// MaxKey returns the BSON maxkey the Value represents. It panics if the value is a BSON type
// other than binary.
func (v Val) MaxKey() {
if v.t != bsontype.MaxKey {
panic(ElementTypeError{"bson.Value.MaxKey", v.t})
}
}
// MaxKeyOK is the same as MaxKey, except it returns a boolean instead of
// panicking.
func (v Val) MaxKeyOK() bool {
return v.t == bsontype.MaxKey
}
// Equal compares v to v2 and returns true if they are equal. Unknown BSON types are
// never equal. Two empty values are equal.
func (v Val) Equal(v2 Val) bool {
if v.Type() != v2.Type() {
return false
}
if v.IsZero() && v2.IsZero() {
return true
}
switch v.Type() {
case bsontype.Double, bsontype.DateTime, bsontype.Timestamp, bsontype.Int64:
return bytes.Equal(v.bootstrap[0:8], v2.bootstrap[0:8])
case bsontype.String:
return v.string() == v2.string()
case bsontype.EmbeddedDocument:
return v.equalDocs(v2)
case bsontype.Array:
return v.Array().Equal(v2.Array())
case bsontype.Binary:
return v.primitive.(primitive.Binary).Equal(v2.primitive.(primitive.Binary))
case bsontype.Undefined:
return true
case bsontype.ObjectID:
return bytes.Equal(v.bootstrap[0:12], v2.bootstrap[0:12])
case bsontype.Boolean:
return v.bootstrap[0] == v2.bootstrap[0]
case bsontype.Null:
return true
case bsontype.Regex:
return v.primitive.(primitive.Regex).Equal(v2.primitive.(primitive.Regex))
case bsontype.DBPointer:
return v.primitive.(primitive.DBPointer).Equal(v2.primitive.(primitive.DBPointer))
case bsontype.JavaScript:
return v.JavaScript() == v2.JavaScript()
case bsontype.Symbol:
return v.Symbol() == v2.Symbol()
case bsontype.CodeWithScope:
code1, scope1 := v.primitive.(primitive.CodeWithScope).Code, v.primitive.(primitive.CodeWithScope).Scope
code2, scope2 := v2.primitive.(primitive.CodeWithScope).Code, v2.primitive.(primitive.CodeWithScope).Scope
return code1 == code2 && v.equalInterfaceDocs(scope1, scope2)
case bsontype.Int32:
return v.Int32() == v2.Int32()
case bsontype.Decimal128:
h, l := v.Decimal128().GetBytes()
h2, l2 := v2.Decimal128().GetBytes()
return h == h2 && l == l2
case bsontype.MinKey:
return true
case bsontype.MaxKey:
return true
default:
return false
}
}
func (v Val) equalDocs(v2 Val) bool {
_, ok1 := v.primitive.(MDoc)
_, ok2 := v2.primitive.(MDoc)
if ok1 || ok2 {
return v.asMDoc().Equal(v2.asMDoc())
}
return v.asDoc().Equal(v2.asDoc())
}
func (Val) equalInterfaceDocs(i, i2 interface{}) bool {
switch d := i.(type) {
case MDoc:
d2, ok := i2.(IDoc)
if !ok {
return false
}
return d.Equal(d2)
case Doc:
d2, ok := i2.(IDoc)
if !ok {
return false
}
return d.Equal(d2)
case nil:
return i2 == nil
default:
return false
}
}

271
mongo/x/bsonx/value_test.go Normal file
View File

@@ -0,0 +1,271 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonx
import (
"reflect"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
)
func TestValue(t *testing.T) {
longstr := "foobarbazqux, hello, world!"
bytestr14 := "fourteen bytes"
bin := primitive.Binary{Subtype: 0xFF, Data: []byte{0x01, 0x02, 0x03}}
oid := primitive.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}
now := time.Now().Truncate(time.Millisecond)
nowdt := now.Unix()*1e3 + int64(now.Nanosecond()/1e6)
regex := primitive.Regex{Pattern: "/foobarbaz/", Options: "abr"}
dbptr := primitive.DBPointer{DB: "foobar", Pointer: oid}
js := "var hello ='world';"
symbol := "foobarbaz"
cws := primitive.CodeWithScope{Code: primitive.JavaScript(js), Scope: Doc{}}
code, scope := js, Doc{}
ts := primitive.Timestamp{I: 12345, T: 67890}
d128 := primitive.NewDecimal128(12345, 67890)
t.Parallel()
testCases := []struct {
name string
fn interface{} // method to call
ret []interface{} // return value
err interface{} // panic result or bool
}{
{"Interface/Double", Double(3.14159).Interface, []interface{}{float64(3.14159)}, nil},
{"Interface/String", String("foo").Interface, []interface{}{"foo"}, nil},
{"Interface/Document", Document(Doc{}).Interface, []interface{}{Doc{}}, nil},
{"Interface/Array", Array(Arr{}).Interface, []interface{}{Arr{}}, nil},
{"Interface/Binary", Binary(bin.Subtype, bin.Data).Interface, []interface{}{bin}, nil},
{"Interface/Undefined", Undefined().Interface, []interface{}{primitive.Undefined{}}, nil},
{"Interface/Null", Null().Interface, []interface{}{primitive.Null{}}, nil},
{"Interface/ObjectID", ObjectID(oid).Interface, []interface{}{oid}, nil},
{"Interface/Boolean", Boolean(true).Interface, []interface{}{bool(true)}, nil},
{"Interface/DateTime", DateTime(1234567890).Interface, []interface{}{int64(1234567890)}, nil},
{"Interface/Time", Time(now).Interface, []interface{}{nowdt}, nil},
{"Interface/Regex", Regex(regex.Pattern, regex.Options).Interface, []interface{}{regex}, nil},
{"Interface/DBPointer", DBPointer(dbptr.DB, dbptr.Pointer).Interface, []interface{}{dbptr}, nil},
{"Interface/JavaScript", JavaScript(js).Interface, []interface{}{js}, nil},
{"Interface/Symbol", Symbol(symbol).Interface, []interface{}{symbol}, nil},
{"Interface/CodeWithScope", CodeWithScope(string(cws.Code), cws.Scope.(Doc)).Interface, []interface{}{cws}, nil},
{"Interface/Int32", Int32(12345).Interface, []interface{}{int32(12345)}, nil},
{"Interface/Timestamp", Timestamp(ts.T, ts.I).Interface, []interface{}{ts}, nil},
{"Interface/Int64", Int64(1234567890).Interface, []interface{}{int64(1234567890)}, nil},
{"Interface/Decimal128", Decimal128(d128).Interface, []interface{}{d128}, nil},
{"Interface/MinKey", MinKey().Interface, []interface{}{primitive.MinKey{}}, nil},
{"Interface/MaxKey", MaxKey().Interface, []interface{}{primitive.MaxKey{}}, nil},
{"Interface/Empty", Val{}.Interface, []interface{}{primitive.Null{}}, nil},
{"IsNumber/Double", Double(0).IsNumber, []interface{}{bool(true)}, nil},
{"IsNumber/Int32", Int32(0).IsNumber, []interface{}{bool(true)}, nil},
{"IsNumber/Int64", Int64(0).IsNumber, []interface{}{bool(true)}, nil},
{"IsNumber/Decimal128", Decimal128(primitive.Decimal128{}).IsNumber, []interface{}{bool(true)}, nil},
{"IsNumber/String", String("").IsNumber, []interface{}{bool(false)}, nil},
{"Double/panic", String("").Double, nil, ElementTypeError{"bson.Value.Double", bsontype.String}},
{"Double/success", Double(3.14159).Double, []interface{}{float64(3.14159)}, nil},
{"DoubleOK/error", String("").DoubleOK, []interface{}{float64(0), false}, nil},
{"DoubleOK/success", Double(3.14159).DoubleOK, []interface{}{float64(3.14159), true}, nil},
{"String/panic", Double(0).StringValue, nil, ElementTypeError{"bson.Value.StringValue", bsontype.Double}},
{"String/success", String("bar").StringValue, []interface{}{"bar"}, nil},
{"String/14bytes", String(bytestr14).StringValue, []interface{}{bytestr14}, nil},
{"String/success(long)", String(longstr).StringValue, []interface{}{longstr}, nil},
{"StringOK/error", Double(0).StringValueOK, []interface{}{"", false}, nil},
{"StringOK/success", String("bar").StringValueOK, []interface{}{"bar", true}, nil},
{"Document/panic", Double(0).Document, nil, ElementTypeError{"bson.Value.Document", bsontype.Double}},
{"Document/success", Document(Doc{}).Document, []interface{}{Doc{}}, nil},
{"DocumentOK/error", Double(0).DocumentOK, []interface{}{(Doc)(nil), false}, nil},
{"DocumentOK/success", Document(Doc{}).DocumentOK, []interface{}{Doc{}, true}, nil},
{"MDocument/panic", Double(0).MDocument, nil, ElementTypeError{"bson.Value.MDocument", bsontype.Double}},
{"MDocument/success", Document(MDoc{}).MDocument, []interface{}{MDoc{}}, nil},
{"MDocumentOK/error", Double(0).MDocumentOK, []interface{}{(MDoc)(nil), false}, nil},
{"MDocumentOK/success", Document(MDoc{}).MDocumentOK, []interface{}{MDoc{}, true}, nil},
{"Document->MDocument/success", Document(Doc{}).MDocument, []interface{}{MDoc{}}, nil},
{"MDocument->Document/success", Document(MDoc{}).Document, []interface{}{Doc{}}, nil},
{"Document->MDocumentOK/success", Document(Doc{}).MDocumentOK, []interface{}{MDoc{}, true}, nil},
{"MDocument->DocumentOK/success", Document(MDoc{}).DocumentOK, []interface{}{Doc{}, true}, nil},
{"Array/panic", Double(0).Array, nil, ElementTypeError{"bson.Value.Array", bsontype.Double}},
{"Array/success", Array(Arr{}).Array, []interface{}{Arr{}}, nil},
{"ArrayOK/error", Double(0).ArrayOK, []interface{}{(Arr)(nil), false}, nil},
{"ArrayOK/success", Array(Arr{}).ArrayOK, []interface{}{Arr{}, true}, nil},
{"Document/NilDocument", Document((Doc)(nil)).Interface, []interface{}{primitive.Null{}}, nil},
{"Array/NilArray", Array((Arr)(nil)).Interface, []interface{}{primitive.Null{}}, nil},
{"Document/Nil", Document(nil).Interface, []interface{}{primitive.Null{}}, nil},
{"Array/Nil", Array(nil).Interface, []interface{}{primitive.Null{}}, nil},
{"Binary/panic", Double(0).Binary, nil, ElementTypeError{"bson.Value.Binary", bsontype.Double}},
{"Binary/success", Binary(bin.Subtype, bin.Data).Binary, []interface{}{bin.Subtype, bin.Data}, nil},
{"BinaryOK/error", Double(0).BinaryOK, []interface{}{byte(0x00), []byte(nil), false}, nil},
{"BinaryOK/success", Binary(bin.Subtype, bin.Data).BinaryOK, []interface{}{bin.Subtype, bin.Data, true}, nil},
{"Undefined/panic", Double(0).Undefined, nil, ElementTypeError{"bson.Value.Undefined", bsontype.Double}},
{"Undefined/success", Undefined().Undefined, nil, nil},
{"UndefinedOK/error", Double(0).UndefinedOK, []interface{}{false}, nil},
{"UndefinedOK/success", Undefined().UndefinedOK, []interface{}{true}, nil},
{"ObjectID/panic", Double(0).ObjectID, nil, ElementTypeError{"bson.Value.ObjectID", bsontype.Double}},
{"ObjectID/success", ObjectID(oid).ObjectID, []interface{}{oid}, nil},
{"ObjectIDOK/error", Double(0).ObjectIDOK, []interface{}{primitive.ObjectID{}, false}, nil},
{"ObjectIDOK/success", ObjectID(oid).ObjectIDOK, []interface{}{oid, true}, nil},
{"Boolean/panic", Double(0).Boolean, nil, ElementTypeError{"bson.Value.Boolean", bsontype.Double}},
{"Boolean/success", Boolean(true).Boolean, []interface{}{bool(true)}, nil},
{"BooleanOK/error", Double(0).BooleanOK, []interface{}{bool(false), false}, nil},
{"BooleanOK/success", Boolean(false).BooleanOK, []interface{}{false, true}, nil},
{"DateTime/panic", Double(0).DateTime, nil, ElementTypeError{"bson.Value.DateTime", bsontype.Double}},
{"DateTime/success", DateTime(1234567890).DateTime, []interface{}{int64(1234567890)}, nil},
{"DateTimeOK/error", Double(0).DateTimeOK, []interface{}{int64(0), false}, nil},
{"DateTimeOK/success", DateTime(987654321).DateTimeOK, []interface{}{int64(987654321), true}, nil},
{"Time/panic", Double(0).Time, nil, ElementTypeError{"bson.Value.Time", bsontype.Double}},
{"Time/success", Time(now).Time, []interface{}{now}, nil},
{"TimeOK/error", Double(0).TimeOK, []interface{}{time.Time{}, false}, nil},
{"TimeOK/success", Time(now).TimeOK, []interface{}{now, true}, nil},
{"Time->DateTime", Time(now).DateTime, []interface{}{nowdt}, nil},
{"DateTime->Time", DateTime(nowdt).Time, []interface{}{now}, nil},
{"Null/panic", Double(0).Null, nil, ElementTypeError{"bson.Value.Null", bsontype.Double}},
{"Null/success", Null().Null, nil, nil},
{"NullOK/error", Double(0).NullOK, []interface{}{false}, nil},
{"NullOK/success", Null().NullOK, []interface{}{true}, nil},
{"Regex/panic", Double(0).Regex, nil, ElementTypeError{"bson.Value.Regex", bsontype.Double}},
{"Regex/success", Regex(regex.Pattern, regex.Options).Regex, []interface{}{regex.Pattern, regex.Options}, nil},
{"RegexOK/error", Double(0).RegexOK, []interface{}{"", "", false}, nil},
{"RegexOK/success", Regex(regex.Pattern, regex.Options).RegexOK, []interface{}{regex.Pattern, regex.Options, true}, nil},
{"DBPointer/panic", Double(0).DBPointer, nil, ElementTypeError{"bson.Value.DBPointer", bsontype.Double}},
{"DBPointer/success", DBPointer(dbptr.DB, dbptr.Pointer).DBPointer, []interface{}{dbptr.DB, dbptr.Pointer}, nil},
{"DBPointerOK/error", Double(0).DBPointerOK, []interface{}{"", primitive.ObjectID{}, false}, nil},
{"DBPointerOK/success", DBPointer(dbptr.DB, dbptr.Pointer).DBPointerOK, []interface{}{dbptr.DB, dbptr.Pointer, true}, nil},
{"JavaScript/panic", Double(0).JavaScript, nil, ElementTypeError{"bson.Value.JavaScript", bsontype.Double}},
{"JavaScript/success", JavaScript(js).JavaScript, []interface{}{js}, nil},
{"JavaScriptOK/error", Double(0).JavaScriptOK, []interface{}{"", false}, nil},
{"JavaScriptOK/success", JavaScript(js).JavaScriptOK, []interface{}{js, true}, nil},
{"Symbol/panic", Double(0).Symbol, nil, ElementTypeError{"bson.Value.Symbol", bsontype.Double}},
{"Symbol/success", Symbol(symbol).Symbol, []interface{}{symbol}, nil},
{"SymbolOK/error", Double(0).SymbolOK, []interface{}{"", false}, nil},
{"SymbolOK/success", Symbol(symbol).SymbolOK, []interface{}{symbol, true}, nil},
{"CodeWithScope/panic", Double(0).CodeWithScope, nil, ElementTypeError{"bson.Value.CodeWithScope", bsontype.Double}},
{"CodeWithScope/success", CodeWithScope(code, scope).CodeWithScope, []interface{}{code, scope}, nil},
{"CodeWithScopeOK/error", Double(0).CodeWithScopeOK, []interface{}{"", (Doc)(nil), false}, nil},
{"CodeWithScopeOK/success", CodeWithScope(code, scope).CodeWithScopeOK, []interface{}{code, scope, true}, nil},
{"Int32/panic", Double(0).Int32, nil, ElementTypeError{"bson.Value.Int32", bsontype.Double}},
{"Int32/success", Int32(12345).Int32, []interface{}{int32(12345)}, nil},
{"Int32OK/error", Double(0).Int32OK, []interface{}{int32(0), false}, nil},
{"Int32OK/success", Int32(54321).Int32OK, []interface{}{int32(54321), true}, nil},
{"Timestamp/panic", Double(0).Timestamp, nil, ElementTypeError{"bson.Value.Timestamp", bsontype.Double}},
{"Timestamp/success", Timestamp(ts.T, ts.I).Timestamp, []interface{}{ts.T, ts.I}, nil},
{"TimestampOK/error", Double(0).TimestampOK, []interface{}{uint32(0), uint32(0), false}, nil},
{"TimestampOK/success", Timestamp(ts.T, ts.I).TimestampOK, []interface{}{ts.T, ts.I, true}, nil},
{"Int64/panic", Double(0).Int64, nil, ElementTypeError{"bson.Value.Int64", bsontype.Double}},
{"Int64/success", Int64(1234567890).Int64, []interface{}{int64(1234567890)}, nil},
{"Int64OK/error", Double(0).Int64OK, []interface{}{int64(0), false}, nil},
{"Int64OK/success", Int64(9876543210).Int64OK, []interface{}{int64(9876543210), true}, nil},
{"Decimal128/panic", Double(0).Decimal128, nil, ElementTypeError{"bson.Value.Decimal128", bsontype.Double}},
{"Decimal128/success", Decimal128(d128).Decimal128, []interface{}{d128}, nil},
{"Decimal128OK/error", Double(0).Decimal128OK, []interface{}{primitive.Decimal128{}, false}, nil},
{"Decimal128OK/success", Decimal128(d128).Decimal128OK, []interface{}{d128, true}, nil},
{"MinKey/panic", Double(0).MinKey, nil, ElementTypeError{"bson.Value.MinKey", bsontype.Double}},
{"MinKey/success", MinKey().MinKey, nil, nil},
{"MinKeyOK/error", Double(0).MinKeyOK, []interface{}{false}, nil},
{"MinKeyOK/success", MinKey().MinKeyOK, []interface{}{true}, nil},
{"MaxKey/panic", Double(0).MaxKey, nil, ElementTypeError{"bson.Value.MaxKey", bsontype.Double}},
{"MaxKey/success", MaxKey().MaxKey, nil, nil},
{"MaxKeyOK/error", Double(0).MaxKeyOK, []interface{}{false}, nil},
{"MaxKeyOK/success", MaxKey().MaxKeyOK, []interface{}{true}, nil},
}
for _, tc := range testCases {
tc := tc // capture range variable
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
defer func() {
err := recover()
if err != nil && !cmp.Equal(err, tc.err) {
t.Errorf("panic errors are not equal. got %v; want %v", err, tc.err)
if tc.err == nil {
panic(err)
}
}
}()
fn := reflect.ValueOf(tc.fn)
if fn.Kind() != reflect.Func {
t.Fatalf("fn must be a function, but is a %s", fn.Kind())
}
ret := fn.Call(nil)
if len(ret) != len(tc.ret) {
t.Fatalf("number of returned values does not match. got %d; want %d", len(ret), len(tc.ret))
}
for idx := range ret {
got, want := ret[idx].Interface(), tc.ret[idx]
if !cmp.Equal(got, want, cmp.Comparer(compareDecimal128)) {
t.Errorf("Return %d does not match. got %v; want %v", idx, got, want)
}
}
})
}
t.Run("Equal", func(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
v1 Val
v2 Val
res bool
}{
{"Different Types", String(""), Double(0), false},
{"Unknown Types", Val{t: bsontype.Type(0x77)}, Val{t: bsontype.Type(0x77)}, false},
{"Empty Types", Val{}, Val{}, true},
{"Double/Equal", Double(3.14159), Double(3.14159), true},
{"Double/Not Equal", Double(3.14159), Double(9.51413), false},
{"DateTime/Equal", DateTime(nowdt), DateTime(nowdt), true},
{"DateTime/Not Equal", DateTime(nowdt), DateTime(0), false},
{"String/Equal", String("hello"), String("hello"), true},
{"String/Not Equal", String("hello"), String("world"), false},
{"Document/Equal", Document(Doc{}), Document(Doc{}), true},
{"Document/Not Equal", Document(Doc{}), Document(Doc{{"", Null()}}), false},
{"Array/Equal", Array(Arr{}), Array(Arr{}), true},
{"Array/Not Equal", Array(Arr{}), Array(Arr{Null()}), false},
{"Binary/Equal", Binary(bin.Subtype, bin.Data), Binary(bin.Subtype, bin.Data), true},
{"Binary/Not Equal", Binary(bin.Subtype, bin.Data), Binary(0x00, nil), false},
{"Undefined/Equal", Undefined(), Undefined(), true},
{"ObjectID/Equal", ObjectID(oid), ObjectID(oid), true},
{"ObjectID/Not Equal", ObjectID(oid), ObjectID(primitive.ObjectID{}), false},
{"Boolean/Equal", Boolean(true), Boolean(true), true},
{"Boolean/Not Equal", Boolean(true), Boolean(false), false},
{"Null/Equal", Null(), Null(), true},
{"Regex/Equal", Regex(regex.Pattern, regex.Options), Regex(regex.Pattern, regex.Options), true},
{"Regex/Not Equal", Regex(regex.Pattern, regex.Options), Regex("", ""), false},
{"DBPointer/Equal", DBPointer(dbptr.DB, dbptr.Pointer), DBPointer(dbptr.DB, dbptr.Pointer), true},
{"DBPointer/Not Equal", DBPointer(dbptr.DB, dbptr.Pointer), DBPointer("", primitive.ObjectID{}), false},
{"JavaScript/Equal", JavaScript(js), JavaScript(js), true},
{"JavaScript/Not Equal", JavaScript(js), JavaScript(""), false},
{"Symbol/Equal", Symbol(symbol), Symbol(symbol), true},
{"Symbol/Not Equal", Symbol(symbol), Symbol(""), false},
{"CodeWithScope/Equal", CodeWithScope(code, scope), CodeWithScope(code, scope), true},
{"CodeWithScope/Equal (equal scope)", CodeWithScope(code, scope), CodeWithScope(code, Doc{}), true},
{"CodeWithScope/Not Equal", CodeWithScope(code, scope), CodeWithScope("", nil), false},
{"Int32/Equal", Int32(12345), Int32(12345), true},
{"Int32/Not Equal", Int32(12345), Int32(54321), false},
{"Timestamp/Equal", Timestamp(ts.T, ts.I), Timestamp(ts.T, ts.I), true},
{"Timestamp/Not Equal", Timestamp(ts.T, ts.I), Timestamp(0, 0), false},
{"Int64/Equal", Int64(1234567890), Int64(1234567890), true},
{"Int64/Not Equal", Int64(1234567890), Int64(9876543210), false},
{"Decimal128/Equal", Decimal128(d128), Decimal128(d128), true},
{"Decimal128/Not Equal", Decimal128(d128), Decimal128(primitive.Decimal128{}), false},
{"MinKey/Equal", MinKey(), MinKey(), true},
{"MaxKey/Equal", MaxKey(), MaxKey(), true},
}
for _, tc := range testCases {
tc := tc // capture range variable
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
res := tc.v1.Equal(tc.v2)
if res != tc.res {
t.Errorf("results do not match. got %v; want %v", res, tc.res)
}
})
}
})
}

View File

@@ -0,0 +1,23 @@
# Driver Library Design
This document outlines the design for this package.
## Deployment, Server, and Connection
Acquiring a `Connection` from a `Server` selected from a `Deployment` enables sending and receiving
wire messages. A `Deployment` represents an set of MongoDB servers and a `Server` represents a
member of that set. These three types form the operation execution stack.
### Compression
Compression is handled by Connection type while uncompression is handled automatically by the
Operation type. This is done because the compressor to use for compressing a wire message is
chosen by the connection during handshake, while uncompression can be performed without this
information. This does make the design of compression non-symmetric, but it makes the design simpler
to implement and more consistent.
## Operation
The `Operation` type handles executing a series of commands using a `Deployment`. For most uses
`Operation` will only execute a single command, but the main use case for a series of commands is
batch split write commands, such as insert. The type itself is heavily documented, so reading the
code and comments together should provide an understanding of how the type works.
This type is not meant to be used directly by callers. Instead a wrapping type should be defined
using the IDL.

View File

@@ -0,0 +1,229 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
"context"
"errors"
"fmt"
"net/http"
"go.mongodb.org/mongo-driver/mongo/address"
"go.mongodb.org/mongo-driver/mongo/description"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
)
// AuthenticatorFactory constructs an authenticator.
type AuthenticatorFactory func(cred *Cred) (Authenticator, error)
var authFactories = make(map[string]AuthenticatorFactory)
func init() {
RegisterAuthenticatorFactory("", newDefaultAuthenticator)
RegisterAuthenticatorFactory(SCRAMSHA1, newScramSHA1Authenticator)
RegisterAuthenticatorFactory(SCRAMSHA256, newScramSHA256Authenticator)
RegisterAuthenticatorFactory(MONGODBCR, newMongoDBCRAuthenticator)
RegisterAuthenticatorFactory(PLAIN, newPlainAuthenticator)
RegisterAuthenticatorFactory(GSSAPI, newGSSAPIAuthenticator)
RegisterAuthenticatorFactory(MongoDBX509, newMongoDBX509Authenticator)
RegisterAuthenticatorFactory(MongoDBAWS, newMongoDBAWSAuthenticator)
}
// CreateAuthenticator creates an authenticator.
func CreateAuthenticator(name string, cred *Cred) (Authenticator, error) {
if f, ok := authFactories[name]; ok {
return f(cred)
}
return nil, newAuthError(fmt.Sprintf("unknown authenticator: %s", name), nil)
}
// RegisterAuthenticatorFactory registers the authenticator factory.
func RegisterAuthenticatorFactory(name string, factory AuthenticatorFactory) {
authFactories[name] = factory
}
// HandshakeOptions packages options that can be passed to the Handshaker()
// function. DBUser is optional but must be of the form <dbname.username>;
// if non-empty, then the connection will do SASL mechanism negotiation.
type HandshakeOptions struct {
AppName string
Authenticator Authenticator
Compressors []string
DBUser string
PerformAuthentication func(description.Server) bool
ClusterClock *session.ClusterClock
ServerAPI *driver.ServerAPIOptions
LoadBalanced bool
HTTPClient *http.Client
}
type authHandshaker struct {
wrapped driver.Handshaker
options *HandshakeOptions
handshakeInfo driver.HandshakeInformation
conversation SpeculativeConversation
}
var _ driver.Handshaker = (*authHandshaker)(nil)
// GetHandshakeInformation performs the initial MongoDB handshake to retrieve the required information for the provided
// connection.
func (ah *authHandshaker) GetHandshakeInformation(ctx context.Context, addr address.Address, conn driver.Connection) (driver.HandshakeInformation, error) {
if ah.wrapped != nil {
return ah.wrapped.GetHandshakeInformation(ctx, addr, conn)
}
op := operation.NewHello().
AppName(ah.options.AppName).
Compressors(ah.options.Compressors).
SASLSupportedMechs(ah.options.DBUser).
ClusterClock(ah.options.ClusterClock).
ServerAPI(ah.options.ServerAPI).
LoadBalanced(ah.options.LoadBalanced)
if ah.options.Authenticator != nil {
if speculativeAuth, ok := ah.options.Authenticator.(SpeculativeAuthenticator); ok {
var err error
ah.conversation, err = speculativeAuth.CreateSpeculativeConversation()
if err != nil {
return driver.HandshakeInformation{}, newAuthError("failed to create conversation", err)
}
firstMsg, err := ah.conversation.FirstMessage()
if err != nil {
return driver.HandshakeInformation{}, newAuthError("failed to create speculative authentication message", err)
}
op = op.SpeculativeAuthenticate(firstMsg)
}
}
var err error
ah.handshakeInfo, err = op.GetHandshakeInformation(ctx, addr, conn)
if err != nil {
return driver.HandshakeInformation{}, newAuthError("handshake failure", err)
}
return ah.handshakeInfo, nil
}
// FinishHandshake performs authentication for conn if necessary.
func (ah *authHandshaker) FinishHandshake(ctx context.Context, conn driver.Connection) error {
performAuth := ah.options.PerformAuthentication
if performAuth == nil {
performAuth = func(serv description.Server) bool {
// Authentication is possible against all server types except arbiters
return serv.Kind != description.RSArbiter
}
}
desc := conn.Description()
if performAuth(desc) && ah.options.Authenticator != nil {
cfg := &Config{
Description: desc,
Connection: conn,
ClusterClock: ah.options.ClusterClock,
HandshakeInfo: ah.handshakeInfo,
ServerAPI: ah.options.ServerAPI,
HTTPClient: ah.options.HTTPClient,
}
if err := ah.authenticate(ctx, cfg); err != nil {
return newAuthError("auth error", err)
}
}
if ah.wrapped == nil {
return nil
}
return ah.wrapped.FinishHandshake(ctx, conn)
}
func (ah *authHandshaker) authenticate(ctx context.Context, cfg *Config) error {
// If the initial hello reply included a response to the speculative authentication attempt, we only need to
// conduct the remainder of the conversation.
if speculativeResponse := ah.handshakeInfo.SpeculativeAuthenticate; speculativeResponse != nil {
// Defensively ensure that the server did not include a response if speculative auth was not attempted.
if ah.conversation == nil {
return errors.New("speculative auth was not attempted but the server included a response")
}
return ah.conversation.Finish(ctx, cfg, speculativeResponse)
}
// If the server does not support speculative authentication or the first attempt was not successful, we need to
// perform authentication from scratch.
return ah.options.Authenticator.Auth(ctx, cfg)
}
// Handshaker creates a connection handshaker for the given authenticator.
func Handshaker(h driver.Handshaker, options *HandshakeOptions) driver.Handshaker {
return &authHandshaker{
wrapped: h,
options: options,
}
}
// Config holds the information necessary to perform an authentication attempt.
type Config struct {
Description description.Server
Connection driver.Connection
ClusterClock *session.ClusterClock
HandshakeInfo driver.HandshakeInformation
ServerAPI *driver.ServerAPIOptions
HTTPClient *http.Client
}
// Authenticator handles authenticating a connection.
type Authenticator interface {
// Auth authenticates the connection.
Auth(context.Context, *Config) error
}
func newAuthError(msg string, inner error) error {
return &Error{
message: msg,
inner: inner,
}
}
func newError(err error, mech string) error {
return &Error{
message: fmt.Sprintf("unable to authenticate using mechanism \"%s\"", mech),
inner: err,
}
}
// Error is an error that occurred during authentication.
type Error struct {
message string
inner error
}
func (e *Error) Error() string {
if e.inner == nil {
return e.message
}
return fmt.Sprintf("%s: %s", e.message, e.inner)
}
// Inner returns the wrapped error.
func (e *Error) Inner() error {
return e.inner
}
// Unwrap returns the underlying error.
func (e *Error) Unwrap() error {
return e.inner
}
// Message returns the message.
func (e *Error) Message() string {
return e.message
}

View File

@@ -0,0 +1,111 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth_test
import (
"encoding/json"
"fmt"
"io/ioutil"
"path"
"testing"
"github.com/stretchr/testify/require"
"go.mongodb.org/mongo-driver/internal/testutil/helpers"
"go.mongodb.org/mongo-driver/mongo/options"
)
type credential struct {
Username string
Password *string
Source string
Mechanism string
MechProps map[string]interface{} `json:"mechanism_properties"`
}
type testCase struct {
Description string
URI string
Valid bool
Credential *credential
}
type testContainer struct {
Tests []testCase
}
// Note a test supporting the deprecated gssapiServiceName property was removed from data/auth/auth_tests.json
const authTestsDir = "../../../../testdata/auth/"
func runTestsInFile(t *testing.T, dirname string, filename string) {
filepath := path.Join(dirname, filename)
content, err := ioutil.ReadFile(filepath)
require.NoError(t, err)
var container testContainer
require.NoError(t, json.Unmarshal(content, &container))
// Remove ".json" from filename.
filename = filename[:len(filename)-5]
for _, testCase := range container.Tests {
runTest(t, filename, testCase)
}
}
func runTest(t *testing.T, filename string, test testCase) {
t.Run(filename+":"+test.Description, func(t *testing.T) {
opts := options.Client().ApplyURI(test.URI)
if test.Valid {
require.NoError(t, opts.Validate())
} else {
require.Error(t, opts.Validate())
return
}
if test.Credential == nil {
require.Nil(t, opts.Auth)
return
}
require.NotNil(t, opts.Auth)
require.Equal(t, test.Credential.Username, opts.Auth.Username)
if test.Credential.Password == nil {
require.False(t, opts.Auth.PasswordSet)
} else {
require.True(t, opts.Auth.PasswordSet)
require.Equal(t, *test.Credential.Password, opts.Auth.Password)
}
require.Equal(t, test.Credential.Source, opts.Auth.AuthSource)
require.Equal(t, test.Credential.Mechanism, opts.Auth.AuthMechanism)
if len(test.Credential.MechProps) > 0 {
require.Equal(t, mapInterfaceToString(test.Credential.MechProps), opts.Auth.AuthMechanismProperties)
} else {
require.Equal(t, 0, len(opts.Auth.AuthMechanismProperties))
}
})
}
// Convert each interface{} value in the map to a string.
func mapInterfaceToString(m map[string]interface{}) map[string]string {
out := make(map[string]string)
for key, value := range m {
out[key] = fmt.Sprint(value)
}
return out
}
// Test case for all connection string spec tests.
func TestAuthSpec(t *testing.T) {
for _, file := range helpers.FindJSONFilesInDir(t, authTestsDir) {
runTestsInFile(t, authTestsDir, file)
}
}

View File

@@ -0,0 +1,126 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth_test
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/require"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
. "go.mongodb.org/mongo-driver/x/mongo/driver/auth"
"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
)
func TestCreateAuthenticator(t *testing.T) {
tests := []struct {
name string
source string
auth Authenticator
}{
{name: "", auth: &DefaultAuthenticator{}},
{name: "SCRAM-SHA-1", auth: &ScramAuthenticator{}},
{name: "SCRAM-SHA-256", auth: &ScramAuthenticator{}},
{name: "MONGODB-CR", auth: &MongoDBCRAuthenticator{}},
{name: "PLAIN", auth: &PlainAuthenticator{}},
{name: "MONGODB-X509", auth: &MongoDBX509Authenticator{}},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
cred := &Cred{
Username: "user",
Password: "pencil",
PasswordSet: true,
}
a, err := CreateAuthenticator(test.name, cred)
require.NoError(t, err)
require.IsType(t, test.auth, a)
})
}
}
func compareResponses(t *testing.T, wm []byte, expectedPayload bsoncore.Document, dbName string) {
_, _, _, opcode, wm, ok := wiremessage.ReadHeader(wm)
if !ok {
t.Fatalf("wiremessage is too short to unmarshal")
}
var actualPayload bsoncore.Document
switch opcode {
case wiremessage.OpQuery:
_, wm, ok := wiremessage.ReadQueryFlags(wm)
if !ok {
t.Fatalf("wiremessage is too short to unmarshal")
}
_, wm, ok = wiremessage.ReadQueryFullCollectionName(wm)
if !ok {
t.Fatalf("wiremessage is too short to unmarshal")
}
_, wm, ok = wiremessage.ReadQueryNumberToSkip(wm)
if !ok {
t.Fatalf("wiremessage is too short to unmarshal")
}
_, wm, ok = wiremessage.ReadQueryNumberToReturn(wm)
if !ok {
t.Fatalf("wiremessage is too short to unmarshal")
}
actualPayload, _, ok = wiremessage.ReadQueryQuery(wm)
if !ok {
t.Fatalf("wiremessage is too short to unmarshal")
}
case wiremessage.OpMsg:
// Append the $db field.
elems, err := expectedPayload.Elements()
if err != nil {
t.Fatalf("expectedPayload is not valid: %v", err)
}
elems = append(elems, bsoncore.AppendStringElement(nil, "$db", dbName))
elems = append(elems, bsoncore.AppendDocumentElement(nil,
"$readPreference",
bsoncore.BuildDocumentFromElements(nil, bsoncore.AppendStringElement(nil, "mode", "primaryPreferred")),
))
bslc := make([][]byte, 0, len(elems)) // BuildDocumentFromElements takes a [][]byte, not a []bsoncore.Element.
for _, elem := range elems {
bslc = append(bslc, elem)
}
expectedPayload = bsoncore.BuildDocumentFromElements(nil, bslc...)
_, wm, ok := wiremessage.ReadMsgFlags(wm)
if !ok {
t.Fatalf("wiremessage is too short to unmarshal")
}
loop:
for {
var stype wiremessage.SectionType
stype, wm, ok = wiremessage.ReadMsgSectionType(wm)
if !ok {
t.Fatalf("wiremessage is too short to unmarshal")
break
}
switch stype {
case wiremessage.DocumentSequence:
_, _, wm, ok = wiremessage.ReadMsgSectionDocumentSequence(wm)
if !ok {
t.Fatalf("wiremessage is too short to unmarshal")
break loop
}
case wiremessage.SingleDocument:
actualPayload, wm, ok = wiremessage.ReadMsgSectionSingleDocument(wm)
if !ok {
t.Fatalf("wiremessage is too short to unmarshal")
}
break loop
}
}
}
if !cmp.Equal(actualPayload, expectedPayload) {
t.Errorf("Payloads don't match. got %v; want %v", actualPayload, expectedPayload)
}
}

View File

@@ -0,0 +1,348 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
"bytes"
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"os"
"strings"
"time"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver/auth/internal/awsv4"
)
type clientState int
const (
clientStarting clientState = iota
clientFirst
clientFinal
clientDone
)
type awsConversation struct {
state clientState
valid bool
nonce []byte
username string
password string
token string
httpClient *http.Client
}
type serverMessage struct {
Nonce primitive.Binary `bson:"s"`
Host string `bson:"h"`
}
type ecsResponse struct {
AccessKeyID string `json:"AccessKeyId"`
SecretAccessKey string `json:"SecretAccessKey"`
Token string `json:"Token"`
}
const (
amzDateFormat = "20060102T150405Z"
awsRelativeURI = "http://169.254.170.2/"
awsEC2URI = "http://169.254.169.254/"
awsEC2RolePath = "latest/meta-data/iam/security-credentials/"
awsEC2TokenPath = "latest/api/token"
defaultRegion = "us-east-1"
maxHostLength = 255
defaultHTTPTimeout = 10 * time.Second
responceNonceLength = 64
)
// Step takes a string provided from a server (or just an empty string for the
// very first conversation step) and attempts to move the authentication
// conversation forward. It returns a string to be sent to the server or an
// error if the server message is invalid. Calling Step after a conversation
// completes is also an error.
func (ac *awsConversation) Step(challenge []byte) (response []byte, err error) {
switch ac.state {
case clientStarting:
ac.state = clientFirst
response = ac.firstMsg()
case clientFirst:
ac.state = clientFinal
response, err = ac.finalMsg(challenge)
case clientFinal:
ac.state = clientDone
ac.valid = true
default:
response, err = nil, errors.New("Conversation already completed")
}
return
}
// Done returns true if the conversation is completed or has errored.
func (ac *awsConversation) Done() bool {
return ac.state == clientDone
}
// Valid returns true if the conversation successfully authenticated with the
// server, including counter-validation that the server actually has the
// user's stored credentials.
func (ac *awsConversation) Valid() bool {
return ac.valid
}
func getRegion(host string) (string, error) {
region := defaultRegion
if len(host) == 0 {
return "", errors.New("invalid STS host: empty")
}
if len(host) > maxHostLength {
return "", errors.New("invalid STS host: too large")
}
// The implicit region for sts.amazonaws.com is us-east-1
if host == "sts.amazonaws.com" {
return region, nil
}
if strings.HasPrefix(host, ".") || strings.HasSuffix(host, ".") || strings.Contains(host, "..") {
return "", errors.New("invalid STS host: empty part")
}
// If the host has multiple parts, the second part is the region
parts := strings.Split(host, ".")
if len(parts) >= 2 {
region = parts[1]
}
return region, nil
}
func (ac *awsConversation) validateAndMakeCredentials() (*awsv4.StaticProvider, error) {
if ac.username != "" && ac.password == "" {
return nil, errors.New("ACCESS_KEY_ID is set, but SECRET_ACCESS_KEY is missing")
}
if ac.username == "" && ac.password != "" {
return nil, errors.New("SECRET_ACCESS_KEY is set, but ACCESS_KEY_ID is missing")
}
if ac.username == "" && ac.password == "" && ac.token != "" {
return nil, errors.New("AWS_SESSION_TOKEN is set, but ACCESS_KEY_ID and SECRET_ACCESS_KEY are missing")
}
if ac.username != "" || ac.password != "" || ac.token != "" {
return &awsv4.StaticProvider{Value: awsv4.Value{
AccessKeyID: ac.username,
SecretAccessKey: ac.password,
SessionToken: ac.token,
}}, nil
}
return nil, nil
}
func executeAWSHTTPRequest(httpClient *http.Client, req *http.Request) ([]byte, error) {
ctx, cancel := context.WithTimeout(context.Background(), defaultHTTPTimeout)
defer cancel()
resp, err := httpClient.Do(req.WithContext(ctx))
if err != nil {
return nil, err
}
defer resp.Body.Close()
return ioutil.ReadAll(resp.Body)
}
func (ac *awsConversation) getEC2Credentials() (*awsv4.StaticProvider, error) {
// get token
req, err := http.NewRequest("PUT", awsEC2URI+awsEC2TokenPath, nil)
if err != nil {
return nil, err
}
req.Header.Set("X-aws-ec2-metadata-token-ttl-seconds", "30")
token, err := executeAWSHTTPRequest(ac.httpClient, req)
if err != nil {
return nil, err
}
if len(token) == 0 {
return nil, errors.New("unable to retrieve token from EC2 metadata")
}
tokenStr := string(token)
// get role name
req, err = http.NewRequest("GET", awsEC2URI+awsEC2RolePath, nil)
if err != nil {
return nil, err
}
req.Header.Set("X-aws-ec2-metadata-token", tokenStr)
role, err := executeAWSHTTPRequest(ac.httpClient, req)
if err != nil {
return nil, err
}
if len(role) == 0 {
return nil, errors.New("unable to retrieve role_name from EC2 metadata")
}
// get credentials
pathWithRole := awsEC2URI + awsEC2RolePath + string(role)
req, err = http.NewRequest("GET", pathWithRole, nil)
if err != nil {
return nil, err
}
req.Header.Set("X-aws-ec2-metadata-token", tokenStr)
creds, err := executeAWSHTTPRequest(ac.httpClient, req)
if err != nil {
return nil, err
}
var es2Resp ecsResponse
err = json.Unmarshal(creds, &es2Resp)
if err != nil {
return nil, err
}
ac.username = es2Resp.AccessKeyID
ac.password = es2Resp.SecretAccessKey
ac.token = es2Resp.Token
return ac.validateAndMakeCredentials()
}
func (ac *awsConversation) getCredentials() (*awsv4.StaticProvider, error) {
// Credentials passed through URI
creds, err := ac.validateAndMakeCredentials()
if creds != nil || err != nil {
return creds, err
}
// Credentials from environment variables
ac.username = os.Getenv("AWS_ACCESS_KEY_ID")
ac.password = os.Getenv("AWS_SECRET_ACCESS_KEY")
ac.token = os.Getenv("AWS_SESSION_TOKEN")
creds, err = ac.validateAndMakeCredentials()
if creds != nil || err != nil {
return creds, err
}
// Credentials from ECS metadata
relativeEcsURI := os.Getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI")
if len(relativeEcsURI) > 0 {
fullURI := awsRelativeURI + relativeEcsURI
req, err := http.NewRequest("GET", fullURI, nil)
if err != nil {
return nil, err
}
body, err := executeAWSHTTPRequest(ac.httpClient, req)
if err != nil {
return nil, err
}
var espResp ecsResponse
err = json.Unmarshal(body, &espResp)
if err != nil {
return nil, err
}
ac.username = espResp.AccessKeyID
ac.password = espResp.SecretAccessKey
ac.token = espResp.Token
creds, err = ac.validateAndMakeCredentials()
if creds != nil || err != nil {
return creds, err
}
}
// Credentials from EC2 metadata
creds, err = ac.getEC2Credentials()
if creds == nil && err == nil {
return nil, errors.New("unable to get credentials")
}
return creds, err
}
func (ac *awsConversation) firstMsg() []byte {
// Values are cached for use in final message parameters
ac.nonce = make([]byte, 32)
_, _ = rand.Read(ac.nonce)
idx, msg := bsoncore.AppendDocumentStart(nil)
msg = bsoncore.AppendInt32Element(msg, "p", 110)
msg = bsoncore.AppendBinaryElement(msg, "r", 0x00, ac.nonce)
msg, _ = bsoncore.AppendDocumentEnd(msg, idx)
return msg
}
func (ac *awsConversation) finalMsg(s1 []byte) ([]byte, error) {
var sm serverMessage
err := bson.Unmarshal(s1, &sm)
if err != nil {
return nil, err
}
// Check nonce prefix
if sm.Nonce.Subtype != 0x00 {
return nil, errors.New("server reply contained unexpected binary subtype")
}
if len(sm.Nonce.Data) != responceNonceLength {
return nil, fmt.Errorf("server reply nonce was not %v bytes", responceNonceLength)
}
if !bytes.HasPrefix(sm.Nonce.Data, ac.nonce) {
return nil, errors.New("server nonce did not extend client nonce")
}
region, err := getRegion(sm.Host)
if err != nil {
return nil, err
}
creds, err := ac.getCredentials()
if err != nil {
return nil, err
}
currentTime := time.Now().UTC()
body := "Action=GetCallerIdentity&Version=2011-06-15"
// Create http.Request
req, _ := http.NewRequest("POST", "/", strings.NewReader(body))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Content-Length", "43")
req.Host = sm.Host
req.Header.Set("X-Amz-Date", currentTime.Format(amzDateFormat))
if len(ac.token) > 0 {
req.Header.Set("X-Amz-Security-Token", ac.token)
}
req.Header.Set("X-MongoDB-Server-Nonce", base64.StdEncoding.EncodeToString(sm.Nonce.Data))
req.Header.Set("X-MongoDB-GS2-CB-Flag", "n")
// Create signer with credentials
signer := awsv4.NewSigner(creds)
// Get signed header
_, err = signer.Sign(req, strings.NewReader(body), "sts", region, currentTime)
if err != nil {
return nil, err
}
// create message
idx, msg := bsoncore.AppendDocumentStart(nil)
msg = bsoncore.AppendStringElement(msg, "a", req.Header.Get("Authorization"))
msg = bsoncore.AppendStringElement(msg, "d", req.Header.Get("X-Amz-Date"))
if len(ac.token) > 0 {
msg = bsoncore.AppendStringElement(msg, "t", ac.token)
}
msg, _ = bsoncore.AppendDocumentEnd(msg, idx)
return msg, nil
}

View File

@@ -0,0 +1,31 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
"context"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
// SpeculativeConversation represents an authentication conversation that can be merged with the initial connection
// handshake.
//
// FirstMessage method returns the first message to be sent to the server. This message will be included in the initial
// hello command.
//
// Finish takes the server response to the initial message and conducts the remainder of the conversation to
// authenticate the provided connection.
type SpeculativeConversation interface {
FirstMessage() (bsoncore.Document, error)
Finish(ctx context.Context, cfg *Config, firstResponse bsoncore.Document) error
}
// SpeculativeAuthenticator represents an authenticator that supports speculative authentication.
type SpeculativeAuthenticator interface {
CreateSpeculativeConversation() (SpeculativeConversation, error)
}

View File

@@ -0,0 +1,16 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
// Cred is a user's credential.
type Cred struct {
Source string
Username string
Password string
PasswordSet bool
Props map[string]string
}

View File

@@ -0,0 +1,98 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
"context"
"fmt"
"go.mongodb.org/mongo-driver/mongo/description"
)
func newDefaultAuthenticator(cred *Cred) (Authenticator, error) {
scram, err := newScramSHA256Authenticator(cred)
if err != nil {
return nil, newAuthError("failed to create internal authenticator", err)
}
speculative, ok := scram.(SpeculativeAuthenticator)
if !ok {
typeErr := fmt.Errorf("expected SCRAM authenticator to be SpeculativeAuthenticator but got %T", scram)
return nil, newAuthError("failed to create internal authenticator", typeErr)
}
return &DefaultAuthenticator{
Cred: cred,
speculativeAuthenticator: speculative,
}, nil
}
// DefaultAuthenticator uses SCRAM-SHA-1 or MONGODB-CR depending
// on the server version.
type DefaultAuthenticator struct {
Cred *Cred
// The authenticator to use for speculative authentication. Because the correct auth mechanism is unknown when doing
// the initial hello, SCRAM-SHA-256 is used for the speculative attempt.
speculativeAuthenticator SpeculativeAuthenticator
}
var _ SpeculativeAuthenticator = (*DefaultAuthenticator)(nil)
// CreateSpeculativeConversation creates a speculative conversation for SCRAM authentication.
func (a *DefaultAuthenticator) CreateSpeculativeConversation() (SpeculativeConversation, error) {
return a.speculativeAuthenticator.CreateSpeculativeConversation()
}
// Auth authenticates the connection.
func (a *DefaultAuthenticator) Auth(ctx context.Context, cfg *Config) error {
var actual Authenticator
var err error
switch chooseAuthMechanism(cfg) {
case SCRAMSHA256:
actual, err = newScramSHA256Authenticator(a.Cred)
case SCRAMSHA1:
actual, err = newScramSHA1Authenticator(a.Cred)
default:
actual, err = newMongoDBCRAuthenticator(a.Cred)
}
if err != nil {
return newAuthError("error creating authenticator", err)
}
return actual.Auth(ctx, cfg)
}
// If a server provides a list of supported mechanisms, we choose
// SCRAM-SHA-256 if it exists or else MUST use SCRAM-SHA-1.
// Otherwise, we decide based on what is supported.
func chooseAuthMechanism(cfg *Config) string {
if saslSupportedMechs := cfg.HandshakeInfo.SaslSupportedMechs; saslSupportedMechs != nil {
for _, v := range saslSupportedMechs {
if v == SCRAMSHA256 {
return v
}
}
return SCRAMSHA1
}
if err := scramSHA1Supported(cfg.HandshakeInfo.Description.WireVersion); err == nil {
return SCRAMSHA1
}
return MONGODBCR
}
// scramSHA1Supported returns an error if the given server version does not support scram-sha-1.
func scramSHA1Supported(wireVersion *description.VersionRange) error {
if wireVersion != nil && wireVersion.Max < 3 {
return fmt.Errorf("SCRAM-SHA-1 is only supported for servers 3.0 or newer")
}
return nil
}

View File

@@ -0,0 +1,23 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Package auth is not for public use.
//
// The API for packages in the 'private' directory have no stability
// guarantee.
//
// The packages within the 'private' directory would normally be put into an
// 'internal' directory to prohibit their use outside the 'mongo' directory.
// However, some MongoDB tools require very low-level access to the building
// blocks of a driver, so we have placed them under 'private' to allow these
// packages to be imported by projects that need them.
//
// These package APIs may be modified in backwards-incompatible ways at any
// time.
//
// You are strongly discouraged from directly using any packages
// under 'private'.
package auth

View File

@@ -0,0 +1,59 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//go:build gssapi && (windows || linux || darwin)
// +build gssapi
// +build windows linux darwin
package auth
import (
"context"
"fmt"
"net"
"go.mongodb.org/mongo-driver/x/mongo/driver/auth/internal/gssapi"
)
// GSSAPI is the mechanism name for GSSAPI.
const GSSAPI = "GSSAPI"
func newGSSAPIAuthenticator(cred *Cred) (Authenticator, error) {
if cred.Source != "" && cred.Source != "$external" {
return nil, newAuthError("GSSAPI source must be empty or $external", nil)
}
return &GSSAPIAuthenticator{
Username: cred.Username,
Password: cred.Password,
PasswordSet: cred.PasswordSet,
Props: cred.Props,
}, nil
}
// GSSAPIAuthenticator uses the GSSAPI algorithm over SASL to authenticate a connection.
type GSSAPIAuthenticator struct {
Username string
Password string
PasswordSet bool
Props map[string]string
}
// Auth authenticates the connection.
func (a *GSSAPIAuthenticator) Auth(ctx context.Context, cfg *Config) error {
target := cfg.Description.Addr.String()
hostname, _, err := net.SplitHostPort(target)
if err != nil {
return newAuthError(fmt.Sprintf("invalid endpoint (%s) specified: %s", target, err), nil)
}
client, err := gssapi.New(hostname, a.Username, a.Password, a.PasswordSet, a.Props)
if err != nil {
return newAuthError("error creating gssapi", err)
}
return ConductSaslConversation(ctx, cfg, "$external", client)
}

View File

@@ -0,0 +1,17 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//go:build !gssapi
// +build !gssapi
package auth
// GSSAPI is the mechanism name for GSSAPI.
const GSSAPI = "GSSAPI"
func newGSSAPIAuthenticator(cred *Cred) (Authenticator, error) {
return nil, newAuthError("GSSAPI support not enabled during build (-tags gssapi)", nil)
}

View File

@@ -0,0 +1,22 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//go:build gssapi && !windows && !linux && !darwin
// +build gssapi,!windows,!linux,!darwin
package auth
import (
"fmt"
"runtime"
)
// GSSAPI is the mechanism name for GSSAPI.
const GSSAPI = "GSSAPI"
func newGSSAPIAuthenticator(cred *Cred) (Authenticator, error) {
return nil, newAuthError(fmt.Sprintf("GSSAPI is not supported on %s", runtime.GOOS), nil)
}

View File

@@ -0,0 +1,45 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//go:build gssapi
// +build gssapi
package auth
import (
"context"
"testing"
"go.mongodb.org/mongo-driver/mongo/address"
"go.mongodb.org/mongo-driver/mongo/description"
)
func TestGSSAPIAuthenticator(t *testing.T) {
t.Run("PropsError", func(t *testing.T) {
// Cannot specify both CANONICALIZE_HOST_NAME and SERVICE_HOST
authenticator := &GSSAPIAuthenticator{
Username: "foo",
Password: "bar",
PasswordSet: true,
Props: map[string]string{
"CANONICALIZE_HOST_NAME": "true",
"SERVICE_HOST": "localhost",
},
}
desc := description.Server{
WireVersion: &description.VersionRange{
Max: 6,
},
Addr: address.Address("foo:27017"),
}
err := authenticator.Auth(context.Background(), &Config{Description: desc})
if err == nil {
t.Fatalf("expected err, got nil")
}
})
}

View File

@@ -0,0 +1,63 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
// - github.com/aws/aws-sdk-go/blob/v1.34.28/aws/credentials/static_provider.go
// - github.com/aws/aws-sdk-go/blob/v1.34.28/aws/credentials/credentials.go
// See THIRD-PARTY-NOTICES for original license terms
package awsv4
import (
"errors"
)
// StaticProviderName provides a name of Static provider
const StaticProviderName = "StaticProvider"
var (
// ErrStaticCredentialsEmpty is emitted when static credentials are empty.
ErrStaticCredentialsEmpty = errors.New("EmptyStaticCreds: static credentials are empty")
)
// A Value is the AWS credentials value for individual credential fields.
type Value struct {
// AWS Access key ID
AccessKeyID string
// AWS Secret Access Key
SecretAccessKey string
// AWS Session Token
SessionToken string
// Provider used to get credentials
ProviderName string
}
// HasKeys returns if the credentials Value has both AccessKeyID and
// SecretAccessKey value set.
func (v Value) HasKeys() bool {
return len(v.AccessKeyID) != 0 && len(v.SecretAccessKey) != 0
}
// A StaticProvider is a set of credentials which are set programmatically,
// and will never expire.
type StaticProvider struct {
Value
}
// Retrieve returns the credentials or error if the credentials are invalid.
func (s *StaticProvider) Retrieve() (Value, error) {
if s.AccessKeyID == "" || s.SecretAccessKey == "" {
return Value{ProviderName: StaticProviderName}, ErrStaticCredentialsEmpty
}
if len(s.Value.ProviderName) == 0 {
s.Value.ProviderName = StaticProviderName
}
return s.Value, nil
}

View File

@@ -0,0 +1,15 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on github.com/aws/aws-sdk-go v1.34.28 by Amazon.com, Inc.
// See THIRD-PARTY-NOTICES for original license terms
// Package awsv4 implements signing for AWS V4 signer with static credentials,
// and is based on and modified from code in the package aws-sdk-go. The
// modifications remove non-static credentials, support for non-sts services,
// and the options for v4.Signer. They also reduce the number of non-Go
// library dependencies.
package awsv4

View File

@@ -0,0 +1,80 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
// - github.com/aws/aws-sdk-go/blob/v1.34.28/aws/request/request.go
// See THIRD-PARTY-NOTICES for original license terms
package awsv4
import (
"net/http"
"strings"
)
// Returns host from request
func getHost(r *http.Request) string {
if r.Host != "" {
return r.Host
}
if r.URL == nil {
return ""
}
return r.URL.Host
}
// Hostname returns u.Host, without any port number.
//
// If Host is an IPv6 literal with a port number, Hostname returns the
// IPv6 literal without the square brackets. IPv6 literals may include
// a zone identifier.
//
// Copied from the Go 1.8 standard library (net/url)
func stripPort(hostport string) string {
colon := strings.IndexByte(hostport, ':')
if colon == -1 {
return hostport
}
if i := strings.IndexByte(hostport, ']'); i != -1 {
return strings.TrimPrefix(hostport[:i], "[")
}
return hostport[:colon]
}
// Port returns the port part of u.Host, without the leading colon.
// If u.Host doesn't contain a port, Port returns an empty string.
//
// Copied from the Go 1.8 standard library (net/url)
func portOnly(hostport string) string {
colon := strings.IndexByte(hostport, ':')
if colon == -1 {
return ""
}
if i := strings.Index(hostport, "]:"); i != -1 {
return hostport[i+len("]:"):]
}
if strings.Contains(hostport, "]") {
return ""
}
return hostport[colon+len(":"):]
}
// Returns true if the specified URI is using the standard port
// (i.e. port 80 for HTTP URIs or 443 for HTTPS URIs)
func isDefaultPort(scheme, port string) bool {
if port == "" {
return true
}
lowerCaseScheme := strings.ToLower(scheme)
if (lowerCaseScheme == "http" && port == "80") || (lowerCaseScheme == "https" && port == "443") {
return true
}
return false
}

View File

@@ -0,0 +1,46 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
// - github.com/aws/aws-sdk-go/blob/v1.34.28/private/protocol/rest/build.go
// See THIRD-PARTY-NOTICES for original license terms
package awsv4
import (
"bytes"
"fmt"
)
// Whether the byte value can be sent without escaping in AWS URLs
var noEscape [256]bool
func init() {
for i := 0; i < len(noEscape); i++ {
// AWS expects every character except these to be escaped
noEscape[i] = (i >= 'A' && i <= 'Z') ||
(i >= 'a' && i <= 'z') ||
(i >= '0' && i <= '9') ||
i == '-' ||
i == '.' ||
i == '_' ||
i == '~'
}
}
// EscapePath escapes part of a URL path in Amazon style
func EscapePath(path string, encodeSep bool) string {
var buf bytes.Buffer
for i := 0; i < len(path); i++ {
c := path[i]
if noEscape[c] || (c == '/' && !encodeSep) {
buf.WriteByte(c)
} else {
fmt.Fprintf(&buf, "%%%02X", c)
}
}
return buf.String()
}

View File

@@ -0,0 +1,98 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
// - github.com/aws/aws-sdk-go/blob/v1.34.28/aws/signer/v4/header_rules.go
// - github.com/aws/aws-sdk-go/blob/v1.34.28/internal/strings/strings.go
// See THIRD-PARTY-NOTICES for original license terms
package awsv4
import (
"strings"
)
// validator houses a set of rule needed for validation of a
// string value
type rules []rule
// rule interface allows for more flexible rules and just simply
// checks whether or not a value adheres to that rule
type rule interface {
IsValid(value string) bool
}
// IsValid will iterate through all rules and see if any rules
// apply to the value and supports nested rules
func (r rules) IsValid(value string) bool {
for _, rule := range r {
if rule.IsValid(value) {
return true
}
}
return false
}
// mapRule generic rule for maps
type mapRule map[string]struct{}
// IsValid for the map rule satisfies whether it exists in the map
func (m mapRule) IsValid(value string) bool {
_, ok := m[value]
return ok
}
// allowlist is a generic rule for allowlisting
type allowlist struct {
rule
}
// IsValid for allowlist checks if the value is within the allowlist
func (a allowlist) IsValid(value string) bool {
return a.rule.IsValid(value)
}
// denylist is a generic rule for denylisting
type denylist struct {
rule
}
// IsValid for allowlist checks if the value is within the allowlist
func (d denylist) IsValid(value string) bool {
return !d.rule.IsValid(value)
}
type patterns []string
// hasPrefixFold tests whether the string s begins with prefix, interpreted as UTF-8 strings,
// under Unicode case-folding.
func hasPrefixFold(s, prefix string) bool {
return len(s) >= len(prefix) && strings.EqualFold(s[0:len(prefix)], prefix)
}
// IsValid for patterns checks each pattern and returns if a match has
// been found
func (p patterns) IsValid(value string) bool {
for _, pattern := range p {
if hasPrefixFold(value, pattern) {
return true
}
}
return false
}
// inclusiveRules rules allow for rules to depend on one another
type inclusiveRules []rule
// IsValid will return true if all rules are true
func (r inclusiveRules) IsValid(value string) bool {
for _, rule := range r {
if !rule.IsValid(value) {
return false
}
}
return true
}

View File

@@ -0,0 +1,472 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on github.com/aws/aws-sdk-go by Amazon.com, Inc. with code from:
// - github.com/aws/aws-sdk-go/blob/v1.34.28/aws/request/request.go
// - github.com/aws/aws-sdk-go/blob/v1.34.28/aws/signer/v4/v4.go
// - github.com/aws/aws-sdk-go/blob/v1.34.28/aws/signer/v4/uri_path.go
// - github.com/aws/aws-sdk-go/blob/v1.34.28/aws/types.go
// See THIRD-PARTY-NOTICES for original license terms
package awsv4
import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"sort"
"strings"
"time"
)
const (
authorizationHeader = "Authorization"
authHeaderSignatureElem = "Signature="
authHeaderPrefix = "AWS4-HMAC-SHA256"
timeFormat = "20060102T150405Z"
shortTimeFormat = "20060102"
awsV4Request = "aws4_request"
// emptyStringSHA256 is a SHA256 of an empty string
emptyStringSHA256 = `e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855`
)
var ignoredHeaders = rules{
denylist{
mapRule{
authorizationHeader: struct{}{},
"User-Agent": struct{}{},
"X-Amzn-Trace-Id": struct{}{},
},
},
}
// Signer applies AWS v4 signing to given request. Use this to sign requests
// that need to be signed with AWS V4 Signatures.
type Signer struct {
Credentials *StaticProvider
}
// NewSigner returns a Signer pointer configured with the credentials and optional
// option values provided. If not options are provided the Signer will use its
// default configuration.
func NewSigner(credentials *StaticProvider) *Signer {
v4 := &Signer{
Credentials: credentials,
}
return v4
}
type signingCtx struct {
ServiceName string
Region string
Request *http.Request
Body io.ReadSeeker
Query url.Values
Time time.Time
SignedHeaderVals http.Header
credValues Value
bodyDigest string
signedHeaders string
canonicalHeaders string
canonicalString string
credentialString string
stringToSign string
signature string
authorization string
}
// Sign signs AWS v4 requests with the provided body, service name, region the
// request is made to, and time the request is signed at. The signTime allows
// you to specify that a request is signed for the future, and cannot be
// used until then.
//
// Returns a list of HTTP headers that were included in the signature or an
// error if signing the request failed. Generally for signed requests this value
// is not needed as the full request context will be captured by the http.Request
// value. It is included for reference though.
//
// Sign will set the request's Body to be the `body` parameter passed in. If
// the body is not already an io.ReadCloser, it will be wrapped within one. If
// a `nil` body parameter passed to Sign, the request's Body field will be
// also set to nil. Its important to note that this functionality will not
// change the request's ContentLength of the request.
//
// Sign differs from Presign in that it will sign the request using HTTP
// header values. This type of signing is intended for http.Request values that
// will not be shared, or are shared in a way the header values on the request
// will not be lost.
//
// The requests body is an io.ReadSeeker so the SHA256 of the body can be
// generated. To bypass the signer computing the hash you can set the
// "X-Amz-Content-Sha256" header with a precomputed value. The signer will
// only compute the hash if the request header value is empty.
func (v4 Signer) Sign(r *http.Request, body io.ReadSeeker, service, region string, signTime time.Time) (http.Header, error) {
return v4.signWithBody(r, body, service, region, signTime)
}
func (v4 Signer) signWithBody(r *http.Request, body io.ReadSeeker, service, region string, signTime time.Time) (http.Header, error) {
ctx := &signingCtx{
Request: r,
Body: body,
Query: r.URL.Query(),
Time: signTime,
ServiceName: service,
Region: region,
}
for key := range ctx.Query {
sort.Strings(ctx.Query[key])
}
if ctx.isRequestSigned() {
ctx.Time = time.Now()
}
var err error
ctx.credValues, err = v4.Credentials.Retrieve()
if err != nil {
return http.Header{}, err
}
ctx.sanitizeHostForHeader()
ctx.assignAmzQueryValues()
if err := ctx.build(); err != nil {
return nil, err
}
var reader io.ReadCloser
if body != nil {
var ok bool
if reader, ok = body.(io.ReadCloser); !ok {
reader = ioutil.NopCloser(body)
}
}
r.Body = reader
return ctx.SignedHeaderVals, nil
}
// sanitizeHostForHeader removes default port from host and updates request.Host
func (ctx *signingCtx) sanitizeHostForHeader() {
r := ctx.Request
host := getHost(r)
port := portOnly(host)
if port != "" && isDefaultPort(r.URL.Scheme, port) {
r.Host = stripPort(host)
}
}
func (ctx *signingCtx) assignAmzQueryValues() {
if ctx.credValues.SessionToken != "" {
ctx.Request.Header.Set("X-Amz-Security-Token", ctx.credValues.SessionToken)
}
}
func (ctx *signingCtx) build() error {
ctx.buildTime() // no depends
ctx.buildCredentialString() // no depends
if err := ctx.buildBodyDigest(); err != nil {
return err
}
unsignedHeaders := ctx.Request.Header
ctx.buildCanonicalHeaders(ignoredHeaders, unsignedHeaders)
ctx.buildCanonicalString() // depends on canon headers / signed headers
ctx.buildStringToSign() // depends on canon string
ctx.buildSignature() // depends on string to sign
parts := []string{
authHeaderPrefix + " Credential=" + ctx.credValues.AccessKeyID + "/" + ctx.credentialString,
"SignedHeaders=" + ctx.signedHeaders,
authHeaderSignatureElem + ctx.signature,
}
ctx.Request.Header.Set(authorizationHeader, strings.Join(parts, ", "))
return nil
}
// GetSignedRequestSignature attempts to extract the signature of the request.
// Returning an error if the request is unsigned, or unable to extract the
// signature.
func GetSignedRequestSignature(r *http.Request) ([]byte, error) {
if auth := r.Header.Get(authorizationHeader); len(auth) != 0 {
ps := strings.Split(auth, ", ")
for _, p := range ps {
if idx := strings.Index(p, authHeaderSignatureElem); idx >= 0 {
sig := p[len(authHeaderSignatureElem):]
if len(sig) == 0 {
return nil, fmt.Errorf("invalid request signature authorization header")
}
return hex.DecodeString(sig)
}
}
}
if sig := r.URL.Query().Get("X-Amz-Signature"); len(sig) != 0 {
return hex.DecodeString(sig)
}
return nil, fmt.Errorf("request not signed")
}
func (ctx *signingCtx) buildTime() {
ctx.Request.Header.Set("X-Amz-Date", formatTime(ctx.Time))
}
func (ctx *signingCtx) buildCredentialString() {
ctx.credentialString = buildSigningScope(ctx.Region, ctx.ServiceName, ctx.Time)
}
func (ctx *signingCtx) buildCanonicalHeaders(r rule, header http.Header) {
headers := make([]string, 0, len(header))
headers = append(headers, "host")
for k, v := range header {
if !r.IsValid(k) {
continue // ignored header
}
if ctx.SignedHeaderVals == nil {
ctx.SignedHeaderVals = make(http.Header)
}
lowerCaseKey := strings.ToLower(k)
if _, ok := ctx.SignedHeaderVals[lowerCaseKey]; ok {
// include additional values
ctx.SignedHeaderVals[lowerCaseKey] = append(ctx.SignedHeaderVals[lowerCaseKey], v...)
continue
}
headers = append(headers, lowerCaseKey)
ctx.SignedHeaderVals[lowerCaseKey] = v
}
sort.Strings(headers)
ctx.signedHeaders = strings.Join(headers, ";")
headerValues := make([]string, len(headers))
for i, k := range headers {
if k == "host" {
if ctx.Request.Host != "" {
headerValues[i] = "host:" + ctx.Request.Host
} else {
headerValues[i] = "host:" + ctx.Request.URL.Host
}
} else {
headerValues[i] = k + ":" +
strings.Join(ctx.SignedHeaderVals[k], ",")
}
}
stripExcessSpaces(headerValues)
ctx.canonicalHeaders = strings.Join(headerValues, "\n")
}
func getURIPath(u *url.URL) string {
var uri string
if len(u.Opaque) > 0 {
uri = "/" + strings.Join(strings.Split(u.Opaque, "/")[3:], "/")
} else {
uri = u.EscapedPath()
}
if len(uri) == 0 {
uri = "/"
}
return uri
}
func (ctx *signingCtx) buildCanonicalString() {
ctx.Request.URL.RawQuery = strings.Replace(ctx.Query.Encode(), "+", "%20", -1)
uri := getURIPath(ctx.Request.URL)
uri = EscapePath(uri, false)
ctx.canonicalString = strings.Join([]string{
ctx.Request.Method,
uri,
ctx.Request.URL.RawQuery,
ctx.canonicalHeaders + "\n",
ctx.signedHeaders,
ctx.bodyDigest,
}, "\n")
}
func (ctx *signingCtx) buildStringToSign() {
ctx.stringToSign = strings.Join([]string{
authHeaderPrefix,
formatTime(ctx.Time),
ctx.credentialString,
hex.EncodeToString(hashSHA256([]byte(ctx.canonicalString))),
}, "\n")
}
func (ctx *signingCtx) buildSignature() {
creds := deriveSigningKey(ctx.Region, ctx.ServiceName, ctx.credValues.SecretAccessKey, ctx.Time)
signature := hmacSHA256(creds, []byte(ctx.stringToSign))
ctx.signature = hex.EncodeToString(signature)
}
func (ctx *signingCtx) buildBodyDigest() error {
hash := ctx.Request.Header.Get("X-Amz-Content-Sha256")
if hash == "" {
if ctx.Body == nil {
hash = emptyStringSHA256
} else {
hashBytes, err := makeSha256Reader(ctx.Body)
if err != nil {
return err
}
hash = hex.EncodeToString(hashBytes)
}
}
ctx.bodyDigest = hash
return nil
}
// isRequestSigned returns if the request is currently signed or presigned
func (ctx *signingCtx) isRequestSigned() bool {
return ctx.Request.Header.Get("Authorization") != ""
}
func hmacSHA256(key []byte, data []byte) []byte {
hash := hmac.New(sha256.New, key)
hash.Write(data)
return hash.Sum(nil)
}
func hashSHA256(data []byte) []byte {
hash := sha256.New()
hash.Write(data)
return hash.Sum(nil)
}
// seekerLen attempts to get the number of bytes remaining at the seeker's
// current position. Returns the number of bytes remaining or error.
func seekerLen(s io.Seeker) (int64, error) {
curOffset, err := s.Seek(0, io.SeekCurrent)
if err != nil {
return 0, err
}
endOffset, err := s.Seek(0, io.SeekEnd)
if err != nil {
return 0, err
}
_, err = s.Seek(curOffset, io.SeekStart)
if err != nil {
return 0, err
}
return endOffset - curOffset, nil
}
func makeSha256Reader(reader io.ReadSeeker) (hashBytes []byte, err error) {
hash := sha256.New()
start, err := reader.Seek(0, io.SeekCurrent)
if err != nil {
return nil, err
}
defer func() {
// ensure error is return if unable to seek back to start of payload.
_, err = reader.Seek(start, io.SeekStart)
}()
// Use CopyN to avoid allocating the 32KB buffer in io.Copy for bodies
// smaller than 32KB. Fall back to io.Copy if we fail to determine the size.
size, err := seekerLen(reader)
if err != nil {
_, _ = io.Copy(hash, reader)
} else {
_, _ = io.CopyN(hash, reader, size)
}
return hash.Sum(nil), nil
}
const doubleSpace = " "
// stripExcessSpaces will rewrite the passed in slice's string values to not
// contain multiple side-by-side spaces.
func stripExcessSpaces(vals []string) {
var j, k, l, m, spaces int
for i, str := range vals {
// Trim trailing spaces
for j = len(str) - 1; j >= 0 && str[j] == ' '; j-- {
}
// Trim leading spaces
for k = 0; k < j && str[k] == ' '; k++ {
}
str = str[k : j+1]
// Strip multiple spaces.
j = strings.Index(str, doubleSpace)
if j < 0 {
vals[i] = str
continue
}
buf := []byte(str)
for k, m, l = j, j, len(buf); k < l; k++ {
if buf[k] == ' ' {
if spaces == 0 {
// First space.
buf[m] = buf[k]
m++
}
spaces++
} else {
// End of multiple spaces.
spaces = 0
buf[m] = buf[k]
m++
}
}
vals[i] = string(buf[:m])
}
}
func buildSigningScope(region, service string, dt time.Time) string {
return strings.Join([]string{
formatShortTime(dt),
region,
service,
awsV4Request,
}, "/")
}
func deriveSigningKey(region, service, secretKey string, dt time.Time) []byte {
keyDate := hmacSHA256([]byte("AWS4"+secretKey), []byte(formatShortTime(dt)))
keyRegion := hmacSHA256(keyDate, []byte(region))
keyService := hmacSHA256(keyRegion, []byte(service))
signingKey := hmacSHA256(keyService, []byte(awsV4Request))
return signingKey
}
func formatShortTime(dt time.Time) string {
return dt.UTC().Format(shortTimeFormat)
}
func formatTime(dt time.Time) string {
return dt.UTC().Format(timeFormat)
}

View File

@@ -0,0 +1,167 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//go:build gssapi && (linux || darwin)
// +build gssapi
// +build linux darwin
package gssapi
/*
#cgo linux CFLAGS: -DGOOS_linux
#cgo linux LDFLAGS: -lgssapi_krb5 -lkrb5
#cgo darwin CFLAGS: -DGOOS_darwin
#cgo darwin LDFLAGS: -framework GSS
#include "gss_wrapper.h"
*/
import "C"
import (
"fmt"
"runtime"
"strings"
"unsafe"
)
// New creates a new SaslClient. The target parameter should be a hostname with no port.
func New(target, username, password string, passwordSet bool, props map[string]string) (*SaslClient, error) {
serviceName := "mongodb"
for key, value := range props {
switch strings.ToUpper(key) {
case "CANONICALIZE_HOST_NAME":
return nil, fmt.Errorf("CANONICALIZE_HOST_NAME is not supported when using gssapi on %s", runtime.GOOS)
case "SERVICE_REALM":
return nil, fmt.Errorf("SERVICE_REALM is not supported when using gssapi on %s", runtime.GOOS)
case "SERVICE_NAME":
serviceName = value
case "SERVICE_HOST":
target = value
default:
return nil, fmt.Errorf("unknown mechanism property %s", key)
}
}
servicePrincipalName := fmt.Sprintf("%s@%s", serviceName, target)
return &SaslClient{
servicePrincipalName: servicePrincipalName,
username: username,
password: password,
passwordSet: passwordSet,
}, nil
}
type SaslClient struct {
servicePrincipalName string
username string
password string
passwordSet bool
// state
state C.gssapi_client_state
contextComplete bool
done bool
}
func (sc *SaslClient) Close() {
C.gssapi_client_destroy(&sc.state)
}
func (sc *SaslClient) Start() (string, []byte, error) {
const mechName = "GSSAPI"
cservicePrincipalName := C.CString(sc.servicePrincipalName)
defer C.free(unsafe.Pointer(cservicePrincipalName))
var cusername *C.char
var cpassword *C.char
if sc.username != "" {
cusername = C.CString(sc.username)
defer C.free(unsafe.Pointer(cusername))
if sc.passwordSet {
cpassword = C.CString(sc.password)
defer C.free(unsafe.Pointer(cpassword))
}
}
status := C.gssapi_client_init(&sc.state, cservicePrincipalName, cusername, cpassword)
if status != C.GSSAPI_OK {
return mechName, nil, sc.getError("unable to initialize client")
}
payload, err := sc.Next(nil)
return mechName, payload, err
}
func (sc *SaslClient) Next(challenge []byte) ([]byte, error) {
var buf unsafe.Pointer
var bufLen C.size_t
var outBuf unsafe.Pointer
var outBufLen C.size_t
if sc.contextComplete {
if sc.username == "" {
var cusername *C.char
status := C.gssapi_client_username(&sc.state, &cusername)
if status != C.GSSAPI_OK {
return nil, sc.getError("unable to acquire username")
}
defer C.free(unsafe.Pointer(cusername))
sc.username = C.GoString((*C.char)(unsafe.Pointer(cusername)))
}
bytes := append([]byte{1, 0, 0, 0}, []byte(sc.username)...)
buf = unsafe.Pointer(&bytes[0])
bufLen = C.size_t(len(bytes))
status := C.gssapi_client_wrap_msg(&sc.state, buf, bufLen, &outBuf, &outBufLen)
if status != C.GSSAPI_OK {
return nil, sc.getError("unable to wrap authz")
}
sc.done = true
} else {
if len(challenge) > 0 {
buf = unsafe.Pointer(&challenge[0])
bufLen = C.size_t(len(challenge))
}
status := C.gssapi_client_negotiate(&sc.state, buf, bufLen, &outBuf, &outBufLen)
switch status {
case C.GSSAPI_OK:
sc.contextComplete = true
case C.GSSAPI_CONTINUE:
default:
return nil, sc.getError("unable to negotiate with server")
}
}
if outBuf != nil {
defer C.free(outBuf)
}
return C.GoBytes(outBuf, C.int(outBufLen)), nil
}
func (sc *SaslClient) Completed() bool {
return sc.done
}
func (sc *SaslClient) getError(prefix string) error {
var desc *C.char
status := C.gssapi_error_desc(sc.state.maj_stat, sc.state.min_stat, &desc)
if status != C.GSSAPI_OK {
if desc != nil {
C.free(unsafe.Pointer(desc))
}
return fmt.Errorf("%s: (%v, %v)", prefix, sc.state.maj_stat, sc.state.min_stat)
}
defer C.free(unsafe.Pointer(desc))
return fmt.Errorf("%s: %v(%v,%v)", prefix, C.GoString(desc), int32(sc.state.maj_stat), int32(sc.state.min_stat))
}

View File

@@ -0,0 +1,254 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//+build gssapi
//+build linux darwin
#include <string.h>
#include <stdio.h>
#include "gss_wrapper.h"
OM_uint32 gssapi_canonicalize_name(
OM_uint32* minor_status,
char *input_name,
gss_OID input_name_type,
gss_name_t *output_name
)
{
OM_uint32 major_status;
gss_name_t imported_name = GSS_C_NO_NAME;
gss_buffer_desc buffer = GSS_C_EMPTY_BUFFER;
buffer.value = input_name;
buffer.length = strlen(input_name);
major_status = gss_import_name(minor_status, &buffer, input_name_type, &imported_name);
if (GSS_ERROR(major_status)) {
return major_status;
}
major_status = gss_canonicalize_name(minor_status, imported_name, (gss_OID)gss_mech_krb5, output_name);
if (imported_name != GSS_C_NO_NAME) {
OM_uint32 ignored;
gss_release_name(&ignored, &imported_name);
}
return major_status;
}
int gssapi_error_desc(
OM_uint32 maj_stat,
OM_uint32 min_stat,
char **desc
)
{
OM_uint32 stat = maj_stat;
int stat_type = GSS_C_GSS_CODE;
if (min_stat != 0) {
stat = min_stat;
stat_type = GSS_C_MECH_CODE;
}
OM_uint32 local_maj_stat, local_min_stat;
OM_uint32 msg_ctx = 0;
gss_buffer_desc desc_buffer;
do
{
local_maj_stat = gss_display_status(
&local_min_stat,
stat,
stat_type,
GSS_C_NO_OID,
&msg_ctx,
&desc_buffer
);
if (GSS_ERROR(local_maj_stat)) {
return GSSAPI_ERROR;
}
if (*desc) {
free(*desc);
}
*desc = malloc(desc_buffer.length+1);
memcpy(*desc, desc_buffer.value, desc_buffer.length+1);
gss_release_buffer(&local_min_stat, &desc_buffer);
}
while(msg_ctx != 0);
return GSSAPI_OK;
}
int gssapi_client_init(
gssapi_client_state *client,
char* spn,
char* username,
char* password
)
{
client->cred = GSS_C_NO_CREDENTIAL;
client->ctx = GSS_C_NO_CONTEXT;
client->maj_stat = gssapi_canonicalize_name(&client->min_stat, spn, GSS_C_NT_HOSTBASED_SERVICE, &client->spn);
if (GSS_ERROR(client->maj_stat)) {
return GSSAPI_ERROR;
}
if (username) {
gss_name_t name;
client->maj_stat = gssapi_canonicalize_name(&client->min_stat, username, GSS_C_NT_USER_NAME, &name);
if (GSS_ERROR(client->maj_stat)) {
return GSSAPI_ERROR;
}
if (password) {
gss_buffer_desc password_buffer;
password_buffer.value = password;
password_buffer.length = strlen(password);
client->maj_stat = gss_acquire_cred_with_password(&client->min_stat, name, &password_buffer, GSS_C_INDEFINITE, GSS_C_NO_OID_SET, GSS_C_INITIATE, &client->cred, NULL, NULL);
} else {
client->maj_stat = gss_acquire_cred(&client->min_stat, name, GSS_C_INDEFINITE, GSS_C_NO_OID_SET, GSS_C_INITIATE, &client->cred, NULL, NULL);
}
if (GSS_ERROR(client->maj_stat)) {
return GSSAPI_ERROR;
}
OM_uint32 ignored;
gss_release_name(&ignored, &name);
}
return GSSAPI_OK;
}
int gssapi_client_username(
gssapi_client_state *client,
char** username
)
{
OM_uint32 ignored;
gss_name_t name = GSS_C_NO_NAME;
client->maj_stat = gss_inquire_context(&client->min_stat, client->ctx, &name, NULL, NULL, NULL, NULL, NULL, NULL);
if (GSS_ERROR(client->maj_stat)) {
return GSSAPI_ERROR;
}
gss_buffer_desc name_buffer;
client->maj_stat = gss_display_name(&client->min_stat, name, &name_buffer, NULL);
if (GSS_ERROR(client->maj_stat)) {
gss_release_name(&ignored, &name);
return GSSAPI_ERROR;
}
*username = malloc(name_buffer.length+1);
memcpy(*username, name_buffer.value, name_buffer.length+1);
gss_release_buffer(&ignored, &name_buffer);
gss_release_name(&ignored, &name);
return GSSAPI_OK;
}
int gssapi_client_negotiate(
gssapi_client_state *client,
void* input,
size_t input_length,
void** output,
size_t* output_length
)
{
gss_buffer_desc input_buffer = GSS_C_EMPTY_BUFFER;
gss_buffer_desc output_buffer = GSS_C_EMPTY_BUFFER;
if (input) {
input_buffer.value = input;
input_buffer.length = input_length;
}
client->maj_stat = gss_init_sec_context(
&client->min_stat,
client->cred,
&client->ctx,
client->spn,
GSS_C_NO_OID,
GSS_C_MUTUAL_FLAG | GSS_C_SEQUENCE_FLAG,
0,
GSS_C_NO_CHANNEL_BINDINGS,
&input_buffer,
NULL,
&output_buffer,
NULL,
NULL
);
if (output_buffer.length) {
*output = malloc(output_buffer.length);
*output_length = output_buffer.length;
memcpy(*output, output_buffer.value, output_buffer.length);
OM_uint32 ignored;
gss_release_buffer(&ignored, &output_buffer);
}
if (GSS_ERROR(client->maj_stat)) {
return GSSAPI_ERROR;
} else if (client->maj_stat == GSS_S_CONTINUE_NEEDED) {
return GSSAPI_CONTINUE;
}
return GSSAPI_OK;
}
int gssapi_client_wrap_msg(
gssapi_client_state *client,
void* input,
size_t input_length,
void** output,
size_t* output_length
)
{
gss_buffer_desc input_buffer = GSS_C_EMPTY_BUFFER;
gss_buffer_desc output_buffer = GSS_C_EMPTY_BUFFER;
input_buffer.value = input;
input_buffer.length = input_length;
client->maj_stat = gss_wrap(&client->min_stat, client->ctx, 0, GSS_C_QOP_DEFAULT, &input_buffer, NULL, &output_buffer);
if (output_buffer.length) {
*output = malloc(output_buffer.length);
*output_length = output_buffer.length;
memcpy(*output, output_buffer.value, output_buffer.length);
gss_release_buffer(&client->min_stat, &output_buffer);
}
if (GSS_ERROR(client->maj_stat)) {
return GSSAPI_ERROR;
}
return GSSAPI_OK;
}
int gssapi_client_destroy(
gssapi_client_state *client
)
{
OM_uint32 ignored;
if (client->ctx != GSS_C_NO_CONTEXT) {
gss_delete_sec_context(&ignored, &client->ctx, GSS_C_NO_BUFFER);
}
if (client->spn != GSS_C_NO_NAME) {
gss_release_name(&ignored, &client->spn);
}
if (client->cred != GSS_C_NO_CREDENTIAL) {
gss_release_cred(&ignored, &client->cred);
}
return GSSAPI_OK;
}

View File

@@ -0,0 +1,72 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//+build gssapi
//+build linux darwin
#ifndef GSS_WRAPPER_H
#define GSS_WRAPPER_H
#include <stdlib.h>
#ifdef GOOS_linux
#include <gssapi/gssapi.h>
#include <gssapi/gssapi_krb5.h>
#endif
#ifdef GOOS_darwin
#include <GSS/GSS.h>
#endif
#define GSSAPI_OK 0
#define GSSAPI_CONTINUE 1
#define GSSAPI_ERROR 2
typedef struct {
gss_name_t spn;
gss_cred_id_t cred;
gss_ctx_id_t ctx;
OM_uint32 maj_stat;
OM_uint32 min_stat;
} gssapi_client_state;
int gssapi_error_desc(
OM_uint32 maj_stat,
OM_uint32 min_stat,
char **desc
);
int gssapi_client_init(
gssapi_client_state *client,
char* spn,
char* username,
char* password
);
int gssapi_client_username(
gssapi_client_state *client,
char** username
);
int gssapi_client_negotiate(
gssapi_client_state *client,
void* input,
size_t input_length,
void** output,
size_t* output_length
);
int gssapi_client_wrap_msg(
gssapi_client_state *client,
void* input,
size_t input_length,
void** output,
size_t* output_length
);
int gssapi_client_destroy(
gssapi_client_state *client
);
#endif

View File

@@ -0,0 +1,353 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//go:build gssapi && windows
// +build gssapi,windows
package gssapi
// #include "sspi_wrapper.h"
import "C"
import (
"fmt"
"net"
"strconv"
"strings"
"sync"
"unsafe"
)
// New creates a new SaslClient. The target parameter should be a hostname with no port.
func New(target, username, password string, passwordSet bool, props map[string]string) (*SaslClient, error) {
initOnce.Do(initSSPI)
if initError != nil {
return nil, initError
}
var err error
serviceName := "mongodb"
serviceRealm := ""
canonicalizeHostName := false
var serviceHostSet bool
for key, value := range props {
switch strings.ToUpper(key) {
case "CANONICALIZE_HOST_NAME":
canonicalizeHostName, err = strconv.ParseBool(value)
if err != nil {
return nil, fmt.Errorf("%s must be a boolean (true, false, 0, 1) but got '%s'", key, value)
}
case "SERVICE_REALM":
serviceRealm = value
case "SERVICE_NAME":
serviceName = value
case "SERVICE_HOST":
serviceHostSet = true
target = value
}
}
if canonicalizeHostName {
// Should not canonicalize the SERVICE_HOST
if serviceHostSet {
return nil, fmt.Errorf("CANONICALIZE_HOST_NAME and SERVICE_HOST canonot both be specified")
}
names, err := net.LookupAddr(target)
if err != nil || len(names) == 0 {
return nil, fmt.Errorf("unable to canonicalize hostname: %s", err)
}
target = names[0]
if target[len(target)-1] == '.' {
target = target[:len(target)-1]
}
}
servicePrincipalName := fmt.Sprintf("%s/%s", serviceName, target)
if serviceRealm != "" {
servicePrincipalName += "@" + serviceRealm
}
return &SaslClient{
servicePrincipalName: servicePrincipalName,
username: username,
password: password,
passwordSet: passwordSet,
}, nil
}
type SaslClient struct {
servicePrincipalName string
username string
password string
passwordSet bool
// state
state C.sspi_client_state
contextComplete bool
done bool
}
func (sc *SaslClient) Close() {
C.sspi_client_destroy(&sc.state)
}
func (sc *SaslClient) Start() (string, []byte, error) {
const mechName = "GSSAPI"
var cusername *C.char
var cpassword *C.char
if sc.username != "" {
cusername = C.CString(sc.username)
defer C.free(unsafe.Pointer(cusername))
if sc.passwordSet {
cpassword = C.CString(sc.password)
defer C.free(unsafe.Pointer(cpassword))
}
}
status := C.sspi_client_init(&sc.state, cusername, cpassword)
if status != C.SSPI_OK {
return mechName, nil, sc.getError("unable to intitialize client")
}
payload, err := sc.Next(nil)
return mechName, payload, err
}
func (sc *SaslClient) Next(challenge []byte) ([]byte, error) {
var outBuf C.PVOID
var outBufLen C.ULONG
if sc.contextComplete {
if sc.username == "" {
var cusername *C.char
status := C.sspi_client_username(&sc.state, &cusername)
if status != C.SSPI_OK {
return nil, sc.getError("unable to acquire username")
}
defer C.free(unsafe.Pointer(cusername))
sc.username = C.GoString((*C.char)(unsafe.Pointer(cusername)))
}
bytes := append([]byte{1, 0, 0, 0}, []byte(sc.username)...)
buf := (C.PVOID)(unsafe.Pointer(&bytes[0]))
bufLen := C.ULONG(len(bytes))
status := C.sspi_client_wrap_msg(&sc.state, buf, bufLen, &outBuf, &outBufLen)
if status != C.SSPI_OK {
return nil, sc.getError("unable to wrap authz")
}
sc.done = true
} else {
var buf C.PVOID
var bufLen C.ULONG
if len(challenge) > 0 {
buf = (C.PVOID)(unsafe.Pointer(&challenge[0]))
bufLen = C.ULONG(len(challenge))
}
cservicePrincipalName := C.CString(sc.servicePrincipalName)
defer C.free(unsafe.Pointer(cservicePrincipalName))
status := C.sspi_client_negotiate(&sc.state, cservicePrincipalName, buf, bufLen, &outBuf, &outBufLen)
switch status {
case C.SSPI_OK:
sc.contextComplete = true
case C.SSPI_CONTINUE:
default:
return nil, sc.getError("unable to negotiate with server")
}
}
if outBuf != C.PVOID(nil) {
defer C.free(unsafe.Pointer(outBuf))
}
return C.GoBytes(unsafe.Pointer(outBuf), C.int(outBufLen)), nil
}
func (sc *SaslClient) Completed() bool {
return sc.done
}
func (sc *SaslClient) getError(prefix string) error {
return getError(prefix, sc.state.status)
}
var initOnce sync.Once
var initError error
func initSSPI() {
rc := C.sspi_init()
if rc != 0 {
initError = fmt.Errorf("error initializing sspi: %v", rc)
}
}
func getError(prefix string, status C.SECURITY_STATUS) error {
var s string
switch status {
case C.SEC_E_ALGORITHM_MISMATCH:
s = "The client and server cannot communicate because they do not possess a common algorithm."
case C.SEC_E_BAD_BINDINGS:
s = "The SSPI channel bindings supplied by the client are incorrect."
case C.SEC_E_BAD_PKGID:
s = "The requested package identifier does not exist."
case C.SEC_E_BUFFER_TOO_SMALL:
s = "The buffers supplied to the function are not large enough to contain the information."
case C.SEC_E_CANNOT_INSTALL:
s = "The security package cannot initialize successfully and should not be installed."
case C.SEC_E_CANNOT_PACK:
s = "The package is unable to pack the context."
case C.SEC_E_CERT_EXPIRED:
s = "The received certificate has expired."
case C.SEC_E_CERT_UNKNOWN:
s = "An unknown error occurred while processing the certificate."
case C.SEC_E_CERT_WRONG_USAGE:
s = "The certificate is not valid for the requested usage."
case C.SEC_E_CONTEXT_EXPIRED:
s = "The application is referencing a context that has already been closed. A properly written application should not receive this error."
case C.SEC_E_CROSSREALM_DELEGATION_FAILURE:
s = "The server attempted to make a Kerberos-constrained delegation request for a target outside the server's realm."
case C.SEC_E_CRYPTO_SYSTEM_INVALID:
s = "The cryptographic system or checksum function is not valid because a required function is unavailable."
case C.SEC_E_DECRYPT_FAILURE:
s = "The specified data could not be decrypted."
case C.SEC_E_DELEGATION_REQUIRED:
s = "The requested operation cannot be completed. The computer must be trusted for delegation"
case C.SEC_E_DOWNGRADE_DETECTED:
s = "The system detected a possible attempt to compromise security. Verify that the server that authenticated you can be contacted."
case C.SEC_E_ENCRYPT_FAILURE:
s = "The specified data could not be encrypted."
case C.SEC_E_ILLEGAL_MESSAGE:
s = "The message received was unexpected or badly formatted."
case C.SEC_E_INCOMPLETE_CREDENTIALS:
s = "The credentials supplied were not complete and could not be verified. The context could not be initialized."
case C.SEC_E_INCOMPLETE_MESSAGE:
s = "The message supplied was incomplete. The signature was not verified."
case C.SEC_E_INSUFFICIENT_MEMORY:
s = "Not enough memory is available to complete the request."
case C.SEC_E_INTERNAL_ERROR:
s = "An error occurred that did not map to an SSPI error code."
case C.SEC_E_INVALID_HANDLE:
s = "The handle passed to the function is not valid."
case C.SEC_E_INVALID_TOKEN:
s = "The token passed to the function is not valid."
case C.SEC_E_ISSUING_CA_UNTRUSTED:
s = "An untrusted certification authority (CA) was detected while processing the smart card certificate used for authentication."
case C.SEC_E_ISSUING_CA_UNTRUSTED_KDC:
s = "An untrusted CA was detected while processing the domain controller certificate used for authentication. The system event log contains additional information."
case C.SEC_E_KDC_CERT_EXPIRED:
s = "The domain controller certificate used for smart card logon has expired."
case C.SEC_E_KDC_CERT_REVOKED:
s = "The domain controller certificate used for smart card logon has been revoked."
case C.SEC_E_KDC_INVALID_REQUEST:
s = "A request that is not valid was sent to the KDC."
case C.SEC_E_KDC_UNABLE_TO_REFER:
s = "The KDC was unable to generate a referral for the service requested."
case C.SEC_E_KDC_UNKNOWN_ETYPE:
s = "The requested encryption type is not supported by the KDC."
case C.SEC_E_LOGON_DENIED:
s = "The logon has been denied"
case C.SEC_E_MAX_REFERRALS_EXCEEDED:
s = "The number of maximum ticket referrals has been exceeded."
case C.SEC_E_MESSAGE_ALTERED:
s = "The message supplied for verification has been altered."
case C.SEC_E_MULTIPLE_ACCOUNTS:
s = "The received certificate was mapped to multiple accounts."
case C.SEC_E_MUST_BE_KDC:
s = "The local computer must be a Kerberos domain controller (KDC)"
case C.SEC_E_NO_AUTHENTICATING_AUTHORITY:
s = "No authority could be contacted for authentication."
case C.SEC_E_NO_CREDENTIALS:
s = "No credentials are available."
case C.SEC_E_NO_IMPERSONATION:
s = "No impersonation is allowed for this context."
case C.SEC_E_NO_IP_ADDRESSES:
s = "Unable to accomplish the requested task because the local computer does not have any IP addresses."
case C.SEC_E_NO_KERB_KEY:
s = "No Kerberos key was found."
case C.SEC_E_NO_PA_DATA:
s = "Policy administrator (PA) data is needed to determine the encryption type"
case C.SEC_E_NO_S4U_PROT_SUPPORT:
s = "The Kerberos subsystem encountered an error. A service for user protocol request was made against a domain controller which does not support service for a user."
case C.SEC_E_NO_TGT_REPLY:
s = "The client is trying to negotiate a context and the server requires a user-to-user connection"
case C.SEC_E_NOT_OWNER:
s = "The caller of the function does not own the credentials."
case C.SEC_E_OK:
s = "The operation completed successfully."
case C.SEC_E_OUT_OF_SEQUENCE:
s = "The message supplied for verification is out of sequence."
case C.SEC_E_PKINIT_CLIENT_FAILURE:
s = "The smart card certificate used for authentication is not trusted."
case C.SEC_E_PKINIT_NAME_MISMATCH:
s = "The client certificate does not contain a valid UPN or does not match the client name in the logon request."
case C.SEC_E_QOP_NOT_SUPPORTED:
s = "The quality of protection attribute is not supported by this package."
case C.SEC_E_REVOCATION_OFFLINE_C:
s = "The revocation status of the smart card certificate used for authentication could not be determined."
case C.SEC_E_REVOCATION_OFFLINE_KDC:
s = "The revocation status of the domain controller certificate used for smart card authentication could not be determined. The system event log contains additional information."
case C.SEC_E_SECPKG_NOT_FOUND:
s = "The security package was not recognized."
case C.SEC_E_SECURITY_QOS_FAILED:
s = "The security context could not be established due to a failure in the requested quality of service (for example"
case C.SEC_E_SHUTDOWN_IN_PROGRESS:
s = "A system shutdown is in progress."
case C.SEC_E_SMARTCARD_CERT_EXPIRED:
s = "The smart card certificate used for authentication has expired."
case C.SEC_E_SMARTCARD_CERT_REVOKED:
s = "The smart card certificate used for authentication has been revoked. Additional information may exist in the event log."
case C.SEC_E_SMARTCARD_LOGON_REQUIRED:
s = "Smart card logon is required and was not used."
case C.SEC_E_STRONG_CRYPTO_NOT_SUPPORTED:
s = "The other end of the security negotiation requires strong cryptography"
case C.SEC_E_TARGET_UNKNOWN:
s = "The target was not recognized."
case C.SEC_E_TIME_SKEW:
s = "The clocks on the client and server computers do not match."
case C.SEC_E_TOO_MANY_PRINCIPALS:
s = "The KDC reply contained more than one principal name."
case C.SEC_E_UNFINISHED_CONTEXT_DELETED:
s = "A security context was deleted before the context was completed. This is considered a logon failure."
case C.SEC_E_UNKNOWN_CREDENTIALS:
s = "The credentials provided were not recognized."
case C.SEC_E_UNSUPPORTED_FUNCTION:
s = "The requested function is not supported."
case C.SEC_E_UNSUPPORTED_PREAUTH:
s = "An unsupported preauthentication mechanism was presented to the Kerberos package."
case C.SEC_E_UNTRUSTED_ROOT:
s = "The certificate chain was issued by an authority that is not trusted."
case C.SEC_E_WRONG_CREDENTIAL_HANDLE:
s = "The supplied credential handle does not match the credential associated with the security context."
case C.SEC_E_WRONG_PRINCIPAL:
s = "The target principal name is incorrect."
case C.SEC_I_COMPLETE_AND_CONTINUE:
s = "The function completed successfully"
case C.SEC_I_COMPLETE_NEEDED:
s = "The function completed successfully"
case C.SEC_I_CONTEXT_EXPIRED:
s = "The message sender has finished using the connection and has initiated a shutdown. For information about initiating or recognizing a shutdown"
case C.SEC_I_CONTINUE_NEEDED:
s = "The function completed successfully"
case C.SEC_I_INCOMPLETE_CREDENTIALS:
s = "The credentials supplied were not complete and could not be verified. Additional information can be returned from the context."
case C.SEC_I_LOCAL_LOGON:
s = "The logon was completed"
case C.SEC_I_NO_LSA_CONTEXT:
s = "There is no LSA mode context associated with this context."
case C.SEC_I_RENEGOTIATE:
s = "The context data must be renegotiated with the peer."
default:
return fmt.Errorf("%s: 0x%x", prefix, uint32(status))
}
return fmt.Errorf("%s: %s(0x%x)", prefix, s, uint32(status))
}

View File

@@ -0,0 +1,249 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//+build gssapi,windows
#include "sspi_wrapper.h"
static HINSTANCE sspi_secur32_dll = NULL;
static PSecurityFunctionTable sspi_functions = NULL;
static const LPSTR SSPI_PACKAGE_NAME = "kerberos";
int sspi_init(
)
{
// Load the secur32.dll library using its exact path. Passing the exact DLL path rather than allowing LoadLibrary to
// search in different locations removes the possibility of DLL preloading attacks. We use GetSystemDirectoryA and
// LoadLibraryA rather than the GetSystemDirectory/LoadLibrary aliases to ensure the ANSI versions are used so we
// don't have to account for variations in char sizes if UNICODE is enabled.
// Passing a 0 size will return the required buffer length to hold the path, including the null terminator.
int requiredLen = GetSystemDirectoryA(NULL, 0);
if (!requiredLen) {
return GetLastError();
}
// Allocate a buffer to hold the system directory + "\secur32.dll" (length 12, not including null terminator).
int actualLen = requiredLen + 12;
char *directoryBuffer = (char *) calloc(1, actualLen);
int directoryLen = GetSystemDirectoryA(directoryBuffer, actualLen);
if (!directoryLen) {
free(directoryBuffer);
return GetLastError();
}
// Append the DLL name to the buffer.
char *dllName = "\\secur32.dll";
strcpy_s(&(directoryBuffer[directoryLen]), actualLen - directoryLen, dllName);
sspi_secur32_dll = LoadLibraryA(directoryBuffer);
free(directoryBuffer);
if (!sspi_secur32_dll) {
return GetLastError();
}
INIT_SECURITY_INTERFACE init_security_interface = (INIT_SECURITY_INTERFACE)GetProcAddress(sspi_secur32_dll, SECURITY_ENTRYPOINT);
if (!init_security_interface) {
return -1;
}
sspi_functions = (*init_security_interface)();
if (!sspi_functions) {
return -2;
}
return SSPI_OK;
}
int sspi_client_init(
sspi_client_state *client,
char* username,
char* password
)
{
TimeStamp timestamp;
if (username) {
if (password) {
SEC_WINNT_AUTH_IDENTITY auth_identity;
#ifdef _UNICODE
auth_identity.Flags = SEC_WINNT_AUTH_IDENTITY_UNICODE;
#else
auth_identity.Flags = SEC_WINNT_AUTH_IDENTITY_ANSI;
#endif
auth_identity.User = (LPSTR) username;
auth_identity.UserLength = strlen(username);
auth_identity.Password = (LPSTR) password;
auth_identity.PasswordLength = strlen(password);
auth_identity.Domain = NULL;
auth_identity.DomainLength = 0;
client->status = sspi_functions->AcquireCredentialsHandle(NULL, SSPI_PACKAGE_NAME, SECPKG_CRED_OUTBOUND, NULL, &auth_identity, NULL, NULL, &client->cred, &timestamp);
} else {
client->status = sspi_functions->AcquireCredentialsHandle(username, SSPI_PACKAGE_NAME, SECPKG_CRED_OUTBOUND, NULL, NULL, NULL, NULL, &client->cred, &timestamp);
}
} else {
client->status = sspi_functions->AcquireCredentialsHandle(NULL, SSPI_PACKAGE_NAME, SECPKG_CRED_OUTBOUND, NULL, NULL, NULL, NULL, &client->cred, &timestamp);
}
if (client->status != SEC_E_OK) {
return SSPI_ERROR;
}
return SSPI_OK;
}
int sspi_client_username(
sspi_client_state *client,
char** username
)
{
SecPkgCredentials_Names names;
client->status = sspi_functions->QueryCredentialsAttributes(&client->cred, SECPKG_CRED_ATTR_NAMES, &names);
if (client->status != SEC_E_OK) {
return SSPI_ERROR;
}
int len = strlen(names.sUserName) + 1;
*username = malloc(len);
memcpy(*username, names.sUserName, len);
sspi_functions->FreeContextBuffer(names.sUserName);
return SSPI_OK;
}
int sspi_client_negotiate(
sspi_client_state *client,
char* spn,
PVOID input,
ULONG input_length,
PVOID* output,
ULONG* output_length
)
{
SecBufferDesc inbuf;
SecBuffer in_bufs[1];
SecBufferDesc outbuf;
SecBuffer out_bufs[1];
if (client->has_ctx > 0) {
inbuf.ulVersion = SECBUFFER_VERSION;
inbuf.cBuffers = 1;
inbuf.pBuffers = in_bufs;
in_bufs[0].pvBuffer = input;
in_bufs[0].cbBuffer = input_length;
in_bufs[0].BufferType = SECBUFFER_TOKEN;
}
outbuf.ulVersion = SECBUFFER_VERSION;
outbuf.cBuffers = 1;
outbuf.pBuffers = out_bufs;
out_bufs[0].pvBuffer = NULL;
out_bufs[0].cbBuffer = 0;
out_bufs[0].BufferType = SECBUFFER_TOKEN;
ULONG context_attr = 0;
client->status = sspi_functions->InitializeSecurityContext(
&client->cred,
client->has_ctx > 0 ? &client->ctx : NULL,
(LPSTR) spn,
ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_MUTUAL_AUTH,
0,
SECURITY_NETWORK_DREP,
client->has_ctx > 0 ? &inbuf : NULL,
0,
&client->ctx,
&outbuf,
&context_attr,
NULL);
if (client->status != SEC_E_OK && client->status != SEC_I_CONTINUE_NEEDED) {
return SSPI_ERROR;
}
client->has_ctx = 1;
*output = malloc(out_bufs[0].cbBuffer);
*output_length = out_bufs[0].cbBuffer;
memcpy(*output, out_bufs[0].pvBuffer, *output_length);
sspi_functions->FreeContextBuffer(out_bufs[0].pvBuffer);
if (client->status == SEC_I_CONTINUE_NEEDED) {
return SSPI_CONTINUE;
}
return SSPI_OK;
}
int sspi_client_wrap_msg(
sspi_client_state *client,
PVOID input,
ULONG input_length,
PVOID* output,
ULONG* output_length
)
{
SecPkgContext_Sizes sizes;
client->status = sspi_functions->QueryContextAttributes(&client->ctx, SECPKG_ATTR_SIZES, &sizes);
if (client->status != SEC_E_OK) {
return SSPI_ERROR;
}
char *msg = malloc((sizes.cbSecurityTrailer + input_length + sizes.cbBlockSize) * sizeof(char));
memcpy(&msg[sizes.cbSecurityTrailer], input, input_length);
SecBuffer wrap_bufs[3];
SecBufferDesc wrap_buf_desc;
wrap_buf_desc.cBuffers = 3;
wrap_buf_desc.pBuffers = wrap_bufs;
wrap_buf_desc.ulVersion = SECBUFFER_VERSION;
wrap_bufs[0].cbBuffer = sizes.cbSecurityTrailer;
wrap_bufs[0].BufferType = SECBUFFER_TOKEN;
wrap_bufs[0].pvBuffer = msg;
wrap_bufs[1].cbBuffer = input_length;
wrap_bufs[1].BufferType = SECBUFFER_DATA;
wrap_bufs[1].pvBuffer = msg + sizes.cbSecurityTrailer;
wrap_bufs[2].cbBuffer = sizes.cbBlockSize;
wrap_bufs[2].BufferType = SECBUFFER_PADDING;
wrap_bufs[2].pvBuffer = msg + sizes.cbSecurityTrailer + input_length;
client->status = sspi_functions->EncryptMessage(&client->ctx, SECQOP_WRAP_NO_ENCRYPT, &wrap_buf_desc, 0);
if (client->status != SEC_E_OK) {
free(msg);
return SSPI_ERROR;
}
*output_length = wrap_bufs[0].cbBuffer + wrap_bufs[1].cbBuffer + wrap_bufs[2].cbBuffer;
*output = malloc(*output_length);
memcpy(*output, wrap_bufs[0].pvBuffer, wrap_bufs[0].cbBuffer);
memcpy(*output + wrap_bufs[0].cbBuffer, wrap_bufs[1].pvBuffer, wrap_bufs[1].cbBuffer);
memcpy(*output + wrap_bufs[0].cbBuffer + wrap_bufs[1].cbBuffer, wrap_bufs[2].pvBuffer, wrap_bufs[2].cbBuffer);
free(msg);
return SSPI_OK;
}
int sspi_client_destroy(
sspi_client_state *client
)
{
if (client->has_ctx > 0) {
sspi_functions->DeleteSecurityContext(&client->ctx);
}
sspi_functions->FreeCredentialsHandle(&client->cred);
return SSPI_OK;
}

View File

@@ -0,0 +1,64 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//+build gssapi,windows
#ifndef SSPI_WRAPPER_H
#define SSPI_WRAPPER_H
#define SECURITY_WIN32 1 /* Required for SSPI */
#include <windows.h>
#include <sspi.h>
#define SSPI_OK 0
#define SSPI_CONTINUE 1
#define SSPI_ERROR 2
typedef struct {
CredHandle cred;
CtxtHandle ctx;
int has_ctx;
SECURITY_STATUS status;
} sspi_client_state;
int sspi_init();
int sspi_client_init(
sspi_client_state *client,
char* username,
char* password
);
int sspi_client_username(
sspi_client_state *client,
char** username
);
int sspi_client_negotiate(
sspi_client_state *client,
char* spn,
PVOID input,
ULONG input_length,
PVOID* output,
ULONG* output_length
);
int sspi_client_wrap_msg(
sspi_client_state *client,
PVOID input,
ULONG input_length,
PVOID* output,
ULONG* output_length
);
int sspi_client_destroy(
sspi_client_state *client
);
#endif

View File

@@ -0,0 +1,82 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
"context"
"errors"
)
// MongoDBAWS is the mechanism name for MongoDBAWS.
const MongoDBAWS = "MONGODB-AWS"
func newMongoDBAWSAuthenticator(cred *Cred) (Authenticator, error) {
if cred.Source != "" && cred.Source != "$external" {
return nil, newAuthError("MONGODB-AWS source must be empty or $external", nil)
}
return &MongoDBAWSAuthenticator{
source: cred.Source,
username: cred.Username,
password: cred.Password,
sessionToken: cred.Props["AWS_SESSION_TOKEN"],
}, nil
}
// MongoDBAWSAuthenticator uses AWS-IAM credentials over SASL to authenticate a connection.
type MongoDBAWSAuthenticator struct {
source string
username string
password string
sessionToken string
}
// Auth authenticates the connection.
func (a *MongoDBAWSAuthenticator) Auth(ctx context.Context, cfg *Config) error {
httpClient := cfg.HTTPClient
if httpClient == nil {
return errors.New("cfg.HTTPClient must not be nil")
}
adapter := &awsSaslAdapter{
conversation: &awsConversation{
username: a.username,
password: a.password,
token: a.sessionToken,
httpClient: httpClient,
},
}
err := ConductSaslConversation(ctx, cfg, a.source, adapter)
if err != nil {
return newAuthError("sasl conversation error", err)
}
return nil
}
type awsSaslAdapter struct {
conversation *awsConversation
}
var _ SaslClient = (*awsSaslAdapter)(nil)
func (a *awsSaslAdapter) Start() (string, []byte, error) {
step, err := a.conversation.Step(nil)
if err != nil {
return MongoDBAWS, nil, err
}
return MongoDBAWS, step, nil
}
func (a *awsSaslAdapter) Next(challenge []byte) ([]byte, error) {
step, err := a.conversation.Step(challenge)
if err != nil {
return nil, err
}
return step, nil
}
func (a *awsSaslAdapter) Completed() bool {
return a.conversation.Done()
}

View File

@@ -0,0 +1,48 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
"errors"
"testing"
"go.mongodb.org/mongo-driver/internal/testutil/assert"
)
func TestGetRegion(t *testing.T) {
longHost := make([]rune, 256)
emptyErr := errors.New("invalid STS host: empty")
tooLongErr := errors.New("invalid STS host: too large")
emptyPartErr := errors.New("invalid STS host: empty part")
testCases := []struct {
name string
host string
err error
region string
}{
{"success default", "sts.amazonaws.com", nil, "us-east-1"},
{"success parse", "first.second", nil, "second"},
{"success no region", "first", nil, "us-east-1"},
{"error host too long", string(longHost), tooLongErr, ""},
{"error host empty", "", emptyErr, ""},
{"error empty middle part", "abc..def", emptyPartErr, ""},
{"error empty part", "first.", emptyPartErr, ""},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
reg, err := getRegion(tc.host)
if tc.err == nil {
assert.Nil(t, err, "error getting region: %v", err)
assert.Equal(t, tc.region, reg, "expected %v, got %v", tc.region, reg)
return
}
assert.NotNil(t, err, "expected error, got nil")
assert.Equal(t, err, tc.err, "expected error: %v, got: %v", tc.err, err)
})
}
}

View File

@@ -0,0 +1,110 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
"context"
"fmt"
"io"
// Ignore gosec warning "Blocklisted import crypto/md5: weak cryptographic primitive". We need
// to use MD5 here to implement the MONGODB-CR specification.
/* #nosec G501 */
"crypto/md5"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
)
// MONGODBCR is the mechanism name for MONGODB-CR.
//
// The MONGODB-CR authentication mechanism is deprecated in MongoDB 3.6 and removed in
// MongoDB 4.0.
const MONGODBCR = "MONGODB-CR"
func newMongoDBCRAuthenticator(cred *Cred) (Authenticator, error) {
return &MongoDBCRAuthenticator{
DB: cred.Source,
Username: cred.Username,
Password: cred.Password,
}, nil
}
// MongoDBCRAuthenticator uses the MONGODB-CR algorithm to authenticate a connection.
//
// The MONGODB-CR authentication mechanism is deprecated in MongoDB 3.6 and removed in
// MongoDB 4.0.
type MongoDBCRAuthenticator struct {
DB string
Username string
Password string
}
// Auth authenticates the connection.
//
// The MONGODB-CR authentication mechanism is deprecated in MongoDB 3.6 and removed in
// MongoDB 4.0.
func (a *MongoDBCRAuthenticator) Auth(ctx context.Context, cfg *Config) error {
db := a.DB
if db == "" {
db = defaultAuthDB
}
doc := bsoncore.BuildDocumentFromElements(nil, bsoncore.AppendInt32Element(nil, "getnonce", 1))
cmd := operation.NewCommand(doc).
Database(db).
Deployment(driver.SingleConnectionDeployment{cfg.Connection}).
ClusterClock(cfg.ClusterClock).
ServerAPI(cfg.ServerAPI)
err := cmd.Execute(ctx)
if err != nil {
return newError(err, MONGODBCR)
}
rdr := cmd.Result()
var getNonceResult struct {
Nonce string `bson:"nonce"`
}
err = bson.Unmarshal(rdr, &getNonceResult)
if err != nil {
return newAuthError("unmarshal error", err)
}
doc = bsoncore.BuildDocumentFromElements(nil,
bsoncore.AppendInt32Element(nil, "authenticate", 1),
bsoncore.AppendStringElement(nil, "user", a.Username),
bsoncore.AppendStringElement(nil, "nonce", getNonceResult.Nonce),
bsoncore.AppendStringElement(nil, "key", a.createKey(getNonceResult.Nonce)),
)
cmd = operation.NewCommand(doc).
Database(db).
Deployment(driver.SingleConnectionDeployment{cfg.Connection}).
ClusterClock(cfg.ClusterClock).
ServerAPI(cfg.ServerAPI)
err = cmd.Execute(ctx)
if err != nil {
return newError(err, MONGODBCR)
}
return nil
}
func (a *MongoDBCRAuthenticator) createKey(nonce string) string {
// Ignore gosec warning "Use of weak cryptographic primitive". We need to use MD5 here to
// implement the MONGODB-CR specification.
/* #nosec G401 */
h := md5.New()
_, _ = io.WriteString(h, nonce)
_, _ = io.WriteString(h, a.Username)
_, _ = io.WriteString(h, mongoPasswordDigest(a.Username, a.Password))
return fmt.Sprintf("%x", h.Sum(nil))
}

View File

@@ -0,0 +1,114 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth_test
import (
"context"
"testing"
"strings"
"go.mongodb.org/mongo-driver/mongo/description"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
. "go.mongodb.org/mongo-driver/x/mongo/driver/auth"
"go.mongodb.org/mongo-driver/x/mongo/driver/drivertest"
)
func TestMongoDBCRAuthenticator_Fails(t *testing.T) {
t.Parallel()
authenticator := MongoDBCRAuthenticator{
DB: "source",
Username: "user",
Password: "pencil",
}
resps := make(chan []byte, 2)
writeReplies(resps, bsoncore.BuildDocumentFromElements(nil,
bsoncore.AppendInt32Element(nil, "ok", 1),
bsoncore.AppendStringElement(nil, "nonce", "2375531c32080ae8"),
), bsoncore.BuildDocumentFromElements(nil,
bsoncore.AppendInt32Element(nil, "ok", 0),
))
desc := description.Server{
WireVersion: &description.VersionRange{
Max: 6,
},
}
c := &drivertest.ChannelConn{
Written: make(chan []byte, 2),
ReadResp: resps,
Desc: desc,
}
err := authenticator.Auth(context.Background(), &Config{Description: desc, Connection: c})
if err == nil {
t.Fatalf("expected an error but got none")
}
errPrefix := "unable to authenticate using mechanism \"MONGODB-CR\""
if !strings.HasPrefix(err.Error(), errPrefix) {
t.Fatalf("expected an err starting with \"%s\" but got \"%s\"", errPrefix, err)
}
}
func TestMongoDBCRAuthenticator_Succeeds(t *testing.T) {
t.Parallel()
authenticator := MongoDBCRAuthenticator{
DB: "source",
Username: "user",
Password: "pencil",
}
resps := make(chan []byte, 2)
writeReplies(resps, bsoncore.BuildDocumentFromElements(nil,
bsoncore.AppendInt32Element(nil, "ok", 1),
bsoncore.AppendStringElement(nil, "nonce", "2375531c32080ae8"),
), bsoncore.BuildDocumentFromElements(nil,
bsoncore.AppendInt32Element(nil, "ok", 1),
))
desc := description.Server{
WireVersion: &description.VersionRange{
Max: 6,
},
}
c := &drivertest.ChannelConn{
Written: make(chan []byte, 2),
ReadResp: resps,
Desc: desc,
}
err := authenticator.Auth(context.Background(), &Config{Description: desc, Connection: c})
if err != nil {
t.Fatalf("expected no error but got \"%s\"", err)
}
if len(c.Written) != 2 {
t.Fatalf("expected 2 messages to be sent but had %d", len(c.Written))
}
want := bsoncore.BuildDocumentFromElements(nil, bsoncore.AppendInt32Element(nil, "getnonce", 1))
compareResponses(t, <-c.Written, want, "source")
expectedAuthenticateDoc := bsoncore.BuildDocumentFromElements(nil,
bsoncore.AppendInt32Element(nil, "authenticate", 1),
bsoncore.AppendStringElement(nil, "user", "user"),
bsoncore.AppendStringElement(nil, "nonce", "2375531c32080ae8"),
bsoncore.AppendStringElement(nil, "key", "21742f26431831d5cfca035a08c5bdf6"),
)
compareResponses(t, <-c.Written, expectedAuthenticateDoc, "source")
}
func writeReplies(c chan []byte, docs ...bsoncore.Document) {
for _, doc := range docs {
reply := drivertest.MakeReply(doc)
c <- reply
}
}

View File

@@ -0,0 +1,55 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
"context"
)
// PLAIN is the mechanism name for PLAIN.
const PLAIN = "PLAIN"
func newPlainAuthenticator(cred *Cred) (Authenticator, error) {
return &PlainAuthenticator{
Username: cred.Username,
Password: cred.Password,
}, nil
}
// PlainAuthenticator uses the PLAIN algorithm over SASL to authenticate a connection.
type PlainAuthenticator struct {
Username string
Password string
}
// Auth authenticates the connection.
func (a *PlainAuthenticator) Auth(ctx context.Context, cfg *Config) error {
return ConductSaslConversation(ctx, cfg, "$external", &plainSaslClient{
username: a.Username,
password: a.Password,
})
}
type plainSaslClient struct {
username string
password string
}
var _ SaslClient = (*plainSaslClient)(nil)
func (c *plainSaslClient) Start() (string, []byte, error) {
b := []byte("\x00" + c.username + "\x00" + c.password)
return PLAIN, b, nil
}
func (c *plainSaslClient) Next(challenge []byte) ([]byte, error) {
return nil, newAuthError("unexpected server challenge", nil)
}
func (c *plainSaslClient) Completed() bool {
return true
}

View File

@@ -0,0 +1,147 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth_test
import (
"context"
"strings"
"testing"
"encoding/base64"
"go.mongodb.org/mongo-driver/mongo/description"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
. "go.mongodb.org/mongo-driver/x/mongo/driver/auth"
"go.mongodb.org/mongo-driver/x/mongo/driver/drivertest"
)
func TestPlainAuthenticator_Fails(t *testing.T) {
t.Parallel()
authenticator := PlainAuthenticator{
Username: "user",
Password: "pencil",
}
resps := make(chan []byte, 1)
writeReplies(resps, bsoncore.BuildDocumentFromElements(nil,
bsoncore.AppendInt32Element(nil, "ok", 1),
bsoncore.AppendInt32Element(nil, "conversationId", 1),
bsoncore.AppendBinaryElement(nil, "payload", 0x00, []byte{}),
bsoncore.AppendInt32Element(nil, "code", 143),
bsoncore.AppendBooleanElement(nil, "done", true),
))
desc := description.Server{
WireVersion: &description.VersionRange{
Max: 6,
},
}
c := &drivertest.ChannelConn{
Written: make(chan []byte, 1),
ReadResp: resps,
Desc: desc,
}
err := authenticator.Auth(context.Background(), &Config{Description: desc, Connection: c})
if err == nil {
t.Fatalf("expected an error but got none")
}
errPrefix := "unable to authenticate using mechanism \"PLAIN\""
if !strings.HasPrefix(err.Error(), errPrefix) {
t.Fatalf("expected an err starting with \"%s\" but got \"%s\"", errPrefix, err)
}
}
func TestPlainAuthenticator_Extra_server_message(t *testing.T) {
t.Parallel()
authenticator := PlainAuthenticator{
Username: "user",
Password: "pencil",
}
resps := make(chan []byte, 2)
writeReplies(resps, bsoncore.BuildDocumentFromElements(nil,
bsoncore.AppendInt32Element(nil, "ok", 1),
bsoncore.AppendInt32Element(nil, "conversationId", 1),
bsoncore.AppendBinaryElement(nil, "payload", 0x00, []byte{}),
bsoncore.AppendBooleanElement(nil, "done", false),
), bsoncore.BuildDocumentFromElements(nil,
bsoncore.AppendInt32Element(nil, "ok", 1),
bsoncore.AppendInt32Element(nil, "conversationId", 1),
bsoncore.AppendBinaryElement(nil, "payload", 0x00, []byte{}),
bsoncore.AppendBooleanElement(nil, "done", true),
))
desc := description.Server{
WireVersion: &description.VersionRange{
Max: 6,
},
}
c := &drivertest.ChannelConn{
Written: make(chan []byte, 1),
ReadResp: resps,
Desc: desc,
}
err := authenticator.Auth(context.Background(), &Config{Description: desc, Connection: c})
if err == nil {
t.Fatalf("expected an error but got none")
}
errPrefix := "unable to authenticate using mechanism \"PLAIN\": unexpected server challenge"
if !strings.HasPrefix(err.Error(), errPrefix) {
t.Fatalf("expected an err starting with \"%s\" but got \"%s\"", errPrefix, err)
}
}
func TestPlainAuthenticator_Succeeds(t *testing.T) {
t.Parallel()
authenticator := PlainAuthenticator{
Username: "user",
Password: "pencil",
}
resps := make(chan []byte, 1)
writeReplies(resps, bsoncore.BuildDocumentFromElements(nil,
bsoncore.AppendInt32Element(nil, "ok", 1),
bsoncore.AppendInt32Element(nil, "conversationId", 1),
bsoncore.AppendBinaryElement(nil, "payload", 0x00, []byte{}),
bsoncore.AppendBooleanElement(nil, "done", true),
))
desc := description.Server{
WireVersion: &description.VersionRange{
Max: 6,
},
}
c := &drivertest.ChannelConn{
Written: make(chan []byte, 1),
ReadResp: resps,
Desc: desc,
}
err := authenticator.Auth(context.Background(), &Config{Description: desc, Connection: c})
if err != nil {
t.Fatalf("expected no error but got \"%s\"", err)
}
if len(c.Written) != 1 {
t.Fatalf("expected 1 messages to be sent but had %d", len(c.Written))
}
payload, _ := base64.StdEncoding.DecodeString("AHVzZXIAcGVuY2ls")
expectedCmd := bsoncore.BuildDocumentFromElements(nil,
bsoncore.AppendInt32Element(nil, "saslStart", 1),
bsoncore.AppendStringElement(nil, "mechanism", "PLAIN"),
bsoncore.AppendBinaryElement(nil, "payload", 0x00, payload),
)
compareResponses(t, <-c.Written, expectedCmd, "$external")
}

View File

@@ -0,0 +1,174 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
"context"
"fmt"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
)
// SaslClient is the client piece of a sasl conversation.
type SaslClient interface {
Start() (string, []byte, error)
Next(challenge []byte) ([]byte, error)
Completed() bool
}
// SaslClientCloser is a SaslClient that has resources to clean up.
type SaslClientCloser interface {
SaslClient
Close()
}
// ExtraOptionsSaslClient is a SaslClient that appends options to the saslStart command.
type ExtraOptionsSaslClient interface {
StartCommandOptions() bsoncore.Document
}
// saslConversation represents a SASL conversation. This type implements the SpeculativeConversation interface so the
// conversation can be executed in multi-step speculative fashion.
type saslConversation struct {
client SaslClient
source string
mechanism string
speculative bool
}
var _ SpeculativeConversation = (*saslConversation)(nil)
func newSaslConversation(client SaslClient, source string, speculative bool) *saslConversation {
authSource := source
if authSource == "" {
authSource = defaultAuthDB
}
return &saslConversation{
client: client,
source: authSource,
speculative: speculative,
}
}
// FirstMessage returns the first message to be sent to the server. This message contains a "db" field so it can be used
// for speculative authentication.
func (sc *saslConversation) FirstMessage() (bsoncore.Document, error) {
var payload []byte
var err error
sc.mechanism, payload, err = sc.client.Start()
if err != nil {
return nil, err
}
saslCmdElements := [][]byte{
bsoncore.AppendInt32Element(nil, "saslStart", 1),
bsoncore.AppendStringElement(nil, "mechanism", sc.mechanism),
bsoncore.AppendBinaryElement(nil, "payload", 0x00, payload),
}
if sc.speculative {
// The "db" field is only appended for speculative auth because the hello command is executed against admin
// so this is needed to tell the server the user's auth source. For a non-speculative attempt, the SASL commands
// will be executed against the auth source.
saslCmdElements = append(saslCmdElements, bsoncore.AppendStringElement(nil, "db", sc.source))
}
if extraOptionsClient, ok := sc.client.(ExtraOptionsSaslClient); ok {
optionsDoc := extraOptionsClient.StartCommandOptions()
saslCmdElements = append(saslCmdElements, bsoncore.AppendDocumentElement(nil, "options", optionsDoc))
}
return bsoncore.BuildDocumentFromElements(nil, saslCmdElements...), nil
}
type saslResponse struct {
ConversationID int `bson:"conversationId"`
Code int `bson:"code"`
Done bool `bson:"done"`
Payload []byte `bson:"payload"`
}
// Finish completes the conversation based on the first server response to authenticate the given connection.
func (sc *saslConversation) Finish(ctx context.Context, cfg *Config, firstResponse bsoncore.Document) error {
if closer, ok := sc.client.(SaslClientCloser); ok {
defer closer.Close()
}
var saslResp saslResponse
err := bson.Unmarshal(firstResponse, &saslResp)
if err != nil {
fullErr := fmt.Errorf("unmarshal error: %v", err)
return newError(fullErr, sc.mechanism)
}
cid := saslResp.ConversationID
var payload []byte
var rdr bsoncore.Document
for {
if saslResp.Code != 0 {
return newError(err, sc.mechanism)
}
if saslResp.Done && sc.client.Completed() {
return nil
}
payload, err = sc.client.Next(saslResp.Payload)
if err != nil {
return newError(err, sc.mechanism)
}
if saslResp.Done && sc.client.Completed() {
return nil
}
doc := bsoncore.BuildDocumentFromElements(nil,
bsoncore.AppendInt32Element(nil, "saslContinue", 1),
bsoncore.AppendInt32Element(nil, "conversationId", int32(cid)),
bsoncore.AppendBinaryElement(nil, "payload", 0x00, payload),
)
saslContinueCmd := operation.NewCommand(doc).
Database(sc.source).
Deployment(driver.SingleConnectionDeployment{cfg.Connection}).
ClusterClock(cfg.ClusterClock).
ServerAPI(cfg.ServerAPI)
err = saslContinueCmd.Execute(ctx)
if err != nil {
return newError(err, sc.mechanism)
}
rdr = saslContinueCmd.Result()
err = bson.Unmarshal(rdr, &saslResp)
if err != nil {
fullErr := fmt.Errorf("unmarshal error: %v", err)
return newError(fullErr, sc.mechanism)
}
}
}
// ConductSaslConversation runs a full SASL conversation to authenticate the given connection.
func ConductSaslConversation(ctx context.Context, cfg *Config, authSource string, client SaslClient) error {
// Create a non-speculative SASL conversation.
conversation := newSaslConversation(client, authSource, false)
saslStartDoc, err := conversation.FirstMessage()
if err != nil {
return newError(err, conversation.mechanism)
}
saslStartCmd := operation.NewCommand(saslStartDoc).
Database(authSource).
Deployment(driver.SingleConnectionDeployment{cfg.Connection}).
ClusterClock(cfg.ClusterClock).
ServerAPI(cfg.ServerAPI)
if err := saslStartCmd.Execute(ctx); err != nil {
return newError(err, conversation.mechanism)
}
return conversation.Finish(ctx, cfg, saslStartCmd.Result())
}

View File

@@ -0,0 +1,130 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Copyright (C) MongoDB, Inc. 2018-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
"context"
"fmt"
"github.com/xdg-go/scram"
"github.com/xdg-go/stringprep"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
const (
// SCRAMSHA1 holds the mechanism name "SCRAM-SHA-1"
SCRAMSHA1 = "SCRAM-SHA-1"
// SCRAMSHA256 holds the mechanism name "SCRAM-SHA-256"
SCRAMSHA256 = "SCRAM-SHA-256"
)
var (
// Additional options for the saslStart command to enable a shorter SCRAM conversation
scramStartOptions bsoncore.Document = bsoncore.BuildDocumentFromElements(nil,
bsoncore.AppendBooleanElement(nil, "skipEmptyExchange", true),
)
)
func newScramSHA1Authenticator(cred *Cred) (Authenticator, error) {
passdigest := mongoPasswordDigest(cred.Username, cred.Password)
client, err := scram.SHA1.NewClientUnprepped(cred.Username, passdigest, "")
if err != nil {
return nil, newAuthError("error initializing SCRAM-SHA-1 client", err)
}
client.WithMinIterations(4096)
return &ScramAuthenticator{
mechanism: SCRAMSHA1,
source: cred.Source,
client: client,
}, nil
}
func newScramSHA256Authenticator(cred *Cred) (Authenticator, error) {
passprep, err := stringprep.SASLprep.Prepare(cred.Password)
if err != nil {
return nil, newAuthError(fmt.Sprintf("error SASLprepping password '%s'", cred.Password), err)
}
client, err := scram.SHA256.NewClientUnprepped(cred.Username, passprep, "")
if err != nil {
return nil, newAuthError("error initializing SCRAM-SHA-256 client", err)
}
client.WithMinIterations(4096)
return &ScramAuthenticator{
mechanism: SCRAMSHA256,
source: cred.Source,
client: client,
}, nil
}
// ScramAuthenticator uses the SCRAM algorithm over SASL to authenticate a connection.
type ScramAuthenticator struct {
mechanism string
source string
client *scram.Client
}
var _ SpeculativeAuthenticator = (*ScramAuthenticator)(nil)
// Auth authenticates the provided connection by conducting a full SASL conversation.
func (a *ScramAuthenticator) Auth(ctx context.Context, cfg *Config) error {
err := ConductSaslConversation(ctx, cfg, a.source, a.createSaslClient())
if err != nil {
return newAuthError("sasl conversation error", err)
}
return nil
}
// CreateSpeculativeConversation creates a speculative conversation for SCRAM authentication.
func (a *ScramAuthenticator) CreateSpeculativeConversation() (SpeculativeConversation, error) {
return newSaslConversation(a.createSaslClient(), a.source, true), nil
}
func (a *ScramAuthenticator) createSaslClient() SaslClient {
return &scramSaslAdapter{
conversation: a.client.NewConversation(),
mechanism: a.mechanism,
}
}
type scramSaslAdapter struct {
mechanism string
conversation *scram.ClientConversation
}
var _ SaslClient = (*scramSaslAdapter)(nil)
var _ ExtraOptionsSaslClient = (*scramSaslAdapter)(nil)
func (a *scramSaslAdapter) Start() (string, []byte, error) {
step, err := a.conversation.Step("")
if err != nil {
return a.mechanism, nil, err
}
return a.mechanism, []byte(step), nil
}
func (a *scramSaslAdapter) Next(challenge []byte) ([]byte, error) {
step, err := a.conversation.Step(string(challenge))
if err != nil {
return nil, err
}
return []byte(step), nil
}
func (a *scramSaslAdapter) Completed() bool {
return a.conversation.Done()
}
func (*scramSaslAdapter) StartCommandOptions() bsoncore.Document {
return scramStartOptions
}

View File

@@ -0,0 +1,120 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
"context"
"testing"
"go.mongodb.org/mongo-driver/internal/testutil/assert"
"go.mongodb.org/mongo-driver/mongo/description"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver/drivertest"
)
const (
scramSha1Nonce = "fyko+d2lbbFgONRv9qkxdawL"
scramSha256Nonce = "rOprNGfwEbeRWgbNEkqO"
)
var (
scramSha1ShortPayloads = [][]byte{
[]byte("r=fyko+d2lbbFgONRv9qkxdawLHo+Vgk7qvUOKUwuWLIWg4l/9SraGMHEE,s=rQ9ZY3MntBeuP3E1TDVC4w==,i=10000"),
[]byte("v=UMWeI25JD1yNYZRMpZ4VHvhZ9e0="),
}
scramSha256ShortPayloads = [][]byte{
[]byte("r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,s=W22ZaJ0SNY7soEsUEjb6gQ==,i=4096"),
[]byte("v=6rriTRBi23WpRR/wtup+mMhUZUn/dB5nLTJRsjl95G4="),
}
scramSha1LongPayloads = append(scramSha1ShortPayloads, []byte{})
scramSha256LongPayloads = append(scramSha256ShortPayloads, []byte{})
)
func TestSCRAM(t *testing.T) {
t.Run("conversation", func(t *testing.T) {
testCases := []struct {
name string
createAuthenticatorFn func(*Cred) (Authenticator, error)
payloads [][]byte
nonce string
}{
{"scram-sha-1 short conversation", newScramSHA1Authenticator, scramSha1ShortPayloads, scramSha1Nonce},
{"scram-sha-256 short conversation", newScramSHA256Authenticator, scramSha256ShortPayloads, scramSha256Nonce},
{"scram-sha-1 long conversation", newScramSHA1Authenticator, scramSha1LongPayloads, scramSha1Nonce},
{"scram-sha-256 long conversation", newScramSHA256Authenticator, scramSha256LongPayloads, scramSha256Nonce},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
authenticator, err := tc.createAuthenticatorFn(&Cred{
Username: "user",
Password: "pencil",
Source: "admin",
})
assert.Nil(t, err, "error creating authenticator: %v", err)
sa, _ := authenticator.(*ScramAuthenticator)
sa.client = sa.client.WithNonceGenerator(func() string {
return tc.nonce
})
responses := make(chan []byte, len(tc.payloads))
writeReplies(responses, createSCRAMConversation(tc.payloads)...)
desc := description.Server{
WireVersion: &description.VersionRange{
Max: 4,
},
}
conn := &drivertest.ChannelConn{
Written: make(chan []byte, len(tc.payloads)),
ReadResp: responses,
Desc: desc,
}
err = authenticator.Auth(context.Background(), &Config{Description: desc, Connection: conn})
assert.Nil(t, err, "Auth error: %v\n", err)
// Verify that the first command sent is saslStart.
assert.True(t, len(conn.Written) > 1, "wire messages were written to the connection")
startCmd, err := drivertest.GetCommandFromQueryWireMessage(<-conn.Written)
assert.Nil(t, err, "error parsing wire message: %v", err)
cmdName := startCmd.Index(0).Key()
assert.Equal(t, cmdName, "saslStart", "cmd name mismatch; expected 'saslStart', got %v", cmdName)
// Verify that the saslStart command always has {options: {skipEmptyExchange: true}}
optionsVal, err := startCmd.LookupErr("options")
assert.Nil(t, err, "no options found in saslStart command")
optionsDoc := optionsVal.Document()
assert.Equal(t, optionsDoc, scramStartOptions, "expected options %v, got %v", scramStartOptions, optionsDoc)
})
}
})
}
func createSCRAMConversation(payloads [][]byte) []bsoncore.Document {
responses := make([]bsoncore.Document, len(payloads))
for idx, payload := range payloads {
res := createSCRAMServerResponse(payload, idx == len(payloads)-1)
responses[idx] = res
}
return responses
}
func createSCRAMServerResponse(payload []byte, done bool) bsoncore.Document {
return bsoncore.BuildDocumentFromElements(nil,
bsoncore.AppendInt32Element(nil, "conversationId", 1),
bsoncore.AppendBinaryElement(nil, "payload", 0x00, payload),
bsoncore.AppendBooleanElement(nil, "done", done),
bsoncore.AppendInt32Element(nil, "ok", 1),
)
}
func writeReplies(c chan []byte, docs ...bsoncore.Document) {
for _, doc := range docs {
reply := drivertest.MakeReply(doc)
c <- reply
}
}

View File

@@ -0,0 +1,253 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
"bytes"
"context"
"testing"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/internal"
"go.mongodb.org/mongo-driver/internal/testutil/assert"
"go.mongodb.org/mongo-driver/mongo/address"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver/drivertest"
)
var (
// The base elements for a hello response.
handshakeHelloElements = [][]byte{
bsoncore.AppendInt32Element(nil, "ok", 1),
bsoncore.AppendBooleanElement(nil, internal.LegacyHelloLowercase, true),
bsoncore.AppendInt32Element(nil, "maxBsonObjectSize", 16777216),
bsoncore.AppendInt32Element(nil, "maxMessageSizeBytes", 48000000),
bsoncore.AppendInt32Element(nil, "minWireVersion", 0),
bsoncore.AppendInt32Element(nil, "maxWireVersion", 6),
}
// The first payload sent by the driver for SCRAM-SHA-1/256 authentication.
firstScramSha1ClientPayload = []byte("n,,n=user,r=fyko+d2lbbFgONRv9qkxdawL")
firstScramSha256ClientPayload = []byte("n,,n=user,r=rOprNGfwEbeRWgbNEkqO")
)
func TestSpeculativeSCRAM(t *testing.T) {
cred := &Cred{
Username: "user",
Password: "pencil",
PasswordSet: true,
Source: "admin",
}
t.Run("speculative response included", func(t *testing.T) {
// Tests for SCRAM-SHA1 and SCRAM-SHA-256 when the hello response contains a reply to the speculative
// authentication attempt. The driver should only send a saslContinue after the hello to complete
// authentication.
testCases := []struct {
name string
mechanism string
firstClientPayload []byte
payloads [][]byte
nonce string
}{
{"SCRAM-SHA-1", "SCRAM-SHA-1", firstScramSha1ClientPayload, scramSha1ShortPayloads, scramSha1Nonce},
{"SCRAM-SHA-256", "SCRAM-SHA-256", firstScramSha256ClientPayload, scramSha256ShortPayloads, scramSha256Nonce},
{"Default", "", firstScramSha256ClientPayload, scramSha256ShortPayloads, scramSha256Nonce},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Create a SCRAM authenticator and overwrite the nonce generator to make the conversation
// deterministic.
authenticator, err := CreateAuthenticator(tc.mechanism, cred)
assert.Nil(t, err, "CreateAuthenticator error: %v", err)
setNonce(t, authenticator, tc.nonce)
// Create a Handshaker and fake connection to authenticate.
handshaker := Handshaker(nil, &HandshakeOptions{
Authenticator: authenticator,
DBUser: "admin.user",
})
responses := make(chan []byte, len(tc.payloads))
writeReplies(responses, createSpeculativeSCRAMHandshake(tc.payloads)...)
conn := &drivertest.ChannelConn{
Written: make(chan []byte, len(tc.payloads)),
ReadResp: responses,
}
// Do both parts of the handshake.
info, err := handshaker.GetHandshakeInformation(context.Background(), address.Address("localhost:27017"), conn)
assert.Nil(t, err, "GetHandshakeInformation error: %v", err)
assert.NotNil(t, info.SpeculativeAuthenticate, "desc.SpeculativeAuthenticate not set")
conn.Desc = info.Description // Set conn.Desc so the new description will be used for the authentication.
err = handshaker.FinishHandshake(context.Background(), conn)
assert.Nil(t, err, "FinishHandshake error: %v", err)
assert.Equal(t, 0, len(conn.ReadResp), "%d messages left unread", len(conn.ReadResp))
// Assert that the driver sent hello with the speculative authentication message.
assert.Equal(t, len(tc.payloads), len(conn.Written), "expected %d wire messages to be sent, got %d",
len(tc.payloads), (conn.Written))
helloCmd, err := drivertest.GetCommandFromQueryWireMessage(<-conn.Written)
assert.Nil(t, err, "error parsing hello command: %v", err)
assertCommandName(t, helloCmd, internal.LegacyHello)
// Assert that the correct document was sent for speculative authentication.
authDocVal, err := helloCmd.LookupErr("speculativeAuthenticate")
assert.Nil(t, err, "expected command %s to contain 'speculativeAuthenticate'", bson.Raw(helloCmd))
authDoc := authDocVal.Document()
sentMechanism := tc.mechanism
if sentMechanism == "" {
sentMechanism = "SCRAM-SHA-256"
}
expectedAuthDoc := bsoncore.BuildDocumentFromElements(nil,
bsoncore.AppendInt32Element(nil, "saslStart", 1),
bsoncore.AppendStringElement(nil, "mechanism", sentMechanism),
bsoncore.AppendBinaryElement(nil, "payload", 0x00, tc.firstClientPayload),
bsoncore.AppendStringElement(nil, "db", "admin"),
bsoncore.AppendDocumentElement(nil, "options", bsoncore.BuildDocumentFromElements(nil,
bsoncore.AppendBooleanElement(nil, "skipEmptyExchange", true),
)),
)
assert.True(t, bytes.Equal(expectedAuthDoc, authDoc),
"expected speculative auth document %s, got %s",
bson.Raw(expectedAuthDoc),
authDoc,
)
// Assert that the last command sent in the handshake is saslContinue.
saslContinueCmd, err := drivertest.GetCommandFromMsgWireMessage(<-conn.Written)
assert.Nil(t, err, "error parsing saslContinue command: %v", err)
assertCommandName(t, saslContinueCmd, "saslContinue")
})
}
})
t.Run("speculative response not included", func(t *testing.T) {
// Tests for SCRAM-SHA-1 and SCRAM-SHA-256 when the hello response does not contain a reply to the
// speculative authentication attempt. The driver should send both saslStart and saslContinue after the initial
// hello.
// There is no test for the default mechanism because we can't control the nonce used for the actual
// authentication attempt after the speculative attempt fails.
testCases := []struct {
mechanism string
payloads [][]byte
nonce string
}{
{"SCRAM-SHA-1", scramSha1ShortPayloads, scramSha1Nonce},
{"SCRAM-SHA-256", scramSha256ShortPayloads, scramSha256Nonce},
}
for _, tc := range testCases {
t.Run(tc.mechanism, func(t *testing.T) {
authenticator, err := CreateAuthenticator(tc.mechanism, cred)
assert.Nil(t, err, "CreateAuthenticator error: %v", err)
setNonce(t, authenticator, tc.nonce)
handshaker := Handshaker(nil, &HandshakeOptions{
Authenticator: authenticator,
DBUser: "admin.user",
})
numResponses := len(tc.payloads) + 1 // +1 for hello response
responses := make(chan []byte, numResponses)
writeReplies(responses, createRegularSCRAMHandshake(tc.payloads)...)
conn := &drivertest.ChannelConn{
Written: make(chan []byte, numResponses),
ReadResp: responses,
}
info, err := handshaker.GetHandshakeInformation(context.Background(), address.Address("localhost:27017"), conn)
assert.Nil(t, err, "GetHandshakeInformation error: %v", err)
assert.Nil(t, info.SpeculativeAuthenticate, "expected desc.SpeculativeAuthenticate to be unset, got %s",
bson.Raw(info.SpeculativeAuthenticate))
conn.Desc = info.Description
err = handshaker.FinishHandshake(context.Background(), conn)
assert.Nil(t, err, "FinishHandshake error: %v", err)
assert.Equal(t, 0, len(conn.ReadResp), "%d messages left unread", len(conn.ReadResp))
assert.Equal(t, numResponses, len(conn.Written), "expected %d wire messages to be sent, got %d",
numResponses, len(conn.Written))
hello, err := drivertest.GetCommandFromQueryWireMessage(<-conn.Written)
assert.Nil(t, err, "error parsing hello command: %v", err)
assertCommandName(t, hello, internal.LegacyHello)
_, err = hello.LookupErr("speculativeAuthenticate")
assert.Nil(t, err, "expected command %s to contain 'speculativeAuthenticate'", bson.Raw(hello))
saslStart, err := drivertest.GetCommandFromMsgWireMessage(<-conn.Written)
assert.Nil(t, err, "error parsing saslStart command: %v", err)
assertCommandName(t, saslStart, "saslStart")
saslContinue, err := drivertest.GetCommandFromMsgWireMessage(<-conn.Written)
assert.Nil(t, err, "error parsing saslContinue command: %v", err)
assertCommandName(t, saslContinue, "saslContinue")
})
}
})
}
func setNonce(t *testing.T, authenticator Authenticator, nonce string) {
t.Helper()
nonceGenerator := func() string {
return nonce
}
switch converted := authenticator.(type) {
case *ScramAuthenticator:
converted.client = converted.client.WithNonceGenerator(nonceGenerator)
case *DefaultAuthenticator:
sa := converted.speculativeAuthenticator.(*ScramAuthenticator)
sa.client = sa.client.WithNonceGenerator(nonceGenerator)
default:
t.Fatalf("invalid authenticator type %T", authenticator)
}
}
// createSpeculativeSCRAMHandshake creates the server replies for a successful speculative SCRAM authentication attempt.
// There are two replies:
//
// 1. hello reply containing a "speculativeAuthenticate" document.
// 2. saslContinue reply with done:true
func createSpeculativeSCRAMHandshake(payloads [][]byte) []bsoncore.Document {
firstAuthResponse := createSCRAMServerResponse(payloads[0], false)
firstAuthElem := bsoncore.AppendDocumentElement(nil, "speculativeAuthenticate", firstAuthResponse)
hello := bsoncore.BuildDocumentFromElements(nil, append(handshakeHelloElements, firstAuthElem)...)
responses := []bsoncore.Document{hello}
for idx := 1; idx < len(payloads); idx++ {
responses = append(responses, createSCRAMServerResponse(payloads[idx], idx == len(payloads)-1))
}
return responses
}
// createRegularSCRAMHandshake creates the server replies for a handshake + SCRAM authentication attempt. There are
// three replies:
//
// 1. hello reply
// 2. saslStart reply with done:false
// 3. saslContinue reply with done:true
func createRegularSCRAMHandshake(payloads [][]byte) []bsoncore.Document {
hello := bsoncore.BuildDocumentFromElements(nil, handshakeHelloElements...)
responses := []bsoncore.Document{hello}
for idx, payload := range payloads {
responses = append(responses, createSCRAMServerResponse(payload, idx == len(payloads)-1))
}
return responses
}
func assertCommandName(t *testing.T, cmd bsoncore.Document, expectedName string) {
t.Helper()
actualName := cmd.Index(0).Key()
assert.Equal(t, expectedName, actualName, "expected command name '%s', got '%s'", expectedName, actualName)
}

View File

@@ -0,0 +1,136 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
"bytes"
"context"
"testing"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/internal"
"go.mongodb.org/mongo-driver/internal/testutil/assert"
"go.mongodb.org/mongo-driver/mongo/address"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver/drivertest"
)
var (
x509Response bsoncore.Document = bsoncore.BuildDocumentFromElements(nil,
bsoncore.AppendStringElement(nil, "dbname", "$external"),
bsoncore.AppendStringElement(nil, "user", "username"),
bsoncore.AppendInt32Element(nil, "ok", 1),
)
)
func TestSpeculativeX509(t *testing.T) {
t.Run("speculative response included", func(t *testing.T) {
// Tests for X509 when the hello response contains a reply to the speculative authentication attempt. The
// driver should not send any more commands after the hello.
authenticator, err := CreateAuthenticator("MONGODB-X509", &Cred{})
assert.Nil(t, err, "CreateAuthenticator error: %v", err)
handshaker := Handshaker(nil, &HandshakeOptions{
Authenticator: authenticator,
})
numResponses := 1
responses := make(chan []byte, numResponses)
writeReplies(responses, createSpeculativeX509Handshake()...)
conn := &drivertest.ChannelConn{
Written: make(chan []byte, numResponses),
ReadResp: responses,
}
info, err := handshaker.GetHandshakeInformation(context.Background(), address.Address("localhost:27017"), conn)
assert.Nil(t, err, "GetDescription error: %v", err)
assert.NotNil(t, info.SpeculativeAuthenticate, "desc.SpeculativeAuthenticate not set")
conn.Desc = info.Description
err = handshaker.FinishHandshake(context.Background(), conn)
assert.Nil(t, err, "FinishHandshake error: %v", err)
assert.Equal(t, 0, len(conn.ReadResp), "%d messages left unread", len(conn.ReadResp))
assert.Equal(t, numResponses, len(conn.Written), "expected %d wire messages to be sent, got %d",
numResponses, len(conn.Written))
hello, err := drivertest.GetCommandFromQueryWireMessage(<-conn.Written)
assert.Nil(t, err, "error parsing hello command: %v", err)
assertCommandName(t, hello, internal.LegacyHello)
authDocVal, err := hello.LookupErr("speculativeAuthenticate")
assert.Nil(t, err, "expected command %s to contain 'speculativeAuthenticate'", bson.Raw(hello))
authDoc := authDocVal.Document()
expectedAuthDoc := bsoncore.BuildDocumentFromElements(nil,
bsoncore.AppendInt32Element(nil, "authenticate", 1),
bsoncore.AppendStringElement(nil, "mechanism", "MONGODB-X509"),
)
assert.True(t, bytes.Equal(expectedAuthDoc, authDoc), "expected speculative auth document %s, got %s",
expectedAuthDoc, authDoc)
})
t.Run("speculative response not included", func(t *testing.T) {
// Tests for X509 when the hello response does not contain a reply to the speculative authentication attempt.
// The driver should send an authenticate command after the hello.
authenticator, err := CreateAuthenticator("MONGODB-X509", &Cred{})
assert.Nil(t, err, "CreateAuthenticator error: %v", err)
handshaker := Handshaker(nil, &HandshakeOptions{
Authenticator: authenticator,
})
numResponses := 2
responses := make(chan []byte, numResponses)
writeReplies(responses, createRegularX509Handshake()...)
conn := &drivertest.ChannelConn{
Written: make(chan []byte, numResponses),
ReadResp: responses,
}
info, err := handshaker.GetHandshakeInformation(context.Background(), address.Address("localhost:27017"), conn)
assert.Nil(t, err, "GetDescription error: %v", err)
assert.Nil(t, info.SpeculativeAuthenticate, "expected desc.SpeculativeAuthenticate to be unset, got %s",
bson.Raw(info.SpeculativeAuthenticate))
conn.Desc = info.Description
err = handshaker.FinishHandshake(context.Background(), conn)
assert.Nil(t, err, "FinishHandshake error: %v", err)
assert.Equal(t, 0, len(conn.ReadResp), "%d messages left unread", len(conn.ReadResp))
assert.Equal(t, numResponses, len(conn.Written), "expected %d wire messages to be sent, got %d",
numResponses, len(conn.Written))
hello, err := drivertest.GetCommandFromQueryWireMessage(<-conn.Written)
assert.Nil(t, err, "error parsing hello command: %v", err)
assertCommandName(t, hello, internal.LegacyHello)
_, err = hello.LookupErr("speculativeAuthenticate")
assert.Nil(t, err, "expected command %s to contain 'speculativeAuthenticate'", bson.Raw(hello))
authenticate, err := drivertest.GetCommandFromMsgWireMessage(<-conn.Written)
assert.Nil(t, err, "error parsing authenticate command: %v", err)
assertCommandName(t, authenticate, "authenticate")
})
}
// createSpeculativeX509Handshake creates the server replies for a successful speculative X509 authentication attempt.
// There is only one reply:
//
// 1. hello reply containing a "speculativeAuthenticate" document.
func createSpeculativeX509Handshake() []bsoncore.Document {
firstAuthElem := bsoncore.AppendDocumentElement(nil, "speculativeAuthenticate", x509Response)
hello := bsoncore.BuildDocumentFromElements(nil, append(handshakeHelloElements, firstAuthElem)...)
return []bsoncore.Document{hello}
}
// createSpeculativeX509Handshake creates the server replies for a handshake + X509 authentication attempt.
// There are two replies:
//
// 1. hello reply
// 2. authenticate reply
func createRegularX509Handshake() []bsoncore.Document {
hello := bsoncore.BuildDocumentFromElements(nil, handshakeHelloElements...)
return []bsoncore.Document{hello, x509Response}
}

View File

@@ -0,0 +1,30 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
"fmt"
"io"
// Ignore gosec warning "Blocklisted import crypto/md5: weak cryptographic primitive". We need
// to use MD5 here to implement the SCRAM specification.
/* #nosec G501 */
"crypto/md5"
)
const defaultAuthDB = "admin"
func mongoPasswordDigest(username, password string) string {
// Ignore gosec warning "Use of weak cryptographic primitive". We need to use MD5 here to
// implement the SCRAM specification.
/* #nosec G401 */
h := md5.New()
_, _ = io.WriteString(h, username)
_, _ = io.WriteString(h, ":mongo:")
_, _ = io.WriteString(h, password)
return fmt.Sprintf("%x", h.Sum(nil))
}

View File

@@ -0,0 +1,85 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package auth
import (
"context"
"go.mongodb.org/mongo-driver/mongo/description"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
)
// MongoDBX509 is the mechanism name for MongoDBX509.
const MongoDBX509 = "MONGODB-X509"
func newMongoDBX509Authenticator(cred *Cred) (Authenticator, error) {
return &MongoDBX509Authenticator{User: cred.Username}, nil
}
// MongoDBX509Authenticator uses X.509 certificates over TLS to authenticate a connection.
type MongoDBX509Authenticator struct {
User string
}
var _ SpeculativeAuthenticator = (*MongoDBX509Authenticator)(nil)
// x509 represents a X509 authentication conversation. This type implements the SpeculativeConversation interface so the
// conversation can be executed in multi-step speculative fashion.
type x509Conversation struct{}
var _ SpeculativeConversation = (*x509Conversation)(nil)
// FirstMessage returns the first message to be sent to the server.
func (c *x509Conversation) FirstMessage() (bsoncore.Document, error) {
return createFirstX509Message(description.Server{}, ""), nil
}
// createFirstX509Message creates the first message for the X509 conversation.
func createFirstX509Message(desc description.Server, user string) bsoncore.Document {
elements := [][]byte{
bsoncore.AppendInt32Element(nil, "authenticate", 1),
bsoncore.AppendStringElement(nil, "mechanism", MongoDBX509),
}
// Server versions < 3.4 require the username to be included in the message. Versions >= 3.4 will extract the
// username from the certificate.
if desc.WireVersion != nil && desc.WireVersion.Max < 5 {
elements = append(elements, bsoncore.AppendStringElement(nil, "user", user))
}
return bsoncore.BuildDocument(nil, elements...)
}
// Finish implements the SpeculativeConversation interface and is a no-op because an X509 conversation only has one
// step.
func (c *x509Conversation) Finish(context.Context, *Config, bsoncore.Document) error {
return nil
}
// CreateSpeculativeConversation creates a speculative conversation for X509 authentication.
func (a *MongoDBX509Authenticator) CreateSpeculativeConversation() (SpeculativeConversation, error) {
return &x509Conversation{}, nil
}
// Auth authenticates the provided connection by conducting an X509 authentication conversation.
func (a *MongoDBX509Authenticator) Auth(ctx context.Context, cfg *Config) error {
requestDoc := createFirstX509Message(cfg.Description, a.User)
authCmd := operation.
NewCommand(requestDoc).
Database("$external").
Deployment(driver.SingleConnectionDeployment{cfg.Connection}).
ClusterClock(cfg.ClusterClock).
ServerAPI(cfg.ServerAPI)
err := authCmd.Execute(ctx)
if err != nil {
return newAuthError("round trip error", err)
}
return nil
}

View File

@@ -0,0 +1,479 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package driver
import (
"context"
"errors"
"fmt"
"strings"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/event"
"go.mongodb.org/mongo-driver/internal"
"go.mongodb.org/mongo-driver/mongo/description"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
)
// BatchCursor is a batch implementation of a cursor. It returns documents in entire batches instead
// of one at a time. An individual document cursor can be built on top of this batch cursor.
type BatchCursor struct {
clientSession *session.Client
clock *session.ClusterClock
comment bsoncore.Value
database string
collection string
id int64
err error
server Server
serverDescription description.Server
errorProcessor ErrorProcessor // This will only be set when pinning to a connection.
connection PinnedConnection
batchSize int32
maxTimeMS int64
currentBatch *bsoncore.DocumentSequence
firstBatch bool
cmdMonitor *event.CommandMonitor
postBatchResumeToken bsoncore.Document
crypt Crypt
serverAPI *ServerAPIOptions
// legacy server (< 3.2) fields
limit int32
numReturned int32 // number of docs returned by server
}
// CursorResponse represents the response from a command the results in a cursor. A BatchCursor can
// be constructed from a CursorResponse.
type CursorResponse struct {
Server Server
ErrorProcessor ErrorProcessor // This will only be set when pinning to a connection.
Connection PinnedConnection
Desc description.Server
FirstBatch *bsoncore.DocumentSequence
Database string
Collection string
ID int64
postBatchResumeToken bsoncore.Document
}
// NewCursorResponse constructs a cursor response from the given response and server. This method
// can be used within the ProcessResponse method for an operation.
func NewCursorResponse(info ResponseInfo) (CursorResponse, error) {
response := info.ServerResponse
cur, ok := response.Lookup("cursor").DocumentOK()
if !ok {
return CursorResponse{}, fmt.Errorf("cursor should be an embedded document but is of BSON type %s", response.Lookup("cursor").Type)
}
elems, err := cur.Elements()
if err != nil {
return CursorResponse{}, err
}
curresp := CursorResponse{Server: info.Server, Desc: info.ConnectionDescription}
for _, elem := range elems {
switch elem.Key() {
case "firstBatch":
arr, ok := elem.Value().ArrayOK()
if !ok {
return CursorResponse{}, fmt.Errorf("firstBatch should be an array but is a BSON %s", elem.Value().Type)
}
curresp.FirstBatch = &bsoncore.DocumentSequence{Style: bsoncore.ArrayStyle, Data: arr}
case "ns":
ns, ok := elem.Value().StringValueOK()
if !ok {
return CursorResponse{}, fmt.Errorf("ns should be a string but is a BSON %s", elem.Value().Type)
}
index := strings.Index(ns, ".")
if index == -1 {
return CursorResponse{}, errors.New("ns field must contain a valid namespace, but is missing '.'")
}
curresp.Database = ns[:index]
curresp.Collection = ns[index+1:]
case "id":
curresp.ID, ok = elem.Value().Int64OK()
if !ok {
return CursorResponse{}, fmt.Errorf("id should be an int64 but it is a BSON %s", elem.Value().Type)
}
case "postBatchResumeToken":
curresp.postBatchResumeToken, ok = elem.Value().DocumentOK()
if !ok {
return CursorResponse{}, fmt.Errorf("post batch resume token should be a document but it is a BSON %s", elem.Value().Type)
}
}
}
// If the deployment is behind a load balancer and the cursor has a non-zero ID, pin the cursor to a connection and
// use the same connection to execute getMore and killCursors commands.
if curresp.Desc.LoadBalanced() && curresp.ID != 0 {
// Cache the server as an ErrorProcessor to use when constructing deployments for cursor commands.
ep, ok := curresp.Server.(ErrorProcessor)
if !ok {
return CursorResponse{}, fmt.Errorf("expected Server used to establish a cursor to implement ErrorProcessor, but got %T", curresp.Server)
}
curresp.ErrorProcessor = ep
refConn, ok := info.Connection.(PinnedConnection)
if !ok {
return CursorResponse{}, fmt.Errorf("expected Connection used to establish a cursor to implement PinnedConnection, but got %T", info.Connection)
}
if err := refConn.PinToCursor(); err != nil {
return CursorResponse{}, fmt.Errorf("error incrementing connection reference count when creating a cursor: %v", err)
}
curresp.Connection = refConn
}
return curresp, nil
}
// CursorOptions are extra options that are required to construct a BatchCursor.
type CursorOptions struct {
BatchSize int32
Comment bsoncore.Value
MaxTimeMS int64
Limit int32
CommandMonitor *event.CommandMonitor
Crypt Crypt
ServerAPI *ServerAPIOptions
}
// NewBatchCursor creates a new BatchCursor from the provided parameters.
func NewBatchCursor(cr CursorResponse, clientSession *session.Client, clock *session.ClusterClock, opts CursorOptions) (*BatchCursor, error) {
ds := cr.FirstBatch
bc := &BatchCursor{
clientSession: clientSession,
clock: clock,
comment: opts.Comment,
database: cr.Database,
collection: cr.Collection,
id: cr.ID,
server: cr.Server,
connection: cr.Connection,
errorProcessor: cr.ErrorProcessor,
batchSize: opts.BatchSize,
maxTimeMS: opts.MaxTimeMS,
cmdMonitor: opts.CommandMonitor,
firstBatch: true,
postBatchResumeToken: cr.postBatchResumeToken,
crypt: opts.Crypt,
serverAPI: opts.ServerAPI,
serverDescription: cr.Desc,
}
if ds != nil {
bc.numReturned = int32(ds.DocumentCount())
}
if cr.Desc.WireVersion == nil || cr.Desc.WireVersion.Max < 4 {
bc.limit = opts.Limit
// Take as many documents from the batch as needed.
if bc.limit != 0 && bc.limit < bc.numReturned {
for i := int32(0); i < bc.limit; i++ {
_, err := ds.Next()
if err != nil {
return nil, err
}
}
ds.Data = ds.Data[:ds.Pos]
ds.ResetIterator()
}
}
bc.currentBatch = ds
return bc, nil
}
// NewEmptyBatchCursor returns a batch cursor that is empty.
func NewEmptyBatchCursor() *BatchCursor {
return &BatchCursor{currentBatch: new(bsoncore.DocumentSequence)}
}
// NewBatchCursorFromDocuments returns a batch cursor with current batch set to a sequence-style
// DocumentSequence containing the provided documents.
func NewBatchCursorFromDocuments(documents []byte) *BatchCursor {
return &BatchCursor{
currentBatch: &bsoncore.DocumentSequence{
Data: documents,
Style: bsoncore.SequenceStyle,
},
// BatchCursors created with this function have no associated ID nor server, so no getMore
// calls will be made.
id: 0,
server: nil,
}
}
// ID returns the cursor ID for this batch cursor.
func (bc *BatchCursor) ID() int64 {
return bc.id
}
// Next indicates if there is another batch available. Returning false does not necessarily indicate
// that the cursor is closed. This method will return false when an empty batch is returned.
//
// If Next returns true, there is a valid batch of documents available. If Next returns false, there
// is not a valid batch of documents available.
func (bc *BatchCursor) Next(ctx context.Context) bool {
if ctx == nil {
ctx = context.Background()
}
if bc.firstBatch {
bc.firstBatch = false
return !bc.currentBatch.Empty()
}
if bc.id == 0 || bc.server == nil {
return false
}
bc.getMore(ctx)
return !bc.currentBatch.Empty()
}
// Batch will return a DocumentSequence for the current batch of documents. The returned
// DocumentSequence is only valid until the next call to Next or Close.
func (bc *BatchCursor) Batch() *bsoncore.DocumentSequence { return bc.currentBatch }
// Err returns the latest error encountered.
func (bc *BatchCursor) Err() error { return bc.err }
// Close closes this batch cursor.
func (bc *BatchCursor) Close(ctx context.Context) error {
if ctx == nil {
ctx = context.Background()
}
err := bc.KillCursor(ctx)
bc.id = 0
bc.currentBatch.Data = nil
bc.currentBatch.Style = 0
bc.currentBatch.ResetIterator()
connErr := bc.unpinConnection()
if err == nil {
err = connErr
}
return err
}
func (bc *BatchCursor) unpinConnection() error {
if bc.connection == nil {
return nil
}
err := bc.connection.UnpinFromCursor()
closeErr := bc.connection.Close()
if err == nil && closeErr != nil {
err = closeErr
}
bc.connection = nil
return err
}
// Server returns the server for this cursor.
func (bc *BatchCursor) Server() Server {
return bc.server
}
func (bc *BatchCursor) clearBatch() {
bc.currentBatch.Data = bc.currentBatch.Data[:0]
}
// KillCursor kills cursor on server without closing batch cursor
func (bc *BatchCursor) KillCursor(ctx context.Context) error {
if bc.server == nil || bc.id == 0 {
return nil
}
return Operation{
CommandFn: func(dst []byte, desc description.SelectedServer) ([]byte, error) {
dst = bsoncore.AppendStringElement(dst, "killCursors", bc.collection)
dst = bsoncore.BuildArrayElement(dst, "cursors", bsoncore.Value{Type: bsontype.Int64, Data: bsoncore.AppendInt64(nil, bc.id)})
return dst, nil
},
Database: bc.database,
Deployment: bc.getOperationDeployment(),
Client: bc.clientSession,
Clock: bc.clock,
Legacy: LegacyKillCursors,
CommandMonitor: bc.cmdMonitor,
ServerAPI: bc.serverAPI,
}.Execute(ctx)
}
// calcGetMoreBatchSize calculates the number of documents to return in the
// response of a "getMore" operation based on the given limit, batchSize, and
// number of documents already returned. Returns false if a non-trivial limit is
// lower than or equal to the number of documents already returned.
func calcGetMoreBatchSize(bc BatchCursor) (int32, bool) {
gmBatchSize := bc.batchSize
// Account for legacy operations that don't support setting a limit.
if bc.limit != 0 && bc.numReturned+bc.batchSize >= bc.limit {
gmBatchSize = bc.limit - bc.numReturned
if gmBatchSize <= 0 {
return gmBatchSize, false
}
}
return gmBatchSize, true
}
func (bc *BatchCursor) getMore(ctx context.Context) {
bc.clearBatch()
if bc.id == 0 {
return
}
numToReturn, ok := calcGetMoreBatchSize(*bc)
if !ok {
if err := bc.Close(ctx); err != nil {
bc.err = err
}
return
}
bc.err = Operation{
CommandFn: func(dst []byte, desc description.SelectedServer) ([]byte, error) {
dst = bsoncore.AppendInt64Element(dst, "getMore", bc.id)
dst = bsoncore.AppendStringElement(dst, "collection", bc.collection)
if numToReturn > 0 {
dst = bsoncore.AppendInt32Element(dst, "batchSize", numToReturn)
}
if bc.maxTimeMS > 0 {
dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", bc.maxTimeMS)
}
// The getMore command does not support commenting pre-4.4.
if bc.comment.Type != bsontype.Type(0) && bc.serverDescription.WireVersion.Max >= 9 {
dst = bsoncore.AppendValueElement(dst, "comment", bc.comment)
}
return dst, nil
},
Database: bc.database,
Deployment: bc.getOperationDeployment(),
ProcessResponseFn: func(info ResponseInfo) error {
response := info.ServerResponse
id, ok := response.Lookup("cursor", "id").Int64OK()
if !ok {
return fmt.Errorf("cursor.id should be an int64 but is a BSON %s", response.Lookup("cursor", "id").Type)
}
bc.id = id
batch, ok := response.Lookup("cursor", "nextBatch").ArrayOK()
if !ok {
return fmt.Errorf("cursor.nextBatch should be an array but is a BSON %s", response.Lookup("cursor", "nextBatch").Type)
}
bc.currentBatch.Style = bsoncore.ArrayStyle
bc.currentBatch.Data = batch
bc.currentBatch.ResetIterator()
bc.numReturned += int32(bc.currentBatch.DocumentCount()) // Required for legacy operations which don't support limit.
pbrt, err := response.LookupErr("cursor", "postBatchResumeToken")
if err != nil {
// I don't really understand why we don't set bc.err here
return nil
}
pbrtDoc, ok := pbrt.DocumentOK()
if !ok {
bc.err = fmt.Errorf("expected BSON type for post batch resume token to be EmbeddedDocument but got %s", pbrt.Type)
return nil
}
bc.postBatchResumeToken = pbrtDoc
return nil
},
Client: bc.clientSession,
Clock: bc.clock,
Legacy: LegacyGetMore,
CommandMonitor: bc.cmdMonitor,
Crypt: bc.crypt,
ServerAPI: bc.serverAPI,
}.Execute(ctx)
// Once the cursor has been drained, we can unpin the connection if one is currently pinned.
if bc.id == 0 {
err := bc.unpinConnection()
if err != nil && bc.err == nil {
bc.err = err
}
}
// If we're in load balanced mode and the pinned connection encounters a network error, we should not use it for
// future commands. Per the spec, the connection will not be unpinned until the cursor is actually closed, but
// we set the cursor ID to 0 to ensure the Close() call will not execute a killCursors command.
if driverErr, ok := bc.err.(Error); ok && driverErr.NetworkError() && bc.connection != nil {
bc.id = 0
}
// Required for legacy operations which don't support limit.
if bc.limit != 0 && bc.numReturned >= bc.limit {
// call KillCursor instead of Close because Close will clear out the data for the current batch.
err := bc.KillCursor(ctx)
if err != nil && bc.err == nil {
bc.err = err
}
}
}
// PostBatchResumeToken returns the latest seen post batch resume token.
func (bc *BatchCursor) PostBatchResumeToken() bsoncore.Document {
return bc.postBatchResumeToken
}
// SetBatchSize sets the batchSize for future getMores.
func (bc *BatchCursor) SetBatchSize(size int32) {
bc.batchSize = size
}
func (bc *BatchCursor) getOperationDeployment() Deployment {
if bc.connection != nil {
return &loadBalancedCursorDeployment{
errorProcessor: bc.errorProcessor,
conn: bc.connection,
}
}
return SingleServerDeployment{bc.server}
}
// loadBalancedCursorDeployment is used as a Deployment for getMore and killCursors commands when pinning to a
// connection in load balanced mode. This type also functions as an ErrorProcessor to ensure that SDAM errors are
// handled for these commands in this mode.
type loadBalancedCursorDeployment struct {
errorProcessor ErrorProcessor
conn PinnedConnection
}
var _ Deployment = (*loadBalancedCursorDeployment)(nil)
var _ Server = (*loadBalancedCursorDeployment)(nil)
var _ ErrorProcessor = (*loadBalancedCursorDeployment)(nil)
func (lbcd *loadBalancedCursorDeployment) SelectServer(_ context.Context, _ description.ServerSelector) (Server, error) {
return lbcd, nil
}
func (lbcd *loadBalancedCursorDeployment) Kind() description.TopologyKind {
return description.LoadBalanced
}
func (lbcd *loadBalancedCursorDeployment) Connection(_ context.Context) (Connection, error) {
return lbcd.conn, nil
}
// RTTMonitor implements the driver.Server interface.
func (lbcd *loadBalancedCursorDeployment) RTTMonitor() RTTMonitor {
return &internal.ZeroRTTMonitor{}
}
func (lbcd *loadBalancedCursorDeployment) ProcessError(err error, conn Connection) ProcessErrorResult {
return lbcd.errorProcessor.ProcessError(err, conn)
}

View File

@@ -0,0 +1,92 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package driver
import (
"testing"
"go.mongodb.org/mongo-driver/internal/testutil/assert"
)
func TestBatchCursor(t *testing.T) {
t.Parallel()
t.Run("setBatchSize", func(t *testing.T) {
t.Parallel()
var size int32
bc := &BatchCursor{
batchSize: size,
}
assert.Equal(t, size, bc.batchSize, "expected batchSize %v, got %v", size, bc.batchSize)
size = int32(4)
bc.SetBatchSize(size)
assert.Equal(t, size, bc.batchSize, "expected batchSize %v, got %v", size, bc.batchSize)
})
t.Run("calcGetMoreBatchSize", func(t *testing.T) {
t.Parallel()
for _, tcase := range []struct {
name string
size, limit, numReturned, expected int32
ok bool
}{
{
name: "empty",
expected: 0,
ok: true,
},
{
name: "batchSize NEQ 0",
size: 4,
expected: 4,
ok: true,
},
{
name: "limit NEQ 0",
limit: 4,
expected: 0,
ok: true,
},
{
name: "limit NEQ and batchSize + numReturned EQ limit",
size: 4,
limit: 8,
numReturned: 4,
expected: 4,
ok: true,
},
{
name: "limit makes batchSize negative",
numReturned: 4,
limit: 2,
expected: -2,
ok: false,
},
} {
tcase := tcase
t.Run(tcase.name, func(t *testing.T) {
t.Parallel()
bc := &BatchCursor{
limit: tcase.limit,
batchSize: tcase.size,
numReturned: tcase.numReturned,
}
bc.SetBatchSize(tcase.size)
size, ok := calcGetMoreBatchSize(*bc)
assert.Equal(t, tcase.expected, size, "expected batchSize %v, got %v", tcase.expected, size)
assert.Equal(t, tcase.ok, ok, "expected ok %v, got %v", tcase.ok, ok)
})
}
})
}

View File

@@ -0,0 +1,76 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package driver
import (
"errors"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
// ErrDocumentTooLarge occurs when a document that is larger than the maximum size accepted by a
// server is passed to an insert command.
var ErrDocumentTooLarge = errors.New("an inserted document is too large")
// Batches contains the necessary information to batch split an operation. This is only used for write
// oeprations.
type Batches struct {
Identifier string
Documents []bsoncore.Document
Current []bsoncore.Document
Ordered *bool
}
// Valid returns true if Batches contains both an identifier and the length of Documents is greater
// than zero.
func (b *Batches) Valid() bool { return b != nil && b.Identifier != "" && len(b.Documents) > 0 }
// ClearBatch clears the Current batch. This must be called before AdvanceBatch will advance to the
// next batch.
func (b *Batches) ClearBatch() { b.Current = b.Current[:0] }
// AdvanceBatch splits the next batch using maxCount and targetBatchSize. This method will do nothing if
// the current batch has not been cleared. We do this so that when this is called during execute we
// can call it without first needing to check if we already have a batch, which makes the code
// simpler and makes retrying easier.
// The maxDocSize parameter is used to check that any one document is not too large. If the first document is bigger
// than targetBatchSize but smaller than maxDocSize, a batch of size 1 containing that document will be created.
func (b *Batches) AdvanceBatch(maxCount, targetBatchSize, maxDocSize int) error {
if len(b.Current) > 0 {
return nil
}
if maxCount <= 0 {
maxCount = 1
}
splitAfter := 0
size := 0
for i, doc := range b.Documents {
if i == maxCount {
break
}
if len(doc) > maxDocSize {
return ErrDocumentTooLarge
}
if size+len(doc) > targetBatchSize {
break
}
size += len(doc)
splitAfter++
}
// if there are no documents, take the first one.
// this can happen if there is a document that is smaller than maxDocSize but greater than targetBatchSize.
if splitAfter == 0 {
splitAfter = 1
}
b.Current, b.Documents = b.Documents[:splitAfter], b.Documents[splitAfter:]
return nil
}

View File

@@ -0,0 +1,137 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package driver
import (
"testing"
"github.com/google/go-cmp/cmp"
"go.mongodb.org/mongo-driver/internal/testutil/assert"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
func TestBatches(t *testing.T) {
t.Run("Valid", func(t *testing.T) {
testCases := []struct {
name string
batches *Batches
want bool
}{
{"nil", nil, false},
{"missing identifier", &Batches{}, false},
{"no documents", &Batches{Identifier: "documents"}, false},
{"valid", &Batches{Identifier: "documents", Documents: make([]bsoncore.Document, 5)}, true},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
want := tc.want
got := tc.batches.Valid()
if got != want {
t.Errorf("Did not get expected result from Valid. got %t; want %t", got, want)
}
})
}
})
t.Run("ClearBatch", func(t *testing.T) {
batches := &Batches{Identifier: "documents", Current: make([]bsoncore.Document, 2, 10)}
if len(batches.Current) != 2 {
t.Fatalf("Length of current batch should be 2, but is %d", len(batches.Current))
}
batches.ClearBatch()
if len(batches.Current) != 0 {
t.Fatalf("Length of current batch should be 0, but is %d", len(batches.Current))
}
})
t.Run("AdvanceBatch", func(t *testing.T) {
documents := make([]bsoncore.Document, 0)
for i := 0; i < 5; i++ {
doc := make(bsoncore.Document, 100)
documents = append(documents, doc)
}
testCases := []struct {
name string
batches *Batches
maxCount int
targetBatchSize int
maxDocSize int
err error
want *Batches
}{
{
"current batch non-zero",
&Batches{Current: make([]bsoncore.Document, 2, 10)},
0, 0, 0, nil,
&Batches{Current: make([]bsoncore.Document, 2, 10)},
},
{
// all of the documents in the batch fit in targetBatchSize so the batch is created successfully
"documents fit in targetBatchSize",
&Batches{Documents: documents},
10, 600, 1000, nil,
&Batches{Documents: documents[:0], Current: documents[0:]},
},
{
// the first doc is bigger than targetBatchSize but smaller than maxDocSize so it is taken alone
"first document larger than targetBatchSize, smaller than maxDocSize",
&Batches{Documents: documents},
10, 5, 100, nil,
&Batches{Documents: documents[1:], Current: documents[:1]},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := tc.batches.AdvanceBatch(tc.maxCount, tc.targetBatchSize, tc.maxDocSize)
if !cmp.Equal(err, tc.err, cmp.Comparer(compareErrors)) {
t.Errorf("Errors do not match. got %v; want %v", err, tc.err)
}
if !cmp.Equal(tc.batches, tc.want) {
t.Errorf("Batches is not in correct state after AdvanceBatch. got %v; want %v", tc.batches, tc.want)
}
})
}
t.Run("middle document larger than targetBatchSize, smaller than maxDocSize", func(t *testing.T) {
// a batch is made but one document is too big, so everything before it is taken.
// on the second call to AdvanceBatch, only the large document is taken
middleLargeDoc := make([]bsoncore.Document, 0)
for i := 0; i < 5; i++ {
doc := make(bsoncore.Document, 100)
middleLargeDoc = append(middleLargeDoc, doc)
}
largeDoc := make(bsoncore.Document, 900)
middleLargeDoc[2] = largeDoc
batches := &Batches{Documents: middleLargeDoc}
maxCount := 10
targetSize := 600
maxDocSize := 1000
// first batch should take first 2 docs (size 100 each)
err := batches.AdvanceBatch(maxCount, targetSize, maxDocSize)
assert.Nil(t, err, "AdvanceBatch error: %v", err)
want := &Batches{Current: middleLargeDoc[:2], Documents: middleLargeDoc[2:]}
assert.Equal(t, want, batches, "expected batches %v, got %v", want, batches)
// second batch should take single large doc (size 900)
batches.ClearBatch()
err = batches.AdvanceBatch(maxCount, targetSize, maxDocSize)
assert.Nil(t, err, "AdvanceBatch error: %v", err)
want = &Batches{Current: middleLargeDoc[2:3], Documents: middleLargeDoc[3:]}
assert.Equal(t, want, batches, "expected batches %v, got %v", want, batches)
// last batch should take last 2 docs (size 100 each)
batches.ClearBatch()
err = batches.AdvanceBatch(maxCount, targetSize, maxDocSize)
assert.Nil(t, err, "AdvanceBatch error: %v", err)
want = &Batches{Current: middleLargeDoc[3:], Documents: middleLargeDoc[:0]}
assert.Equal(t, want, batches, "expected batches %v, got %v", want, batches)
})
})
}

View File

@@ -0,0 +1,53 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package driver
import (
"testing"
"go.mongodb.org/mongo-driver/internal"
"go.mongodb.org/mongo-driver/internal/testutil/assert"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
func TestCommandMonitoring(t *testing.T) {
t.Run("redactCommand", func(t *testing.T) {
emptyDoc := bsoncore.BuildDocumentFromElements(nil)
legacyHello := bsoncore.BuildDocumentFromElements(nil,
bsoncore.AppendInt32Element(nil, internal.LegacyHello, 1),
)
legacyHelloLowercase := bsoncore.BuildDocumentFromElements(nil,
bsoncore.AppendInt32Element(nil, internal.LegacyHelloLowercase, 1),
)
legacyHelloSpeculative := bsoncore.BuildDocumentFromElements(nil,
bsoncore.AppendInt32Element(nil, internal.LegacyHello, 1),
bsoncore.AppendDocumentElement(nil, "speculativeAuthenticate", emptyDoc),
)
legacyHelloSpeculativeLowercase := bsoncore.BuildDocumentFromElements(nil,
bsoncore.AppendInt32Element(nil, internal.LegacyHelloLowercase, 1),
bsoncore.AppendDocumentElement(nil, "speculativeAuthenticate", emptyDoc),
)
testCases := []struct {
name string
commandName string
command bsoncore.Document
redacted bool
}{
{"legacy hello", internal.LegacyHello, legacyHello, false},
{"legacy hello lowercase", internal.LegacyHelloLowercase, legacyHelloLowercase, false},
{"legacy hello speculative auth", internal.LegacyHello, legacyHelloSpeculative, true},
{"legacy hello speculative auth lowercase", internal.LegacyHello, legacyHelloSpeculativeLowercase, true},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
canMonitor := (&Operation{}).redactCommand(tc.commandName, tc.command)
assert.Equal(t, tc.redacted, canMonitor, "expected redacted %v, got %v", tc.redacted, canMonitor)
})
}
})
}

View File

@@ -0,0 +1,145 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package driver
import (
"bytes"
"compress/zlib"
"fmt"
"io"
"sync"
"github.com/golang/snappy"
"github.com/klauspost/compress/zstd"
"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
)
// CompressionOpts holds settings for how to compress a payload
type CompressionOpts struct {
Compressor wiremessage.CompressorID
ZlibLevel int
ZstdLevel int
UncompressedSize int32
}
var zstdEncoders sync.Map // map[zstd.EncoderLevel]*zstd.Encoder
func getZstdEncoder(level zstd.EncoderLevel) (*zstd.Encoder, error) {
if v, ok := zstdEncoders.Load(level); ok {
return v.(*zstd.Encoder), nil
}
encoder, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(level))
if err != nil {
return nil, err
}
zstdEncoders.Store(level, encoder)
return encoder, nil
}
var zlibEncoders sync.Map // map[int /*level*/]*zlibEncoder
func getZlibEncoder(level int) (*zlibEncoder, error) {
if v, ok := zlibEncoders.Load(level); ok {
return v.(*zlibEncoder), nil
}
writer, err := zlib.NewWriterLevel(nil, level)
if err != nil {
return nil, err
}
encoder := &zlibEncoder{writer: writer, buf: new(bytes.Buffer)}
zlibEncoders.Store(level, encoder)
return encoder, nil
}
type zlibEncoder struct {
mu sync.Mutex
writer *zlib.Writer
buf *bytes.Buffer
}
func (e *zlibEncoder) Encode(dst, src []byte) ([]byte, error) {
e.mu.Lock()
defer e.mu.Unlock()
e.buf.Reset()
e.writer.Reset(e.buf)
_, err := e.writer.Write(src)
if err != nil {
return nil, err
}
err = e.writer.Close()
if err != nil {
return nil, err
}
dst = append(dst[:0], e.buf.Bytes()...)
return dst, nil
}
// CompressPayload takes a byte slice and compresses it according to the options passed
func CompressPayload(in []byte, opts CompressionOpts) ([]byte, error) {
switch opts.Compressor {
case wiremessage.CompressorNoOp:
return in, nil
case wiremessage.CompressorSnappy:
return snappy.Encode(nil, in), nil
case wiremessage.CompressorZLib:
encoder, err := getZlibEncoder(opts.ZlibLevel)
if err != nil {
return nil, err
}
return encoder.Encode(nil, in)
case wiremessage.CompressorZstd:
encoder, err := getZstdEncoder(zstd.EncoderLevelFromZstd(opts.ZstdLevel))
if err != nil {
return nil, err
}
return encoder.EncodeAll(in, nil), nil
default:
return nil, fmt.Errorf("unknown compressor ID %v", opts.Compressor)
}
}
// DecompressPayload takes a byte slice that has been compressed and undoes it according to the options passed
func DecompressPayload(in []byte, opts CompressionOpts) (uncompressed []byte, err error) {
switch opts.Compressor {
case wiremessage.CompressorNoOp:
return in, nil
case wiremessage.CompressorSnappy:
uncompressed = make([]byte, opts.UncompressedSize)
return snappy.Decode(uncompressed, in)
case wiremessage.CompressorZLib:
r, err := zlib.NewReader(bytes.NewReader(in))
if err != nil {
return nil, err
}
defer func() {
err = r.Close()
}()
uncompressed = make([]byte, opts.UncompressedSize)
_, err = io.ReadFull(r, uncompressed)
if err != nil {
return nil, err
}
return uncompressed, nil
case wiremessage.CompressorZstd:
r, err := zstd.NewReader(bytes.NewBuffer(in))
if err != nil {
return nil, err
}
defer r.Close()
uncompressed = make([]byte, opts.UncompressedSize)
_, err = io.ReadFull(r, uncompressed)
if err != nil {
return nil, err
}
return uncompressed, nil
default:
return nil, fmt.Errorf("unknown compressor ID %v", opts.Compressor)
}
}

View File

@@ -0,0 +1,80 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package driver
import (
"os"
"testing"
"github.com/stretchr/testify/assert"
"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
)
func TestCompression(t *testing.T) {
compressors := []wiremessage.CompressorID{
wiremessage.CompressorNoOp,
wiremessage.CompressorSnappy,
wiremessage.CompressorZLib,
wiremessage.CompressorZstd,
}
for _, compressor := range compressors {
t.Run(compressor.String(), func(t *testing.T) {
payload := []byte("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt")
opts := CompressionOpts{
Compressor: compressor,
ZlibLevel: wiremessage.DefaultZlibLevel,
ZstdLevel: wiremessage.DefaultZstdLevel,
UncompressedSize: int32(len(payload)),
}
compressed, err := CompressPayload(payload, opts)
assert.NoError(t, err)
assert.NotEqual(t, 0, len(compressed))
decompressed, err := DecompressPayload(compressed, opts)
assert.NoError(t, err)
assert.Equal(t, payload, decompressed)
})
}
}
func BenchmarkCompressPayload(b *testing.B) {
payload := func() []byte {
buf, err := os.ReadFile("compression.go")
if err != nil {
b.Log(err)
b.FailNow()
}
for i := 1; i < 10; i++ {
buf = append(buf, buf...)
}
return buf
}()
compressors := []wiremessage.CompressorID{
wiremessage.CompressorSnappy,
wiremessage.CompressorZLib,
wiremessage.CompressorZstd,
}
for _, compressor := range compressors {
b.Run(compressor.String(), func(b *testing.B) {
opts := CompressionOpts{
Compressor: compressor,
ZlibLevel: wiremessage.DefaultZlibLevel,
ZstdLevel: wiremessage.DefaultZstdLevel,
}
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_, err := CompressPayload(payload, opts)
if err != nil {
b.Error(err)
}
}
})
})
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,356 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package connstring_test
import (
"encoding/json"
"fmt"
"io/ioutil"
"math"
"path"
"strings"
"testing"
"time"
"github.com/stretchr/testify/require"
"go.mongodb.org/mongo-driver/internal/testutil/helpers"
"go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
)
type host struct {
Type string
Host string
Port json.Number
}
type auth struct {
Username string
Password *string
DB string
}
type testCase struct {
Description string
URI string
Valid bool
Warning bool
Hosts []host
Auth *auth
Options map[string]interface{}
}
type testContainer struct {
Tests []testCase
}
const connstringTestsDir = "../../../../testdata/connection-string/"
const urioptionsTestDir = "../../../../testdata/uri-options/"
func (h *host) toString() string {
switch h.Type {
case "unix":
return h.Host
case "ip_literal":
if len(h.Port) == 0 {
return "[" + h.Host + "]"
}
return "[" + h.Host + "]" + ":" + string(h.Port)
case "ipv4":
fallthrough
case "hostname":
if len(h.Port) == 0 {
return h.Host
}
return h.Host + ":" + string(h.Port)
}
return ""
}
func hostsToStrings(hosts []host) []string {
out := make([]string, len(hosts))
for i, host := range hosts {
out[i] = host.toString()
}
return out
}
func runTestsInFile(t *testing.T, dirname string, filename string, warningsError bool) {
filepath := path.Join(dirname, filename)
content, err := ioutil.ReadFile(filepath)
require.NoError(t, err)
var container testContainer
require.NoError(t, json.Unmarshal(content, &container))
// Remove ".json" from filename.
filename = filename[:len(filename)-5]
for _, testCase := range container.Tests {
runTest(t, filename, testCase, warningsError)
}
}
var skipDescriptions = map[string]struct{}{
"Valid options specific to single-threaded drivers are parsed correctly": {},
}
var skipKeywords = []string{
"tlsAllowInvalidHostnames",
"tlsAllowInvalidCertificates",
"tlsDisableCertificateRevocationCheck",
"serverSelectionTryOnce",
}
func runTest(t *testing.T, filename string, test testCase, warningsError bool) {
t.Run(filename+"/"+test.Description, func(t *testing.T) {
if _, skip := skipDescriptions[test.Description]; skip {
t.Skip()
}
for _, keyword := range skipKeywords {
if strings.Contains(test.Description, keyword) {
t.Skipf("skipping because keyword %s", keyword)
}
}
cs, err := connstring.ParseAndValidate(test.URI)
// Since we don't have warnings in Go, we return warnings as errors.
//
// This is a bit unfortunate, but since we do raise warnings as errors with the newer
// URI options, but don't with some of the older things, we do a switch on the filename
// here. We are trying to not break existing user applications that have unrecognized
// options.
if test.Valid && !(test.Warning && warningsError) {
require.NoError(t, err)
} else {
require.Error(t, err)
return
}
require.Equal(t, test.URI, cs.Original)
if test.Hosts != nil {
require.Equal(t, hostsToStrings(test.Hosts), cs.Hosts)
}
if test.Auth != nil {
require.Equal(t, test.Auth.Username, cs.Username)
if test.Auth.Password == nil {
require.False(t, cs.PasswordSet)
} else {
require.True(t, cs.PasswordSet)
require.Equal(t, *test.Auth.Password, cs.Password)
}
if test.Auth.DB != cs.Database {
require.Equal(t, test.Auth.DB, cs.AuthSource)
} else {
require.Equal(t, test.Auth.DB, cs.Database)
}
}
// Check that all options are present.
verifyConnStringOptions(t, cs, test.Options)
// Check that non-present options are unset. This will be redundant with the above checks
// for options that are present.
var ok bool
_, ok = test.Options["maxpoolsize"]
require.Equal(t, ok, cs.MaxPoolSizeSet)
})
}
// Test case for all connection string spec tests.
func TestConnStringSpec(t *testing.T) {
for _, file := range helpers.FindJSONFilesInDir(t, connstringTestsDir) {
runTestsInFile(t, connstringTestsDir, file, false)
}
}
func TestURIOptionsSpec(t *testing.T) {
for _, file := range helpers.FindJSONFilesInDir(t, urioptionsTestDir) {
runTestsInFile(t, urioptionsTestDir, file, true)
}
}
// verifyConnStringOptions verifies the options on the connection string.
func verifyConnStringOptions(t *testing.T, cs connstring.ConnString, options map[string]interface{}) {
// Check that all options are present.
for key, value := range options {
key = strings.ToLower(key)
switch key {
case "appname":
require.Equal(t, value, cs.AppName)
case "authsource":
require.Equal(t, value, cs.AuthSource)
case "authmechanism":
require.Equal(t, value, cs.AuthMechanism)
case "authmechanismproperties":
convertedMap := value.(map[string]interface{})
require.Equal(t,
mapInterfaceToString(convertedMap),
cs.AuthMechanismProperties)
case "compressors":
require.Equal(t, convertToStringSlice(value), cs.Compressors)
case "connecttimeoutms":
require.Equal(t, value, float64(cs.ConnectTimeout/time.Millisecond))
case "directconnection":
require.True(t, cs.DirectConnectionSet)
require.Equal(t, value, cs.DirectConnection)
case "heartbeatfrequencyms":
require.Equal(t, value, float64(cs.HeartbeatInterval/time.Millisecond))
case "journal":
require.True(t, cs.JSet)
require.Equal(t, value, cs.J)
case "loadbalanced":
require.True(t, cs.LoadBalancedSet)
require.Equal(t, value, cs.LoadBalanced)
case "localthresholdms":
require.True(t, cs.LocalThresholdSet)
require.Equal(t, value, float64(cs.LocalThreshold/time.Millisecond))
case "maxidletimems":
require.Equal(t, value, float64(cs.MaxConnIdleTime/time.Millisecond))
case "maxpoolsize":
require.True(t, cs.MaxPoolSizeSet)
require.Equal(t, value, cs.MaxPoolSize)
case "maxstalenessseconds":
require.True(t, cs.MaxStalenessSet)
require.Equal(t, value, float64(cs.MaxStaleness/time.Second))
case "minpoolsize":
require.True(t, cs.MinPoolSizeSet)
require.Equal(t, value, int64(cs.MinPoolSize))
case "readpreference":
require.Equal(t, value, cs.ReadPreference)
case "readpreferencetags":
sm, ok := value.([]interface{})
require.True(t, ok)
tags := make([]map[string]string, 0, len(sm))
for _, i := range sm {
m, ok := i.(map[string]interface{})
require.True(t, ok)
tags = append(tags, mapInterfaceToString(m))
}
require.Equal(t, tags, cs.ReadPreferenceTagSets)
case "readconcernlevel":
require.Equal(t, value, cs.ReadConcernLevel)
case "replicaset":
require.Equal(t, value, cs.ReplicaSet)
case "retrywrites":
require.True(t, cs.RetryWritesSet)
require.Equal(t, value, cs.RetryWrites)
case "serverselectiontimeoutms":
require.Equal(t, value, float64(cs.ServerSelectionTimeout/time.Millisecond))
case "srvmaxhosts":
require.Equal(t, value, float64(cs.SRVMaxHosts))
case "srvservicename":
require.Equal(t, value, cs.SRVServiceName)
case "ssl", "tls":
require.Equal(t, value, cs.SSL)
case "sockettimeoutms":
require.Equal(t, value, float64(cs.SocketTimeout/time.Millisecond))
case "tlsallowinvalidcertificates", "tlsallowinvalidhostnames", "tlsinsecure":
require.True(t, cs.SSLInsecureSet)
require.Equal(t, value, cs.SSLInsecure)
case "tlscafile":
require.True(t, cs.SSLCaFileSet)
require.Equal(t, value, cs.SSLCaFile)
case "tlscertificatekeyfile":
require.True(t, cs.SSLClientCertificateKeyFileSet)
require.Equal(t, value, cs.SSLClientCertificateKeyFile)
case "tlscertificatekeyfilepassword":
require.True(t, cs.SSLClientCertificateKeyPasswordSet)
require.Equal(t, value, cs.SSLClientCertificateKeyPassword())
case "w":
if cs.WNumberSet {
valueInt := getIntFromInterface(value)
require.NotNil(t, valueInt)
require.Equal(t, *valueInt, int64(cs.WNumber))
} else {
require.Equal(t, value, cs.WString)
}
case "wtimeoutms":
require.Equal(t, value, float64(cs.WTimeout/time.Millisecond))
case "waitqueuetimeoutms":
case "zlibcompressionlevel":
require.Equal(t, value, float64(cs.ZlibLevel))
case "zstdcompressionlevel":
require.Equal(t, value, float64(cs.ZstdLevel))
case "tlsdisableocspendpointcheck":
require.Equal(t, value, cs.SSLDisableOCSPEndpointCheck)
default:
opt, ok := cs.UnknownOptions[key]
require.True(t, ok)
require.Contains(t, opt, fmt.Sprint(value))
}
}
}
// Convert each interface{} value in the map to a string.
func mapInterfaceToString(m map[string]interface{}) map[string]string {
out := make(map[string]string)
for key, value := range m {
out[key] = fmt.Sprint(value)
}
return out
}
// getIntFromInterface attempts to convert an empty interface value to an integer.
//
// Returns nil if it is not possible.
func getIntFromInterface(i interface{}) *int64 {
var out int64
switch v := i.(type) {
case int:
out = int64(v)
case int32:
out = int64(v)
case int64:
out = v
case float32:
f := float64(v)
if math.Floor(f) != f || f > float64(math.MaxInt64) {
break
}
out = int64(f)
case float64:
if math.Floor(v) != v || v > float64(math.MaxInt64) {
break
}
out = int64(v)
default:
return nil
}
return &out
}
func convertToStringSlice(i interface{}) []string {
s, ok := i.([]interface{})
if !ok {
return nil
}
ret := make([]string, 0, len(s))
for _, v := range s {
str, ok := v.(string)
if !ok {
continue
}
ret = append(ret, str)
}
return ret
}

View File

@@ -0,0 +1,631 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package connstring_test
import (
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
)
func TestAppName(t *testing.T) {
tests := []struct {
s string
expected string
err bool
}{
{s: "appName=Funny", expected: "Funny"},
{s: "appName=awesome", expected: "awesome"},
{s: "appName=", expected: ""},
}
for _, test := range tests {
s := fmt.Sprintf("mongodb://localhost/?%s", test.s)
t.Run(s, func(t *testing.T) {
cs, err := connstring.ParseAndValidate(s)
if test.err {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, test.expected, cs.AppName)
}
})
}
}
func TestAuthMechanism(t *testing.T) {
tests := []struct {
s string
expected string
err bool
}{
{s: "authMechanism=scram-sha-1", expected: "scram-sha-1"},
{s: "authMechanism=scram-sha-256", expected: "scram-sha-256"},
{s: "authMechanism=mongodb-CR", expected: "mongodb-CR"},
{s: "authMechanism=plain", expected: "plain"},
}
for _, test := range tests {
s := fmt.Sprintf("mongodb://user:pass@localhost/?%s", test.s)
t.Run(s, func(t *testing.T) {
cs, err := connstring.ParseAndValidate(s)
if test.err {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, test.expected, cs.AuthMechanism)
}
})
}
}
func TestAuthSource(t *testing.T) {
tests := []struct {
s string
expected string
err bool
}{
{s: "foobar?authSource=bazqux", expected: "bazqux"},
{s: "foobar", expected: "foobar"},
{s: "", expected: "admin"},
}
for _, test := range tests {
s := fmt.Sprintf("mongodb://user:pass@localhost/%s", test.s)
t.Run(s, func(t *testing.T) {
cs, err := connstring.ParseAndValidate(s)
if test.err {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, test.expected, cs.AuthSource)
}
})
}
}
func TestConnect(t *testing.T) {
tests := []struct {
s string
expected connstring.ConnectMode
err bool
}{
{s: "connect=automatic", expected: connstring.AutoConnect},
{s: "connect=AUTOMATIC", expected: connstring.AutoConnect},
{s: "connect=direct", expected: connstring.SingleConnect},
{s: "connect=blah", err: true},
// Combinations of connect and directConnection where connect is set first - conflicting combinations must
// error.
{s: "connect=automatic&directConnection=true", err: true},
{s: "connect=automatic&directConnection=false", expected: connstring.AutoConnect},
{s: "connect=direct&directConnection=true", expected: connstring.SingleConnect},
{s: "connect=direct&directConnection=false", err: true},
// Combinations of connect and directConnection where directConnection is set first.
{s: "directConnection=true&connect=automatic", err: true},
{s: "directConnection=false&connect=automatic", expected: connstring.AutoConnect},
{s: "directConnection=true&connect=direct", expected: connstring.SingleConnect},
{s: "directConnection=false&connect=direct", err: true},
}
for _, test := range tests {
s := fmt.Sprintf("mongodb://localhost/?%s", test.s)
t.Run(s, func(t *testing.T) {
cs, err := connstring.ParseAndValidate(s)
if test.err {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, test.expected, cs.Connect)
}
})
}
}
func TestDirectConnection(t *testing.T) {
testCases := []struct {
s string
expected bool
err bool
}{
{"directConnection=true", true, false},
{"directConnection=false", false, false},
{"directConnection=TRUE", true, false},
{"directConnection=FALSE", false, false},
{"directConnection=blah", false, true},
}
for _, tc := range testCases {
s := fmt.Sprintf("mongodb://localhost/?%s", tc.s)
t.Run(s, func(t *testing.T) {
cs, err := connstring.ParseAndValidate(s)
if tc.err {
assert.NotNil(t, err, "expected error, got nil")
return
}
assert.Nil(t, err, "expected no error, got %v", err)
assert.Equal(t, tc.expected, cs.DirectConnection, "expected DirectConnection value %v, got %v", tc.expected,
cs.DirectConnection)
assert.True(t, cs.DirectConnectionSet, "expected DirectConnectionSet to be true, got false")
})
}
}
func TestConnectTimeout(t *testing.T) {
tests := []struct {
s string
expected time.Duration
err bool
}{
{s: "connectTimeoutMS=10", expected: time.Duration(10) * time.Millisecond},
{s: "connectTimeoutMS=100", expected: time.Duration(100) * time.Millisecond},
{s: "connectTimeoutMS=-2", err: true},
{s: "connectTimeoutMS=gsdge", err: true},
}
for _, test := range tests {
s := fmt.Sprintf("mongodb://localhost/?%s", test.s)
t.Run(s, func(t *testing.T) {
cs, err := connstring.ParseAndValidate(s)
if test.err {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, test.expected, cs.ConnectTimeout)
require.True(t, cs.ConnectTimeoutSet)
}
})
}
}
func TestHeartbeatInterval(t *testing.T) {
tests := []struct {
s string
expected time.Duration
err bool
}{
{s: "heartbeatIntervalMS=10", expected: time.Duration(10) * time.Millisecond},
{s: "heartbeatIntervalMS=100", expected: time.Duration(100) * time.Millisecond},
{s: "heartbeatIntervalMS=-2", err: true},
{s: "heartbeatIntervalMS=gsdge", err: true},
}
for _, test := range tests {
s := fmt.Sprintf("mongodb://localhost/?%s", test.s)
t.Run(s, func(t *testing.T) {
cs, err := connstring.ParseAndValidate(s)
if test.err {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, test.expected, cs.HeartbeatInterval)
}
})
}
}
func TestLocalThreshold(t *testing.T) {
tests := []struct {
s string
expected time.Duration
err bool
}{
{s: "localThresholdMS=0", expected: time.Duration(0) * time.Millisecond},
{s: "localThresholdMS=10", expected: time.Duration(10) * time.Millisecond},
{s: "localThresholdMS=100", expected: time.Duration(100) * time.Millisecond},
{s: "localThresholdMS=-2", err: true},
{s: "localThresholdMS=gsdge", err: true},
}
for _, test := range tests {
s := fmt.Sprintf("mongodb://localhost/?%s", test.s)
t.Run(s, func(t *testing.T) {
cs, err := connstring.ParseAndValidate(s)
if test.err {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, test.expected, cs.LocalThreshold)
}
})
}
}
func TestMaxConnIdleTime(t *testing.T) {
tests := []struct {
s string
expected time.Duration
err bool
}{
{s: "maxIdleTimeMS=10", expected: time.Duration(10) * time.Millisecond},
{s: "maxIdleTimeMS=100", expected: time.Duration(100) * time.Millisecond},
{s: "maxIdleTimeMS=-2", err: true},
{s: "maxIdleTimeMS=gsdge", err: true},
}
for _, test := range tests {
s := fmt.Sprintf("mongodb://localhost/?%s", test.s)
t.Run(s, func(t *testing.T) {
cs, err := connstring.ParseAndValidate(s)
if test.err {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, test.expected, cs.MaxConnIdleTime)
}
})
}
}
func TestMaxPoolSize(t *testing.T) {
tests := []struct {
s string
expected uint64
err bool
}{
{s: "maxPoolSize=10", expected: 10},
{s: "maxPoolSize=100", expected: 100},
{s: "maxPoolSize=-2", err: true},
{s: "maxPoolSize=gsdge", err: true},
}
for _, test := range tests {
s := fmt.Sprintf("mongodb://localhost/?%s", test.s)
t.Run(s, func(t *testing.T) {
cs, err := connstring.ParseAndValidate(s)
if test.err {
require.Error(t, err)
} else {
require.NoError(t, err)
require.True(t, cs.MaxPoolSizeSet)
require.Equal(t, test.expected, cs.MaxPoolSize)
}
})
}
}
func TestMinPoolSize(t *testing.T) {
tests := []struct {
s string
expected uint64
err bool
}{
{s: "minPoolSize=10", expected: 10},
{s: "minPoolSize=100", expected: 100},
{s: "minPoolSize=-2", err: true},
{s: "minPoolSize=gsdge", err: true},
}
for _, test := range tests {
s := fmt.Sprintf("mongodb://localhost/?%s", test.s)
t.Run(s, func(t *testing.T) {
cs, err := connstring.ParseAndValidate(s)
if test.err {
require.Error(t, err)
} else {
require.NoError(t, err)
require.True(t, cs.MinPoolSizeSet)
require.Equal(t, test.expected, cs.MinPoolSize)
}
})
}
}
func TestMaxConnecting(t *testing.T) {
tests := []struct {
s string
expected uint64
err bool
}{
{s: "maxConnecting=10", expected: 10},
{s: "maxConnecting=100", expected: 100},
{s: "maxConnecting=-2", err: true},
{s: "maxConnecting=gsdge", err: true},
}
for _, test := range tests {
s := fmt.Sprintf("mongodb://localhost/?%s", test.s)
t.Run(s, func(t *testing.T) {
cs, err := connstring.ParseAndValidate(s)
if test.err {
require.Error(t, err)
} else {
require.NoError(t, err)
require.True(t, cs.MaxConnectingSet)
require.Equal(t, test.expected, cs.MaxConnecting)
}
})
}
}
func TestReadPreference(t *testing.T) {
tests := []struct {
s string
expected string
err bool
}{
{s: "readPreference=primary", expected: "primary"},
{s: "readPreference=secondaryPreferred", expected: "secondaryPreferred"},
{s: "readPreference=something", expected: "something"}, // we don't validate here
}
for _, test := range tests {
s := fmt.Sprintf("mongodb://localhost/?%s", test.s)
t.Run(s, func(t *testing.T) {
cs, err := connstring.ParseAndValidate(s)
if test.err {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, test.expected, cs.ReadPreference)
}
})
}
}
func TestReadPreferenceTags(t *testing.T) {
tests := []struct {
s string
expected []map[string]string
err bool
}{
{s: "", expected: nil},
{s: "readPreferenceTags=one:1", expected: []map[string]string{{"one": "1"}}},
{s: "readPreferenceTags=one:1,two:2", expected: []map[string]string{{"one": "1", "two": "2"}}},
{s: "readPreferenceTags=one:1&readPreferenceTags=two:2", expected: []map[string]string{{"one": "1"}, {"two": "2"}}},
{s: "readPreferenceTags=one:1:3,two:2", err: true},
{s: "readPreferenceTags=one:1&readPreferenceTags=two:2&readPreferenceTags=", expected: []map[string]string{{"one": "1"}, {"two": "2"}, {}}},
{s: "readPreferenceTags=", expected: []map[string]string{{}}},
}
for _, test := range tests {
s := fmt.Sprintf("mongodb://localhost/?%s", test.s)
t.Run(s, func(t *testing.T) {
cs, err := connstring.ParseAndValidate(s)
if test.err {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, test.expected, cs.ReadPreferenceTagSets)
}
})
}
}
func TestMaxStaleness(t *testing.T) {
tests := []struct {
s string
expected time.Duration
err bool
}{
{s: "maxStaleness=10", expected: time.Duration(10) * time.Second},
{s: "maxStaleness=100", expected: time.Duration(100) * time.Second},
{s: "maxStaleness=-2", err: true},
{s: "maxStaleness=gsdge", err: true},
}
for _, test := range tests {
s := fmt.Sprintf("mongodb://localhost/?%s", test.s)
t.Run(s, func(t *testing.T) {
cs, err := connstring.ParseAndValidate(s)
if test.err {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, test.expected, cs.MaxStaleness)
}
})
}
}
func TestReplicaSet(t *testing.T) {
tests := []struct {
s string
expected string
err bool
}{
{s: "replicaSet=auto", expected: "auto"},
{s: "replicaSet=rs0", expected: "rs0"},
}
for _, test := range tests {
s := fmt.Sprintf("mongodb://localhost/?%s", test.s)
t.Run(s, func(t *testing.T) {
cs, err := connstring.ParseAndValidate(s)
if test.err {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, test.expected, cs.ReplicaSet)
}
})
}
}
func TestRetryWrites(t *testing.T) {
tests := []struct {
s string
expected bool
err bool
}{
{s: "retryWrites=true", expected: true},
{s: "retryWrites=false", expected: false},
{s: "retryWrites=foobar", expected: false, err: true},
}
for _, test := range tests {
s := fmt.Sprintf("mongodb://localhost/?%s", test.s)
t.Run(s, func(t *testing.T) {
cs, err := connstring.ParseAndValidate(s)
if test.err {
require.Error(t, err)
return
}
require.NoError(t, err)
require.Equal(t, test.expected, cs.RetryWrites)
require.True(t, cs.RetryWritesSet)
})
}
}
func TestRetryReads(t *testing.T) {
tests := []struct {
s string
expected bool
err bool
}{
{s: "retryReads=true", expected: true},
{s: "retryReads=false", expected: false},
{s: "retryReads=foobar", expected: false, err: true},
}
for _, test := range tests {
s := fmt.Sprintf("mongodb://localhost/?%s", test.s)
t.Run(s, func(t *testing.T) {
cs, err := connstring.ParseAndValidate(s)
if test.err {
require.Error(t, err)
return
}
require.NoError(t, err)
require.Equal(t, test.expected, cs.RetryReads)
require.True(t, cs.RetryReadsSet)
})
}
}
func TestScheme(t *testing.T) {
// Can't unit test 'mongodb+srv' because that requires networking. Tested
// in x/mongo/driver/topology/initial_dns_seedlist_discovery_test.go
cs, err := connstring.ParseAndValidate("mongodb://localhost/")
require.NoError(t, err)
require.Equal(t, cs.Scheme, "mongodb")
require.Equal(t, cs.Scheme, connstring.SchemeMongoDB)
}
func TestServerSelectionTimeout(t *testing.T) {
tests := []struct {
s string
expected time.Duration
err bool
}{
{s: "serverSelectionTimeoutMS=10", expected: time.Duration(10) * time.Millisecond},
{s: "serverSelectionTimeoutMS=100", expected: time.Duration(100) * time.Millisecond},
{s: "serverSelectionTimeoutMS=-2", err: true},
{s: "serverSelectionTimeoutMS=gsdge", err: true},
}
for _, test := range tests {
s := fmt.Sprintf("mongodb://localhost/?%s", test.s)
t.Run(s, func(t *testing.T) {
cs, err := connstring.ParseAndValidate(s)
if test.err {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, test.expected, cs.ServerSelectionTimeout)
require.True(t, cs.ServerSelectionTimeoutSet)
}
})
}
}
func TestSocketTimeout(t *testing.T) {
tests := []struct {
s string
expected time.Duration
err bool
}{
{s: "socketTimeoutMS=10", expected: time.Duration(10) * time.Millisecond},
{s: "socketTimeoutMS=100", expected: time.Duration(100) * time.Millisecond},
{s: "socketTimeoutMS=-2", err: true},
{s: "socketTimeoutMS=gsdge", err: true},
}
for _, test := range tests {
s := fmt.Sprintf("mongodb://localhost/?%s", test.s)
t.Run(s, func(t *testing.T) {
cs, err := connstring.ParseAndValidate(s)
if test.err {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, test.expected, cs.SocketTimeout)
require.True(t, cs.SocketTimeoutSet)
}
})
}
}
func TestWTimeout(t *testing.T) {
tests := []struct {
s string
expected time.Duration
err bool
}{
{s: "wtimeoutMS=10", expected: time.Duration(10) * time.Millisecond},
{s: "wtimeoutMS=100", expected: time.Duration(100) * time.Millisecond},
{s: "wtimeoutMS=-2", err: true},
{s: "wtimeoutMS=gsdge", err: true},
}
for _, test := range tests {
s := fmt.Sprintf("mongodb://localhost/?%s", test.s)
t.Run(s, func(t *testing.T) {
cs, err := connstring.ParseAndValidate(s)
if test.err {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, test.expected, cs.WTimeout)
require.True(t, cs.WTimeoutSet)
}
})
}
}
func TestCompressionOptions(t *testing.T) {
tests := []struct {
name string
uriOptions string
compressors []string
zlibLevel int
zstdLevel int
err bool
}{
{name: "SingleCompressor", uriOptions: "compressors=zlib", compressors: []string{"zlib"}},
{name: "MultiCompressors", uriOptions: "compressors=snappy,zlib", compressors: []string{"snappy", "zlib"}},
{name: "ZlibWithLevel", uriOptions: "compressors=zlib&zlibCompressionLevel=7", compressors: []string{"zlib"}, zlibLevel: 7},
{name: "DefaultZlibLevel", uriOptions: "compressors=zlib&zlibCompressionLevel=-1", compressors: []string{"zlib"}, zlibLevel: 6},
{name: "InvalidZlibLevel", uriOptions: "compressors=zlib&zlibCompressionLevel=-2", compressors: []string{"zlib"}, err: true},
{name: "ZstdWithLevel", uriOptions: "compressors=zstd&zstdCompressionLevel=20", compressors: []string{"zstd"}, zstdLevel: 20},
{name: "DefaultZstdLevel", uriOptions: "compressors=zstd&zstdCompressionLevel=-1", compressors: []string{"zstd"}, zstdLevel: 6},
{name: "InvalidZstdLevel", uriOptions: "compressors=zstd&zstdCompressionLevel=30", compressors: []string{"zstd"}, err: true},
}
for _, tc := range tests {
uri := fmt.Sprintf("mongodb://localhost/?%s", tc.uriOptions)
t.Run(tc.name, func(t *testing.T) {
cs, err := connstring.ParseAndValidate(uri)
if tc.err {
assert.Error(t, err)
} else {
require.NoError(t, err)
assert.Equal(t, tc.compressors, cs.Compressors)
if tc.zlibLevel != 0 {
assert.Equal(t, tc.zlibLevel, cs.ZlibLevel)
}
if tc.zstdLevel != 0 {
assert.Equal(t, tc.zstdLevel, cs.ZstdLevel)
}
}
})
}
}

View File

@@ -0,0 +1,473 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package driver
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"
"os"
"strings"
"time"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/internal"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt"
"go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt/options"
)
const (
defaultKmsPort = 443
defaultKmsTimeout = 10 * time.Second
)
// CollectionInfoFn is a callback used to retrieve collection information.
type CollectionInfoFn func(ctx context.Context, db string, filter bsoncore.Document) (bsoncore.Document, error)
// KeyRetrieverFn is a callback used to retrieve keys from the key vault.
type KeyRetrieverFn func(ctx context.Context, filter bsoncore.Document) ([]bsoncore.Document, error)
// MarkCommandFn is a callback used to add encryption markings to a command.
type MarkCommandFn func(ctx context.Context, db string, cmd bsoncore.Document) (bsoncore.Document, error)
// CryptOptions specifies options to configure a Crypt instance.
type CryptOptions struct {
MongoCrypt *mongocrypt.MongoCrypt
CollInfoFn CollectionInfoFn
KeyFn KeyRetrieverFn
MarkFn MarkCommandFn
TLSConfig map[string]*tls.Config
HTTPClient *http.Client
BypassAutoEncryption bool
BypassQueryAnalysis bool
}
// Crypt is an interface implemented by types that can encrypt and decrypt instances of
// bsoncore.Document.
//
// Users should rely on the driver's crypt type (used by default) for encryption and decryption
// unless they are perfectly confident in another implementation of Crypt.
type Crypt interface {
// Encrypt encrypts the given command.
Encrypt(ctx context.Context, db string, cmd bsoncore.Document) (bsoncore.Document, error)
// Decrypt decrypts the given command response.
Decrypt(ctx context.Context, cmdResponse bsoncore.Document) (bsoncore.Document, error)
// CreateDataKey creates a data key using the given KMS provider and options.
CreateDataKey(ctx context.Context, kmsProvider string, opts *options.DataKeyOptions) (bsoncore.Document, error)
// EncryptExplicit encrypts the given value with the given options.
EncryptExplicit(ctx context.Context, val bsoncore.Value, opts *options.ExplicitEncryptionOptions) (byte, []byte, error)
// DecryptExplicit decrypts the given encrypted value.
DecryptExplicit(ctx context.Context, subtype byte, data []byte) (bsoncore.Value, error)
// Close cleans up any resources associated with the Crypt instance.
Close()
// BypassAutoEncryption returns true if auto-encryption should be bypassed.
BypassAutoEncryption() bool
// RewrapDataKey attempts to rewrap the document data keys matching the filter, preparing the re-wrapped documents
// to be returned as a slice of bsoncore.Document.
RewrapDataKey(ctx context.Context, filter []byte, opts *options.RewrapManyDataKeyOptions) ([]bsoncore.Document, error)
}
// crypt consumes the libmongocrypt.MongoCrypt type to iterate the mongocrypt state machine and perform encryption
// and decryption.
type crypt struct {
mongoCrypt *mongocrypt.MongoCrypt
collInfoFn CollectionInfoFn
keyFn KeyRetrieverFn
markFn MarkCommandFn
tlsConfig map[string]*tls.Config
httpClient *http.Client
bypassAutoEncryption bool
}
// NewCrypt creates a new Crypt instance configured with the given AutoEncryptionOptions.
func NewCrypt(opts *CryptOptions) Crypt {
c := &crypt{
mongoCrypt: opts.MongoCrypt,
collInfoFn: opts.CollInfoFn,
keyFn: opts.KeyFn,
markFn: opts.MarkFn,
tlsConfig: opts.TLSConfig,
httpClient: opts.HTTPClient,
bypassAutoEncryption: opts.BypassAutoEncryption,
}
if c.httpClient == nil {
c.httpClient = internal.DefaultHTTPClient
}
return c
}
// Encrypt encrypts the given command.
func (c *crypt) Encrypt(ctx context.Context, db string, cmd bsoncore.Document) (bsoncore.Document, error) {
if c.bypassAutoEncryption {
return cmd, nil
}
cryptCtx, err := c.mongoCrypt.CreateEncryptionContext(db, cmd)
if err != nil {
return nil, err
}
defer cryptCtx.Close()
return c.executeStateMachine(ctx, cryptCtx, db)
}
// Decrypt decrypts the given command response.
func (c *crypt) Decrypt(ctx context.Context, cmdResponse bsoncore.Document) (bsoncore.Document, error) {
cryptCtx, err := c.mongoCrypt.CreateDecryptionContext(cmdResponse)
if err != nil {
return nil, err
}
defer cryptCtx.Close()
return c.executeStateMachine(ctx, cryptCtx, "")
}
// CreateDataKey creates a data key using the given KMS provider and options.
func (c *crypt) CreateDataKey(ctx context.Context, kmsProvider string, opts *options.DataKeyOptions) (bsoncore.Document, error) {
cryptCtx, err := c.mongoCrypt.CreateDataKeyContext(kmsProvider, opts)
if err != nil {
return nil, err
}
defer cryptCtx.Close()
return c.executeStateMachine(ctx, cryptCtx, "")
}
// RewrapDataKey attempts to rewrap the document data keys matching the filter, preparing the re-wrapped documents to
// be returned as a slice of bsoncore.Document.
func (c *crypt) RewrapDataKey(ctx context.Context, filter []byte,
opts *options.RewrapManyDataKeyOptions) ([]bsoncore.Document, error) {
cryptCtx, err := c.mongoCrypt.RewrapDataKeyContext(filter, opts)
if err != nil {
return nil, err
}
defer cryptCtx.Close()
rewrappedBSON, err := c.executeStateMachine(ctx, cryptCtx, "")
if err != nil {
return nil, err
}
if rewrappedBSON == nil {
return nil, nil
}
// mongocrypt_ctx_rewrap_many_datakey_init wraps the documents in a BSON of the form { "v": [(BSON document), ...] }
// where each BSON document in the slice is a document containing a rewrapped datakey.
rewrappedDocumentBytes, err := rewrappedBSON.LookupErr("v")
if err != nil {
return nil, err
}
// Parse the resulting BSON as individual documents.
rewrappedDocsArray, ok := rewrappedDocumentBytes.ArrayOK()
if !ok {
return nil, fmt.Errorf("expected results from mongocrypt_ctx_rewrap_many_datakey_init to be an array")
}
rewrappedDocumentValues, err := rewrappedDocsArray.Values()
if err != nil {
return nil, err
}
rewrappedDocuments := []bsoncore.Document{}
for _, rewrappedDocumentValue := range rewrappedDocumentValues {
if rewrappedDocumentValue.Type != bsontype.EmbeddedDocument {
// If a value in the document's array returned by mongocrypt is anything other than an embedded document,
// then something is wrong and we should terminate the routine.
return nil, fmt.Errorf("expected value of type %q, got: %q",
bsontype.EmbeddedDocument.String(),
rewrappedDocumentValue.Type.String())
}
rewrappedDocuments = append(rewrappedDocuments, rewrappedDocumentValue.Document())
}
return rewrappedDocuments, nil
}
// EncryptExplicit encrypts the given value with the given options.
func (c *crypt) EncryptExplicit(ctx context.Context, val bsoncore.Value, opts *options.ExplicitEncryptionOptions) (byte, []byte, error) {
idx, doc := bsoncore.AppendDocumentStart(nil)
doc = bsoncore.AppendValueElement(doc, "v", val)
doc, _ = bsoncore.AppendDocumentEnd(doc, idx)
cryptCtx, err := c.mongoCrypt.CreateExplicitEncryptionContext(doc, opts)
if err != nil {
return 0, nil, err
}
defer cryptCtx.Close()
res, err := c.executeStateMachine(ctx, cryptCtx, "")
if err != nil {
return 0, nil, err
}
sub, data := res.Lookup("v").Binary()
return sub, data, nil
}
// DecryptExplicit decrypts the given encrypted value.
func (c *crypt) DecryptExplicit(ctx context.Context, subtype byte, data []byte) (bsoncore.Value, error) {
idx, doc := bsoncore.AppendDocumentStart(nil)
doc = bsoncore.AppendBinaryElement(doc, "v", subtype, data)
doc, _ = bsoncore.AppendDocumentEnd(doc, idx)
cryptCtx, err := c.mongoCrypt.CreateExplicitDecryptionContext(doc)
if err != nil {
return bsoncore.Value{}, err
}
defer cryptCtx.Close()
res, err := c.executeStateMachine(ctx, cryptCtx, "")
if err != nil {
return bsoncore.Value{}, err
}
return res.Lookup("v"), nil
}
// Close cleans up any resources associated with the Crypt instance.
func (c *crypt) Close() {
c.mongoCrypt.Close()
if c.httpClient == internal.DefaultHTTPClient {
internal.CloseIdleHTTPConnections(c.httpClient)
}
}
func (c *crypt) BypassAutoEncryption() bool {
return c.bypassAutoEncryption
}
func (c *crypt) executeStateMachine(ctx context.Context, cryptCtx *mongocrypt.Context, db string) (bsoncore.Document, error) {
var err error
for {
state := cryptCtx.State()
switch state {
case mongocrypt.NeedMongoCollInfo:
err = c.collectionInfo(ctx, cryptCtx, db)
case mongocrypt.NeedMongoMarkings:
err = c.markCommand(ctx, cryptCtx, db)
case mongocrypt.NeedMongoKeys:
err = c.retrieveKeys(ctx, cryptCtx)
case mongocrypt.NeedKms:
err = c.decryptKeys(cryptCtx)
case mongocrypt.Ready:
return cryptCtx.Finish()
case mongocrypt.Done:
return nil, nil
case mongocrypt.NeedKmsCredentials:
err = c.provideKmsProviders(ctx, cryptCtx)
default:
return nil, fmt.Errorf("invalid Crypt state: %v", state)
}
if err != nil {
return nil, err
}
}
}
func (c *crypt) collectionInfo(ctx context.Context, cryptCtx *mongocrypt.Context, db string) error {
op, err := cryptCtx.NextOperation()
if err != nil {
return err
}
collInfo, err := c.collInfoFn(ctx, db, op)
if err != nil {
return err
}
if collInfo != nil {
if err = cryptCtx.AddOperationResult(collInfo); err != nil {
return err
}
}
return cryptCtx.CompleteOperation()
}
func (c *crypt) markCommand(ctx context.Context, cryptCtx *mongocrypt.Context, db string) error {
op, err := cryptCtx.NextOperation()
if err != nil {
return err
}
markedCmd, err := c.markFn(ctx, db, op)
if err != nil {
return err
}
if err = cryptCtx.AddOperationResult(markedCmd); err != nil {
return err
}
return cryptCtx.CompleteOperation()
}
func (c *crypt) retrieveKeys(ctx context.Context, cryptCtx *mongocrypt.Context) error {
op, err := cryptCtx.NextOperation()
if err != nil {
return err
}
keys, err := c.keyFn(ctx, op)
if err != nil {
return err
}
for _, key := range keys {
if err = cryptCtx.AddOperationResult(key); err != nil {
return err
}
}
return cryptCtx.CompleteOperation()
}
func (c *crypt) decryptKeys(cryptCtx *mongocrypt.Context) error {
for {
kmsCtx := cryptCtx.NextKmsContext()
if kmsCtx == nil {
break
}
if err := c.decryptKey(kmsCtx); err != nil {
return err
}
}
return cryptCtx.FinishKmsContexts()
}
func (c *crypt) decryptKey(kmsCtx *mongocrypt.KmsContext) error {
host, err := kmsCtx.HostName()
if err != nil {
return err
}
msg, err := kmsCtx.Message()
if err != nil {
return err
}
// add a port to the address if it's not already present
addr := host
if idx := strings.IndexByte(host, ':'); idx == -1 {
addr = fmt.Sprintf("%s:%d", host, defaultKmsPort)
}
kmsProvider := kmsCtx.KMSProvider()
tlsCfg := c.tlsConfig[kmsProvider]
if tlsCfg == nil {
tlsCfg = &tls.Config{MinVersion: tls.VersionTLS12}
}
conn, err := tls.Dial("tcp", addr, tlsCfg)
if err != nil {
return err
}
defer func() {
_ = conn.Close()
}()
if err = conn.SetWriteDeadline(time.Now().Add(defaultKmsTimeout)); err != nil {
return err
}
if _, err = conn.Write(msg); err != nil {
return err
}
for {
bytesNeeded := kmsCtx.BytesNeeded()
if bytesNeeded == 0 {
return nil
}
res := make([]byte, bytesNeeded)
bytesRead, err := conn.Read(res)
if err != nil && err != io.EOF {
return err
}
if err = kmsCtx.FeedResponse(res[:bytesRead]); err != nil {
return err
}
}
}
// needsKmsProvider returns true if provider was initially set to an empty document.
// An empty document signals the driver to fetch credentials.
func needsKmsProvider(kmsProviders bsoncore.Document, provider string) bool {
val, err := kmsProviders.LookupErr(provider)
if err != nil {
// KMS provider is not configured.
return false
}
doc, ok := val.DocumentOK()
// KMS provider is an empty document.
return ok && len(doc) == 5
}
func getGCPAccessToken(ctx context.Context, httpClient *http.Client) (string, error) {
metadataHost := "metadata.google.internal"
if envhost := os.Getenv("GCE_METADATA_HOST"); envhost != "" {
metadataHost = envhost
}
url := fmt.Sprintf("http://%s/computeMetadata/v1/instance/service-accounts/default/token", metadataHost)
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return "", internal.WrapErrorf(err, "unable to retrieve GCP credentials")
}
req.Header.Set("Metadata-Flavor", "Google")
resp, err := httpClient.Do(req.WithContext(ctx))
if err != nil {
return "", internal.WrapErrorf(err, "unable to retrieve GCP credentials")
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return "", internal.WrapErrorf(err, "unable to retrieve GCP credentials: error reading response body")
}
if resp.StatusCode != http.StatusOK {
return "", internal.WrapErrorf(err, "unable to retrieve GCP credentials: expected StatusCode 200, got StatusCode: %v. Response body: %s", resp.StatusCode, body)
}
var tokenResponse struct {
AccessToken string `json:"access_token"`
}
// Attempt to read body as JSON
err = json.Unmarshal(body, &tokenResponse)
if err != nil {
return "", internal.WrapErrorf(err, "unable to retrieve GCP credentials: error reading body JSON. Response body: %s", body)
}
if tokenResponse.AccessToken == "" {
return "", fmt.Errorf("unable to retrieve GCP credentials: got unexpected empty accessToken from GCP Metadata Server. Response body: %s", body)
}
return tokenResponse.AccessToken, nil
}
func (c *crypt) provideKmsProviders(ctx context.Context, cryptCtx *mongocrypt.Context) error {
kmsProviders := c.mongoCrypt.GetKmsProviders()
builder := bsoncore.NewDocumentBuilder()
if needsKmsProvider(kmsProviders, "gcp") {
// "gcp" KMS provider is an empty document.
// Attempt to fetch from GCP Instance Metadata server.
{
token, err := getGCPAccessToken(ctx, c.httpClient)
if err != nil {
return err
}
builder.StartDocument("gcp").
AppendString("accessToken", token).
FinishDocument()
}
}
return cryptCtx.ProvideKmsProviders(builder.Build())
}

View File

@@ -0,0 +1,144 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package dns
import (
"errors"
"fmt"
"net"
"runtime"
"strings"
)
// Resolver resolves DNS records.
type Resolver struct {
// Holds the functions to use for DNS lookups
LookupSRV func(string, string, string) (string, []*net.SRV, error)
LookupTXT func(string) ([]string, error)
}
// DefaultResolver is a Resolver that uses the default Resolver from the net package.
var DefaultResolver = &Resolver{net.LookupSRV, net.LookupTXT}
// ParseHosts uses the srv string and service name to get the hosts.
func (r *Resolver) ParseHosts(host string, srvName string, stopOnErr bool) ([]string, error) {
parsedHosts := strings.Split(host, ",")
if len(parsedHosts) != 1 {
return nil, fmt.Errorf("URI with SRV must include one and only one hostname")
}
return r.fetchSeedlistFromSRV(parsedHosts[0], srvName, stopOnErr)
}
// GetConnectionArgsFromTXT gets the TXT record associated with the host and returns the connection arguments.
func (r *Resolver) GetConnectionArgsFromTXT(host string) ([]string, error) {
var connectionArgsFromTXT []string
// error ignored because not finding a TXT record should not be
// considered an error.
recordsFromTXT, _ := r.LookupTXT(host)
// This is a temporary fix to get around bug https://github.com/golang/go/issues/21472.
// It will currently incorrectly concatenate multiple TXT records to one
// on windows.
if runtime.GOOS == "windows" {
recordsFromTXT = []string{strings.Join(recordsFromTXT, "")}
}
if len(recordsFromTXT) > 1 {
return nil, errors.New("multiple records from TXT not supported")
}
if len(recordsFromTXT) > 0 {
connectionArgsFromTXT = strings.FieldsFunc(recordsFromTXT[0], func(r rune) bool { return r == ';' || r == '&' })
err := validateTXTResult(connectionArgsFromTXT)
if err != nil {
return nil, err
}
}
return connectionArgsFromTXT, nil
}
func (r *Resolver) fetchSeedlistFromSRV(host string, srvName string, stopOnErr bool) ([]string, error) {
var err error
_, _, err = net.SplitHostPort(host)
if err == nil {
// we were able to successfully extract a port from the host,
// but should not be able to when using SRV
return nil, fmt.Errorf("URI with srv must not include a port number")
}
// default to "mongodb" as service name if not supplied
if srvName == "" {
srvName = "mongodb"
}
_, addresses, err := r.LookupSRV(srvName, "tcp", host)
if err != nil {
return nil, err
}
trimmedHost := strings.TrimSuffix(host, ".")
parsedHosts := make([]string, 0, len(addresses))
for _, address := range addresses {
trimmedAddressTarget := strings.TrimSuffix(address.Target, ".")
err := validateSRVResult(trimmedAddressTarget, trimmedHost)
if err != nil {
if stopOnErr {
return nil, err
}
continue
}
parsedHosts = append(parsedHosts, fmt.Sprintf("%s:%d", trimmedAddressTarget, address.Port))
}
return parsedHosts, nil
}
func validateSRVResult(recordFromSRV, inputHostName string) error {
separatedInputDomain := strings.Split(inputHostName, ".")
separatedRecord := strings.Split(recordFromSRV, ".")
if len(separatedRecord) < 2 {
return errors.New("DNS name must contain at least 2 labels")
}
if len(separatedRecord) < len(separatedInputDomain) {
return errors.New("Domain suffix from SRV record not matched input domain")
}
inputDomainSuffix := separatedInputDomain[1:]
domainSuffixOffset := len(separatedRecord) - (len(separatedInputDomain) - 1)
recordDomainSuffix := separatedRecord[domainSuffixOffset:]
for ix, label := range inputDomainSuffix {
if label != recordDomainSuffix[ix] {
return errors.New("Domain suffix from SRV record not matched input domain")
}
}
return nil
}
var allowedTXTOptions = map[string]struct{}{
"authsource": {},
"replicaset": {},
"loadbalanced": {},
}
func validateTXTResult(paramsFromTXT []string) error {
for _, param := range paramsFromTXT {
kv := strings.SplitN(param, "=", 2)
if len(kv) != 2 {
return errors.New("Invalid TXT record")
}
key := strings.ToLower(kv[0])
if _, ok := allowedTXTOptions[key]; !ok {
return fmt.Errorf("Cannot specify option '%s' in TXT record", kv[0])
}
}
return nil
}

View File

@@ -0,0 +1,269 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package driver // import "go.mongodb.org/mongo-driver/x/mongo/driver"
import (
"context"
"time"
"go.mongodb.org/mongo-driver/internal"
"go.mongodb.org/mongo-driver/mongo/address"
"go.mongodb.org/mongo-driver/mongo/description"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver/session"
)
// Deployment is implemented by types that can select a server from a deployment.
type Deployment interface {
SelectServer(context.Context, description.ServerSelector) (Server, error)
Kind() description.TopologyKind
}
// Connector represents a type that can connect to a server.
type Connector interface {
Connect() error
}
// Disconnector represents a type that can disconnect from a server.
type Disconnector interface {
Disconnect(context.Context) error
}
// Subscription represents a subscription to topology updates. A subscriber can receive updates through the
// Updates field.
type Subscription struct {
Updates <-chan description.Topology
ID uint64
}
// Subscriber represents a type to which another type can subscribe. A subscription contains a channel that
// is updated with topology descriptions.
type Subscriber interface {
Subscribe() (*Subscription, error)
Unsubscribe(*Subscription) error
}
// Server represents a MongoDB server. Implementations should pool connections and handle the
// retrieving and returning of connections.
type Server interface {
Connection(context.Context) (Connection, error)
// RTTMonitor returns the round-trip time monitor associated with this server.
RTTMonitor() RTTMonitor
}
// Connection represents a connection to a MongoDB server.
type Connection interface {
WriteWireMessage(context.Context, []byte) error
ReadWireMessage(ctx context.Context) ([]byte, error)
Description() description.Server
// Close closes any underlying connection and returns or frees any resources held by the
// connection. Close is idempotent and can be called multiple times, although subsequent calls
// to Close may return an error. A connection cannot be used after it is closed.
Close() error
ID() string
ServerConnectionID() *int64
DriverConnectionID() uint64 // TODO(GODRIVER-2824): change type to int64.
Address() address.Address
Stale() bool
}
// RTTMonitor represents a round-trip-time monitor.
type RTTMonitor interface {
// EWMA returns the exponentially weighted moving average observed round-trip time.
EWMA() time.Duration
// Min returns the minimum observed round-trip time over the window period.
Min() time.Duration
// P90 returns the 90th percentile observed round-trip time over the window period.
P90() time.Duration
// Stats returns stringified stats of the current state of the monitor.
Stats() string
}
var _ RTTMonitor = &internal.ZeroRTTMonitor{}
// PinnedConnection represents a Connection that can be pinned by one or more cursors or transactions. Implementations
// of this interface should maintain the following invariants:
//
// 1. Each Pin* call should increment the number of references for the connection.
// 2. Each Unpin* call should decrement the number of references for the connection.
// 3. Calls to Close() should be ignored until all resources have unpinned the connection.
type PinnedConnection interface {
Connection
PinToCursor() error
PinToTransaction() error
UnpinFromCursor() error
UnpinFromTransaction() error
}
// The session.LoadBalancedTransactionConnection type is a copy of PinnedConnection that was introduced to avoid
// import cycles. This compile-time assertion ensures that these types remain in sync if the PinnedConnection interface
// is changed in the future.
var _ PinnedConnection = (session.LoadBalancedTransactionConnection)(nil)
// LocalAddresser is a type that is able to supply its local address
type LocalAddresser interface {
LocalAddress() address.Address
}
// Expirable represents an expirable object.
type Expirable interface {
Expire() error
Alive() bool
}
// StreamerConnection represents a Connection that supports streaming wire protocol messages using the moreToCome and
// exhaustAllowed flags.
//
// The SetStreaming and CurrentlyStreaming functions correspond to the moreToCome flag on server responses. If a
// response has moreToCome set, SetStreaming(true) will be called and CurrentlyStreaming() should return true.
//
// CanStream corresponds to the exhaustAllowed flag. The operations layer will set exhaustAllowed on outgoing wire
// messages to inform the server that the driver supports streaming.
type StreamerConnection interface {
Connection
SetStreaming(bool)
CurrentlyStreaming() bool
SupportsStreaming() bool
}
// Compressor is an interface used to compress wire messages. If a Connection supports compression
// it should implement this interface as well. The CompressWireMessage method will be called during
// the execution of an operation if the wire message is allowed to be compressed.
type Compressor interface {
CompressWireMessage(src, dst []byte) ([]byte, error)
}
// ProcessErrorResult represents the result of a ErrorProcessor.ProcessError() call. Exact values for this type can be
// checked directly (e.g. res == ServerMarkedUnknown), but it is recommended that applications use the ServerChanged()
// function instead.
type ProcessErrorResult int
const (
// NoChange indicates that the error did not affect the state of the server.
NoChange ProcessErrorResult = iota
// ServerMarkedUnknown indicates that the error only resulted in the server being marked as Unknown.
ServerMarkedUnknown
// ConnectionPoolCleared indicates that the error resulted in the server being marked as Unknown and its connection
// pool being cleared.
ConnectionPoolCleared
)
// ErrorProcessor implementations can handle processing errors, which may modify their internal state.
// If this type is implemented by a Server, then Operation.Execute will call it's ProcessError
// method after it decodes a wire message.
type ErrorProcessor interface {
ProcessError(err error, conn Connection) ProcessErrorResult
}
// HandshakeInformation contains information extracted from a MongoDB connection handshake. This is a helper type that
// augments description.Server by also tracking server connection ID and authentication-related fields. We use this type
// rather than adding authentication-related fields to description.Server to avoid retaining sensitive information in a
// user-facing type. The server connection ID is stored in this type because unlike description.Server, all handshakes are
// correlated with a single network connection.
type HandshakeInformation struct {
Description description.Server
SpeculativeAuthenticate bsoncore.Document
ServerConnectionID *int64
SaslSupportedMechs []string
}
// Handshaker is the interface implemented by types that can perform a MongoDB
// handshake over a provided driver.Connection. This is used during connection
// initialization. Implementations must be goroutine safe.
type Handshaker interface {
GetHandshakeInformation(context.Context, address.Address, Connection) (HandshakeInformation, error)
FinishHandshake(context.Context, Connection) error
}
// SingleServerDeployment is an implementation of Deployment that always returns a single server.
type SingleServerDeployment struct{ Server }
var _ Deployment = SingleServerDeployment{}
// SelectServer implements the Deployment interface. This method does not use the
// description.SelectedServer provided and instead returns the embedded Server.
func (ssd SingleServerDeployment) SelectServer(context.Context, description.ServerSelector) (Server, error) {
return ssd.Server, nil
}
// Kind implements the Deployment interface. It always returns description.Single.
func (SingleServerDeployment) Kind() description.TopologyKind { return description.Single }
// SingleConnectionDeployment is an implementation of Deployment that always returns the same Connection. This
// implementation should only be used for connection handshakes and server heartbeats as it does not implement
// ErrorProcessor, which is necessary for application operations.
type SingleConnectionDeployment struct{ C Connection }
var _ Deployment = SingleConnectionDeployment{}
var _ Server = SingleConnectionDeployment{}
// SelectServer implements the Deployment interface. This method does not use the
// description.SelectedServer provided and instead returns itself. The Connections returned from the
// Connection method have a no-op Close method.
func (ssd SingleConnectionDeployment) SelectServer(context.Context, description.ServerSelector) (Server, error) {
return ssd, nil
}
// Kind implements the Deployment interface. It always returns description.Single.
func (ssd SingleConnectionDeployment) Kind() description.TopologyKind { return description.Single }
// Connection implements the Server interface. It always returns the embedded connection.
func (ssd SingleConnectionDeployment) Connection(context.Context) (Connection, error) {
return ssd.C, nil
}
// RTTMonitor implements the driver.Server interface.
func (ssd SingleConnectionDeployment) RTTMonitor() RTTMonitor {
return &internal.ZeroRTTMonitor{}
}
// TODO(GODRIVER-617): We can likely use 1 type for both the Type and the RetryMode by using 2 bits for the mode and 1
// TODO bit for the type. Although in the practical sense, we might not want to do that since the type of retryability
// TODO is tied to the operation itself and isn't going change, e.g. and insert operation will always be a write,
// TODO however some operations are both reads and writes, for instance aggregate is a read but with a $out parameter
// TODO it's a write.
// Type specifies whether an operation is a read, write, or unknown.
type Type uint
// THese are the availables types of Type.
const (
_ Type = iota
Write
Read
)
// RetryMode specifies the way that retries are handled for retryable operations.
type RetryMode uint
// These are the modes available for retrying. Note that if Timeout is specified on the Client, the
// operation will automatically retry as many times as possible within the context's deadline
// unless RetryNone is used.
const (
// RetryNone disables retrying.
RetryNone RetryMode = iota
// RetryOnce will enable retrying the entire operation once if Timeout is not specified.
RetryOnce
// RetryOncePerCommand will enable retrying each command associated with an operation if Timeout
// is not specified. For example, if an insert is batch split into 4 commands then each of
// those commands is eligible for one retry.
RetryOncePerCommand
// RetryContext will enable retrying until the context.Context's deadline is exceeded or it is
// cancelled.
RetryContext
)
// Enabled returns if this RetryMode enables retrying.
func (rm RetryMode) Enabled() bool {
return rm == RetryOnce || rm == RetryOncePerCommand || rm == RetryContext
}

View File

@@ -0,0 +1,153 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package drivertest
import (
"context"
"errors"
"go.mongodb.org/mongo-driver/mongo/address"
"go.mongodb.org/mongo-driver/mongo/description"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
)
// ChannelConn implements the driver.Connection interface by reading and writing wire messages
// to a channel
type ChannelConn struct {
WriteErr error
Written chan []byte
ReadResp chan []byte
ReadErr chan error
Desc description.Server
}
// WriteWireMessage implements the driver.Connection interface.
func (c *ChannelConn) WriteWireMessage(ctx context.Context, wm []byte) error {
// Copy wm in case it came from a buffer pool.
b := make([]byte, len(wm))
copy(b, wm)
select {
case c.Written <- b:
default:
c.WriteErr = errors.New("could not write wiremessage to written channel")
}
return c.WriteErr
}
// ReadWireMessage implements the driver.Connection interface.
func (c *ChannelConn) ReadWireMessage(ctx context.Context) ([]byte, error) {
var wm []byte
var err error
select {
case wm = <-c.ReadResp:
case err = <-c.ReadErr:
case <-ctx.Done():
}
return wm, err
}
// Description implements the driver.Connection interface.
func (c *ChannelConn) Description() description.Server { return c.Desc }
// Close implements the driver.Connection interface.
func (c *ChannelConn) Close() error {
return nil
}
// ID implements the driver.Connection interface.
func (c *ChannelConn) ID() string {
return "faked"
}
// DriverConnectionID implements the driver.Connection interface.
// TODO(GODRIVER-2824): replace return type with int64.
func (c *ChannelConn) DriverConnectionID() uint64 {
return 0
}
// ServerConnectionID implements the driver.Connection interface.
func (c *ChannelConn) ServerConnectionID() *int64 {
serverConnectionID := int64(42)
return &serverConnectionID
}
// Address implements the driver.Connection interface.
func (c *ChannelConn) Address() address.Address { return address.Address("0.0.0.0") }
// Stale implements the driver.Connection interface.
func (c *ChannelConn) Stale() bool {
return false
}
// MakeReply creates an OP_REPLY wiremessage from a BSON document
func MakeReply(doc bsoncore.Document) []byte {
var dst []byte
idx, dst := wiremessage.AppendHeaderStart(dst, 10, 9, wiremessage.OpReply)
dst = wiremessage.AppendReplyFlags(dst, 0)
dst = wiremessage.AppendReplyCursorID(dst, 0)
dst = wiremessage.AppendReplyStartingFrom(dst, 0)
dst = wiremessage.AppendReplyNumberReturned(dst, 1)
dst = append(dst, doc...)
return bsoncore.UpdateLength(dst, idx, int32(len(dst[idx:])))
}
// GetCommandFromQueryWireMessage returns the command sent in an OP_QUERY wire message.
func GetCommandFromQueryWireMessage(wm []byte) (bsoncore.Document, error) {
var ok bool
_, _, _, _, wm, ok = wiremessage.ReadHeader(wm)
if !ok {
return nil, errors.New("could not read header")
}
_, wm, ok = wiremessage.ReadQueryFlags(wm)
if !ok {
return nil, errors.New("could not read flags")
}
_, wm, ok = wiremessage.ReadQueryFullCollectionName(wm)
if !ok {
return nil, errors.New("could not read fullCollectionName")
}
_, wm, ok = wiremessage.ReadQueryNumberToSkip(wm)
if !ok {
return nil, errors.New("could not read numberToSkip")
}
_, wm, ok = wiremessage.ReadQueryNumberToReturn(wm)
if !ok {
return nil, errors.New("could not read numberToReturn")
}
var query bsoncore.Document
query, wm, ok = wiremessage.ReadQueryQuery(wm)
if !ok {
return nil, errors.New("could not read query")
}
return query, nil
}
// GetCommandFromMsgWireMessage returns the command document sent in an OP_MSG wire message.
func GetCommandFromMsgWireMessage(wm []byte) (bsoncore.Document, error) {
var ok bool
_, _, _, _, wm, ok = wiremessage.ReadHeader(wm)
if !ok {
return nil, errors.New("could not read header")
}
_, wm, ok = wiremessage.ReadMsgFlags(wm)
if !ok {
return nil, errors.New("could not read flags")
}
_, wm, ok = wiremessage.ReadMsgSectionType(wm)
if !ok {
return nil, errors.New("could not read section type")
}
cmdDoc, wm, ok := wiremessage.ReadMsgSectionSingleDocument(wm)
if !ok {
return nil, errors.New("could not read command document")
}
return cmdDoc, nil
}

View File

@@ -0,0 +1,102 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package drivertest
import (
"errors"
"net"
"time"
)
// ChannelNetConn implements the net.Conn interface by reading and writing wire messages to a channel.
type ChannelNetConn struct {
WriteErr error
Written chan []byte
ReadResp chan []byte
ReadErr chan error
}
// Read reads data from the connection
func (c *ChannelNetConn) Read(b []byte) (int, error) {
var wm []byte
var err error
select {
case wm = <-c.ReadResp:
case err = <-c.ReadErr:
}
return copy(b, wm), err
}
// Write writes data to the connection.
func (c *ChannelNetConn) Write(b []byte) (int, error) {
copyBuf := make([]byte, len(b))
copy(copyBuf, b)
select {
case c.Written <- copyBuf:
default:
c.WriteErr = errors.New("could not write wm to Written channel")
}
return len(b), c.WriteErr
}
// Close closes the connection.
func (c *ChannelNetConn) Close() error {
return nil
}
// LocalAddr returns the local network address.
func (c *ChannelNetConn) LocalAddr() net.Addr {
return nil
}
// RemoteAddr returns the remote network address.
func (c *ChannelNetConn) RemoteAddr() net.Addr {
return nil
}
// SetDeadline sets the read and write deadlines associated with the connection.
func (c *ChannelNetConn) SetDeadline(_ time.Time) error {
return nil
}
// SetReadDeadline sets the read and write deadlines associated with the connection.
func (c *ChannelNetConn) SetReadDeadline(_ time.Time) error {
return nil
}
// SetWriteDeadline sets the read and write deadlines associated with the connection.
func (c *ChannelNetConn) SetWriteDeadline(_ time.Time) error {
return nil
}
// GetWrittenMessage gets the last wire message written to the connection
func (c *ChannelNetConn) GetWrittenMessage() []byte {
select {
case wm := <-c.Written:
return wm
default:
return nil
}
}
// AddResponse adds a response to the connection.
func (c *ChannelNetConn) AddResponse(resp []byte) error {
select {
case c.ReadResp <- resp[:4]:
default:
return errors.New("could not write length bytes")
}
select {
case c.ReadResp <- resp[4:]:
default:
return errors.New("could not write response bytes")
}
return nil
}

View File

@@ -0,0 +1,520 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package driver
import (
"bytes"
"errors"
"fmt"
"strings"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/internal"
"go.mongodb.org/mongo-driver/mongo/description"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
var (
retryableCodes = []int32{11600, 11602, 10107, 13435, 13436, 189, 91, 7, 6, 89, 9001, 262}
nodeIsRecoveringCodes = []int32{11600, 11602, 13436, 189, 91}
notPrimaryCodes = []int32{10107, 13435, 10058}
nodeIsShuttingDownCodes = []int32{11600, 91}
unknownReplWriteConcernCode = int32(79)
unsatisfiableWriteConcernCode = int32(100)
)
var (
// UnknownTransactionCommitResult is an error label for unknown transaction commit results.
UnknownTransactionCommitResult = "UnknownTransactionCommitResult"
// TransientTransactionError is an error label for transient errors with transactions.
TransientTransactionError = "TransientTransactionError"
// NetworkError is an error label for network errors.
NetworkError = "NetworkError"
// RetryableWriteError is an error lable for retryable write errors.
RetryableWriteError = "RetryableWriteError"
// NoWritesPerformed is an error label indicated that no writes were performed for an operation.
NoWritesPerformed = "NoWritesPerformed"
// ErrCursorNotFound is the cursor not found error for legacy find operations.
ErrCursorNotFound = errors.New("cursor not found")
// ErrUnacknowledgedWrite is returned from functions that have an unacknowledged
// write concern.
ErrUnacknowledgedWrite = errors.New("unacknowledged write")
// ErrUnsupportedStorageEngine is returned when a retryable write is attempted against a server
// that uses a storage engine that does not support retryable writes
ErrUnsupportedStorageEngine = errors.New("this MongoDB deployment does not support retryable writes. Please add retryWrites=false to your connection string")
// ErrDeadlineWouldBeExceeded is returned when a Timeout set on an operation would be exceeded
// if the operation were sent to the server.
ErrDeadlineWouldBeExceeded = errors.New("operation not sent to server, as Timeout would be exceeded")
// ErrNegativeMaxTime is returned when MaxTime on an operation is a negative value.
ErrNegativeMaxTime = errors.New("a negative value was provided for MaxTime on an operation")
)
// QueryFailureError is an error representing a command failure as a document.
type QueryFailureError struct {
Message string
Response bsoncore.Document
Wrapped error
}
// Error implements the error interface.
func (e QueryFailureError) Error() string {
return fmt.Sprintf("%s: %v", e.Message, e.Response)
}
// Unwrap returns the underlying error.
func (e QueryFailureError) Unwrap() error {
return e.Wrapped
}
// ResponseError is an error parsing the response to a command.
type ResponseError struct {
Message string
Wrapped error
}
// NewCommandResponseError creates a CommandResponseError.
func NewCommandResponseError(msg string, err error) ResponseError {
return ResponseError{Message: msg, Wrapped: err}
}
// Error implements the error interface.
func (e ResponseError) Error() string {
if e.Wrapped != nil {
return fmt.Sprintf("%s: %s", e.Message, e.Wrapped)
}
return e.Message
}
// WriteCommandError is an error for a write command.
type WriteCommandError struct {
WriteConcernError *WriteConcernError
WriteErrors WriteErrors
Labels []string
Raw bsoncore.Document
}
// UnsupportedStorageEngine returns whether or not the WriteCommandError comes from a retryable write being attempted
// against a server that has a storage engine where they are not supported
func (wce WriteCommandError) UnsupportedStorageEngine() bool {
for _, writeError := range wce.WriteErrors {
if writeError.Code == 20 && strings.HasPrefix(strings.ToLower(writeError.Message), "transaction numbers") {
return true
}
}
return false
}
func (wce WriteCommandError) Error() string {
var buf bytes.Buffer
fmt.Fprint(&buf, "write command error: [")
fmt.Fprintf(&buf, "{%s}, ", wce.WriteErrors)
fmt.Fprintf(&buf, "{%s}]", wce.WriteConcernError)
return buf.String()
}
// Retryable returns true if the error is retryable
func (wce WriteCommandError) Retryable(wireVersion *description.VersionRange) bool {
for _, label := range wce.Labels {
if label == RetryableWriteError {
return true
}
}
if wireVersion != nil && wireVersion.Max >= 9 {
return false
}
if wce.WriteConcernError == nil {
return false
}
return (*wce.WriteConcernError).Retryable()
}
// HasErrorLabel returns true if the error contains the specified label.
func (wce WriteCommandError) HasErrorLabel(label string) bool {
if wce.Labels != nil {
for _, l := range wce.Labels {
if l == label {
return true
}
}
}
return false
}
// WriteConcernError is a write concern failure that occurred as a result of a
// write operation.
type WriteConcernError struct {
Name string
Code int64
Message string
Details bsoncore.Document
Labels []string
TopologyVersion *description.TopologyVersion
Raw bsoncore.Document
}
func (wce WriteConcernError) Error() string {
if wce.Name != "" {
return fmt.Sprintf("(%v) %v", wce.Name, wce.Message)
}
return wce.Message
}
// Retryable returns true if the error is retryable
func (wce WriteConcernError) Retryable() bool {
for _, code := range retryableCodes {
if wce.Code == int64(code) {
return true
}
}
return false
}
// NodeIsRecovering returns true if this error is a node is recovering error.
func (wce WriteConcernError) NodeIsRecovering() bool {
for _, code := range nodeIsRecoveringCodes {
if wce.Code == int64(code) {
return true
}
}
hasNoCode := wce.Code == 0
return hasNoCode && strings.Contains(wce.Message, "node is recovering")
}
// NodeIsShuttingDown returns true if this error is a node is shutting down error.
func (wce WriteConcernError) NodeIsShuttingDown() bool {
for _, code := range nodeIsShuttingDownCodes {
if wce.Code == int64(code) {
return true
}
}
hasNoCode := wce.Code == 0
return hasNoCode && strings.Contains(wce.Message, "node is shutting down")
}
// NotPrimary returns true if this error is a not primary error.
func (wce WriteConcernError) NotPrimary() bool {
for _, code := range notPrimaryCodes {
if wce.Code == int64(code) {
return true
}
}
hasNoCode := wce.Code == 0
return hasNoCode && strings.Contains(wce.Message, internal.LegacyNotPrimary)
}
// WriteError is a non-write concern failure that occurred as a result of a write
// operation.
type WriteError struct {
Index int64
Code int64
Message string
Details bsoncore.Document
Raw bsoncore.Document
}
func (we WriteError) Error() string { return we.Message }
// WriteErrors is a group of non-write concern failures that occurred as a result
// of a write operation.
type WriteErrors []WriteError
func (we WriteErrors) Error() string {
var buf bytes.Buffer
fmt.Fprint(&buf, "write errors: [")
for idx, err := range we {
if idx != 0 {
fmt.Fprintf(&buf, ", ")
}
fmt.Fprintf(&buf, "{%s}", err)
}
fmt.Fprint(&buf, "]")
return buf.String()
}
// Error is a command execution error from the database.
type Error struct {
Code int32
Message string
Labels []string
Name string
Wrapped error
TopologyVersion *description.TopologyVersion
Raw bsoncore.Document
}
// UnsupportedStorageEngine returns whether e came as a result of an unsupported storage engine
func (e Error) UnsupportedStorageEngine() bool {
return e.Code == 20 && strings.HasPrefix(strings.ToLower(e.Message), "transaction numbers")
}
// Error implements the error interface.
func (e Error) Error() string {
if e.Name != "" {
return fmt.Sprintf("(%v) %v", e.Name, e.Message)
}
return e.Message
}
// Unwrap returns the underlying error.
func (e Error) Unwrap() error {
return e.Wrapped
}
// HasErrorLabel returns true if the error contains the specified label.
func (e Error) HasErrorLabel(label string) bool {
if e.Labels != nil {
for _, l := range e.Labels {
if l == label {
return true
}
}
}
return false
}
// RetryableRead returns true if the error is retryable for a read operation
func (e Error) RetryableRead() bool {
for _, label := range e.Labels {
if label == NetworkError {
return true
}
}
for _, code := range retryableCodes {
if e.Code == code {
return true
}
}
return false
}
// RetryableWrite returns true if the error is retryable for a write operation
func (e Error) RetryableWrite(wireVersion *description.VersionRange) bool {
for _, label := range e.Labels {
if label == NetworkError || label == RetryableWriteError {
return true
}
}
if wireVersion != nil && wireVersion.Max >= 9 {
return false
}
for _, code := range retryableCodes {
if e.Code == code {
return true
}
}
return false
}
// NetworkError returns true if the error is a network error.
func (e Error) NetworkError() bool {
for _, label := range e.Labels {
if label == NetworkError {
return true
}
}
return false
}
// NodeIsRecovering returns true if this error is a node is recovering error.
func (e Error) NodeIsRecovering() bool {
for _, code := range nodeIsRecoveringCodes {
if e.Code == code {
return true
}
}
hasNoCode := e.Code == 0
return hasNoCode && strings.Contains(e.Message, "node is recovering")
}
// NodeIsShuttingDown returns true if this error is a node is shutting down error.
func (e Error) NodeIsShuttingDown() bool {
for _, code := range nodeIsShuttingDownCodes {
if e.Code == code {
return true
}
}
hasNoCode := e.Code == 0
return hasNoCode && strings.Contains(e.Message, "node is shutting down")
}
// NotPrimary returns true if this error is a not primary error.
func (e Error) NotPrimary() bool {
for _, code := range notPrimaryCodes {
if e.Code == code {
return true
}
}
hasNoCode := e.Code == 0
return hasNoCode && strings.Contains(e.Message, internal.LegacyNotPrimary)
}
// NamespaceNotFound returns true if this errors is a NamespaceNotFound error.
func (e Error) NamespaceNotFound() bool {
return e.Code == 26 || e.Message == "ns not found"
}
// ExtractErrorFromServerResponse extracts an error from a server response bsoncore.Document
// if there is one. Also used in testing for SDAM.
func ExtractErrorFromServerResponse(doc bsoncore.Document) error {
var errmsg, codeName string
var code int32
var labels []string
var ok bool
var tv *description.TopologyVersion
var wcError WriteCommandError
elems, err := doc.Elements()
if err != nil {
return err
}
for _, elem := range elems {
switch elem.Key() {
case "ok":
switch elem.Value().Type {
case bson.TypeInt32:
if elem.Value().Int32() == 1 {
ok = true
}
case bson.TypeInt64:
if elem.Value().Int64() == 1 {
ok = true
}
case bson.TypeDouble:
if elem.Value().Double() == 1 {
ok = true
}
}
case "errmsg":
if str, okay := elem.Value().StringValueOK(); okay {
errmsg = str
}
case "codeName":
if str, okay := elem.Value().StringValueOK(); okay {
codeName = str
}
case "code":
if c, okay := elem.Value().Int32OK(); okay {
code = c
}
case "errorLabels":
if arr, okay := elem.Value().ArrayOK(); okay {
vals, err := arr.Values()
if err != nil {
continue
}
for _, val := range vals {
if str, ok := val.StringValueOK(); ok {
labels = append(labels, str)
}
}
}
case "writeErrors":
arr, exists := elem.Value().ArrayOK()
if !exists {
break
}
vals, err := arr.Values()
if err != nil {
continue
}
for _, val := range vals {
var we WriteError
doc, exists := val.DocumentOK()
if !exists {
continue
}
if index, exists := doc.Lookup("index").AsInt64OK(); exists {
we.Index = index
}
if code, exists := doc.Lookup("code").AsInt64OK(); exists {
we.Code = code
}
if msg, exists := doc.Lookup("errmsg").StringValueOK(); exists {
we.Message = msg
}
if info, exists := doc.Lookup("errInfo").DocumentOK(); exists {
we.Details = make([]byte, len(info))
copy(we.Details, info)
}
we.Raw = doc
wcError.WriteErrors = append(wcError.WriteErrors, we)
}
case "writeConcernError":
doc, exists := elem.Value().DocumentOK()
if !exists {
break
}
wcError.WriteConcernError = new(WriteConcernError)
wcError.WriteConcernError.Raw = doc
if code, exists := doc.Lookup("code").AsInt64OK(); exists {
wcError.WriteConcernError.Code = code
}
if name, exists := doc.Lookup("codeName").StringValueOK(); exists {
wcError.WriteConcernError.Name = name
}
if msg, exists := doc.Lookup("errmsg").StringValueOK(); exists {
wcError.WriteConcernError.Message = msg
}
if info, exists := doc.Lookup("errInfo").DocumentOK(); exists {
wcError.WriteConcernError.Details = make([]byte, len(info))
copy(wcError.WriteConcernError.Details, info)
}
if errLabels, exists := doc.Lookup("errorLabels").ArrayOK(); exists {
vals, err := errLabels.Values()
if err != nil {
continue
}
for _, val := range vals {
if str, ok := val.StringValueOK(); ok {
labels = append(labels, str)
}
}
}
case "topologyVersion":
doc, ok := elem.Value().DocumentOK()
if !ok {
break
}
version, err := description.NewTopologyVersion(bson.Raw(doc))
if err == nil {
tv = version
}
}
}
if !ok {
if errmsg == "" {
errmsg = "command failed"
}
return Error{
Code: code,
Message: errmsg,
Name: codeName,
Labels: labels,
TopologyVersion: tv,
Raw: doc,
}
}
if len(wcError.WriteErrors) > 0 || wcError.WriteConcernError != nil {
wcError.Labels = labels
if wcError.WriteConcernError != nil {
wcError.WriteConcernError.TopologyVersion = tv
}
wcError.Raw = doc
return wcError
}
return nil
}

View File

@@ -0,0 +1,195 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package integration
import (
"bytes"
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/event"
"go.mongodb.org/mongo-driver/internal/testutil"
"go.mongodb.org/mongo-driver/mongo/description"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
"go.mongodb.org/mongo-driver/x/mongo/driver/topology"
)
func setUpMonitor() (*event.CommandMonitor, chan *event.CommandStartedEvent, chan *event.CommandSucceededEvent, chan *event.CommandFailedEvent) {
started := make(chan *event.CommandStartedEvent, 1)
succeeded := make(chan *event.CommandSucceededEvent, 1)
failed := make(chan *event.CommandFailedEvent, 1)
return &event.CommandMonitor{
Started: func(ctx context.Context, e *event.CommandStartedEvent) {
started <- e
},
Succeeded: func(ctx context.Context, e *event.CommandSucceededEvent) {
succeeded <- e
},
Failed: func(ctx context.Context, e *event.CommandFailedEvent) {
failed <- e
},
}, started, succeeded, failed
}
func skipIfBelow32(ctx context.Context, t *testing.T, topo *topology.Topology) {
server, err := topo.SelectServer(ctx, description.WriteSelector())
noerr(t, err)
versionCmd := bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "serverStatus", 1))
serverStatus, err := testutil.RunCommand(t, server, dbName, versionCmd)
noerr(t, err)
version, err := serverStatus.LookupErr("version")
noerr(t, err)
if testutil.CompareVersions(t, version.StringValue(), "3.2") < 0 {
t.Skip()
}
}
func TestAggregate(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
t.Run("TestMaxTimeMSInGetMore", func(t *testing.T) {
ctx := context.Background()
monitor, started, succeeded, failed := setUpMonitor()
dbName := "TestAggMaxTimeDB"
collName := "TestAggMaxTimeColl"
top := testutil.MonitoredTopology(t, dbName, monitor)
clearChannels(started, succeeded, failed)
skipIfBelow32(ctx, t, top)
clearChannels(started, succeeded, failed)
err := operation.NewInsert(
bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "x", 1)),
bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "x", 1)),
bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "x", 1)),
).Collection(collName).Database(dbName).
Deployment(top).ServerSelector(description.WriteSelector()).Execute(context.Background())
noerr(t, err)
clearChannels(started, succeeded, failed)
op := operation.NewAggregate(bsoncore.BuildDocumentFromElements(nil)).
Collection(collName).Database(dbName).Deployment(top).ServerSelector(description.WriteSelector()).
CommandMonitor(monitor).BatchSize(2)
err = op.Execute(context.Background())
noerr(t, err)
batchCursor, err := op.Result(driver.CursorOptions{MaxTimeMS: 10, BatchSize: 2, CommandMonitor: monitor})
noerr(t, err)
var e *event.CommandStartedEvent
select {
case e = <-started:
case <-time.After(2000 * time.Millisecond):
t.Fatal("timed out waiting for aggregate")
}
require.Equal(t, "aggregate", e.CommandName)
clearChannels(started, succeeded, failed)
// first Next() should automatically return true
require.True(t, batchCursor.Next(ctx), "expected true from first Next, got false")
clearChannels(started, succeeded, failed)
batchCursor.Next(ctx) // should do getMore
select {
case e = <-started:
case <-time.After(200 * time.Millisecond):
t.Fatal("timed out waiting for getMore")
}
require.Equal(t, "getMore", e.CommandName)
_, err = e.Command.LookupErr("maxTimeMS")
noerr(t, err)
})
t.Run("Multiple Batches", func(t *testing.T) {
ds := []bsoncore.Document{
bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "_id", 1)),
bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "_id", 2)),
bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "_id", 3)),
bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "_id", 4)),
bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "_id", 5)),
}
wc := writeconcern.New(writeconcern.WMajority())
testutil.AutoInsertDocs(t, wc, ds...)
op := operation.NewAggregate(bsoncore.BuildArray(nil,
bsoncore.BuildDocumentValue(
bsoncore.BuildDocumentElement(nil,
"$match", bsoncore.BuildDocumentElement(nil,
"_id", bsoncore.AppendInt32Element(nil, "$gt", 2),
),
),
),
bsoncore.BuildDocumentValue(
bsoncore.BuildDocumentElement(nil,
"$sort", bsoncore.AppendInt32Element(nil, "_id", 1),
),
),
)).Collection(testutil.ColName(t)).Database(dbName).Deployment(testutil.Topology(t)).
ServerSelector(description.WriteSelector()).BatchSize(2)
err := op.Execute(context.Background())
noerr(t, err)
cursor, err := op.Result(driver.CursorOptions{BatchSize: 2})
noerr(t, err)
var got []bsoncore.Document
for i := 0; i < 2; i++ {
if !cursor.Next(context.Background()) {
t.Error("Cursor should have results, but does not have a next result")
}
docs, err := cursor.Batch().Documents()
noerr(t, err)
got = append(got, docs...)
}
readers := ds[2:]
for i, g := range got {
if !bytes.Equal(g[:len(readers[i])], readers[i]) {
t.Errorf("Did not get expected document. got %v; want %v", bson.Raw(g[:len(readers[i])]), readers[i])
}
}
if cursor.Next(context.Background()) {
t.Error("Cursor should be exhausted but has more results")
}
})
t.Run("AllowDiskUse", func(t *testing.T) {
ds := []bsoncore.Document{
bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "_id", 1)),
bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "_id", 2)),
}
wc := writeconcern.New(writeconcern.WMajority())
testutil.AutoInsertDocs(t, wc, ds...)
op := operation.NewAggregate(bsoncore.BuildArray(nil)).Collection(testutil.ColName(t)).Database(dbName).
Deployment(testutil.Topology(t)).ServerSelector(description.WriteSelector()).AllowDiskUse(true)
err := op.Execute(context.Background())
if err != nil {
t.Errorf("Expected no error from allowing disk use, but got %v", err)
}
})
}
func clearChannels(s chan *event.CommandStartedEvent, succ chan *event.CommandSucceededEvent, f chan *event.CommandFailedEvent) {
for len(s) > 0 {
<-s
}
for len(succ) > 0 {
<-succ
}
for len(f) > 0 {
<-f
}
}

View File

@@ -0,0 +1,76 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package integration
import (
"context"
"os"
"testing"
"github.com/stretchr/testify/require"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/internal/testutil"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
)
func TestCompression(t *testing.T) {
comp := os.Getenv("MONGO_GO_DRIVER_COMPRESSOR")
if len(comp) == 0 {
t.Skip("Skipping because no compressor specified")
}
wc := writeconcern.New(writeconcern.WMajority())
collOne := testutil.ColName(t)
testutil.DropCollection(t, testutil.DBName(t), collOne)
testutil.InsertDocs(t, testutil.DBName(t), collOne, wc,
bsoncore.BuildDocument(nil, bsoncore.AppendStringElement(nil, "name", "compression_test")),
)
cmd := operation.NewCommand(bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "serverStatus", 1))).
Deployment(testutil.Topology(t)).
Database(testutil.DBName(t))
ctx := context.Background()
err := cmd.Execute(ctx)
noerr(t, err)
result := cmd.Result()
serverVersion, err := result.LookupErr("version")
noerr(t, err)
if testutil.CompareVersions(t, serverVersion.StringValue(), "3.4") < 0 {
t.Skip("skipping compression test for version < 3.4")
}
networkVal, err := result.LookupErr("network")
noerr(t, err)
require.Equal(t, networkVal.Type, bson.TypeEmbeddedDocument)
compressionVal, err := networkVal.Document().LookupErr("compression")
noerr(t, err)
compressorDoc, err := compressionVal.Document().LookupErr(comp)
noerr(t, err)
compressorKey := "compressor"
compareTo36 := testutil.CompareVersions(t, serverVersion.StringValue(), "3.6")
if compareTo36 < 0 {
compressorKey = "compressed"
}
compressor, err := compressorDoc.Document().LookupErr(compressorKey)
noerr(t, err)
bytesIn, err := compressor.Document().LookupErr("bytesIn")
noerr(t, err)
require.True(t, bytesIn.IsNumber())
require.True(t, bytesIn.Int64() > 0)
}

View File

@@ -0,0 +1,56 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package integration
import (
"context"
"testing"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
"go.mongodb.org/mongo-driver/x/mongo/driver/topology"
)
func TestInsert(t *testing.T) {
t.Skip()
topo, err := topology.New(nil)
if err != nil {
t.Fatalf("Couldn't connect topology: %v", err)
}
_ = topo.Connect()
doc := bsoncore.BuildDocument(nil, bsoncore.AppendDoubleElement(nil, "pi", 3.14159))
iop := operation.NewInsert(doc).Database("foo").Collection("bar").Deployment(topo)
err = iop.Execute(context.Background())
if err != nil {
t.Fatalf("Couldn't execute insert operation: %v", err)
}
t.Log(iop.Result())
fop := operation.NewFind(bsoncore.BuildDocument(nil, bsoncore.AppendDoubleElement(nil, "pi", 3.14159))).
Database("foo").Collection("bar").Deployment(topo).BatchSize(1)
err = fop.Execute(context.Background())
if err != nil {
t.Fatalf("Couldn't execute find operation: %v", err)
}
cur, err := fop.Result(driver.CursorOptions{BatchSize: 2})
if err != nil {
t.Fatalf("Couldn't get cursor result from find operation: %v", err)
}
for cur.Next(context.Background()) {
batch := cur.Batch()
docs, err := batch.Documents()
if err != nil {
t.Fatalf("Couldn't iterate batch: %v", err)
}
for i, doc := range docs {
t.Log(i, doc)
}
}
}

View File

@@ -0,0 +1,7 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package integration

View File

@@ -0,0 +1,113 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package integration
import (
"flag"
"fmt"
"os"
"strings"
"testing"
"go.mongodb.org/mongo-driver/x/mongo/driver/auth"
"go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
"go.mongodb.org/mongo-driver/x/mongo/driver/topology"
)
var host *string
var connectionString connstring.ConnString
var dbName string
func TestMain(m *testing.M) {
flag.Parse()
mongodbURI := os.Getenv("MONGODB_URI")
if mongodbURI == "" {
mongodbURI = "mongodb://localhost:27017"
}
mongodbURI = addTLSConfigToURI(mongodbURI)
mongodbURI = addCompressorToURI(mongodbURI)
var err error
connectionString, err = connstring.ParseAndValidate(mongodbURI)
if err != nil {
fmt.Printf("Could not parse connection string: %v\n", err)
os.Exit(1)
}
host = &connectionString.Hosts[0]
dbName = fmt.Sprintf("mongo-go-driver-%d", os.Getpid())
if connectionString.Database != "" {
dbName = connectionString.Database
}
os.Exit(m.Run())
}
func noerr(t *testing.T, err error) {
if err != nil {
t.Helper()
t.Errorf("Unexpected error: %v", err)
t.FailNow()
}
}
func autherr(t *testing.T, err error) {
t.Helper()
switch e := err.(type) {
case topology.ConnectionError:
_, ok := e.Wrapped.(*auth.Error)
if !ok {
t.Fatal("Expected auth error and didn't get one")
}
case *auth.Error:
return
default:
t.Fatal("Expected auth error and didn't get one")
}
}
// addTLSConfigToURI checks for the environmental variable indicating that the tests are being run
// on an SSL-enabled server, and if so, returns a new URI with the necessary configuration.
func addTLSConfigToURI(uri string) string {
caFile := os.Getenv("MONGO_GO_DRIVER_CA_FILE")
if len(caFile) == 0 {
return uri
}
if !strings.ContainsRune(uri, '?') {
if uri[len(uri)-1] != '/' {
uri += "/"
}
uri += "?"
} else {
uri += "&"
}
return uri + "ssl=true&sslCertificateAuthorityFile=" + caFile
}
func addCompressorToURI(uri string) string {
comp := os.Getenv("MONGO_GO_DRIVER_COMPRESSOR")
if len(comp) == 0 {
return uri
}
if !strings.ContainsRune(uri, '?') {
if uri[len(uri)-1] != '/' {
uri += "/"
}
uri += "?"
} else {
uri += "&"
}
return uri + "compressors=" + comp
}

View File

@@ -0,0 +1,178 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package integration
import (
"context"
"fmt"
"os"
"testing"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/internal/testutil"
"go.mongodb.org/mongo-driver/mongo/description"
"go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/mongo/writeconcern"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver"
)
type scramTestCase struct {
username string
password string
mechanisms []string
altPassword string
}
func TestSCRAM(t *testing.T) {
if os.Getenv("AUTH") != "auth" {
t.Skip("Skipping because authentication is required")
}
server, err := testutil.Topology(t).SelectServer(context.Background(), description.WriteSelector())
noerr(t, err)
serverConnection, err := server.Connection(context.Background())
noerr(t, err)
defer serverConnection.Close()
if !serverConnection.Description().WireVersion.Includes(7) {
t.Skip("Skipping because MongoDB 4.0 is needed for SCRAM-SHA-256")
}
// Unicode constants for testing
var romanFour = "\u2163" // ROMAN NUMERAL FOUR -> SASL prepped is "IV"
var romanNine = "\u2168" // ROMAN NUMERAL NINE -> SASL prepped is "IX"
testUsers := []scramTestCase{
// SCRAM spec test steps 1-3
{username: "sha1", password: "sha1", mechanisms: []string{"SCRAM-SHA-1"}},
{username: "sha256", password: "sha256", mechanisms: []string{"SCRAM-SHA-256"}},
{username: "both", password: "both", mechanisms: []string{"SCRAM-SHA-1", "SCRAM-SHA-256"}},
// SCRAM spec test step 4
{username: "IX", password: "IX", mechanisms: []string{"SCRAM-SHA-256"}, altPassword: "I\u00ADX"},
{username: romanNine, password: romanFour, mechanisms: []string{"SCRAM-SHA-256"}, altPassword: "I\u00ADV"},
}
// Verify that test (root) user is authenticated. If this fails, the
// rest of the test can't succeed.
wc := writeconcern.New(writeconcern.WMajority())
collOne := testutil.ColName(t)
testutil.DropCollection(t, testutil.DBName(t), collOne)
testutil.InsertDocs(t, testutil.DBName(t),
collOne, wc, bsoncore.BuildDocument(nil, bsoncore.AppendStringElement(nil, "name", "scram_test")),
)
// Test step 1: Create users for test cases
err = createScramUsers(t, server, testUsers)
if err != nil {
t.Fatal(err)
}
// Step 2 and 3a: For each auth mechanism, "SCRAM-SHA-1", "SCRAM-SHA-256"
// and "negotiate" (a fake, placeholder mechanism), iterate over each user
// and ensure that each mechanism that should succeed does so and each
// that should fail does so.
for _, m := range []string{"SCRAM-SHA-1", "SCRAM-SHA-256", "negotiate"} {
for _, c := range testUsers {
t.Run(
fmt.Sprintf("%s %s", c.username, m),
func(t *testing.T) {
err := testScramUserAuthWithMech(t, c, m)
if m == "negotiate" || hasAuthMech(c.mechanisms, m) {
noerr(t, err)
} else {
autherr(t, err)
}
},
)
}
}
// Step 3b: test non-existing user with negotiation fails with
// an auth.Error type.
bogus := scramTestCase{username: "eliot", password: "trustno1"}
err = testScramUserAuthWithMech(t, bogus, "negotiate")
autherr(t, err)
// XXX Step 4: test alternate password forms
for _, c := range testUsers {
if c.altPassword == "" {
continue
}
c.password = c.altPassword
t.Run(
fmt.Sprintf("%s alternate password", c.username),
func(t *testing.T) {
err := testScramUserAuthWithMech(t, c, "SCRAM-SHA-256")
noerr(t, err)
},
)
}
}
func hasAuthMech(mechs []string, m string) bool {
for _, v := range mechs {
if v == m {
return true
}
}
return false
}
func testScramUserAuthWithMech(t *testing.T, c scramTestCase, mech string) error {
t.Helper()
credential := options.Credential{
Username: c.username,
Password: c.password,
AuthSource: testutil.DBName(t),
}
switch mech {
case "negotiate":
credential.AuthMechanism = ""
default:
credential.AuthMechanism = mech
}
return runScramAuthTest(t, credential)
}
func runScramAuthTest(t *testing.T, credential options.Credential) error {
t.Helper()
topology := testutil.TopologyWithCredential(t, credential)
server, err := topology.SelectServer(context.Background(), description.WriteSelector())
noerr(t, err)
cmd := bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "dbstats", 1))
_, err = testutil.RunCommand(t, server, testutil.DBName(t), cmd)
return err
}
func createScramUsers(t *testing.T, s driver.Server, cases []scramTestCase) error {
db := testutil.DBName(t)
for _, c := range cases {
var values []bsoncore.Value
for _, v := range c.mechanisms {
values = append(values, bsoncore.Value{Type: bsontype.String, Data: bsoncore.AppendString(nil, v)})
}
newUserCmd := bsoncore.BuildDocumentFromElements(nil,
bsoncore.AppendStringElement(nil, "createUser", c.username),
bsoncore.AppendStringElement(nil, "pwd", c.password),
bsoncore.AppendArrayElement(nil, "roles", bsoncore.BuildArray(nil,
bsoncore.Value{Type: bsontype.EmbeddedDocument, Data: bsoncore.BuildDocumentFromElements(nil,
bsoncore.AppendStringElement(nil, "role", "readWrite"),
bsoncore.AppendStringElement(nil, "db", db),
)},
)),
bsoncore.AppendArrayElement(nil, "mechanisms", bsoncore.BuildArray(nil, values...)),
)
_, err := testutil.RunCommand(t, s, db, newUserCmd)
if err != nil {
return fmt.Errorf("Couldn't create user '%s' on db '%s': %v", c.username, testutil.DBName(t), err)
}
}
return nil
}

View File

@@ -0,0 +1,22 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package driver
// LegacyOperationKind indicates if an operation is a legacy find, getMore, or killCursors. This is used
// in Operation.Execute, which will create legacy OP_QUERY, OP_GET_MORE, or OP_KILL_CURSORS instead
// of sending them as a command.
type LegacyOperationKind uint
// These constants represent the three different kinds of legacy operations.
const (
LegacyNone LegacyOperationKind = iota
LegacyFind
LegacyGetMore
LegacyKillCursors
LegacyListCollections
LegacyListIndexes
)

View File

@@ -0,0 +1,134 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package driver
import (
"context"
"errors"
"io"
"strings"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
// ListCollectionsBatchCursor is a special batch cursor returned from ListCollections that properly
// handles current and legacy ListCollections operations.
type ListCollectionsBatchCursor struct {
legacy bool // server version < 3.0
bc *BatchCursor
currentBatch *bsoncore.DocumentSequence
err error
}
// NewListCollectionsBatchCursor creates a new non-legacy ListCollectionsCursor.
func NewListCollectionsBatchCursor(bc *BatchCursor) (*ListCollectionsBatchCursor, error) {
if bc == nil {
return nil, errors.New("batch cursor must not be nil")
}
return &ListCollectionsBatchCursor{bc: bc, currentBatch: new(bsoncore.DocumentSequence)}, nil
}
// NewLegacyListCollectionsBatchCursor creates a new legacy ListCollectionsCursor.
func NewLegacyListCollectionsBatchCursor(bc *BatchCursor) (*ListCollectionsBatchCursor, error) {
if bc == nil {
return nil, errors.New("batch cursor must not be nil")
}
return &ListCollectionsBatchCursor{legacy: true, bc: bc, currentBatch: new(bsoncore.DocumentSequence)}, nil
}
// ID returns the cursor ID for this batch cursor.
func (lcbc *ListCollectionsBatchCursor) ID() int64 {
return lcbc.bc.ID()
}
// Next indicates if there is another batch available. Returning false does not necessarily indicate
// that the cursor is closed. This method will return false when an empty batch is returned.
//
// If Next returns true, there is a valid batch of documents available. If Next returns false, there
// is not a valid batch of documents available.
func (lcbc *ListCollectionsBatchCursor) Next(ctx context.Context) bool {
if !lcbc.bc.Next(ctx) {
return false
}
if !lcbc.legacy {
lcbc.currentBatch.Style = lcbc.bc.currentBatch.Style
lcbc.currentBatch.Data = lcbc.bc.currentBatch.Data
lcbc.currentBatch.ResetIterator()
return true
}
lcbc.currentBatch.Style = bsoncore.SequenceStyle
lcbc.currentBatch.Data = lcbc.currentBatch.Data[:0]
var doc bsoncore.Document
for {
doc, lcbc.err = lcbc.bc.currentBatch.Next()
if lcbc.err != nil {
if lcbc.err == io.EOF {
lcbc.err = nil
break
}
return false
}
doc, lcbc.err = lcbc.projectNameElement(doc)
if lcbc.err != nil {
return false
}
lcbc.currentBatch.Data = append(lcbc.currentBatch.Data, doc...)
}
return true
}
// Batch will return a DocumentSequence for the current batch of documents. The returned
// DocumentSequence is only valid until the next call to Next or Close.
func (lcbc *ListCollectionsBatchCursor) Batch() *bsoncore.DocumentSequence { return lcbc.currentBatch }
// Server returns a pointer to the cursor's server.
func (lcbc *ListCollectionsBatchCursor) Server() Server { return lcbc.bc.server }
// Err returns the latest error encountered.
func (lcbc *ListCollectionsBatchCursor) Err() error {
if lcbc.err != nil {
return lcbc.err
}
return lcbc.bc.Err()
}
// Close closes this batch cursor.
func (lcbc *ListCollectionsBatchCursor) Close(ctx context.Context) error { return lcbc.bc.Close(ctx) }
// project out the database name for a legacy server
func (*ListCollectionsBatchCursor) projectNameElement(rawDoc bsoncore.Document) (bsoncore.Document, error) {
elems, err := rawDoc.Elements()
if err != nil {
return nil, err
}
var filteredElems []byte
for _, elem := range elems {
key := elem.Key()
if key != "name" {
filteredElems = append(filteredElems, elem...)
continue
}
name := elem.Value().StringValue()
collName := name[strings.Index(name, ".")+1:]
filteredElems = bsoncore.AppendStringElement(filteredElems, "name", collName)
}
var filteredDoc []byte
filteredDoc = bsoncore.BuildDocument(filteredDoc, filteredElems)
return filteredDoc, nil
}
// SetBatchSize sets the batchSize for future getMores.
func (lcbc *ListCollectionsBatchCursor) SetBatchSize(size int32) {
lcbc.bc.SetBatchSize(size)
}

View File

@@ -0,0 +1,56 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//go:build cse
// +build cse
package mongocrypt
// #include <mongocrypt.h>
import "C"
import (
"unsafe"
)
// binary is a wrapper type around a mongocrypt_binary_t*
type binary struct {
wrapped *C.mongocrypt_binary_t
}
// newBinary creates an empty binary instance.
func newBinary() *binary {
return &binary{
wrapped: C.mongocrypt_binary_new(),
}
}
// newBinaryFromBytes creates a binary instance from a byte buffer.
func newBinaryFromBytes(data []byte) *binary {
if len(data) == 0 {
return newBinary()
}
// We don't need C.CBytes here because data cannot go out of scope. Any mongocrypt function that takes a
// mongocrypt_binary_t will make a copy of the data so the data can be garbage collected after calling.
addr := (*C.uint8_t)(unsafe.Pointer(&data[0])) // uint8_t*
dataLen := C.uint32_t(len(data)) // uint32_t
return &binary{
wrapped: C.mongocrypt_binary_new_from_data(addr, dataLen),
}
}
// toBytes converts the given binary instance to []byte.
func (b *binary) toBytes() []byte {
dataPtr := C.mongocrypt_binary_data(b.wrapped) // C.uint8_t*
dataLen := C.mongocrypt_binary_len(b.wrapped) // C.uint32_t
return C.GoBytes(unsafe.Pointer(dataPtr), C.int(dataLen))
}
// close cleans up any resources associated with the given binary instance.
func (b *binary) close() {
C.mongocrypt_binary_destroy(b.wrapped)
}

View File

@@ -0,0 +1,44 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//go:build cse
// +build cse
package mongocrypt
// #include <mongocrypt.h>
import "C"
import (
"fmt"
)
// Error represents an error from an operation on a MongoCrypt instance.
type Error struct {
Code int32
Message string
}
// Error implements the error interface.
func (e Error) Error() string {
return fmt.Sprintf("mongocrypt error %d: %v", e.Code, e.Message)
}
// errorFromStatus builds a Error from a mongocrypt_status_t object.
func errorFromStatus(status *C.mongocrypt_status_t) error {
cCode := C.mongocrypt_status_code(status) // uint32_t
// mongocrypt_status_message takes uint32_t* as its second param to store the length of the returned string.
// pass nil because the length is handled by C.GoString
cMsg := C.mongocrypt_status_message(status, nil) // const char*
var msg string
if cMsg != nil {
msg = C.GoString(cMsg)
}
return Error{
Code: int32(cCode),
Message: msg,
}
}

View File

@@ -0,0 +1,21 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//go:build !cse
// +build !cse
package mongocrypt
// Error represents an error from an operation on a MongoCrypt instance.
type Error struct {
Code int32
Message string
}
// Error implements the error interface
func (Error) Error() string {
panic(cseNotSupportedMsg)
}

Some files were not shown because too many files have changed in this diff Show More