Copied mongo repo (to patch it)

This commit is contained in:
2023-06-18 15:50:55 +02:00
parent 21d241f9b1
commit d471d7c396
544 changed files with 142039 additions and 1 deletions

View File

@@ -0,0 +1,33 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonrw
import "testing"
func compareErrors(err1, err2 error) bool {
if err1 == nil && err2 == nil {
return true
}
if err1 == nil || err2 == nil {
return false
}
if err1.Error() != err2.Error() {
return false
}
return true
}
func noerr(t *testing.T, err error) {
if err != nil {
t.Helper()
t.Errorf("Unexpected error: (%T)%v", err, err)
t.FailNow()
}
}

View File

@@ -0,0 +1,848 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Package bsonrwtest provides utilities for testing the "bson/bsonrw" package.
package bsonrwtest // import "go.mongodb.org/mongo-driver/bson/bsonrw/bsonrwtest"
import (
"testing"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
var _ bsonrw.ValueReader = (*ValueReaderWriter)(nil)
var _ bsonrw.ValueWriter = (*ValueReaderWriter)(nil)
// Invoked is a type used to indicate what method was called last.
type Invoked byte
// These are the different methods that can be invoked.
const (
Nothing Invoked = iota
ReadArray
ReadBinary
ReadBoolean
ReadDocument
ReadCodeWithScope
ReadDBPointer
ReadDateTime
ReadDecimal128
ReadDouble
ReadInt32
ReadInt64
ReadJavascript
ReadMaxKey
ReadMinKey
ReadNull
ReadObjectID
ReadRegex
ReadString
ReadSymbol
ReadTimestamp
ReadUndefined
ReadElement
ReadValue
WriteArray
WriteBinary
WriteBinaryWithSubtype
WriteBoolean
WriteCodeWithScope
WriteDBPointer
WriteDateTime
WriteDecimal128
WriteDouble
WriteInt32
WriteInt64
WriteJavascript
WriteMaxKey
WriteMinKey
WriteNull
WriteObjectID
WriteRegex
WriteString
WriteDocument
WriteSymbol
WriteTimestamp
WriteUndefined
WriteDocumentElement
WriteDocumentEnd
WriteArrayElement
WriteArrayEnd
Skip
)
func (i Invoked) String() string {
switch i {
case Nothing:
return "Nothing"
case ReadArray:
return "ReadArray"
case ReadBinary:
return "ReadBinary"
case ReadBoolean:
return "ReadBoolean"
case ReadDocument:
return "ReadDocument"
case ReadCodeWithScope:
return "ReadCodeWithScope"
case ReadDBPointer:
return "ReadDBPointer"
case ReadDateTime:
return "ReadDateTime"
case ReadDecimal128:
return "ReadDecimal128"
case ReadDouble:
return "ReadDouble"
case ReadInt32:
return "ReadInt32"
case ReadInt64:
return "ReadInt64"
case ReadJavascript:
return "ReadJavascript"
case ReadMaxKey:
return "ReadMaxKey"
case ReadMinKey:
return "ReadMinKey"
case ReadNull:
return "ReadNull"
case ReadObjectID:
return "ReadObjectID"
case ReadRegex:
return "ReadRegex"
case ReadString:
return "ReadString"
case ReadSymbol:
return "ReadSymbol"
case ReadTimestamp:
return "ReadTimestamp"
case ReadUndefined:
return "ReadUndefined"
case ReadElement:
return "ReadElement"
case ReadValue:
return "ReadValue"
case WriteArray:
return "WriteArray"
case WriteBinary:
return "WriteBinary"
case WriteBinaryWithSubtype:
return "WriteBinaryWithSubtype"
case WriteBoolean:
return "WriteBoolean"
case WriteCodeWithScope:
return "WriteCodeWithScope"
case WriteDBPointer:
return "WriteDBPointer"
case WriteDateTime:
return "WriteDateTime"
case WriteDecimal128:
return "WriteDecimal128"
case WriteDouble:
return "WriteDouble"
case WriteInt32:
return "WriteInt32"
case WriteInt64:
return "WriteInt64"
case WriteJavascript:
return "WriteJavascript"
case WriteMaxKey:
return "WriteMaxKey"
case WriteMinKey:
return "WriteMinKey"
case WriteNull:
return "WriteNull"
case WriteObjectID:
return "WriteObjectID"
case WriteRegex:
return "WriteRegex"
case WriteString:
return "WriteString"
case WriteDocument:
return "WriteDocument"
case WriteSymbol:
return "WriteSymbol"
case WriteTimestamp:
return "WriteTimestamp"
case WriteUndefined:
return "WriteUndefined"
case WriteDocumentElement:
return "WriteDocumentElement"
case WriteDocumentEnd:
return "WriteDocumentEnd"
case WriteArrayElement:
return "WriteArrayElement"
case WriteArrayEnd:
return "WriteArrayEnd"
default:
return "<unknown>"
}
}
// ValueReaderWriter is a test implementation of a bsonrw.ValueReader and bsonrw.ValueWriter
type ValueReaderWriter struct {
T *testing.T
Invoked Invoked
Return interface{} // Can be a primitive or a bsoncore.Value
BSONType bsontype.Type
Err error
ErrAfter Invoked // error after this method is called
depth uint64
}
// prevent infinite recursion.
func (llvrw *ValueReaderWriter) checkdepth() {
llvrw.depth++
if llvrw.depth > 1000 {
panic("max depth exceeded")
}
}
// Type implements the bsonrw.ValueReader interface.
func (llvrw *ValueReaderWriter) Type() bsontype.Type {
llvrw.checkdepth()
return llvrw.BSONType
}
// Skip implements the bsonrw.ValueReader interface.
func (llvrw *ValueReaderWriter) Skip() error {
llvrw.checkdepth()
llvrw.Invoked = Skip
if llvrw.ErrAfter == llvrw.Invoked {
return llvrw.Err
}
return nil
}
// ReadArray implements the bsonrw.ValueReader interface.
func (llvrw *ValueReaderWriter) ReadArray() (bsonrw.ArrayReader, error) {
llvrw.checkdepth()
llvrw.Invoked = ReadArray
if llvrw.ErrAfter == llvrw.Invoked {
return nil, llvrw.Err
}
return llvrw, nil
}
// ReadBinary implements the bsonrw.ValueReader interface.
func (llvrw *ValueReaderWriter) ReadBinary() (b []byte, btype byte, err error) {
llvrw.checkdepth()
llvrw.Invoked = ReadBinary
if llvrw.ErrAfter == llvrw.Invoked {
return nil, 0x00, llvrw.Err
}
switch tt := llvrw.Return.(type) {
case bsoncore.Value:
subtype, data, _, ok := bsoncore.ReadBinary(tt.Data)
if !ok {
llvrw.T.Error("Invalid Value provided for return value of ReadBinary.")
return nil, 0x00, nil
}
return data, subtype, nil
default:
llvrw.T.Errorf("Incorrect type provided for return value of ReadBinary: %T", llvrw.Return)
return nil, 0x00, nil
}
}
// ReadBoolean implements the bsonrw.ValueReader interface.
func (llvrw *ValueReaderWriter) ReadBoolean() (bool, error) {
llvrw.checkdepth()
llvrw.Invoked = ReadBoolean
if llvrw.ErrAfter == llvrw.Invoked {
return false, llvrw.Err
}
switch tt := llvrw.Return.(type) {
case bool:
return tt, nil
case bsoncore.Value:
b, _, ok := bsoncore.ReadBoolean(tt.Data)
if !ok {
llvrw.T.Error("Invalid Value provided for return value of ReadBoolean.")
return false, nil
}
return b, nil
default:
llvrw.T.Errorf("Incorrect type provided for return value of ReadBoolean: %T", llvrw.Return)
return false, nil
}
}
// ReadDocument implements the bsonrw.ValueReader interface.
func (llvrw *ValueReaderWriter) ReadDocument() (bsonrw.DocumentReader, error) {
llvrw.checkdepth()
llvrw.Invoked = ReadDocument
if llvrw.ErrAfter == llvrw.Invoked {
return nil, llvrw.Err
}
return llvrw, nil
}
// ReadCodeWithScope implements the bsonrw.ValueReader interface.
func (llvrw *ValueReaderWriter) ReadCodeWithScope() (code string, dr bsonrw.DocumentReader, err error) {
llvrw.checkdepth()
llvrw.Invoked = ReadCodeWithScope
if llvrw.ErrAfter == llvrw.Invoked {
return "", nil, llvrw.Err
}
return "", llvrw, nil
}
// ReadDBPointer implements the bsonrw.ValueReader interface.
func (llvrw *ValueReaderWriter) ReadDBPointer() (ns string, oid primitive.ObjectID, err error) {
llvrw.checkdepth()
llvrw.Invoked = ReadDBPointer
if llvrw.ErrAfter == llvrw.Invoked {
return "", primitive.ObjectID{}, llvrw.Err
}
switch tt := llvrw.Return.(type) {
case bsoncore.Value:
ns, oid, _, ok := bsoncore.ReadDBPointer(tt.Data)
if !ok {
llvrw.T.Error("Invalid Value instance provided for return value of ReadDBPointer")
return "", primitive.ObjectID{}, nil
}
return ns, oid, nil
default:
llvrw.T.Errorf("Incorrect type provided for return value of ReadDBPointer: %T", llvrw.Return)
return "", primitive.ObjectID{}, nil
}
}
// ReadDateTime implements the bsonrw.ValueReader interface.
func (llvrw *ValueReaderWriter) ReadDateTime() (int64, error) {
llvrw.checkdepth()
llvrw.Invoked = ReadDateTime
if llvrw.ErrAfter == llvrw.Invoked {
return 0, llvrw.Err
}
dt, ok := llvrw.Return.(int64)
if !ok {
llvrw.T.Errorf("Incorrect type provided for return value of ReadDateTime: %T", llvrw.Return)
return 0, nil
}
return dt, nil
}
// ReadDecimal128 implements the bsonrw.ValueReader interface.
func (llvrw *ValueReaderWriter) ReadDecimal128() (primitive.Decimal128, error) {
llvrw.checkdepth()
llvrw.Invoked = ReadDecimal128
if llvrw.ErrAfter == llvrw.Invoked {
return primitive.Decimal128{}, llvrw.Err
}
d128, ok := llvrw.Return.(primitive.Decimal128)
if !ok {
llvrw.T.Errorf("Incorrect type provided for return value of ReadDecimal128: %T", llvrw.Return)
return primitive.Decimal128{}, nil
}
return d128, nil
}
// ReadDouble implements the bsonrw.ValueReader interface.
func (llvrw *ValueReaderWriter) ReadDouble() (float64, error) {
llvrw.checkdepth()
llvrw.Invoked = ReadDouble
if llvrw.ErrAfter == llvrw.Invoked {
return 0, llvrw.Err
}
f64, ok := llvrw.Return.(float64)
if !ok {
llvrw.T.Errorf("Incorrect type provided for return value of ReadDouble: %T", llvrw.Return)
return 0, nil
}
return f64, nil
}
// ReadInt32 implements the bsonrw.ValueReader interface.
func (llvrw *ValueReaderWriter) ReadInt32() (int32, error) {
llvrw.checkdepth()
llvrw.Invoked = ReadInt32
if llvrw.ErrAfter == llvrw.Invoked {
return 0, llvrw.Err
}
i32, ok := llvrw.Return.(int32)
if !ok {
llvrw.T.Errorf("Incorrect type provided for return value of ReadInt32: %T", llvrw.Return)
return 0, nil
}
return i32, nil
}
// ReadInt64 implements the bsonrw.ValueReader interface.
func (llvrw *ValueReaderWriter) ReadInt64() (int64, error) {
llvrw.checkdepth()
llvrw.Invoked = ReadInt64
if llvrw.ErrAfter == llvrw.Invoked {
return 0, llvrw.Err
}
i64, ok := llvrw.Return.(int64)
if !ok {
llvrw.T.Errorf("Incorrect type provided for return value of ReadInt64: %T", llvrw.Return)
return 0, nil
}
return i64, nil
}
// ReadJavascript implements the bsonrw.ValueReader interface.
func (llvrw *ValueReaderWriter) ReadJavascript() (code string, err error) {
llvrw.checkdepth()
llvrw.Invoked = ReadJavascript
if llvrw.ErrAfter == llvrw.Invoked {
return "", llvrw.Err
}
js, ok := llvrw.Return.(string)
if !ok {
llvrw.T.Errorf("Incorrect type provided for return value of ReadJavascript: %T", llvrw.Return)
return "", nil
}
return js, nil
}
// ReadMaxKey implements the bsonrw.ValueReader interface.
func (llvrw *ValueReaderWriter) ReadMaxKey() error {
llvrw.checkdepth()
llvrw.Invoked = ReadMaxKey
if llvrw.ErrAfter == llvrw.Invoked {
return llvrw.Err
}
return nil
}
// ReadMinKey implements the bsonrw.ValueReader interface.
func (llvrw *ValueReaderWriter) ReadMinKey() error {
llvrw.checkdepth()
llvrw.Invoked = ReadMinKey
if llvrw.ErrAfter == llvrw.Invoked {
return llvrw.Err
}
return nil
}
// ReadNull implements the bsonrw.ValueReader interface.
func (llvrw *ValueReaderWriter) ReadNull() error {
llvrw.checkdepth()
llvrw.Invoked = ReadNull
if llvrw.ErrAfter == llvrw.Invoked {
return llvrw.Err
}
return nil
}
// ReadObjectID implements the bsonrw.ValueReader interface.
func (llvrw *ValueReaderWriter) ReadObjectID() (primitive.ObjectID, error) {
llvrw.checkdepth()
llvrw.Invoked = ReadObjectID
if llvrw.ErrAfter == llvrw.Invoked {
return primitive.ObjectID{}, llvrw.Err
}
oid, ok := llvrw.Return.(primitive.ObjectID)
if !ok {
llvrw.T.Errorf("Incorrect type provided for return value of ReadObjectID: %T", llvrw.Return)
return primitive.ObjectID{}, nil
}
return oid, nil
}
// ReadRegex implements the bsonrw.ValueReader interface.
func (llvrw *ValueReaderWriter) ReadRegex() (pattern string, options string, err error) {
llvrw.checkdepth()
llvrw.Invoked = ReadRegex
if llvrw.ErrAfter == llvrw.Invoked {
return "", "", llvrw.Err
}
switch tt := llvrw.Return.(type) {
case bsoncore.Value:
pattern, options, _, ok := bsoncore.ReadRegex(tt.Data)
if !ok {
llvrw.T.Error("Invalid Value instance provided for ReadRegex")
return "", "", nil
}
return pattern, options, nil
default:
llvrw.T.Errorf("Incorrect type provided for return value of ReadRegex: %T", llvrw.Return)
return "", "", nil
}
}
// ReadString implements the bsonrw.ValueReader interface.
func (llvrw *ValueReaderWriter) ReadString() (string, error) {
llvrw.checkdepth()
llvrw.Invoked = ReadString
if llvrw.ErrAfter == llvrw.Invoked {
return "", llvrw.Err
}
str, ok := llvrw.Return.(string)
if !ok {
llvrw.T.Errorf("Incorrect type provided for return value of ReadString: %T", llvrw.Return)
return "", nil
}
return str, nil
}
// ReadSymbol implements the bsonrw.ValueReader interface.
func (llvrw *ValueReaderWriter) ReadSymbol() (symbol string, err error) {
llvrw.checkdepth()
llvrw.Invoked = ReadSymbol
if llvrw.ErrAfter == llvrw.Invoked {
return "", llvrw.Err
}
switch tt := llvrw.Return.(type) {
case string:
return tt, nil
case bsoncore.Value:
symbol, _, ok := bsoncore.ReadSymbol(tt.Data)
if !ok {
llvrw.T.Error("Invalid Value instance provided for ReadSymbol")
return "", nil
}
return symbol, nil
default:
llvrw.T.Errorf("Incorrect type provided for return value of ReadSymbol: %T", llvrw.Return)
return "", nil
}
}
// ReadTimestamp implements the bsonrw.ValueReader interface.
func (llvrw *ValueReaderWriter) ReadTimestamp() (t uint32, i uint32, err error) {
llvrw.checkdepth()
llvrw.Invoked = ReadTimestamp
if llvrw.ErrAfter == llvrw.Invoked {
return 0, 0, llvrw.Err
}
switch tt := llvrw.Return.(type) {
case bsoncore.Value:
t, i, _, ok := bsoncore.ReadTimestamp(tt.Data)
if !ok {
llvrw.T.Errorf("Invalid Value instance provided for return value of ReadTimestamp")
return 0, 0, nil
}
return t, i, nil
default:
llvrw.T.Errorf("Incorrect type provided for return value of ReadTimestamp: %T", llvrw.Return)
return 0, 0, nil
}
}
// ReadUndefined implements the bsonrw.ValueReader interface.
func (llvrw *ValueReaderWriter) ReadUndefined() error {
llvrw.checkdepth()
llvrw.Invoked = ReadUndefined
if llvrw.ErrAfter == llvrw.Invoked {
return llvrw.Err
}
return nil
}
// WriteArray implements the bsonrw.ValueWriter interface.
func (llvrw *ValueReaderWriter) WriteArray() (bsonrw.ArrayWriter, error) {
llvrw.checkdepth()
llvrw.Invoked = WriteArray
if llvrw.ErrAfter == llvrw.Invoked {
return nil, llvrw.Err
}
return llvrw, nil
}
// WriteBinary implements the bsonrw.ValueWriter interface.
func (llvrw *ValueReaderWriter) WriteBinary(b []byte) error {
llvrw.checkdepth()
llvrw.Invoked = WriteBinary
if llvrw.ErrAfter == llvrw.Invoked {
return llvrw.Err
}
return nil
}
// WriteBinaryWithSubtype implements the bsonrw.ValueWriter interface.
func (llvrw *ValueReaderWriter) WriteBinaryWithSubtype(b []byte, btype byte) error {
llvrw.checkdepth()
llvrw.Invoked = WriteBinaryWithSubtype
if llvrw.ErrAfter == llvrw.Invoked {
return llvrw.Err
}
return nil
}
// WriteBoolean implements the bsonrw.ValueWriter interface.
func (llvrw *ValueReaderWriter) WriteBoolean(bool) error {
llvrw.checkdepth()
llvrw.Invoked = WriteBoolean
if llvrw.ErrAfter == llvrw.Invoked {
return llvrw.Err
}
return nil
}
// WriteCodeWithScope implements the bsonrw.ValueWriter interface.
func (llvrw *ValueReaderWriter) WriteCodeWithScope(code string) (bsonrw.DocumentWriter, error) {
llvrw.checkdepth()
llvrw.Invoked = WriteCodeWithScope
if llvrw.ErrAfter == llvrw.Invoked {
return nil, llvrw.Err
}
return llvrw, nil
}
// WriteDBPointer implements the bsonrw.ValueWriter interface.
func (llvrw *ValueReaderWriter) WriteDBPointer(ns string, oid primitive.ObjectID) error {
llvrw.checkdepth()
llvrw.Invoked = WriteDBPointer
if llvrw.ErrAfter == llvrw.Invoked {
return llvrw.Err
}
return nil
}
// WriteDateTime implements the bsonrw.ValueWriter interface.
func (llvrw *ValueReaderWriter) WriteDateTime(dt int64) error {
llvrw.checkdepth()
llvrw.Invoked = WriteDateTime
if llvrw.ErrAfter == llvrw.Invoked {
return llvrw.Err
}
return nil
}
// WriteDecimal128 implements the bsonrw.ValueWriter interface.
func (llvrw *ValueReaderWriter) WriteDecimal128(primitive.Decimal128) error {
llvrw.checkdepth()
llvrw.Invoked = WriteDecimal128
if llvrw.ErrAfter == llvrw.Invoked {
return llvrw.Err
}
return nil
}
// WriteDouble implements the bsonrw.ValueWriter interface.
func (llvrw *ValueReaderWriter) WriteDouble(float64) error {
llvrw.checkdepth()
llvrw.Invoked = WriteDouble
if llvrw.ErrAfter == llvrw.Invoked {
return llvrw.Err
}
return nil
}
// WriteInt32 implements the bsonrw.ValueWriter interface.
func (llvrw *ValueReaderWriter) WriteInt32(int32) error {
llvrw.checkdepth()
llvrw.Invoked = WriteInt32
if llvrw.ErrAfter == llvrw.Invoked {
return llvrw.Err
}
return nil
}
// WriteInt64 implements the bsonrw.ValueWriter interface.
func (llvrw *ValueReaderWriter) WriteInt64(int64) error {
llvrw.checkdepth()
llvrw.Invoked = WriteInt64
if llvrw.ErrAfter == llvrw.Invoked {
return llvrw.Err
}
return nil
}
// WriteJavascript implements the bsonrw.ValueWriter interface.
func (llvrw *ValueReaderWriter) WriteJavascript(code string) error {
llvrw.checkdepth()
llvrw.Invoked = WriteJavascript
if llvrw.ErrAfter == llvrw.Invoked {
return llvrw.Err
}
return nil
}
// WriteMaxKey implements the bsonrw.ValueWriter interface.
func (llvrw *ValueReaderWriter) WriteMaxKey() error {
llvrw.checkdepth()
llvrw.Invoked = WriteMaxKey
if llvrw.ErrAfter == llvrw.Invoked {
return llvrw.Err
}
return nil
}
// WriteMinKey implements the bsonrw.ValueWriter interface.
func (llvrw *ValueReaderWriter) WriteMinKey() error {
llvrw.checkdepth()
llvrw.Invoked = WriteMinKey
if llvrw.ErrAfter == llvrw.Invoked {
return llvrw.Err
}
return nil
}
// WriteNull implements the bsonrw.ValueWriter interface.
func (llvrw *ValueReaderWriter) WriteNull() error {
llvrw.checkdepth()
llvrw.Invoked = WriteNull
if llvrw.ErrAfter == llvrw.Invoked {
return llvrw.Err
}
return nil
}
// WriteObjectID implements the bsonrw.ValueWriter interface.
func (llvrw *ValueReaderWriter) WriteObjectID(primitive.ObjectID) error {
llvrw.checkdepth()
llvrw.Invoked = WriteObjectID
if llvrw.ErrAfter == llvrw.Invoked {
return llvrw.Err
}
return nil
}
// WriteRegex implements the bsonrw.ValueWriter interface.
func (llvrw *ValueReaderWriter) WriteRegex(pattern string, options string) error {
llvrw.checkdepth()
llvrw.Invoked = WriteRegex
if llvrw.ErrAfter == llvrw.Invoked {
return llvrw.Err
}
return nil
}
// WriteString implements the bsonrw.ValueWriter interface.
func (llvrw *ValueReaderWriter) WriteString(string) error {
llvrw.checkdepth()
llvrw.Invoked = WriteString
if llvrw.ErrAfter == llvrw.Invoked {
return llvrw.Err
}
return nil
}
// WriteDocument implements the bsonrw.ValueWriter interface.
func (llvrw *ValueReaderWriter) WriteDocument() (bsonrw.DocumentWriter, error) {
llvrw.checkdepth()
llvrw.Invoked = WriteDocument
if llvrw.ErrAfter == llvrw.Invoked {
return nil, llvrw.Err
}
return llvrw, nil
}
// WriteSymbol implements the bsonrw.ValueWriter interface.
func (llvrw *ValueReaderWriter) WriteSymbol(symbol string) error {
llvrw.checkdepth()
llvrw.Invoked = WriteSymbol
if llvrw.ErrAfter == llvrw.Invoked {
return llvrw.Err
}
return nil
}
// WriteTimestamp implements the bsonrw.ValueWriter interface.
func (llvrw *ValueReaderWriter) WriteTimestamp(t uint32, i uint32) error {
llvrw.checkdepth()
llvrw.Invoked = WriteTimestamp
if llvrw.ErrAfter == llvrw.Invoked {
return llvrw.Err
}
return nil
}
// WriteUndefined implements the bsonrw.ValueWriter interface.
func (llvrw *ValueReaderWriter) WriteUndefined() error {
llvrw.checkdepth()
llvrw.Invoked = WriteUndefined
if llvrw.ErrAfter == llvrw.Invoked {
return llvrw.Err
}
return nil
}
// ReadElement implements the bsonrw.DocumentReader interface.
func (llvrw *ValueReaderWriter) ReadElement() (string, bsonrw.ValueReader, error) {
llvrw.checkdepth()
llvrw.Invoked = ReadElement
if llvrw.ErrAfter == llvrw.Invoked {
return "", nil, llvrw.Err
}
return "", llvrw, nil
}
// WriteDocumentElement implements the bsonrw.DocumentWriter interface.
func (llvrw *ValueReaderWriter) WriteDocumentElement(string) (bsonrw.ValueWriter, error) {
llvrw.checkdepth()
llvrw.Invoked = WriteDocumentElement
if llvrw.ErrAfter == llvrw.Invoked {
return nil, llvrw.Err
}
return llvrw, nil
}
// WriteDocumentEnd implements the bsonrw.DocumentWriter interface.
func (llvrw *ValueReaderWriter) WriteDocumentEnd() error {
llvrw.checkdepth()
llvrw.Invoked = WriteDocumentEnd
if llvrw.ErrAfter == llvrw.Invoked {
return llvrw.Err
}
return nil
}
// ReadValue implements the bsonrw.ArrayReader interface.
func (llvrw *ValueReaderWriter) ReadValue() (bsonrw.ValueReader, error) {
llvrw.checkdepth()
llvrw.Invoked = ReadValue
if llvrw.ErrAfter == llvrw.Invoked {
return nil, llvrw.Err
}
return llvrw, nil
}
// WriteArrayElement implements the bsonrw.ArrayWriter interface.
func (llvrw *ValueReaderWriter) WriteArrayElement() (bsonrw.ValueWriter, error) {
llvrw.checkdepth()
llvrw.Invoked = WriteArrayElement
if llvrw.ErrAfter == llvrw.Invoked {
return nil, llvrw.Err
}
return llvrw, nil
}
// WriteArrayEnd implements the bsonrw.ArrayWriter interface.
func (llvrw *ValueReaderWriter) WriteArrayEnd() error {
llvrw.checkdepth()
llvrw.Invoked = WriteArrayEnd
if llvrw.ErrAfter == llvrw.Invoked {
return llvrw.Err
}
return nil
}

445
mongo/bson/bsonrw/copier.go Normal file
View File

@@ -0,0 +1,445 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonrw
import (
"fmt"
"io"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
// Copier is a type that allows copying between ValueReaders, ValueWriters, and
// []byte values.
type Copier struct{}
// NewCopier creates a new copier with the given registry. If a nil registry is provided
// a default registry is used.
func NewCopier() Copier {
return Copier{}
}
// CopyDocument handles copying a document from src to dst.
func CopyDocument(dst ValueWriter, src ValueReader) error {
return Copier{}.CopyDocument(dst, src)
}
// CopyDocument handles copying one document from the src to the dst.
func (c Copier) CopyDocument(dst ValueWriter, src ValueReader) error {
dr, err := src.ReadDocument()
if err != nil {
return err
}
dw, err := dst.WriteDocument()
if err != nil {
return err
}
return c.copyDocumentCore(dw, dr)
}
// CopyArrayFromBytes copies the values from a BSON array represented as a
// []byte to a ValueWriter.
func (c Copier) CopyArrayFromBytes(dst ValueWriter, src []byte) error {
aw, err := dst.WriteArray()
if err != nil {
return err
}
err = c.CopyBytesToArrayWriter(aw, src)
if err != nil {
return err
}
return aw.WriteArrayEnd()
}
// CopyDocumentFromBytes copies the values from a BSON document represented as a
// []byte to a ValueWriter.
func (c Copier) CopyDocumentFromBytes(dst ValueWriter, src []byte) error {
dw, err := dst.WriteDocument()
if err != nil {
return err
}
err = c.CopyBytesToDocumentWriter(dw, src)
if err != nil {
return err
}
return dw.WriteDocumentEnd()
}
type writeElementFn func(key string) (ValueWriter, error)
// CopyBytesToArrayWriter copies the values from a BSON Array represented as a []byte to an
// ArrayWriter.
func (c Copier) CopyBytesToArrayWriter(dst ArrayWriter, src []byte) error {
wef := func(_ string) (ValueWriter, error) {
return dst.WriteArrayElement()
}
return c.copyBytesToValueWriter(src, wef)
}
// CopyBytesToDocumentWriter copies the values from a BSON document represented as a []byte to a
// DocumentWriter.
func (c Copier) CopyBytesToDocumentWriter(dst DocumentWriter, src []byte) error {
wef := func(key string) (ValueWriter, error) {
return dst.WriteDocumentElement(key)
}
return c.copyBytesToValueWriter(src, wef)
}
func (c Copier) copyBytesToValueWriter(src []byte, wef writeElementFn) error {
// TODO(skriptble): Create errors types here. Anything thats a tag should be a property.
length, rem, ok := bsoncore.ReadLength(src)
if !ok {
return fmt.Errorf("couldn't read length from src, not enough bytes. length=%d", len(src))
}
if len(src) < int(length) {
return fmt.Errorf("length read exceeds number of bytes available. length=%d bytes=%d", len(src), length)
}
rem = rem[:length-4]
var t bsontype.Type
var key string
var val bsoncore.Value
for {
t, rem, ok = bsoncore.ReadType(rem)
if !ok {
return io.EOF
}
if t == bsontype.Type(0) {
if len(rem) != 0 {
return fmt.Errorf("document end byte found before end of document. remaining bytes=%v", rem)
}
break
}
key, rem, ok = bsoncore.ReadKey(rem)
if !ok {
return fmt.Errorf("invalid key found. remaining bytes=%v", rem)
}
// write as either array element or document element using writeElementFn
vw, err := wef(key)
if err != nil {
return err
}
val, rem, ok = bsoncore.ReadValue(rem, t)
if !ok {
return fmt.Errorf("not enough bytes available to read type. bytes=%d type=%s", len(rem), t)
}
err = c.CopyValueFromBytes(vw, t, val.Data)
if err != nil {
return err
}
}
return nil
}
// CopyDocumentToBytes copies an entire document from the ValueReader and
// returns it as bytes.
func (c Copier) CopyDocumentToBytes(src ValueReader) ([]byte, error) {
return c.AppendDocumentBytes(nil, src)
}
// AppendDocumentBytes functions the same as CopyDocumentToBytes, but will
// append the result to dst.
func (c Copier) AppendDocumentBytes(dst []byte, src ValueReader) ([]byte, error) {
if br, ok := src.(BytesReader); ok {
_, dst, err := br.ReadValueBytes(dst)
return dst, err
}
vw := vwPool.Get().(*valueWriter)
defer vwPool.Put(vw)
vw.reset(dst)
err := c.CopyDocument(vw, src)
dst = vw.buf
return dst, err
}
// AppendArrayBytes copies an array from the ValueReader to dst.
func (c Copier) AppendArrayBytes(dst []byte, src ValueReader) ([]byte, error) {
if br, ok := src.(BytesReader); ok {
_, dst, err := br.ReadValueBytes(dst)
return dst, err
}
vw := vwPool.Get().(*valueWriter)
defer vwPool.Put(vw)
vw.reset(dst)
err := c.copyArray(vw, src)
dst = vw.buf
return dst, err
}
// CopyValueFromBytes will write the value represtend by t and src to dst.
func (c Copier) CopyValueFromBytes(dst ValueWriter, t bsontype.Type, src []byte) error {
if wvb, ok := dst.(BytesWriter); ok {
return wvb.WriteValueBytes(t, src)
}
vr := vrPool.Get().(*valueReader)
defer vrPool.Put(vr)
vr.reset(src)
vr.pushElement(t)
return c.CopyValue(dst, vr)
}
// CopyValueToBytes copies a value from src and returns it as a bsontype.Type and a
// []byte.
func (c Copier) CopyValueToBytes(src ValueReader) (bsontype.Type, []byte, error) {
return c.AppendValueBytes(nil, src)
}
// AppendValueBytes functions the same as CopyValueToBytes, but will append the
// result to dst.
func (c Copier) AppendValueBytes(dst []byte, src ValueReader) (bsontype.Type, []byte, error) {
if br, ok := src.(BytesReader); ok {
return br.ReadValueBytes(dst)
}
vw := vwPool.Get().(*valueWriter)
defer vwPool.Put(vw)
start := len(dst)
vw.reset(dst)
vw.push(mElement)
err := c.CopyValue(vw, src)
if err != nil {
return 0, dst, err
}
return bsontype.Type(vw.buf[start]), vw.buf[start+2:], nil
}
// CopyValue will copy a single value from src to dst.
func (c Copier) CopyValue(dst ValueWriter, src ValueReader) error {
var err error
switch src.Type() {
case bsontype.Double:
var f64 float64
f64, err = src.ReadDouble()
if err != nil {
break
}
err = dst.WriteDouble(f64)
case bsontype.String:
var str string
str, err = src.ReadString()
if err != nil {
return err
}
err = dst.WriteString(str)
case bsontype.EmbeddedDocument:
err = c.CopyDocument(dst, src)
case bsontype.Array:
err = c.copyArray(dst, src)
case bsontype.Binary:
var data []byte
var subtype byte
data, subtype, err = src.ReadBinary()
if err != nil {
break
}
err = dst.WriteBinaryWithSubtype(data, subtype)
case bsontype.Undefined:
err = src.ReadUndefined()
if err != nil {
break
}
err = dst.WriteUndefined()
case bsontype.ObjectID:
var oid primitive.ObjectID
oid, err = src.ReadObjectID()
if err != nil {
break
}
err = dst.WriteObjectID(oid)
case bsontype.Boolean:
var b bool
b, err = src.ReadBoolean()
if err != nil {
break
}
err = dst.WriteBoolean(b)
case bsontype.DateTime:
var dt int64
dt, err = src.ReadDateTime()
if err != nil {
break
}
err = dst.WriteDateTime(dt)
case bsontype.Null:
err = src.ReadNull()
if err != nil {
break
}
err = dst.WriteNull()
case bsontype.Regex:
var pattern, options string
pattern, options, err = src.ReadRegex()
if err != nil {
break
}
err = dst.WriteRegex(pattern, options)
case bsontype.DBPointer:
var ns string
var pointer primitive.ObjectID
ns, pointer, err = src.ReadDBPointer()
if err != nil {
break
}
err = dst.WriteDBPointer(ns, pointer)
case bsontype.JavaScript:
var js string
js, err = src.ReadJavascript()
if err != nil {
break
}
err = dst.WriteJavascript(js)
case bsontype.Symbol:
var symbol string
symbol, err = src.ReadSymbol()
if err != nil {
break
}
err = dst.WriteSymbol(symbol)
case bsontype.CodeWithScope:
var code string
var srcScope DocumentReader
code, srcScope, err = src.ReadCodeWithScope()
if err != nil {
break
}
var dstScope DocumentWriter
dstScope, err = dst.WriteCodeWithScope(code)
if err != nil {
break
}
err = c.copyDocumentCore(dstScope, srcScope)
case bsontype.Int32:
var i32 int32
i32, err = src.ReadInt32()
if err != nil {
break
}
err = dst.WriteInt32(i32)
case bsontype.Timestamp:
var t, i uint32
t, i, err = src.ReadTimestamp()
if err != nil {
break
}
err = dst.WriteTimestamp(t, i)
case bsontype.Int64:
var i64 int64
i64, err = src.ReadInt64()
if err != nil {
break
}
err = dst.WriteInt64(i64)
case bsontype.Decimal128:
var d128 primitive.Decimal128
d128, err = src.ReadDecimal128()
if err != nil {
break
}
err = dst.WriteDecimal128(d128)
case bsontype.MinKey:
err = src.ReadMinKey()
if err != nil {
break
}
err = dst.WriteMinKey()
case bsontype.MaxKey:
err = src.ReadMaxKey()
if err != nil {
break
}
err = dst.WriteMaxKey()
default:
err = fmt.Errorf("Cannot copy unknown BSON type %s", src.Type())
}
return err
}
func (c Copier) copyArray(dst ValueWriter, src ValueReader) error {
ar, err := src.ReadArray()
if err != nil {
return err
}
aw, err := dst.WriteArray()
if err != nil {
return err
}
for {
vr, err := ar.ReadValue()
if err == ErrEOA {
break
}
if err != nil {
return err
}
vw, err := aw.WriteArrayElement()
if err != nil {
return err
}
err = c.CopyValue(vw, vr)
if err != nil {
return err
}
}
return aw.WriteArrayEnd()
}
func (c Copier) copyDocumentCore(dw DocumentWriter, dr DocumentReader) error {
for {
key, vr, err := dr.ReadElement()
if err == ErrEOD {
break
}
if err != nil {
return err
}
vw, err := dw.WriteDocumentElement(key)
if err != nil {
return err
}
err = c.CopyValue(vw, vr)
if err != nil {
return err
}
}
return dw.WriteDocumentEnd()
}

View File

@@ -0,0 +1,529 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonrw
import (
"bytes"
"errors"
"fmt"
"testing"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
func TestCopier(t *testing.T) {
t.Run("CopyDocument", func(t *testing.T) {
t.Run("ReadDocument Error", func(t *testing.T) {
want := errors.New("ReadDocumentError")
src := &TestValueReaderWriter{t: t, err: want, errAfter: llvrwReadDocument}
got := Copier{}.CopyDocument(nil, src)
if !compareErrors(got, want) {
t.Errorf("Did not receive correct error. got %v; want %v", got, want)
}
})
t.Run("WriteDocument Error", func(t *testing.T) {
want := errors.New("WriteDocumentError")
src := &TestValueReaderWriter{}
dst := &TestValueReaderWriter{t: t, err: want, errAfter: llvrwWriteDocument}
got := Copier{}.CopyDocument(dst, src)
if !compareErrors(got, want) {
t.Errorf("Did not receive correct error. got %v; want %v", got, want)
}
})
t.Run("success", func(t *testing.T) {
idx, doc := bsoncore.AppendDocumentStart(nil)
doc = bsoncore.AppendStringElement(doc, "Hello", "world")
doc, err := bsoncore.AppendDocumentEnd(doc, idx)
noerr(t, err)
src := newValueReader(doc)
dst := newValueWriterFromSlice(make([]byte, 0))
want := doc
err = Copier{}.CopyDocument(dst, src)
noerr(t, err)
got := dst.buf
if !bytes.Equal(got, want) {
t.Errorf("Bytes are not equal. got %v; want %v", got, want)
}
})
})
t.Run("copyArray", func(t *testing.T) {
t.Run("ReadArray Error", func(t *testing.T) {
want := errors.New("ReadArrayError")
src := &TestValueReaderWriter{t: t, err: want, errAfter: llvrwReadArray}
got := Copier{}.copyArray(nil, src)
if !compareErrors(got, want) {
t.Errorf("Did not receive correct error. got %v; want %v", got, want)
}
})
t.Run("WriteArray Error", func(t *testing.T) {
want := errors.New("WriteArrayError")
src := &TestValueReaderWriter{}
dst := &TestValueReaderWriter{t: t, err: want, errAfter: llvrwWriteArray}
got := Copier{}.copyArray(dst, src)
if !compareErrors(got, want) {
t.Errorf("Did not receive correct error. got %v; want %v", got, want)
}
})
t.Run("success", func(t *testing.T) {
idx, doc := bsoncore.AppendDocumentStart(nil)
aidx, doc := bsoncore.AppendArrayElementStart(doc, "foo")
doc = bsoncore.AppendStringElement(doc, "0", "Hello, world!")
doc, err := bsoncore.AppendArrayEnd(doc, aidx)
noerr(t, err)
doc, err = bsoncore.AppendDocumentEnd(doc, idx)
noerr(t, err)
src := newValueReader(doc)
_, err = src.ReadDocument()
noerr(t, err)
_, _, err = src.ReadElement()
noerr(t, err)
dst := newValueWriterFromSlice(make([]byte, 0))
_, err = dst.WriteDocument()
noerr(t, err)
_, err = dst.WriteDocumentElement("foo")
noerr(t, err)
want := doc
err = Copier{}.copyArray(dst, src)
noerr(t, err)
err = dst.WriteDocumentEnd()
noerr(t, err)
got := dst.buf
if !bytes.Equal(got, want) {
t.Errorf("Bytes are not equal. got %v; want %v", got, want)
}
})
})
t.Run("CopyValue", func(t *testing.T) {
testCases := []struct {
name string
dst *TestValueReaderWriter
src *TestValueReaderWriter
err error
}{
{
"Double/src/error",
&TestValueReaderWriter{},
&TestValueReaderWriter{bsontype: bsontype.Double, err: errors.New("1"), errAfter: llvrwReadDouble},
errors.New("1"),
},
{
"Double/dst/error",
&TestValueReaderWriter{bsontype: bsontype.Double, err: errors.New("2"), errAfter: llvrwWriteDouble},
&TestValueReaderWriter{bsontype: bsontype.Double, readval: float64(3.14159)},
errors.New("2"),
},
{
"String/src/error",
&TestValueReaderWriter{},
&TestValueReaderWriter{bsontype: bsontype.String, err: errors.New("1"), errAfter: llvrwReadString},
errors.New("1"),
},
{
"String/dst/error",
&TestValueReaderWriter{bsontype: bsontype.String, err: errors.New("2"), errAfter: llvrwWriteString},
&TestValueReaderWriter{bsontype: bsontype.String, readval: "hello, world"},
errors.New("2"),
},
{
"Document/src/error",
&TestValueReaderWriter{},
&TestValueReaderWriter{bsontype: bsontype.EmbeddedDocument, err: errors.New("1"), errAfter: llvrwReadDocument},
errors.New("1"),
},
{
"Array/dst/error",
&TestValueReaderWriter{},
&TestValueReaderWriter{bsontype: bsontype.Array, err: errors.New("2"), errAfter: llvrwReadArray},
errors.New("2"),
},
{
"Binary/src/error",
&TestValueReaderWriter{},
&TestValueReaderWriter{bsontype: bsontype.Binary, err: errors.New("1"), errAfter: llvrwReadBinary},
errors.New("1"),
},
{
"Binary/dst/error",
&TestValueReaderWriter{bsontype: bsontype.Binary, err: errors.New("2"), errAfter: llvrwWriteBinaryWithSubtype},
&TestValueReaderWriter{
bsontype: bsontype.Binary,
readval: bsoncore.Value{
Type: bsontype.Binary,
Data: []byte{0x03, 0x00, 0x00, 0x00, 0xFF, 0x01, 0x02, 0x03},
},
},
errors.New("2"),
},
{
"Undefined/src/error",
&TestValueReaderWriter{},
&TestValueReaderWriter{bsontype: bsontype.Undefined, err: errors.New("1"), errAfter: llvrwReadUndefined},
errors.New("1"),
},
{
"Undefined/dst/error",
&TestValueReaderWriter{bsontype: bsontype.Undefined, err: errors.New("2"), errAfter: llvrwWriteUndefined},
&TestValueReaderWriter{bsontype: bsontype.Undefined},
errors.New("2"),
},
{
"ObjectID/src/error",
&TestValueReaderWriter{},
&TestValueReaderWriter{bsontype: bsontype.ObjectID, err: errors.New("1"), errAfter: llvrwReadObjectID},
errors.New("1"),
},
{
"ObjectID/dst/error",
&TestValueReaderWriter{bsontype: bsontype.ObjectID, err: errors.New("2"), errAfter: llvrwWriteObjectID},
&TestValueReaderWriter{bsontype: bsontype.ObjectID, readval: primitive.ObjectID{0x01, 0x02, 0x03}},
errors.New("2"),
},
{
"Boolean/src/error",
&TestValueReaderWriter{},
&TestValueReaderWriter{bsontype: bsontype.Boolean, err: errors.New("1"), errAfter: llvrwReadBoolean},
errors.New("1"),
},
{
"Boolean/dst/error",
&TestValueReaderWriter{bsontype: bsontype.Boolean, err: errors.New("2"), errAfter: llvrwWriteBoolean},
&TestValueReaderWriter{bsontype: bsontype.Boolean, readval: bool(true)},
errors.New("2"),
},
{
"DateTime/src/error",
&TestValueReaderWriter{},
&TestValueReaderWriter{bsontype: bsontype.DateTime, err: errors.New("1"), errAfter: llvrwReadDateTime},
errors.New("1"),
},
{
"DateTime/dst/error",
&TestValueReaderWriter{bsontype: bsontype.DateTime, err: errors.New("2"), errAfter: llvrwWriteDateTime},
&TestValueReaderWriter{bsontype: bsontype.DateTime, readval: int64(1234567890)},
errors.New("2"),
},
{
"Null/src/error",
&TestValueReaderWriter{},
&TestValueReaderWriter{bsontype: bsontype.Null, err: errors.New("1"), errAfter: llvrwReadNull},
errors.New("1"),
},
{
"Null/dst/error",
&TestValueReaderWriter{bsontype: bsontype.Null, err: errors.New("2"), errAfter: llvrwWriteNull},
&TestValueReaderWriter{bsontype: bsontype.Null},
errors.New("2"),
},
{
"Regex/src/error",
&TestValueReaderWriter{},
&TestValueReaderWriter{bsontype: bsontype.Regex, err: errors.New("1"), errAfter: llvrwReadRegex},
errors.New("1"),
},
{
"Regex/dst/error",
&TestValueReaderWriter{bsontype: bsontype.Regex, err: errors.New("2"), errAfter: llvrwWriteRegex},
&TestValueReaderWriter{
bsontype: bsontype.Regex,
readval: bsoncore.Value{
Type: bsontype.Regex,
Data: bsoncore.AppendRegex(nil, "hello", "world"),
},
},
errors.New("2"),
},
{
"DBPointer/src/error",
&TestValueReaderWriter{},
&TestValueReaderWriter{bsontype: bsontype.DBPointer, err: errors.New("1"), errAfter: llvrwReadDBPointer},
errors.New("1"),
},
{
"DBPointer/dst/error",
&TestValueReaderWriter{bsontype: bsontype.DBPointer, err: errors.New("2"), errAfter: llvrwWriteDBPointer},
&TestValueReaderWriter{
bsontype: bsontype.DBPointer,
readval: bsoncore.Value{
Type: bsontype.DBPointer,
Data: bsoncore.AppendDBPointer(nil, "foo", primitive.ObjectID{0x01, 0x02, 0x03}),
},
},
errors.New("2"),
},
{
"Javascript/src/error",
&TestValueReaderWriter{},
&TestValueReaderWriter{bsontype: bsontype.JavaScript, err: errors.New("1"), errAfter: llvrwReadJavascript},
errors.New("1"),
},
{
"Javascript/dst/error",
&TestValueReaderWriter{bsontype: bsontype.JavaScript, err: errors.New("2"), errAfter: llvrwWriteJavascript},
&TestValueReaderWriter{bsontype: bsontype.JavaScript, readval: "hello, world"},
errors.New("2"),
},
{
"Symbol/src/error",
&TestValueReaderWriter{},
&TestValueReaderWriter{bsontype: bsontype.Symbol, err: errors.New("1"), errAfter: llvrwReadSymbol},
errors.New("1"),
},
{
"Symbol/dst/error",
&TestValueReaderWriter{bsontype: bsontype.Symbol, err: errors.New("2"), errAfter: llvrwWriteSymbol},
&TestValueReaderWriter{
bsontype: bsontype.Symbol,
readval: bsoncore.Value{
Type: bsontype.Symbol,
Data: bsoncore.AppendSymbol(nil, "hello, world"),
},
},
errors.New("2"),
},
{
"CodeWithScope/src/error",
&TestValueReaderWriter{},
&TestValueReaderWriter{bsontype: bsontype.CodeWithScope, err: errors.New("1"), errAfter: llvrwReadCodeWithScope},
errors.New("1"),
},
{
"CodeWithScope/dst/error",
&TestValueReaderWriter{bsontype: bsontype.CodeWithScope, err: errors.New("2"), errAfter: llvrwWriteCodeWithScope},
&TestValueReaderWriter{bsontype: bsontype.CodeWithScope},
errors.New("2"),
},
{
"CodeWithScope/dst/copyDocumentCore error",
&TestValueReaderWriter{err: errors.New("3"), errAfter: llvrwWriteDocumentElement},
&TestValueReaderWriter{bsontype: bsontype.CodeWithScope},
errors.New("3"),
},
{
"Int32/src/error",
&TestValueReaderWriter{},
&TestValueReaderWriter{bsontype: bsontype.Int32, err: errors.New("1"), errAfter: llvrwReadInt32},
errors.New("1"),
},
{
"Int32/dst/error",
&TestValueReaderWriter{bsontype: bsontype.Int32, err: errors.New("2"), errAfter: llvrwWriteInt32},
&TestValueReaderWriter{bsontype: bsontype.Int32, readval: int32(12345)},
errors.New("2"),
},
{
"Timestamp/src/error",
&TestValueReaderWriter{},
&TestValueReaderWriter{bsontype: bsontype.Timestamp, err: errors.New("1"), errAfter: llvrwReadTimestamp},
errors.New("1"),
},
{
"Timestamp/dst/error",
&TestValueReaderWriter{bsontype: bsontype.Timestamp, err: errors.New("2"), errAfter: llvrwWriteTimestamp},
&TestValueReaderWriter{
bsontype: bsontype.Timestamp,
readval: bsoncore.Value{
Type: bsontype.Timestamp,
Data: bsoncore.AppendTimestamp(nil, 12345, 67890),
},
},
errors.New("2"),
},
{
"Int64/src/error",
&TestValueReaderWriter{},
&TestValueReaderWriter{bsontype: bsontype.Int64, err: errors.New("1"), errAfter: llvrwReadInt64},
errors.New("1"),
},
{
"Int64/dst/error",
&TestValueReaderWriter{bsontype: bsontype.Int64, err: errors.New("2"), errAfter: llvrwWriteInt64},
&TestValueReaderWriter{bsontype: bsontype.Int64, readval: int64(1234567890)},
errors.New("2"),
},
{
"Decimal128/src/error",
&TestValueReaderWriter{},
&TestValueReaderWriter{bsontype: bsontype.Decimal128, err: errors.New("1"), errAfter: llvrwReadDecimal128},
errors.New("1"),
},
{
"Decimal128/dst/error",
&TestValueReaderWriter{bsontype: bsontype.Decimal128, err: errors.New("2"), errAfter: llvrwWriteDecimal128},
&TestValueReaderWriter{bsontype: bsontype.Decimal128, readval: primitive.NewDecimal128(12345, 67890)},
errors.New("2"),
},
{
"MinKey/src/error",
&TestValueReaderWriter{},
&TestValueReaderWriter{bsontype: bsontype.MinKey, err: errors.New("1"), errAfter: llvrwReadMinKey},
errors.New("1"),
},
{
"MinKey/dst/error",
&TestValueReaderWriter{bsontype: bsontype.MinKey, err: errors.New("2"), errAfter: llvrwWriteMinKey},
&TestValueReaderWriter{bsontype: bsontype.MinKey},
errors.New("2"),
},
{
"MaxKey/src/error",
&TestValueReaderWriter{},
&TestValueReaderWriter{bsontype: bsontype.MaxKey, err: errors.New("1"), errAfter: llvrwReadMaxKey},
errors.New("1"),
},
{
"MaxKey/dst/error",
&TestValueReaderWriter{bsontype: bsontype.MaxKey, err: errors.New("2"), errAfter: llvrwWriteMaxKey},
&TestValueReaderWriter{bsontype: bsontype.MaxKey},
errors.New("2"),
},
{
"Unknown BSON type error",
&TestValueReaderWriter{},
&TestValueReaderWriter{},
fmt.Errorf("Cannot copy unknown BSON type %s", bsontype.Type(0)),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tc.dst.t, tc.src.t = t, t
err := Copier{}.CopyValue(tc.dst, tc.src)
if !compareErrors(err, tc.err) {
t.Errorf("Did not receive expected error. got %v; want %v", err, tc.err)
}
})
}
})
t.Run("CopyValueFromBytes", func(t *testing.T) {
t.Run("BytesWriter", func(t *testing.T) {
vw := newValueWriterFromSlice(make([]byte, 0))
_, err := vw.WriteDocument()
noerr(t, err)
_, err = vw.WriteDocumentElement("foo")
noerr(t, err)
err = Copier{}.CopyValueFromBytes(vw, bsontype.String, bsoncore.AppendString(nil, "bar"))
noerr(t, err)
err = vw.WriteDocumentEnd()
noerr(t, err)
var idx int32
want, err := bsoncore.AppendDocumentEnd(
bsoncore.AppendStringElement(
bsoncore.AppendDocumentStartInline(nil, &idx),
"foo", "bar",
),
idx,
)
noerr(t, err)
got := vw.buf
if !bytes.Equal(got, want) {
t.Errorf("Bytes are not equal. got %v; want %v", got, want)
}
})
t.Run("Non BytesWriter", func(t *testing.T) {
llvrw := &TestValueReaderWriter{t: t}
err := Copier{}.CopyValueFromBytes(llvrw, bsontype.String, bsoncore.AppendString(nil, "bar"))
noerr(t, err)
got, want := llvrw.invoked, llvrwWriteString
if got != want {
t.Errorf("Incorrect method invoked on llvrw. got %v; want %v", got, want)
}
})
})
t.Run("CopyValueToBytes", func(t *testing.T) {
t.Run("BytesReader", func(t *testing.T) {
var idx int32
b, err := bsoncore.AppendDocumentEnd(
bsoncore.AppendStringElement(
bsoncore.AppendDocumentStartInline(nil, &idx),
"hello", "world",
),
idx,
)
noerr(t, err)
vr := newValueReader(b)
_, err = vr.ReadDocument()
noerr(t, err)
_, _, err = vr.ReadElement()
noerr(t, err)
btype, got, err := Copier{}.CopyValueToBytes(vr)
noerr(t, err)
want := bsoncore.AppendString(nil, "world")
if btype != bsontype.String {
t.Errorf("Incorrect type returned. got %v; want %v", btype, bsontype.String)
}
if !bytes.Equal(got, want) {
t.Errorf("Bytes do not match. got %v; want %v", got, want)
}
})
t.Run("Non BytesReader", func(t *testing.T) {
llvrw := &TestValueReaderWriter{t: t, bsontype: bsontype.String, readval: "Hello, world!"}
btype, got, err := Copier{}.CopyValueToBytes(llvrw)
noerr(t, err)
want := bsoncore.AppendString(nil, "Hello, world!")
if btype != bsontype.String {
t.Errorf("Incorrect type returned. got %v; want %v", btype, bsontype.String)
}
if !bytes.Equal(got, want) {
t.Errorf("Bytes do not match. got %v; want %v", got, want)
}
})
})
t.Run("AppendValueBytes", func(t *testing.T) {
t.Run("BytesReader", func(t *testing.T) {
var idx int32
b, err := bsoncore.AppendDocumentEnd(
bsoncore.AppendStringElement(
bsoncore.AppendDocumentStartInline(nil, &idx),
"hello", "world",
),
idx,
)
noerr(t, err)
vr := newValueReader(b)
_, err = vr.ReadDocument()
noerr(t, err)
_, _, err = vr.ReadElement()
noerr(t, err)
btype, got, err := Copier{}.AppendValueBytes(nil, vr)
noerr(t, err)
want := bsoncore.AppendString(nil, "world")
if btype != bsontype.String {
t.Errorf("Incorrect type returned. got %v; want %v", btype, bsontype.String)
}
if !bytes.Equal(got, want) {
t.Errorf("Bytes do not match. got %v; want %v", got, want)
}
})
t.Run("Non BytesReader", func(t *testing.T) {
llvrw := &TestValueReaderWriter{t: t, bsontype: bsontype.String, readval: "Hello, world!"}
btype, got, err := Copier{}.AppendValueBytes(nil, llvrw)
noerr(t, err)
want := bsoncore.AppendString(nil, "Hello, world!")
if btype != bsontype.String {
t.Errorf("Incorrect type returned. got %v; want %v", btype, bsontype.String)
}
if !bytes.Equal(got, want) {
t.Errorf("Bytes do not match. got %v; want %v", got, want)
}
})
t.Run("CopyValue error", func(t *testing.T) {
want := errors.New("CopyValue error")
llvrw := &TestValueReaderWriter{t: t, bsontype: bsontype.String, err: want, errAfter: llvrwReadString}
_, _, got := Copier{}.AppendValueBytes(make([]byte, 0), llvrw)
if !compareErrors(got, want) {
t.Errorf("Errors do not match. got %v; want %v", got, want)
}
})
})
}

9
mongo/bson/bsonrw/doc.go Normal file
View File

@@ -0,0 +1,9 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Package bsonrw contains abstractions for reading and writing
// BSON and BSON like types from sources.
package bsonrw // import "go.mongodb.org/mongo-driver/bson/bsonrw"

View File

@@ -0,0 +1,806 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonrw
import (
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"io"
"strings"
"go.mongodb.org/mongo-driver/bson/bsontype"
)
const maxNestingDepth = 200
// ErrInvalidJSON indicates the JSON input is invalid
var ErrInvalidJSON = errors.New("invalid JSON input")
type jsonParseState byte
const (
jpsStartState jsonParseState = iota
jpsSawBeginObject
jpsSawEndObject
jpsSawBeginArray
jpsSawEndArray
jpsSawColon
jpsSawComma
jpsSawKey
jpsSawValue
jpsDoneState
jpsInvalidState
)
type jsonParseMode byte
const (
jpmInvalidMode jsonParseMode = iota
jpmObjectMode
jpmArrayMode
)
type extJSONValue struct {
t bsontype.Type
v interface{}
}
type extJSONObject struct {
keys []string
values []*extJSONValue
}
type extJSONParser struct {
js *jsonScanner
s jsonParseState
m []jsonParseMode
k string
v *extJSONValue
err error
canonical bool
depth int
maxDepth int
emptyObject bool
relaxedUUID bool
}
// newExtJSONParser returns a new extended JSON parser, ready to to begin
// parsing from the first character of the argued json input. It will not
// perform any read-ahead and will therefore not report any errors about
// malformed JSON at this point.
func newExtJSONParser(r io.Reader, canonical bool) *extJSONParser {
return &extJSONParser{
js: &jsonScanner{r: r},
s: jpsStartState,
m: []jsonParseMode{},
canonical: canonical,
maxDepth: maxNestingDepth,
}
}
// peekType examines the next value and returns its BSON Type
func (ejp *extJSONParser) peekType() (bsontype.Type, error) {
var t bsontype.Type
var err error
initialState := ejp.s
ejp.advanceState()
switch ejp.s {
case jpsSawValue:
t = ejp.v.t
case jpsSawBeginArray:
t = bsontype.Array
case jpsInvalidState:
err = ejp.err
case jpsSawComma:
// in array mode, seeing a comma means we need to progress again to actually observe a type
if ejp.peekMode() == jpmArrayMode {
return ejp.peekType()
}
case jpsSawEndArray:
// this would only be a valid state if we were in array mode, so return end-of-array error
err = ErrEOA
case jpsSawBeginObject:
// peek key to determine type
ejp.advanceState()
switch ejp.s {
case jpsSawEndObject: // empty embedded document
t = bsontype.EmbeddedDocument
ejp.emptyObject = true
case jpsInvalidState:
err = ejp.err
case jpsSawKey:
if initialState == jpsStartState {
return bsontype.EmbeddedDocument, nil
}
t = wrapperKeyBSONType(ejp.k)
// if $uuid is encountered, parse as binary subtype 4
if ejp.k == "$uuid" {
ejp.relaxedUUID = true
t = bsontype.Binary
}
switch t {
case bsontype.JavaScript:
// just saw $code, need to check for $scope at same level
_, err = ejp.readValue(bsontype.JavaScript)
if err != nil {
break
}
switch ejp.s {
case jpsSawEndObject: // type is TypeJavaScript
case jpsSawComma:
ejp.advanceState()
if ejp.s == jpsSawKey && ejp.k == "$scope" {
t = bsontype.CodeWithScope
} else {
err = fmt.Errorf("invalid extended JSON: unexpected key %s in CodeWithScope object", ejp.k)
}
case jpsInvalidState:
err = ejp.err
default:
err = ErrInvalidJSON
}
case bsontype.CodeWithScope:
err = errors.New("invalid extended JSON: code with $scope must contain $code before $scope")
}
}
}
return t, err
}
// readKey parses the next key and its type and returns them
func (ejp *extJSONParser) readKey() (string, bsontype.Type, error) {
if ejp.emptyObject {
ejp.emptyObject = false
return "", 0, ErrEOD
}
// advance to key (or return with error)
switch ejp.s {
case jpsStartState:
ejp.advanceState()
if ejp.s == jpsSawBeginObject {
ejp.advanceState()
}
case jpsSawBeginObject:
ejp.advanceState()
case jpsSawValue, jpsSawEndObject, jpsSawEndArray:
ejp.advanceState()
switch ejp.s {
case jpsSawBeginObject, jpsSawComma:
ejp.advanceState()
case jpsSawEndObject:
return "", 0, ErrEOD
case jpsDoneState:
return "", 0, io.EOF
case jpsInvalidState:
return "", 0, ejp.err
default:
return "", 0, ErrInvalidJSON
}
case jpsSawKey: // do nothing (key was peeked before)
default:
return "", 0, invalidRequestError("key")
}
// read key
var key string
switch ejp.s {
case jpsSawKey:
key = ejp.k
case jpsSawEndObject:
return "", 0, ErrEOD
case jpsInvalidState:
return "", 0, ejp.err
default:
return "", 0, invalidRequestError("key")
}
// check for colon
ejp.advanceState()
if err := ensureColon(ejp.s, key); err != nil {
return "", 0, err
}
// peek at the value to determine type
t, err := ejp.peekType()
if err != nil {
return "", 0, err
}
return key, t, nil
}
// readValue returns the value corresponding to the Type returned by peekType
func (ejp *extJSONParser) readValue(t bsontype.Type) (*extJSONValue, error) {
if ejp.s == jpsInvalidState {
return nil, ejp.err
}
var v *extJSONValue
switch t {
case bsontype.Null, bsontype.Boolean, bsontype.String:
if ejp.s != jpsSawValue {
return nil, invalidRequestError(t.String())
}
v = ejp.v
case bsontype.Int32, bsontype.Int64, bsontype.Double:
// relaxed version allows these to be literal number values
if ejp.s == jpsSawValue {
v = ejp.v
break
}
fallthrough
case bsontype.Decimal128, bsontype.Symbol, bsontype.ObjectID, bsontype.MinKey, bsontype.MaxKey, bsontype.Undefined:
switch ejp.s {
case jpsSawKey:
// read colon
ejp.advanceState()
if err := ensureColon(ejp.s, ejp.k); err != nil {
return nil, err
}
// read value
ejp.advanceState()
if ejp.s != jpsSawValue || !ejp.ensureExtValueType(t) {
return nil, invalidJSONErrorForType("value", t)
}
v = ejp.v
// read end object
ejp.advanceState()
if ejp.s != jpsSawEndObject {
return nil, invalidJSONErrorForType("} after value", t)
}
default:
return nil, invalidRequestError(t.String())
}
case bsontype.Binary, bsontype.Regex, bsontype.Timestamp, bsontype.DBPointer:
if ejp.s != jpsSawKey {
return nil, invalidRequestError(t.String())
}
// read colon
ejp.advanceState()
if err := ensureColon(ejp.s, ejp.k); err != nil {
return nil, err
}
ejp.advanceState()
if t == bsontype.Binary && ejp.s == jpsSawValue {
// convert relaxed $uuid format
if ejp.relaxedUUID {
defer func() { ejp.relaxedUUID = false }()
uuid, err := ejp.v.parseSymbol()
if err != nil {
return nil, err
}
// RFC 4122 defines the length of a UUID as 36 and the hyphens in a UUID as appearing
// in the 8th, 13th, 18th, and 23rd characters.
//
// See https://tools.ietf.org/html/rfc4122#section-3
valid := len(uuid) == 36 &&
string(uuid[8]) == "-" &&
string(uuid[13]) == "-" &&
string(uuid[18]) == "-" &&
string(uuid[23]) == "-"
if !valid {
return nil, fmt.Errorf("$uuid value does not follow RFC 4122 format regarding length and hyphens")
}
// remove hyphens
uuidNoHyphens := strings.Replace(uuid, "-", "", -1)
if len(uuidNoHyphens) != 32 {
return nil, fmt.Errorf("$uuid value does not follow RFC 4122 format regarding length and hyphens")
}
// convert hex to bytes
bytes, err := hex.DecodeString(uuidNoHyphens)
if err != nil {
return nil, fmt.Errorf("$uuid value does not follow RFC 4122 format regarding hex bytes: %v", err)
}
ejp.advanceState()
if ejp.s != jpsSawEndObject {
return nil, invalidJSONErrorForType("$uuid and value and then }", bsontype.Binary)
}
base64 := &extJSONValue{
t: bsontype.String,
v: base64.StdEncoding.EncodeToString(bytes),
}
subType := &extJSONValue{
t: bsontype.String,
v: "04",
}
v = &extJSONValue{
t: bsontype.EmbeddedDocument,
v: &extJSONObject{
keys: []string{"base64", "subType"},
values: []*extJSONValue{base64, subType},
},
}
break
}
// convert legacy $binary format
base64 := ejp.v
ejp.advanceState()
if ejp.s != jpsSawComma {
return nil, invalidJSONErrorForType(",", bsontype.Binary)
}
ejp.advanceState()
key, t, err := ejp.readKey()
if err != nil {
return nil, err
}
if key != "$type" {
return nil, invalidJSONErrorForType("$type", bsontype.Binary)
}
subType, err := ejp.readValue(t)
if err != nil {
return nil, err
}
ejp.advanceState()
if ejp.s != jpsSawEndObject {
return nil, invalidJSONErrorForType("2 key-value pairs and then }", bsontype.Binary)
}
v = &extJSONValue{
t: bsontype.EmbeddedDocument,
v: &extJSONObject{
keys: []string{"base64", "subType"},
values: []*extJSONValue{base64, subType},
},
}
break
}
// read KV pairs
if ejp.s != jpsSawBeginObject {
return nil, invalidJSONErrorForType("{", t)
}
keys, vals, err := ejp.readObject(2, true)
if err != nil {
return nil, err
}
ejp.advanceState()
if ejp.s != jpsSawEndObject {
return nil, invalidJSONErrorForType("2 key-value pairs and then }", t)
}
v = &extJSONValue{t: bsontype.EmbeddedDocument, v: &extJSONObject{keys: keys, values: vals}}
case bsontype.DateTime:
switch ejp.s {
case jpsSawValue:
v = ejp.v
case jpsSawKey:
// read colon
ejp.advanceState()
if err := ensureColon(ejp.s, ejp.k); err != nil {
return nil, err
}
ejp.advanceState()
switch ejp.s {
case jpsSawBeginObject:
keys, vals, err := ejp.readObject(1, true)
if err != nil {
return nil, err
}
v = &extJSONValue{t: bsontype.EmbeddedDocument, v: &extJSONObject{keys: keys, values: vals}}
case jpsSawValue:
if ejp.canonical {
return nil, invalidJSONError("{")
}
v = ejp.v
default:
if ejp.canonical {
return nil, invalidJSONErrorForType("object", t)
}
return nil, invalidJSONErrorForType("ISO-8601 Internet Date/Time Format as described in RFC-3339", t)
}
ejp.advanceState()
if ejp.s != jpsSawEndObject {
return nil, invalidJSONErrorForType("value and then }", t)
}
default:
return nil, invalidRequestError(t.String())
}
case bsontype.JavaScript:
switch ejp.s {
case jpsSawKey:
// read colon
ejp.advanceState()
if err := ensureColon(ejp.s, ejp.k); err != nil {
return nil, err
}
// read value
ejp.advanceState()
if ejp.s != jpsSawValue {
return nil, invalidJSONErrorForType("value", t)
}
v = ejp.v
// read end object or comma and just return
ejp.advanceState()
case jpsSawEndObject:
v = ejp.v
default:
return nil, invalidRequestError(t.String())
}
case bsontype.CodeWithScope:
if ejp.s == jpsSawKey && ejp.k == "$scope" {
v = ejp.v // this is the $code string from earlier
// read colon
ejp.advanceState()
if err := ensureColon(ejp.s, ejp.k); err != nil {
return nil, err
}
// read {
ejp.advanceState()
if ejp.s != jpsSawBeginObject {
return nil, invalidJSONError("$scope to be embedded document")
}
} else {
return nil, invalidRequestError(t.String())
}
case bsontype.EmbeddedDocument, bsontype.Array:
return nil, invalidRequestError(t.String())
}
return v, nil
}
// readObject is a utility method for reading full objects of known (or expected) size
// it is useful for extended JSON types such as binary, datetime, regex, and timestamp
func (ejp *extJSONParser) readObject(numKeys int, started bool) ([]string, []*extJSONValue, error) {
keys := make([]string, numKeys)
vals := make([]*extJSONValue, numKeys)
if !started {
ejp.advanceState()
if ejp.s != jpsSawBeginObject {
return nil, nil, invalidJSONError("{")
}
}
for i := 0; i < numKeys; i++ {
key, t, err := ejp.readKey()
if err != nil {
return nil, nil, err
}
switch ejp.s {
case jpsSawKey:
v, err := ejp.readValue(t)
if err != nil {
return nil, nil, err
}
keys[i] = key
vals[i] = v
case jpsSawValue:
keys[i] = key
vals[i] = ejp.v
default:
return nil, nil, invalidJSONError("value")
}
}
ejp.advanceState()
if ejp.s != jpsSawEndObject {
return nil, nil, invalidJSONError("}")
}
return keys, vals, nil
}
// advanceState reads the next JSON token from the scanner and transitions
// from the current state based on that token's type
func (ejp *extJSONParser) advanceState() {
if ejp.s == jpsDoneState || ejp.s == jpsInvalidState {
return
}
jt, err := ejp.js.nextToken()
if err != nil {
ejp.err = err
ejp.s = jpsInvalidState
return
}
valid := ejp.validateToken(jt.t)
if !valid {
ejp.err = unexpectedTokenError(jt)
ejp.s = jpsInvalidState
return
}
switch jt.t {
case jttBeginObject:
ejp.s = jpsSawBeginObject
ejp.pushMode(jpmObjectMode)
ejp.depth++
if ejp.depth > ejp.maxDepth {
ejp.err = nestingDepthError(jt.p, ejp.depth)
ejp.s = jpsInvalidState
}
case jttEndObject:
ejp.s = jpsSawEndObject
ejp.depth--
if ejp.popMode() != jpmObjectMode {
ejp.err = unexpectedTokenError(jt)
ejp.s = jpsInvalidState
}
case jttBeginArray:
ejp.s = jpsSawBeginArray
ejp.pushMode(jpmArrayMode)
case jttEndArray:
ejp.s = jpsSawEndArray
if ejp.popMode() != jpmArrayMode {
ejp.err = unexpectedTokenError(jt)
ejp.s = jpsInvalidState
}
case jttColon:
ejp.s = jpsSawColon
case jttComma:
ejp.s = jpsSawComma
case jttEOF:
ejp.s = jpsDoneState
if len(ejp.m) != 0 {
ejp.err = unexpectedTokenError(jt)
ejp.s = jpsInvalidState
}
case jttString:
switch ejp.s {
case jpsSawComma:
if ejp.peekMode() == jpmArrayMode {
ejp.s = jpsSawValue
ejp.v = extendJSONToken(jt)
return
}
fallthrough
case jpsSawBeginObject:
ejp.s = jpsSawKey
ejp.k = jt.v.(string)
return
}
fallthrough
default:
ejp.s = jpsSawValue
ejp.v = extendJSONToken(jt)
}
}
var jpsValidTransitionTokens = map[jsonParseState]map[jsonTokenType]bool{
jpsStartState: {
jttBeginObject: true,
jttBeginArray: true,
jttInt32: true,
jttInt64: true,
jttDouble: true,
jttString: true,
jttBool: true,
jttNull: true,
jttEOF: true,
},
jpsSawBeginObject: {
jttEndObject: true,
jttString: true,
},
jpsSawEndObject: {
jttEndObject: true,
jttEndArray: true,
jttComma: true,
jttEOF: true,
},
jpsSawBeginArray: {
jttBeginObject: true,
jttBeginArray: true,
jttEndArray: true,
jttInt32: true,
jttInt64: true,
jttDouble: true,
jttString: true,
jttBool: true,
jttNull: true,
},
jpsSawEndArray: {
jttEndObject: true,
jttEndArray: true,
jttComma: true,
jttEOF: true,
},
jpsSawColon: {
jttBeginObject: true,
jttBeginArray: true,
jttInt32: true,
jttInt64: true,
jttDouble: true,
jttString: true,
jttBool: true,
jttNull: true,
},
jpsSawComma: {
jttBeginObject: true,
jttBeginArray: true,
jttInt32: true,
jttInt64: true,
jttDouble: true,
jttString: true,
jttBool: true,
jttNull: true,
},
jpsSawKey: {
jttColon: true,
},
jpsSawValue: {
jttEndObject: true,
jttEndArray: true,
jttComma: true,
jttEOF: true,
},
jpsDoneState: {},
jpsInvalidState: {},
}
func (ejp *extJSONParser) validateToken(jtt jsonTokenType) bool {
switch ejp.s {
case jpsSawEndObject:
// if we are at depth zero and the next token is a '{',
// we can consider it valid only if we are not in array mode.
if jtt == jttBeginObject && ejp.depth == 0 {
return ejp.peekMode() != jpmArrayMode
}
case jpsSawComma:
switch ejp.peekMode() {
// the only valid next token after a comma inside a document is a string (a key)
case jpmObjectMode:
return jtt == jttString
case jpmInvalidMode:
return false
}
}
_, ok := jpsValidTransitionTokens[ejp.s][jtt]
return ok
}
// ensureExtValueType returns true if the current value has the expected
// value type for single-key extended JSON types. For example,
// {"$numberInt": v} v must be TypeString
func (ejp *extJSONParser) ensureExtValueType(t bsontype.Type) bool {
switch t {
case bsontype.MinKey, bsontype.MaxKey:
return ejp.v.t == bsontype.Int32
case bsontype.Undefined:
return ejp.v.t == bsontype.Boolean
case bsontype.Int32, bsontype.Int64, bsontype.Double, bsontype.Decimal128, bsontype.Symbol, bsontype.ObjectID:
return ejp.v.t == bsontype.String
default:
return false
}
}
func (ejp *extJSONParser) pushMode(m jsonParseMode) {
ejp.m = append(ejp.m, m)
}
func (ejp *extJSONParser) popMode() jsonParseMode {
l := len(ejp.m)
if l == 0 {
return jpmInvalidMode
}
m := ejp.m[l-1]
ejp.m = ejp.m[:l-1]
return m
}
func (ejp *extJSONParser) peekMode() jsonParseMode {
l := len(ejp.m)
if l == 0 {
return jpmInvalidMode
}
return ejp.m[l-1]
}
func extendJSONToken(jt *jsonToken) *extJSONValue {
var t bsontype.Type
switch jt.t {
case jttInt32:
t = bsontype.Int32
case jttInt64:
t = bsontype.Int64
case jttDouble:
t = bsontype.Double
case jttString:
t = bsontype.String
case jttBool:
t = bsontype.Boolean
case jttNull:
t = bsontype.Null
default:
return nil
}
return &extJSONValue{t: t, v: jt.v}
}
func ensureColon(s jsonParseState, key string) error {
if s != jpsSawColon {
return fmt.Errorf("invalid JSON input: missing colon after key \"%s\"", key)
}
return nil
}
func invalidRequestError(s string) error {
return fmt.Errorf("invalid request to read %s", s)
}
func invalidJSONError(expected string) error {
return fmt.Errorf("invalid JSON input; expected %s", expected)
}
func invalidJSONErrorForType(expected string, t bsontype.Type) error {
return fmt.Errorf("invalid JSON input; expected %s for %s", expected, t)
}
func unexpectedTokenError(jt *jsonToken) error {
switch jt.t {
case jttInt32, jttInt64, jttDouble:
return fmt.Errorf("invalid JSON input; unexpected number (%v) at position %d", jt.v, jt.p)
case jttString:
return fmt.Errorf("invalid JSON input; unexpected string (\"%v\") at position %d", jt.v, jt.p)
case jttBool:
return fmt.Errorf("invalid JSON input; unexpected boolean literal (%v) at position %d", jt.v, jt.p)
case jttNull:
return fmt.Errorf("invalid JSON input; unexpected null literal at position %d", jt.p)
case jttEOF:
return fmt.Errorf("invalid JSON input; unexpected end of input at position %d", jt.p)
default:
return fmt.Errorf("invalid JSON input; unexpected %c at position %d", jt.v.(byte), jt.p)
}
}
func nestingDepthError(p, depth int) error {
return fmt.Errorf("invalid JSON input; nesting too deep (%d levels) at position %d", depth, p)
}

View File

@@ -0,0 +1,788 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonrw
import (
"io"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"go.mongodb.org/mongo-driver/bson/bsontype"
)
var (
keyDiff = specificDiff("key")
typDiff = specificDiff("type")
valDiff = specificDiff("value")
expectErrEOF = expectSpecificError(io.EOF)
expectErrEOD = expectSpecificError(ErrEOD)
expectErrEOA = expectSpecificError(ErrEOA)
)
type expectedErrorFunc func(t *testing.T, err error, desc string)
type peekTypeTestCase struct {
desc string
input string
typs []bsontype.Type
errFs []expectedErrorFunc
}
type readKeyValueTestCase struct {
desc string
input string
keys []string
typs []bsontype.Type
vals []*extJSONValue
keyEFs []expectedErrorFunc
valEFs []expectedErrorFunc
}
func expectSpecificError(expected error) expectedErrorFunc {
return func(t *testing.T, err error, desc string) {
if err != expected {
t.Helper()
t.Errorf("%s: Expected %v but got: %v", desc, expected, err)
t.FailNow()
}
}
}
func specificDiff(name string) func(t *testing.T, expected, actual interface{}, desc string) {
return func(t *testing.T, expected, actual interface{}, desc string) {
if diff := cmp.Diff(expected, actual); diff != "" {
t.Helper()
t.Errorf("%s: Incorrect JSON %s (-want, +got): %s\n", desc, name, diff)
t.FailNow()
}
}
}
func expectErrorNOOP(_ *testing.T, _ error, _ string) {
}
func readKeyDiff(t *testing.T, eKey, aKey string, eTyp, aTyp bsontype.Type, err error, errF expectedErrorFunc, desc string) {
keyDiff(t, eKey, aKey, desc)
typDiff(t, eTyp, aTyp, desc)
errF(t, err, desc)
}
func readValueDiff(t *testing.T, eVal, aVal *extJSONValue, err error, errF expectedErrorFunc, desc string) {
if aVal != nil {
typDiff(t, eVal.t, aVal.t, desc)
valDiff(t, eVal.v, aVal.v, desc)
} else {
valDiff(t, eVal, aVal, desc)
}
errF(t, err, desc)
}
func TestExtJSONParserPeekType(t *testing.T) {
makeValidPeekTypeTestCase := func(input string, typ bsontype.Type, desc string) peekTypeTestCase {
return peekTypeTestCase{
desc: desc, input: input,
typs: []bsontype.Type{typ},
errFs: []expectedErrorFunc{expectNoError},
}
}
makeInvalidTestCase := func(desc, input string, lastEF expectedErrorFunc) peekTypeTestCase {
return peekTypeTestCase{
desc: desc, input: input,
typs: []bsontype.Type{bsontype.Type(0)},
errFs: []expectedErrorFunc{lastEF},
}
}
makeInvalidPeekTypeTestCase := func(desc, input string, lastEF expectedErrorFunc) peekTypeTestCase {
return peekTypeTestCase{
desc: desc, input: input,
typs: []bsontype.Type{bsontype.Array, bsontype.String, bsontype.Type(0)},
errFs: []expectedErrorFunc{expectNoError, expectNoError, lastEF},
}
}
cases := []peekTypeTestCase{
makeValidPeekTypeTestCase(`null`, bsontype.Null, "Null"),
makeValidPeekTypeTestCase(`"string"`, bsontype.String, "String"),
makeValidPeekTypeTestCase(`true`, bsontype.Boolean, "Boolean--true"),
makeValidPeekTypeTestCase(`false`, bsontype.Boolean, "Boolean--false"),
makeValidPeekTypeTestCase(`{"$minKey": 1}`, bsontype.MinKey, "MinKey"),
makeValidPeekTypeTestCase(`{"$maxKey": 1}`, bsontype.MaxKey, "MaxKey"),
makeValidPeekTypeTestCase(`{"$numberInt": "42"}`, bsontype.Int32, "Int32"),
makeValidPeekTypeTestCase(`{"$numberLong": "42"}`, bsontype.Int64, "Int64"),
makeValidPeekTypeTestCase(`{"$symbol": "symbol"}`, bsontype.Symbol, "Symbol"),
makeValidPeekTypeTestCase(`{"$numberDouble": "42.42"}`, bsontype.Double, "Double"),
makeValidPeekTypeTestCase(`{"$undefined": true}`, bsontype.Undefined, "Undefined"),
makeValidPeekTypeTestCase(`{"$numberDouble": "NaN"}`, bsontype.Double, "Double--NaN"),
makeValidPeekTypeTestCase(`{"$numberDecimal": "1234"}`, bsontype.Decimal128, "Decimal"),
makeValidPeekTypeTestCase(`{"foo": "bar"}`, bsontype.EmbeddedDocument, "Toplevel document"),
makeValidPeekTypeTestCase(`{"$date": {"$numberLong": "0"}}`, bsontype.DateTime, "Datetime"),
makeValidPeekTypeTestCase(`{"$code": "function() {}"}`, bsontype.JavaScript, "Code no scope"),
makeValidPeekTypeTestCase(`[{"$numberInt": "1"},{"$numberInt": "2"}]`, bsontype.Array, "Array"),
makeValidPeekTypeTestCase(`{"$timestamp": {"t": 42, "i": 1}}`, bsontype.Timestamp, "Timestamp"),
makeValidPeekTypeTestCase(`{"$oid": "57e193d7a9cc81b4027498b5"}`, bsontype.ObjectID, "Object ID"),
makeValidPeekTypeTestCase(`{"$binary": {"base64": "AQIDBAU=", "subType": "80"}}`, bsontype.Binary, "Binary"),
makeValidPeekTypeTestCase(`{"$code": "function() {}", "$scope": {}}`, bsontype.CodeWithScope, "Code With Scope"),
makeValidPeekTypeTestCase(`{"$binary": {"base64": "o0w498Or7cijeBSpkquNtg==", "subType": "03"}}`, bsontype.Binary, "Binary"),
makeValidPeekTypeTestCase(`{"$binary": "o0w498Or7cijeBSpkquNtg==", "$type": "03"}`, bsontype.Binary, "Binary"),
makeValidPeekTypeTestCase(`{"$regularExpression": {"pattern": "foo*", "options": "ix"}}`, bsontype.Regex, "Regular expression"),
makeValidPeekTypeTestCase(`{"$dbPointer": {"$ref": "db.collection", "$id": {"$oid": "57e193d7a9cc81b4027498b1"}}}`, bsontype.DBPointer, "DBPointer"),
makeValidPeekTypeTestCase(`{"$ref": "collection", "$id": {"$oid": "57fd71e96e32ab4225b723fb"}, "$db": "database"}`, bsontype.EmbeddedDocument, "DBRef"),
makeInvalidPeekTypeTestCase("invalid array--missing ]", `["a"`, expectError),
makeInvalidPeekTypeTestCase("invalid array--colon in array", `["a":`, expectError),
makeInvalidPeekTypeTestCase("invalid array--extra comma", `["a",,`, expectError),
makeInvalidPeekTypeTestCase("invalid array--trailing comma", `["a",]`, expectError),
makeInvalidPeekTypeTestCase("peekType after end of array", `["a"]`, expectErrEOA),
{
desc: "invalid array--leading comma",
input: `[,`,
typs: []bsontype.Type{bsontype.Array, bsontype.Type(0)},
errFs: []expectedErrorFunc{expectNoError, expectError},
},
makeInvalidTestCase("lone $scope", `{"$scope": {}}`, expectError),
makeInvalidTestCase("empty code with unknown extra key", `{"$code":"", "0":""}`, expectError),
makeInvalidTestCase("non-empty code with unknown extra key", `{"$code":"foobar", "0":""}`, expectError),
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
ejp := newExtJSONParser(strings.NewReader(tc.input), true)
// Manually set the parser's starting state to jpsSawColon so peekType will read ahead to find the extjson
// type of the value. If not set, the parser will be in jpsStartState and advance to jpsSawKey, which will
// cause it to return without peeking the extjson type.
ejp.s = jpsSawColon
for i, eTyp := range tc.typs {
errF := tc.errFs[i]
typ, err := ejp.peekType()
errF(t, err, tc.desc)
if err != nil {
// Don't inspect the type if there was an error
return
}
typDiff(t, eTyp, typ, tc.desc)
}
})
}
}
func TestExtJSONParserReadKeyReadValue(t *testing.T) {
// several test cases will use the same keys, types, and values, and only differ on input structure
keys := []string{"_id", "Symbol", "String", "Int32", "Int64", "Int", "MinKey"}
types := []bsontype.Type{bsontype.ObjectID, bsontype.Symbol, bsontype.String, bsontype.Int32, bsontype.Int64, bsontype.Int32, bsontype.MinKey}
values := []*extJSONValue{
{t: bsontype.String, v: "57e193d7a9cc81b4027498b5"},
{t: bsontype.String, v: "symbol"},
{t: bsontype.String, v: "string"},
{t: bsontype.String, v: "42"},
{t: bsontype.String, v: "42"},
{t: bsontype.Int32, v: int32(42)},
{t: bsontype.Int32, v: int32(1)},
}
errFuncs := make([]expectedErrorFunc, 7)
for i := 0; i < 7; i++ {
errFuncs[i] = expectNoError
}
firstKeyError := func(desc, input string) readKeyValueTestCase {
return readKeyValueTestCase{
desc: desc,
input: input,
keys: []string{""},
typs: []bsontype.Type{bsontype.Type(0)},
vals: []*extJSONValue{nil},
keyEFs: []expectedErrorFunc{expectError},
valEFs: []expectedErrorFunc{expectErrorNOOP},
}
}
secondKeyError := func(desc, input, firstKey string, firstType bsontype.Type, firstValue *extJSONValue) readKeyValueTestCase {
return readKeyValueTestCase{
desc: desc,
input: input,
keys: []string{firstKey, ""},
typs: []bsontype.Type{firstType, bsontype.Type(0)},
vals: []*extJSONValue{firstValue, nil},
keyEFs: []expectedErrorFunc{expectNoError, expectError},
valEFs: []expectedErrorFunc{expectNoError, expectErrorNOOP},
}
}
cases := []readKeyValueTestCase{
{
desc: "normal spacing",
input: `{
"_id": { "$oid": "57e193d7a9cc81b4027498b5" },
"Symbol": { "$symbol": "symbol" },
"String": "string",
"Int32": { "$numberInt": "42" },
"Int64": { "$numberLong": "42" },
"Int": 42,
"MinKey": { "$minKey": 1 }
}`,
keys: keys, typs: types, vals: values,
keyEFs: errFuncs, valEFs: errFuncs,
},
{
desc: "new line before comma",
input: `{ "_id": { "$oid": "57e193d7a9cc81b4027498b5" }
, "Symbol": { "$symbol": "symbol" }
, "String": "string"
, "Int32": { "$numberInt": "42" }
, "Int64": { "$numberLong": "42" }
, "Int": 42
, "MinKey": { "$minKey": 1 }
}`,
keys: keys, typs: types, vals: values,
keyEFs: errFuncs, valEFs: errFuncs,
},
{
desc: "tabs around colons",
input: `{
"_id": { "$oid" : "57e193d7a9cc81b4027498b5" },
"Symbol": { "$symbol" : "symbol" },
"String": "string",
"Int32": { "$numberInt" : "42" },
"Int64": { "$numberLong": "42" },
"Int": 42,
"MinKey": { "$minKey": 1 }
}`,
keys: keys, typs: types, vals: values,
keyEFs: errFuncs, valEFs: errFuncs,
},
{
desc: "no whitespace",
input: `{"_id":{"$oid":"57e193d7a9cc81b4027498b5"},"Symbol":{"$symbol":"symbol"},"String":"string","Int32":{"$numberInt":"42"},"Int64":{"$numberLong":"42"},"Int":42,"MinKey":{"$minKey":1}}`,
keys: keys, typs: types, vals: values,
keyEFs: errFuncs, valEFs: errFuncs,
},
{
desc: "mixed whitespace",
input: ` {
"_id" : { "$oid": "57e193d7a9cc81b4027498b5" },
"Symbol" : { "$symbol": "symbol" } ,
"String" : "string",
"Int32" : { "$numberInt": "42" } ,
"Int64" : {"$numberLong" : "42"},
"Int" : 42,
"MinKey" : { "$minKey": 1 } } `,
keys: keys, typs: types, vals: values,
keyEFs: errFuncs, valEFs: errFuncs,
},
{
desc: "nested object",
input: `{"k1": 1, "k2": { "k3": { "k4": 4 } }, "k5": 5}`,
keys: []string{"k1", "k2", "k3", "k4", "", "", "k5", ""},
typs: []bsontype.Type{bsontype.Int32, bsontype.EmbeddedDocument, bsontype.EmbeddedDocument, bsontype.Int32, bsontype.Type(0), bsontype.Type(0), bsontype.Int32, bsontype.Type(0)},
vals: []*extJSONValue{
{t: bsontype.Int32, v: int32(1)}, nil, nil, {t: bsontype.Int32, v: int32(4)}, nil, nil, {t: bsontype.Int32, v: int32(5)}, nil,
},
keyEFs: []expectedErrorFunc{
expectNoError, expectNoError, expectNoError, expectNoError, expectErrEOD,
expectErrEOD, expectNoError, expectErrEOD,
},
valEFs: []expectedErrorFunc{
expectNoError, expectError, expectError, expectNoError, expectErrorNOOP,
expectErrorNOOP, expectNoError, expectErrorNOOP,
},
},
{
desc: "invalid input: invalid values for extended type",
input: `{"a": {"$numberInt": "1", "x"`,
keys: []string{"a"},
typs: []bsontype.Type{bsontype.Int32},
vals: []*extJSONValue{nil},
keyEFs: []expectedErrorFunc{expectNoError},
valEFs: []expectedErrorFunc{expectError},
},
firstKeyError("invalid input: missing key--EOF", "{"),
firstKeyError("invalid input: missing key--colon first", "{:"),
firstKeyError("invalid input: missing value", `{"a":`),
firstKeyError("invalid input: missing colon", `{"a" 1`),
firstKeyError("invalid input: extra colon", `{"a"::`),
secondKeyError("invalid input: missing }", `{"a": 1`, "a", bsontype.Int32, &extJSONValue{t: bsontype.Int32, v: int32(1)}),
secondKeyError("invalid input: missing comma", `{"a": 1 "b"`, "a", bsontype.Int32, &extJSONValue{t: bsontype.Int32, v: int32(1)}),
secondKeyError("invalid input: extra comma", `{"a": 1,, "b"`, "a", bsontype.Int32, &extJSONValue{t: bsontype.Int32, v: int32(1)}),
secondKeyError("invalid input: trailing comma in object", `{"a": 1,}`, "a", bsontype.Int32, &extJSONValue{t: bsontype.Int32, v: int32(1)}),
{
desc: "invalid input: lone scope after a complete value",
input: `{"a": "", "b": {"$scope: ""}}`,
keys: []string{"a"},
typs: []bsontype.Type{bsontype.String},
vals: []*extJSONValue{{bsontype.String, ""}},
keyEFs: []expectedErrorFunc{expectNoError, expectNoError},
valEFs: []expectedErrorFunc{expectNoError, expectError},
},
{
desc: "invalid input: lone scope nested",
input: `{"a":{"b":{"$scope":{`,
keys: []string{},
typs: []bsontype.Type{},
vals: []*extJSONValue{nil},
keyEFs: []expectedErrorFunc{expectNoError},
valEFs: []expectedErrorFunc{expectError},
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
ejp := newExtJSONParser(strings.NewReader(tc.input), true)
for i, eKey := range tc.keys {
eTyp := tc.typs[i]
eVal := tc.vals[i]
keyErrF := tc.keyEFs[i]
valErrF := tc.valEFs[i]
k, typ, err := ejp.readKey()
readKeyDiff(t, eKey, k, eTyp, typ, err, keyErrF, tc.desc)
v, err := ejp.readValue(typ)
readValueDiff(t, eVal, v, err, valErrF, tc.desc)
}
})
}
}
type ejpExpectationTest func(t *testing.T, p *extJSONParser, expectedKey string, expectedType bsontype.Type, expectedValue interface{})
type ejpTestCase struct {
f ejpExpectationTest
p *extJSONParser
k string
t bsontype.Type
v interface{}
}
// expectSingleValue is used for simple JSON types (strings, numbers, literals) and for extended JSON types that
// have single key-value pairs (i.e. { "$minKey": 1 }, { "$numberLong": "42.42" })
func expectSingleValue(t *testing.T, p *extJSONParser, expectedKey string, expectedType bsontype.Type, expectedValue interface{}) {
eVal := expectedValue.(*extJSONValue)
k, typ, err := p.readKey()
readKeyDiff(t, expectedKey, k, expectedType, typ, err, expectNoError, expectedKey)
v, err := p.readValue(typ)
readValueDiff(t, eVal, v, err, expectNoError, expectedKey)
}
// expectMultipleValues is used for values that are subdocuments of known size and with known keys (such as extended
// JSON types { "$timestamp": {"t": 1, "i": 1} } and { "$regularExpression": {"pattern": "", options: ""} })
func expectMultipleValues(t *testing.T, p *extJSONParser, expectedKey string, expectedType bsontype.Type, expectedValue interface{}) {
k, typ, err := p.readKey()
readKeyDiff(t, expectedKey, k, expectedType, typ, err, expectNoError, expectedKey)
v, err := p.readValue(typ)
expectNoError(t, err, "")
typDiff(t, bsontype.EmbeddedDocument, v.t, expectedKey)
actObj := v.v.(*extJSONObject)
expObj := expectedValue.(*extJSONObject)
for i, actKey := range actObj.keys {
expKey := expObj.keys[i]
actVal := actObj.values[i]
expVal := expObj.values[i]
keyDiff(t, expKey, actKey, expectedKey)
typDiff(t, expVal.t, actVal.t, expectedKey)
valDiff(t, expVal.v, actVal.v, expectedKey)
}
}
type ejpKeyTypValTriple struct {
key string
typ bsontype.Type
val *extJSONValue
}
type ejpSubDocumentTestValue struct {
code string // code is only used for TypeCodeWithScope (and is ignored for TypeEmbeddedDocument
ktvs []ejpKeyTypValTriple // list of (key, type, value) triples; this is "scope" for TypeCodeWithScope
}
// expectSubDocument is used for embedded documents and code with scope types; it reads all the keys and values
// in the embedded document (or scope for codeWithScope) and compares them to the expectedValue's list of (key, type,
// value) triples
func expectSubDocument(t *testing.T, p *extJSONParser, expectedKey string, expectedType bsontype.Type, expectedValue interface{}) {
subdoc := expectedValue.(ejpSubDocumentTestValue)
k, typ, err := p.readKey()
readKeyDiff(t, expectedKey, k, expectedType, typ, err, expectNoError, expectedKey)
if expectedType == bsontype.CodeWithScope {
v, err := p.readValue(typ)
readValueDiff(t, &extJSONValue{t: bsontype.String, v: subdoc.code}, v, err, expectNoError, expectedKey)
}
for _, ktv := range subdoc.ktvs {
eKey := ktv.key
eTyp := ktv.typ
eVal := ktv.val
k, typ, err = p.readKey()
readKeyDiff(t, eKey, k, eTyp, typ, err, expectNoError, expectedKey)
v, err := p.readValue(typ)
readValueDiff(t, eVal, v, err, expectNoError, expectedKey)
}
if expectedType == bsontype.CodeWithScope {
// expect scope doc to close
k, typ, err = p.readKey()
readKeyDiff(t, "", k, bsontype.Type(0), typ, err, expectErrEOD, expectedKey)
}
// expect subdoc to close
k, typ, err = p.readKey()
readKeyDiff(t, "", k, bsontype.Type(0), typ, err, expectErrEOD, expectedKey)
}
// expectArray takes the expectedKey, ignores the expectedType, and uses the expectedValue
// as a slice of (type Type, value *extJSONValue) pairs
func expectArray(t *testing.T, p *extJSONParser, expectedKey string, _ bsontype.Type, expectedValue interface{}) {
ktvs := expectedValue.([]ejpKeyTypValTriple)
k, typ, err := p.readKey()
readKeyDiff(t, expectedKey, k, bsontype.Array, typ, err, expectNoError, expectedKey)
for _, ktv := range ktvs {
eTyp := ktv.typ
eVal := ktv.val
typ, err = p.peekType()
typDiff(t, eTyp, typ, expectedKey)
expectNoError(t, err, expectedKey)
v, err := p.readValue(typ)
readValueDiff(t, eVal, v, err, expectNoError, expectedKey)
}
// expect array to end
typ, err = p.peekType()
typDiff(t, bsontype.Type(0), typ, expectedKey)
expectErrEOA(t, err, expectedKey)
}
func TestExtJSONParserAllTypes(t *testing.T) {
in := ` { "_id" : { "$oid": "57e193d7a9cc81b4027498b5"}
, "Symbol" : { "$symbol": "symbol"}
, "String" : "string"
, "Int32" : { "$numberInt": "42"}
, "Int64" : { "$numberLong": "42"}
, "Double" : { "$numberDouble": "42.42"}
, "SpecialFloat" : { "$numberDouble": "NaN" }
, "Decimal" : { "$numberDecimal": "1234" }
, "Binary" : { "$binary": { "base64": "o0w498Or7cijeBSpkquNtg==", "subType": "03" } }
, "BinaryLegacy" : { "$binary": "o0w498Or7cijeBSpkquNtg==", "$type": "03" }
, "BinaryUserDefined" : { "$binary": { "base64": "AQIDBAU=", "subType": "80" } }
, "Code" : { "$code": "function() {}" }
, "CodeWithEmptyScope" : { "$code": "function() {}", "$scope": {} }
, "CodeWithScope" : { "$code": "function() {}", "$scope": { "x": 1 } }
, "EmptySubdocument" : {}
, "Subdocument" : { "foo": "bar", "baz": { "$numberInt": "42" } }
, "Array" : [{"$numberInt": "1"}, {"$numberLong": "2"}, {"$numberDouble": "3"}, 4, "string", 5.0]
, "Timestamp" : { "$timestamp": { "t": 42, "i": 1 } }
, "RegularExpression" : { "$regularExpression": { "pattern": "foo*", "options": "ix" } }
, "DatetimeEpoch" : { "$date": { "$numberLong": "0" } }
, "DatetimePositive" : { "$date": { "$numberLong": "9223372036854775807" } }
, "DatetimeNegative" : { "$date": { "$numberLong": "-9223372036854775808" } }
, "True" : true
, "False" : false
, "DBPointer" : { "$dbPointer": { "$ref": "db.collection", "$id": { "$oid": "57e193d7a9cc81b4027498b1" } } }
, "DBRef" : { "$ref": "collection", "$id": { "$oid": "57fd71e96e32ab4225b723fb" }, "$db": "database" }
, "DBRefNoDB" : { "$ref": "collection", "$id": { "$oid": "57fd71e96e32ab4225b723fb" } }
, "MinKey" : { "$minKey": 1 }
, "MaxKey" : { "$maxKey": 1 }
, "Null" : null
, "Undefined" : { "$undefined": true }
}`
ejp := newExtJSONParser(strings.NewReader(in), true)
cases := []ejpTestCase{
{
f: expectSingleValue, p: ejp,
k: "_id", t: bsontype.ObjectID, v: &extJSONValue{t: bsontype.String, v: "57e193d7a9cc81b4027498b5"},
},
{
f: expectSingleValue, p: ejp,
k: "Symbol", t: bsontype.Symbol, v: &extJSONValue{t: bsontype.String, v: "symbol"},
},
{
f: expectSingleValue, p: ejp,
k: "String", t: bsontype.String, v: &extJSONValue{t: bsontype.String, v: "string"},
},
{
f: expectSingleValue, p: ejp,
k: "Int32", t: bsontype.Int32, v: &extJSONValue{t: bsontype.String, v: "42"},
},
{
f: expectSingleValue, p: ejp,
k: "Int64", t: bsontype.Int64, v: &extJSONValue{t: bsontype.String, v: "42"},
},
{
f: expectSingleValue, p: ejp,
k: "Double", t: bsontype.Double, v: &extJSONValue{t: bsontype.String, v: "42.42"},
},
{
f: expectSingleValue, p: ejp,
k: "SpecialFloat", t: bsontype.Double, v: &extJSONValue{t: bsontype.String, v: "NaN"},
},
{
f: expectSingleValue, p: ejp,
k: "Decimal", t: bsontype.Decimal128, v: &extJSONValue{t: bsontype.String, v: "1234"},
},
{
f: expectMultipleValues, p: ejp,
k: "Binary", t: bsontype.Binary,
v: &extJSONObject{
keys: []string{"base64", "subType"},
values: []*extJSONValue{
{t: bsontype.String, v: "o0w498Or7cijeBSpkquNtg=="},
{t: bsontype.String, v: "03"},
},
},
},
{
f: expectMultipleValues, p: ejp,
k: "BinaryLegacy", t: bsontype.Binary,
v: &extJSONObject{
keys: []string{"base64", "subType"},
values: []*extJSONValue{
{t: bsontype.String, v: "o0w498Or7cijeBSpkquNtg=="},
{t: bsontype.String, v: "03"},
},
},
},
{
f: expectMultipleValues, p: ejp,
k: "BinaryUserDefined", t: bsontype.Binary,
v: &extJSONObject{
keys: []string{"base64", "subType"},
values: []*extJSONValue{
{t: bsontype.String, v: "AQIDBAU="},
{t: bsontype.String, v: "80"},
},
},
},
{
f: expectSingleValue, p: ejp,
k: "Code", t: bsontype.JavaScript, v: &extJSONValue{t: bsontype.String, v: "function() {}"},
},
{
f: expectSubDocument, p: ejp,
k: "CodeWithEmptyScope", t: bsontype.CodeWithScope,
v: ejpSubDocumentTestValue{
code: "function() {}",
ktvs: []ejpKeyTypValTriple{},
},
},
{
f: expectSubDocument, p: ejp,
k: "CodeWithScope", t: bsontype.CodeWithScope,
v: ejpSubDocumentTestValue{
code: "function() {}",
ktvs: []ejpKeyTypValTriple{
{"x", bsontype.Int32, &extJSONValue{t: bsontype.Int32, v: int32(1)}},
},
},
},
{
f: expectSubDocument, p: ejp,
k: "EmptySubdocument", t: bsontype.EmbeddedDocument,
v: ejpSubDocumentTestValue{
ktvs: []ejpKeyTypValTriple{},
},
},
{
f: expectSubDocument, p: ejp,
k: "Subdocument", t: bsontype.EmbeddedDocument,
v: ejpSubDocumentTestValue{
ktvs: []ejpKeyTypValTriple{
{"foo", bsontype.String, &extJSONValue{t: bsontype.String, v: "bar"}},
{"baz", bsontype.Int32, &extJSONValue{t: bsontype.String, v: "42"}},
},
},
},
{
f: expectArray, p: ejp,
k: "Array", t: bsontype.Array,
v: []ejpKeyTypValTriple{
{typ: bsontype.Int32, val: &extJSONValue{t: bsontype.String, v: "1"}},
{typ: bsontype.Int64, val: &extJSONValue{t: bsontype.String, v: "2"}},
{typ: bsontype.Double, val: &extJSONValue{t: bsontype.String, v: "3"}},
{typ: bsontype.Int32, val: &extJSONValue{t: bsontype.Int32, v: int32(4)}},
{typ: bsontype.String, val: &extJSONValue{t: bsontype.String, v: "string"}},
{typ: bsontype.Double, val: &extJSONValue{t: bsontype.Double, v: 5.0}},
},
},
{
f: expectMultipleValues, p: ejp,
k: "Timestamp", t: bsontype.Timestamp,
v: &extJSONObject{
keys: []string{"t", "i"},
values: []*extJSONValue{
{t: bsontype.Int32, v: int32(42)},
{t: bsontype.Int32, v: int32(1)},
},
},
},
{
f: expectMultipleValues, p: ejp,
k: "RegularExpression", t: bsontype.Regex,
v: &extJSONObject{
keys: []string{"pattern", "options"},
values: []*extJSONValue{
{t: bsontype.String, v: "foo*"},
{t: bsontype.String, v: "ix"},
},
},
},
{
f: expectMultipleValues, p: ejp,
k: "DatetimeEpoch", t: bsontype.DateTime,
v: &extJSONObject{
keys: []string{"$numberLong"},
values: []*extJSONValue{
{t: bsontype.String, v: "0"},
},
},
},
{
f: expectMultipleValues, p: ejp,
k: "DatetimePositive", t: bsontype.DateTime,
v: &extJSONObject{
keys: []string{"$numberLong"},
values: []*extJSONValue{
{t: bsontype.String, v: "9223372036854775807"},
},
},
},
{
f: expectMultipleValues, p: ejp,
k: "DatetimeNegative", t: bsontype.DateTime,
v: &extJSONObject{
keys: []string{"$numberLong"},
values: []*extJSONValue{
{t: bsontype.String, v: "-9223372036854775808"},
},
},
},
{
f: expectSingleValue, p: ejp,
k: "True", t: bsontype.Boolean, v: &extJSONValue{t: bsontype.Boolean, v: true},
},
{
f: expectSingleValue, p: ejp,
k: "False", t: bsontype.Boolean, v: &extJSONValue{t: bsontype.Boolean, v: false},
},
{
f: expectMultipleValues, p: ejp,
k: "DBPointer", t: bsontype.DBPointer,
v: &extJSONObject{
keys: []string{"$ref", "$id"},
values: []*extJSONValue{
{t: bsontype.String, v: "db.collection"},
{t: bsontype.String, v: "57e193d7a9cc81b4027498b1"},
},
},
},
{
f: expectSubDocument, p: ejp,
k: "DBRef", t: bsontype.EmbeddedDocument,
v: ejpSubDocumentTestValue{
ktvs: []ejpKeyTypValTriple{
{"$ref", bsontype.String, &extJSONValue{t: bsontype.String, v: "collection"}},
{"$id", bsontype.ObjectID, &extJSONValue{t: bsontype.String, v: "57fd71e96e32ab4225b723fb"}},
{"$db", bsontype.String, &extJSONValue{t: bsontype.String, v: "database"}},
},
},
},
{
f: expectSubDocument, p: ejp,
k: "DBRefNoDB", t: bsontype.EmbeddedDocument,
v: ejpSubDocumentTestValue{
ktvs: []ejpKeyTypValTriple{
{"$ref", bsontype.String, &extJSONValue{t: bsontype.String, v: "collection"}},
{"$id", bsontype.ObjectID, &extJSONValue{t: bsontype.String, v: "57fd71e96e32ab4225b723fb"}},
},
},
},
{
f: expectSingleValue, p: ejp,
k: "MinKey", t: bsontype.MinKey, v: &extJSONValue{t: bsontype.Int32, v: int32(1)},
},
{
f: expectSingleValue, p: ejp,
k: "MaxKey", t: bsontype.MaxKey, v: &extJSONValue{t: bsontype.Int32, v: int32(1)},
},
{
f: expectSingleValue, p: ejp,
k: "Null", t: bsontype.Null, v: &extJSONValue{t: bsontype.Null, v: nil},
},
{
f: expectSingleValue, p: ejp,
k: "Undefined", t: bsontype.Undefined, v: &extJSONValue{t: bsontype.Boolean, v: true},
},
}
// run the test cases
for _, tc := range cases {
tc.f(t, tc.p, tc.k, tc.t, tc.v)
}
// expect end of whole document: read final }
k, typ, err := ejp.readKey()
readKeyDiff(t, "", k, bsontype.Type(0), typ, err, expectErrEOD, "")
// expect end of whole document: read EOF
k, typ, err = ejp.readKey()
readKeyDiff(t, "", k, bsontype.Type(0), typ, err, expectErrEOF, "")
if diff := cmp.Diff(jpsDoneState, ejp.s); diff != "" {
t.Errorf("expected parser to be in done state but instead is in %v\n", ejp.s)
t.FailNow()
}
}
func TestExtJSONValue(t *testing.T) {
t.Run("Large Date", func(t *testing.T) {
val := &extJSONValue{
t: bsontype.String,
v: "3001-01-01T00:00:00Z",
}
intVal, err := val.parseDateTime()
if err != nil {
t.Fatalf("error parsing date time: %v", err)
}
if intVal <= 0 {
t.Fatalf("expected value above 0, got %v", intVal)
}
})
t.Run("fallback time format", func(t *testing.T) {
val := &extJSONValue{
t: bsontype.String,
v: "2019-06-04T14:54:31.416+0000",
}
_, err := val.parseDateTime()
if err != nil {
t.Fatalf("error parsing date time: %v", err)
}
})
}

View File

@@ -0,0 +1,644 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonrw
import (
"fmt"
"io"
"sync"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
)
// ExtJSONValueReaderPool is a pool for ValueReaders that read ExtJSON.
type ExtJSONValueReaderPool struct {
pool sync.Pool
}
// NewExtJSONValueReaderPool instantiates a new ExtJSONValueReaderPool.
func NewExtJSONValueReaderPool() *ExtJSONValueReaderPool {
return &ExtJSONValueReaderPool{
pool: sync.Pool{
New: func() interface{} {
return new(extJSONValueReader)
},
},
}
}
// Get retrieves a ValueReader from the pool and uses src as the underlying ExtJSON.
func (bvrp *ExtJSONValueReaderPool) Get(r io.Reader, canonical bool) (ValueReader, error) {
vr := bvrp.pool.Get().(*extJSONValueReader)
return vr.reset(r, canonical)
}
// Put inserts a ValueReader into the pool. If the ValueReader is not a ExtJSON ValueReader nothing
// is inserted into the pool and ok will be false.
func (bvrp *ExtJSONValueReaderPool) Put(vr ValueReader) (ok bool) {
bvr, ok := vr.(*extJSONValueReader)
if !ok {
return false
}
bvr, _ = bvr.reset(nil, false)
bvrp.pool.Put(bvr)
return true
}
type ejvrState struct {
mode mode
vType bsontype.Type
depth int
}
// extJSONValueReader is for reading extended JSON.
type extJSONValueReader struct {
p *extJSONParser
stack []ejvrState
frame int
}
// NewExtJSONValueReader creates a new ValueReader from a given io.Reader
// It will interpret the JSON of r as canonical or relaxed according to the
// given canonical flag
func NewExtJSONValueReader(r io.Reader, canonical bool) (ValueReader, error) {
return newExtJSONValueReader(r, canonical)
}
func newExtJSONValueReader(r io.Reader, canonical bool) (*extJSONValueReader, error) {
ejvr := new(extJSONValueReader)
return ejvr.reset(r, canonical)
}
func (ejvr *extJSONValueReader) reset(r io.Reader, canonical bool) (*extJSONValueReader, error) {
p := newExtJSONParser(r, canonical)
typ, err := p.peekType()
if err != nil {
return nil, ErrInvalidJSON
}
var m mode
switch typ {
case bsontype.EmbeddedDocument:
m = mTopLevel
case bsontype.Array:
m = mArray
default:
m = mValue
}
stack := make([]ejvrState, 1, 5)
stack[0] = ejvrState{
mode: m,
vType: typ,
}
return &extJSONValueReader{
p: p,
stack: stack,
}, nil
}
func (ejvr *extJSONValueReader) advanceFrame() {
if ejvr.frame+1 >= len(ejvr.stack) { // We need to grow the stack
length := len(ejvr.stack)
if length+1 >= cap(ejvr.stack) {
// double it
buf := make([]ejvrState, 2*cap(ejvr.stack)+1)
copy(buf, ejvr.stack)
ejvr.stack = buf
}
ejvr.stack = ejvr.stack[:length+1]
}
ejvr.frame++
// Clean the stack
ejvr.stack[ejvr.frame].mode = 0
ejvr.stack[ejvr.frame].vType = 0
ejvr.stack[ejvr.frame].depth = 0
}
func (ejvr *extJSONValueReader) pushDocument() {
ejvr.advanceFrame()
ejvr.stack[ejvr.frame].mode = mDocument
ejvr.stack[ejvr.frame].depth = ejvr.p.depth
}
func (ejvr *extJSONValueReader) pushCodeWithScope() {
ejvr.advanceFrame()
ejvr.stack[ejvr.frame].mode = mCodeWithScope
}
func (ejvr *extJSONValueReader) pushArray() {
ejvr.advanceFrame()
ejvr.stack[ejvr.frame].mode = mArray
}
func (ejvr *extJSONValueReader) push(m mode, t bsontype.Type) {
ejvr.advanceFrame()
ejvr.stack[ejvr.frame].mode = m
ejvr.stack[ejvr.frame].vType = t
}
func (ejvr *extJSONValueReader) pop() {
switch ejvr.stack[ejvr.frame].mode {
case mElement, mValue:
ejvr.frame--
case mDocument, mArray, mCodeWithScope:
ejvr.frame -= 2 // we pop twice to jump over the vrElement: vrDocument -> vrElement -> vrDocument/TopLevel/etc...
}
}
func (ejvr *extJSONValueReader) skipObject() {
// read entire object until depth returns to 0 (last ending } or ] seen)
depth := 1
for depth > 0 {
ejvr.p.advanceState()
// If object is empty, raise depth and continue. When emptyObject is true, the
// parser has already read both the opening and closing brackets of an empty
// object ("{}"), so the next valid token will be part of the parent document,
// not part of the nested document.
//
// If there is a comma, there are remaining fields, emptyObject must be set back
// to false, and comma must be skipped with advanceState().
if ejvr.p.emptyObject {
if ejvr.p.s == jpsSawComma {
ejvr.p.emptyObject = false
ejvr.p.advanceState()
}
depth--
continue
}
switch ejvr.p.s {
case jpsSawBeginObject, jpsSawBeginArray:
depth++
case jpsSawEndObject, jpsSawEndArray:
depth--
}
}
}
func (ejvr *extJSONValueReader) invalidTransitionErr(destination mode, name string, modes []mode) error {
te := TransitionError{
name: name,
current: ejvr.stack[ejvr.frame].mode,
destination: destination,
modes: modes,
action: "read",
}
if ejvr.frame != 0 {
te.parent = ejvr.stack[ejvr.frame-1].mode
}
return te
}
func (ejvr *extJSONValueReader) typeError(t bsontype.Type) error {
return fmt.Errorf("positioned on %s, but attempted to read %s", ejvr.stack[ejvr.frame].vType, t)
}
func (ejvr *extJSONValueReader) ensureElementValue(t bsontype.Type, destination mode, callerName string, addModes ...mode) error {
switch ejvr.stack[ejvr.frame].mode {
case mElement, mValue:
if ejvr.stack[ejvr.frame].vType != t {
return ejvr.typeError(t)
}
default:
modes := []mode{mElement, mValue}
if addModes != nil {
modes = append(modes, addModes...)
}
return ejvr.invalidTransitionErr(destination, callerName, modes)
}
return nil
}
func (ejvr *extJSONValueReader) Type() bsontype.Type {
return ejvr.stack[ejvr.frame].vType
}
func (ejvr *extJSONValueReader) Skip() error {
switch ejvr.stack[ejvr.frame].mode {
case mElement, mValue:
default:
return ejvr.invalidTransitionErr(0, "Skip", []mode{mElement, mValue})
}
defer ejvr.pop()
t := ejvr.stack[ejvr.frame].vType
switch t {
case bsontype.Array, bsontype.EmbeddedDocument, bsontype.CodeWithScope:
// read entire array, doc or CodeWithScope
ejvr.skipObject()
default:
_, err := ejvr.p.readValue(t)
if err != nil {
return err
}
}
return nil
}
func (ejvr *extJSONValueReader) ReadArray() (ArrayReader, error) {
switch ejvr.stack[ejvr.frame].mode {
case mTopLevel: // allow reading array from top level
case mArray:
return ejvr, nil
default:
if err := ejvr.ensureElementValue(bsontype.Array, mArray, "ReadArray", mTopLevel, mArray); err != nil {
return nil, err
}
}
ejvr.pushArray()
return ejvr, nil
}
func (ejvr *extJSONValueReader) ReadBinary() (b []byte, btype byte, err error) {
if err := ejvr.ensureElementValue(bsontype.Binary, 0, "ReadBinary"); err != nil {
return nil, 0, err
}
v, err := ejvr.p.readValue(bsontype.Binary)
if err != nil {
return nil, 0, err
}
b, btype, err = v.parseBinary()
ejvr.pop()
return b, btype, err
}
func (ejvr *extJSONValueReader) ReadBoolean() (bool, error) {
if err := ejvr.ensureElementValue(bsontype.Boolean, 0, "ReadBoolean"); err != nil {
return false, err
}
v, err := ejvr.p.readValue(bsontype.Boolean)
if err != nil {
return false, err
}
if v.t != bsontype.Boolean {
return false, fmt.Errorf("expected type bool, but got type %s", v.t)
}
ejvr.pop()
return v.v.(bool), nil
}
func (ejvr *extJSONValueReader) ReadDocument() (DocumentReader, error) {
switch ejvr.stack[ejvr.frame].mode {
case mTopLevel:
return ejvr, nil
case mElement, mValue:
if ejvr.stack[ejvr.frame].vType != bsontype.EmbeddedDocument {
return nil, ejvr.typeError(bsontype.EmbeddedDocument)
}
ejvr.pushDocument()
return ejvr, nil
default:
return nil, ejvr.invalidTransitionErr(mDocument, "ReadDocument", []mode{mTopLevel, mElement, mValue})
}
}
func (ejvr *extJSONValueReader) ReadCodeWithScope() (code string, dr DocumentReader, err error) {
if err = ejvr.ensureElementValue(bsontype.CodeWithScope, 0, "ReadCodeWithScope"); err != nil {
return "", nil, err
}
v, err := ejvr.p.readValue(bsontype.CodeWithScope)
if err != nil {
return "", nil, err
}
code, err = v.parseJavascript()
ejvr.pushCodeWithScope()
return code, ejvr, err
}
func (ejvr *extJSONValueReader) ReadDBPointer() (ns string, oid primitive.ObjectID, err error) {
if err = ejvr.ensureElementValue(bsontype.DBPointer, 0, "ReadDBPointer"); err != nil {
return "", primitive.NilObjectID, err
}
v, err := ejvr.p.readValue(bsontype.DBPointer)
if err != nil {
return "", primitive.NilObjectID, err
}
ns, oid, err = v.parseDBPointer()
ejvr.pop()
return ns, oid, err
}
func (ejvr *extJSONValueReader) ReadDateTime() (int64, error) {
if err := ejvr.ensureElementValue(bsontype.DateTime, 0, "ReadDateTime"); err != nil {
return 0, err
}
v, err := ejvr.p.readValue(bsontype.DateTime)
if err != nil {
return 0, err
}
d, err := v.parseDateTime()
ejvr.pop()
return d, err
}
func (ejvr *extJSONValueReader) ReadDecimal128() (primitive.Decimal128, error) {
if err := ejvr.ensureElementValue(bsontype.Decimal128, 0, "ReadDecimal128"); err != nil {
return primitive.Decimal128{}, err
}
v, err := ejvr.p.readValue(bsontype.Decimal128)
if err != nil {
return primitive.Decimal128{}, err
}
d, err := v.parseDecimal128()
ejvr.pop()
return d, err
}
func (ejvr *extJSONValueReader) ReadDouble() (float64, error) {
if err := ejvr.ensureElementValue(bsontype.Double, 0, "ReadDouble"); err != nil {
return 0, err
}
v, err := ejvr.p.readValue(bsontype.Double)
if err != nil {
return 0, err
}
d, err := v.parseDouble()
ejvr.pop()
return d, err
}
func (ejvr *extJSONValueReader) ReadInt32() (int32, error) {
if err := ejvr.ensureElementValue(bsontype.Int32, 0, "ReadInt32"); err != nil {
return 0, err
}
v, err := ejvr.p.readValue(bsontype.Int32)
if err != nil {
return 0, err
}
i, err := v.parseInt32()
ejvr.pop()
return i, err
}
func (ejvr *extJSONValueReader) ReadInt64() (int64, error) {
if err := ejvr.ensureElementValue(bsontype.Int64, 0, "ReadInt64"); err != nil {
return 0, err
}
v, err := ejvr.p.readValue(bsontype.Int64)
if err != nil {
return 0, err
}
i, err := v.parseInt64()
ejvr.pop()
return i, err
}
func (ejvr *extJSONValueReader) ReadJavascript() (code string, err error) {
if err = ejvr.ensureElementValue(bsontype.JavaScript, 0, "ReadJavascript"); err != nil {
return "", err
}
v, err := ejvr.p.readValue(bsontype.JavaScript)
if err != nil {
return "", err
}
code, err = v.parseJavascript()
ejvr.pop()
return code, err
}
func (ejvr *extJSONValueReader) ReadMaxKey() error {
if err := ejvr.ensureElementValue(bsontype.MaxKey, 0, "ReadMaxKey"); err != nil {
return err
}
v, err := ejvr.p.readValue(bsontype.MaxKey)
if err != nil {
return err
}
err = v.parseMinMaxKey("max")
ejvr.pop()
return err
}
func (ejvr *extJSONValueReader) ReadMinKey() error {
if err := ejvr.ensureElementValue(bsontype.MinKey, 0, "ReadMinKey"); err != nil {
return err
}
v, err := ejvr.p.readValue(bsontype.MinKey)
if err != nil {
return err
}
err = v.parseMinMaxKey("min")
ejvr.pop()
return err
}
func (ejvr *extJSONValueReader) ReadNull() error {
if err := ejvr.ensureElementValue(bsontype.Null, 0, "ReadNull"); err != nil {
return err
}
v, err := ejvr.p.readValue(bsontype.Null)
if err != nil {
return err
}
if v.t != bsontype.Null {
return fmt.Errorf("expected type null but got type %s", v.t)
}
ejvr.pop()
return nil
}
func (ejvr *extJSONValueReader) ReadObjectID() (primitive.ObjectID, error) {
if err := ejvr.ensureElementValue(bsontype.ObjectID, 0, "ReadObjectID"); err != nil {
return primitive.ObjectID{}, err
}
v, err := ejvr.p.readValue(bsontype.ObjectID)
if err != nil {
return primitive.ObjectID{}, err
}
oid, err := v.parseObjectID()
ejvr.pop()
return oid, err
}
func (ejvr *extJSONValueReader) ReadRegex() (pattern string, options string, err error) {
if err = ejvr.ensureElementValue(bsontype.Regex, 0, "ReadRegex"); err != nil {
return "", "", err
}
v, err := ejvr.p.readValue(bsontype.Regex)
if err != nil {
return "", "", err
}
pattern, options, err = v.parseRegex()
ejvr.pop()
return pattern, options, err
}
func (ejvr *extJSONValueReader) ReadString() (string, error) {
if err := ejvr.ensureElementValue(bsontype.String, 0, "ReadString"); err != nil {
return "", err
}
v, err := ejvr.p.readValue(bsontype.String)
if err != nil {
return "", err
}
if v.t != bsontype.String {
return "", fmt.Errorf("expected type string but got type %s", v.t)
}
ejvr.pop()
return v.v.(string), nil
}
func (ejvr *extJSONValueReader) ReadSymbol() (symbol string, err error) {
if err = ejvr.ensureElementValue(bsontype.Symbol, 0, "ReadSymbol"); err != nil {
return "", err
}
v, err := ejvr.p.readValue(bsontype.Symbol)
if err != nil {
return "", err
}
symbol, err = v.parseSymbol()
ejvr.pop()
return symbol, err
}
func (ejvr *extJSONValueReader) ReadTimestamp() (t uint32, i uint32, err error) {
if err = ejvr.ensureElementValue(bsontype.Timestamp, 0, "ReadTimestamp"); err != nil {
return 0, 0, err
}
v, err := ejvr.p.readValue(bsontype.Timestamp)
if err != nil {
return 0, 0, err
}
t, i, err = v.parseTimestamp()
ejvr.pop()
return t, i, err
}
func (ejvr *extJSONValueReader) ReadUndefined() error {
if err := ejvr.ensureElementValue(bsontype.Undefined, 0, "ReadUndefined"); err != nil {
return err
}
v, err := ejvr.p.readValue(bsontype.Undefined)
if err != nil {
return err
}
err = v.parseUndefined()
ejvr.pop()
return err
}
func (ejvr *extJSONValueReader) ReadElement() (string, ValueReader, error) {
switch ejvr.stack[ejvr.frame].mode {
case mTopLevel, mDocument, mCodeWithScope:
default:
return "", nil, ejvr.invalidTransitionErr(mElement, "ReadElement", []mode{mTopLevel, mDocument, mCodeWithScope})
}
name, t, err := ejvr.p.readKey()
if err != nil {
if err == ErrEOD {
if ejvr.stack[ejvr.frame].mode == mCodeWithScope {
_, err := ejvr.p.peekType()
if err != nil {
return "", nil, err
}
}
ejvr.pop()
}
return "", nil, err
}
ejvr.push(mElement, t)
return name, ejvr, nil
}
func (ejvr *extJSONValueReader) ReadValue() (ValueReader, error) {
switch ejvr.stack[ejvr.frame].mode {
case mArray:
default:
return nil, ejvr.invalidTransitionErr(mValue, "ReadValue", []mode{mArray})
}
t, err := ejvr.p.peekType()
if err != nil {
if err == ErrEOA {
ejvr.pop()
}
return nil, err
}
ejvr.push(mValue, t)
return ejvr, nil
}

View File

@@ -0,0 +1,168 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonrw
import (
"fmt"
"io"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"go.mongodb.org/mongo-driver/bson/bsontype"
)
func TestExtJSONReader(t *testing.T) {
t.Run("ReadDocument", func(t *testing.T) {
t.Run("EmbeddedDocument", func(t *testing.T) {
ejvr := &extJSONValueReader{
stack: []ejvrState{
{mode: mTopLevel},
{mode: mElement, vType: bsontype.Boolean},
},
frame: 1,
}
ejvr.stack[1].mode = mArray
wanterr := ejvr.invalidTransitionErr(mDocument, "ReadDocument", []mode{mTopLevel, mElement, mValue})
_, err := ejvr.ReadDocument()
if err == nil || err.Error() != wanterr.Error() {
t.Errorf("Incorrect returned error. got %v; want %v", err, wanterr)
}
})
})
t.Run("invalid transition", func(t *testing.T) {
t.Run("Skip", func(t *testing.T) {
ejvr := &extJSONValueReader{stack: []ejvrState{{mode: mTopLevel}}}
wanterr := (&extJSONValueReader{stack: []ejvrState{{mode: mTopLevel}}}).invalidTransitionErr(0, "Skip", []mode{mElement, mValue})
goterr := ejvr.Skip()
if !cmp.Equal(goterr, wanterr, cmp.Comparer(compareErrors)) {
t.Errorf("Expected correct invalid transition error. got %v; want %v", goterr, wanterr)
}
})
})
}
func TestReadMultipleTopLevelDocuments(t *testing.T) {
testCases := []struct {
name string
input string
expected [][]byte
}{
{
"single top-level document",
"{\"foo\":1}",
[][]byte{
{0x0E, 0x00, 0x00, 0x00, 0x10, 'f', 'o', 'o', 0x00, 0x01, 0x00, 0x00, 0x00, 0x00},
},
},
{
"single top-level document with leading and trailing whitespace",
"\n\n {\"foo\":1} \n",
[][]byte{
{0x0E, 0x00, 0x00, 0x00, 0x10, 'f', 'o', 'o', 0x00, 0x01, 0x00, 0x00, 0x00, 0x00},
},
},
{
"two top-level documents",
"{\"foo\":1}{\"foo\":2}",
[][]byte{
{0x0E, 0x00, 0x00, 0x00, 0x10, 'f', 'o', 'o', 0x00, 0x01, 0x00, 0x00, 0x00, 0x00},
{0x0E, 0x00, 0x00, 0x00, 0x10, 'f', 'o', 'o', 0x00, 0x02, 0x00, 0x00, 0x00, 0x00},
},
},
{
"two top-level documents with leading and trailing whitespace and whitespace separation ",
"\n\n {\"foo\":1}\n{\"foo\":2}\n ",
[][]byte{
{0x0E, 0x00, 0x00, 0x00, 0x10, 'f', 'o', 'o', 0x00, 0x01, 0x00, 0x00, 0x00, 0x00},
{0x0E, 0x00, 0x00, 0x00, 0x10, 'f', 'o', 'o', 0x00, 0x02, 0x00, 0x00, 0x00, 0x00},
},
},
{
"top-level array with single document",
"[{\"foo\":1}]",
[][]byte{
{0x0E, 0x00, 0x00, 0x00, 0x10, 'f', 'o', 'o', 0x00, 0x01, 0x00, 0x00, 0x00, 0x00},
},
},
{
"top-level array with 2 documents",
"[{\"foo\":1},{\"foo\":2}]",
[][]byte{
{0x0E, 0x00, 0x00, 0x00, 0x10, 'f', 'o', 'o', 0x00, 0x01, 0x00, 0x00, 0x00, 0x00},
{0x0E, 0x00, 0x00, 0x00, 0x10, 'f', 'o', 'o', 0x00, 0x02, 0x00, 0x00, 0x00, 0x00},
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
r := strings.NewReader(tc.input)
vr, err := NewExtJSONValueReader(r, false)
if err != nil {
t.Fatalf("expected no error, but got %v", err)
}
actual, err := readAllDocuments(vr)
if err != nil {
t.Fatalf("expected no error, but got %v", err)
}
if diff := cmp.Diff(tc.expected, actual); diff != "" {
t.Fatalf("expected does not match actual: %v", diff)
}
})
}
}
func readAllDocuments(vr ValueReader) ([][]byte, error) {
c := NewCopier()
var actual [][]byte
switch vr.Type() {
case bsontype.EmbeddedDocument:
for {
result, err := c.CopyDocumentToBytes(vr)
if err != nil {
if err == io.EOF {
break
}
return nil, err
}
actual = append(actual, result)
}
case bsontype.Array:
ar, err := vr.ReadArray()
if err != nil {
return nil, err
}
for {
evr, err := ar.ReadValue()
if err != nil {
if err == ErrEOA {
break
}
return nil, err
}
result, err := c.CopyDocumentToBytes(evr)
if err != nil {
return nil, err
}
actual = append(actual, result)
}
default:
return nil, fmt.Errorf("expected an array or a document, but got %s", vr.Type())
}
return actual, nil
}

View File

@@ -0,0 +1,223 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on github.com/golang/go by The Go Authors
// See THIRD-PARTY-NOTICES for original license terms.
package bsonrw
import "unicode/utf8"
// safeSet holds the value true if the ASCII character with the given array
// position can be represented inside a JSON string without any further
// escaping.
//
// All values are true except for the ASCII control characters (0-31), the
// double quote ("), and the backslash character ("\").
var safeSet = [utf8.RuneSelf]bool{
' ': true,
'!': true,
'"': false,
'#': true,
'$': true,
'%': true,
'&': true,
'\'': true,
'(': true,
')': true,
'*': true,
'+': true,
',': true,
'-': true,
'.': true,
'/': true,
'0': true,
'1': true,
'2': true,
'3': true,
'4': true,
'5': true,
'6': true,
'7': true,
'8': true,
'9': true,
':': true,
';': true,
'<': true,
'=': true,
'>': true,
'?': true,
'@': true,
'A': true,
'B': true,
'C': true,
'D': true,
'E': true,
'F': true,
'G': true,
'H': true,
'I': true,
'J': true,
'K': true,
'L': true,
'M': true,
'N': true,
'O': true,
'P': true,
'Q': true,
'R': true,
'S': true,
'T': true,
'U': true,
'V': true,
'W': true,
'X': true,
'Y': true,
'Z': true,
'[': true,
'\\': false,
']': true,
'^': true,
'_': true,
'`': true,
'a': true,
'b': true,
'c': true,
'd': true,
'e': true,
'f': true,
'g': true,
'h': true,
'i': true,
'j': true,
'k': true,
'l': true,
'm': true,
'n': true,
'o': true,
'p': true,
'q': true,
'r': true,
's': true,
't': true,
'u': true,
'v': true,
'w': true,
'x': true,
'y': true,
'z': true,
'{': true,
'|': true,
'}': true,
'~': true,
'\u007f': true,
}
// htmlSafeSet holds the value true if the ASCII character with the given
// array position can be safely represented inside a JSON string, embedded
// inside of HTML <script> tags, without any additional escaping.
//
// All values are true except for the ASCII control characters (0-31), the
// double quote ("), the backslash character ("\"), HTML opening and closing
// tags ("<" and ">"), and the ampersand ("&").
var htmlSafeSet = [utf8.RuneSelf]bool{
' ': true,
'!': true,
'"': false,
'#': true,
'$': true,
'%': true,
'&': false,
'\'': true,
'(': true,
')': true,
'*': true,
'+': true,
',': true,
'-': true,
'.': true,
'/': true,
'0': true,
'1': true,
'2': true,
'3': true,
'4': true,
'5': true,
'6': true,
'7': true,
'8': true,
'9': true,
':': true,
';': true,
'<': false,
'=': true,
'>': false,
'?': true,
'@': true,
'A': true,
'B': true,
'C': true,
'D': true,
'E': true,
'F': true,
'G': true,
'H': true,
'I': true,
'J': true,
'K': true,
'L': true,
'M': true,
'N': true,
'O': true,
'P': true,
'Q': true,
'R': true,
'S': true,
'T': true,
'U': true,
'V': true,
'W': true,
'X': true,
'Y': true,
'Z': true,
'[': true,
'\\': false,
']': true,
'^': true,
'_': true,
'`': true,
'a': true,
'b': true,
'c': true,
'd': true,
'e': true,
'f': true,
'g': true,
'h': true,
'i': true,
'j': true,
'k': true,
'l': true,
'm': true,
'n': true,
'o': true,
'p': true,
'q': true,
'r': true,
's': true,
't': true,
'u': true,
'v': true,
'w': true,
'x': true,
'y': true,
'z': true,
'{': true,
'|': true,
'}': true,
'~': true,
'\u007f': true,
}

View File

@@ -0,0 +1,492 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonrw
import (
"encoding/base64"
"errors"
"fmt"
"math"
"strconv"
"time"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
)
func wrapperKeyBSONType(key string) bsontype.Type {
switch key {
case "$numberInt":
return bsontype.Int32
case "$numberLong":
return bsontype.Int64
case "$oid":
return bsontype.ObjectID
case "$symbol":
return bsontype.Symbol
case "$numberDouble":
return bsontype.Double
case "$numberDecimal":
return bsontype.Decimal128
case "$binary":
return bsontype.Binary
case "$code":
return bsontype.JavaScript
case "$scope":
return bsontype.CodeWithScope
case "$timestamp":
return bsontype.Timestamp
case "$regularExpression":
return bsontype.Regex
case "$dbPointer":
return bsontype.DBPointer
case "$date":
return bsontype.DateTime
case "$minKey":
return bsontype.MinKey
case "$maxKey":
return bsontype.MaxKey
case "$undefined":
return bsontype.Undefined
}
return bsontype.EmbeddedDocument
}
func (ejv *extJSONValue) parseBinary() (b []byte, subType byte, err error) {
if ejv.t != bsontype.EmbeddedDocument {
return nil, 0, fmt.Errorf("$binary value should be object, but instead is %s", ejv.t)
}
binObj := ejv.v.(*extJSONObject)
bFound := false
stFound := false
for i, key := range binObj.keys {
val := binObj.values[i]
switch key {
case "base64":
if bFound {
return nil, 0, errors.New("duplicate base64 key in $binary")
}
if val.t != bsontype.String {
return nil, 0, fmt.Errorf("$binary base64 value should be string, but instead is %s", val.t)
}
base64Bytes, err := base64.StdEncoding.DecodeString(val.v.(string))
if err != nil {
return nil, 0, fmt.Errorf("invalid $binary base64 string: %s", val.v.(string))
}
b = base64Bytes
bFound = true
case "subType":
if stFound {
return nil, 0, errors.New("duplicate subType key in $binary")
}
if val.t != bsontype.String {
return nil, 0, fmt.Errorf("$binary subType value should be string, but instead is %s", val.t)
}
i, err := strconv.ParseInt(val.v.(string), 16, 64)
if err != nil {
return nil, 0, fmt.Errorf("invalid $binary subType string: %s", val.v.(string))
}
subType = byte(i)
stFound = true
default:
return nil, 0, fmt.Errorf("invalid key in $binary object: %s", key)
}
}
if !bFound {
return nil, 0, errors.New("missing base64 field in $binary object")
}
if !stFound {
return nil, 0, errors.New("missing subType field in $binary object")
}
return b, subType, nil
}
func (ejv *extJSONValue) parseDBPointer() (ns string, oid primitive.ObjectID, err error) {
if ejv.t != bsontype.EmbeddedDocument {
return "", primitive.NilObjectID, fmt.Errorf("$dbPointer value should be object, but instead is %s", ejv.t)
}
dbpObj := ejv.v.(*extJSONObject)
oidFound := false
nsFound := false
for i, key := range dbpObj.keys {
val := dbpObj.values[i]
switch key {
case "$ref":
if nsFound {
return "", primitive.NilObjectID, errors.New("duplicate $ref key in $dbPointer")
}
if val.t != bsontype.String {
return "", primitive.NilObjectID, fmt.Errorf("$dbPointer $ref value should be string, but instead is %s", val.t)
}
ns = val.v.(string)
nsFound = true
case "$id":
if oidFound {
return "", primitive.NilObjectID, errors.New("duplicate $id key in $dbPointer")
}
if val.t != bsontype.String {
return "", primitive.NilObjectID, fmt.Errorf("$dbPointer $id value should be string, but instead is %s", val.t)
}
oid, err = primitive.ObjectIDFromHex(val.v.(string))
if err != nil {
return "", primitive.NilObjectID, err
}
oidFound = true
default:
return "", primitive.NilObjectID, fmt.Errorf("invalid key in $dbPointer object: %s", key)
}
}
if !nsFound {
return "", oid, errors.New("missing $ref field in $dbPointer object")
}
if !oidFound {
return "", oid, errors.New("missing $id field in $dbPointer object")
}
return ns, oid, nil
}
const (
rfc3339Milli = "2006-01-02T15:04:05.999Z07:00"
)
var (
timeFormats = []string{rfc3339Milli, "2006-01-02T15:04:05.999Z0700"}
)
func (ejv *extJSONValue) parseDateTime() (int64, error) {
switch ejv.t {
case bsontype.Int32:
return int64(ejv.v.(int32)), nil
case bsontype.Int64:
return ejv.v.(int64), nil
case bsontype.String:
return parseDatetimeString(ejv.v.(string))
case bsontype.EmbeddedDocument:
return parseDatetimeObject(ejv.v.(*extJSONObject))
default:
return 0, fmt.Errorf("$date value should be string or object, but instead is %s", ejv.t)
}
}
func parseDatetimeString(data string) (int64, error) {
var t time.Time
var err error
// try acceptable time formats until one matches
for _, format := range timeFormats {
t, err = time.Parse(format, data)
if err == nil {
break
}
}
if err != nil {
return 0, fmt.Errorf("invalid $date value string: %s", data)
}
return int64(primitive.NewDateTimeFromTime(t)), nil
}
func parseDatetimeObject(data *extJSONObject) (d int64, err error) {
dFound := false
for i, key := range data.keys {
val := data.values[i]
switch key {
case "$numberLong":
if dFound {
return 0, errors.New("duplicate $numberLong key in $date")
}
if val.t != bsontype.String {
return 0, fmt.Errorf("$date $numberLong field should be string, but instead is %s", val.t)
}
d, err = val.parseInt64()
if err != nil {
return 0, err
}
dFound = true
default:
return 0, fmt.Errorf("invalid key in $date object: %s", key)
}
}
if !dFound {
return 0, errors.New("missing $numberLong field in $date object")
}
return d, nil
}
func (ejv *extJSONValue) parseDecimal128() (primitive.Decimal128, error) {
if ejv.t != bsontype.String {
return primitive.Decimal128{}, fmt.Errorf("$numberDecimal value should be string, but instead is %s", ejv.t)
}
d, err := primitive.ParseDecimal128(ejv.v.(string))
if err != nil {
return primitive.Decimal128{}, fmt.Errorf("$invalid $numberDecimal string: %s", ejv.v.(string))
}
return d, nil
}
func (ejv *extJSONValue) parseDouble() (float64, error) {
if ejv.t == bsontype.Double {
return ejv.v.(float64), nil
}
if ejv.t != bsontype.String {
return 0, fmt.Errorf("$numberDouble value should be string, but instead is %s", ejv.t)
}
switch ejv.v.(string) {
case "Infinity":
return math.Inf(1), nil
case "-Infinity":
return math.Inf(-1), nil
case "NaN":
return math.NaN(), nil
}
f, err := strconv.ParseFloat(ejv.v.(string), 64)
if err != nil {
return 0, err
}
return f, nil
}
func (ejv *extJSONValue) parseInt32() (int32, error) {
if ejv.t == bsontype.Int32 {
return ejv.v.(int32), nil
}
if ejv.t != bsontype.String {
return 0, fmt.Errorf("$numberInt value should be string, but instead is %s", ejv.t)
}
i, err := strconv.ParseInt(ejv.v.(string), 10, 64)
if err != nil {
return 0, err
}
if i < math.MinInt32 || i > math.MaxInt32 {
return 0, fmt.Errorf("$numberInt value should be int32 but instead is int64: %d", i)
}
return int32(i), nil
}
func (ejv *extJSONValue) parseInt64() (int64, error) {
if ejv.t == bsontype.Int64 {
return ejv.v.(int64), nil
}
if ejv.t != bsontype.String {
return 0, fmt.Errorf("$numberLong value should be string, but instead is %s", ejv.t)
}
i, err := strconv.ParseInt(ejv.v.(string), 10, 64)
if err != nil {
return 0, err
}
return i, nil
}
func (ejv *extJSONValue) parseJavascript() (code string, err error) {
if ejv.t != bsontype.String {
return "", fmt.Errorf("$code value should be string, but instead is %s", ejv.t)
}
return ejv.v.(string), nil
}
func (ejv *extJSONValue) parseMinMaxKey(minmax string) error {
if ejv.t != bsontype.Int32 {
return fmt.Errorf("$%sKey value should be int32, but instead is %s", minmax, ejv.t)
}
if ejv.v.(int32) != 1 {
return fmt.Errorf("$%sKey value must be 1, but instead is %d", minmax, ejv.v.(int32))
}
return nil
}
func (ejv *extJSONValue) parseObjectID() (primitive.ObjectID, error) {
if ejv.t != bsontype.String {
return primitive.NilObjectID, fmt.Errorf("$oid value should be string, but instead is %s", ejv.t)
}
return primitive.ObjectIDFromHex(ejv.v.(string))
}
func (ejv *extJSONValue) parseRegex() (pattern, options string, err error) {
if ejv.t != bsontype.EmbeddedDocument {
return "", "", fmt.Errorf("$regularExpression value should be object, but instead is %s", ejv.t)
}
regexObj := ejv.v.(*extJSONObject)
patFound := false
optFound := false
for i, key := range regexObj.keys {
val := regexObj.values[i]
switch key {
case "pattern":
if patFound {
return "", "", errors.New("duplicate pattern key in $regularExpression")
}
if val.t != bsontype.String {
return "", "", fmt.Errorf("$regularExpression pattern value should be string, but instead is %s", val.t)
}
pattern = val.v.(string)
patFound = true
case "options":
if optFound {
return "", "", errors.New("duplicate options key in $regularExpression")
}
if val.t != bsontype.String {
return "", "", fmt.Errorf("$regularExpression options value should be string, but instead is %s", val.t)
}
options = val.v.(string)
optFound = true
default:
return "", "", fmt.Errorf("invalid key in $regularExpression object: %s", key)
}
}
if !patFound {
return "", "", errors.New("missing pattern field in $regularExpression object")
}
if !optFound {
return "", "", errors.New("missing options field in $regularExpression object")
}
return pattern, options, nil
}
func (ejv *extJSONValue) parseSymbol() (string, error) {
if ejv.t != bsontype.String {
return "", fmt.Errorf("$symbol value should be string, but instead is %s", ejv.t)
}
return ejv.v.(string), nil
}
func (ejv *extJSONValue) parseTimestamp() (t, i uint32, err error) {
if ejv.t != bsontype.EmbeddedDocument {
return 0, 0, fmt.Errorf("$timestamp value should be object, but instead is %s", ejv.t)
}
handleKey := func(key string, val *extJSONValue, flag bool) (uint32, error) {
if flag {
return 0, fmt.Errorf("duplicate %s key in $timestamp", key)
}
switch val.t {
case bsontype.Int32:
value := val.v.(int32)
if value < 0 {
return 0, fmt.Errorf("$timestamp %s number should be uint32: %d", key, value)
}
return uint32(value), nil
case bsontype.Int64:
value := val.v.(int64)
if value < 0 || value > int64(math.MaxUint32) {
return 0, fmt.Errorf("$timestamp %s number should be uint32: %d", key, value)
}
return uint32(value), nil
default:
return 0, fmt.Errorf("$timestamp %s value should be uint32, but instead is %s", key, val.t)
}
}
tsObj := ejv.v.(*extJSONObject)
tFound := false
iFound := false
for j, key := range tsObj.keys {
val := tsObj.values[j]
switch key {
case "t":
if t, err = handleKey(key, val, tFound); err != nil {
return 0, 0, err
}
tFound = true
case "i":
if i, err = handleKey(key, val, iFound); err != nil {
return 0, 0, err
}
iFound = true
default:
return 0, 0, fmt.Errorf("invalid key in $timestamp object: %s", key)
}
}
if !tFound {
return 0, 0, errors.New("missing t field in $timestamp object")
}
if !iFound {
return 0, 0, errors.New("missing i field in $timestamp object")
}
return t, i, nil
}
func (ejv *extJSONValue) parseUndefined() error {
if ejv.t != bsontype.Boolean {
return fmt.Errorf("undefined value should be boolean, but instead is %s", ejv.t)
}
if !ejv.v.(bool) {
return fmt.Errorf("$undefined balue boolean should be true, but instead is %v", ejv.v.(bool))
}
return nil
}

View File

@@ -0,0 +1,732 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonrw
import (
"bytes"
"encoding/base64"
"fmt"
"io"
"math"
"sort"
"strconv"
"strings"
"sync"
"time"
"unicode/utf8"
"go.mongodb.org/mongo-driver/bson/primitive"
)
// ExtJSONValueWriterPool is a pool for ExtJSON ValueWriters.
type ExtJSONValueWriterPool struct {
pool sync.Pool
}
// NewExtJSONValueWriterPool creates a new pool for ValueWriter instances that write to ExtJSON.
func NewExtJSONValueWriterPool() *ExtJSONValueWriterPool {
return &ExtJSONValueWriterPool{
pool: sync.Pool{
New: func() interface{} {
return new(extJSONValueWriter)
},
},
}
}
// Get retrieves a ExtJSON ValueWriter from the pool and resets it to use w as the destination.
func (bvwp *ExtJSONValueWriterPool) Get(w io.Writer, canonical, escapeHTML bool) ValueWriter {
vw := bvwp.pool.Get().(*extJSONValueWriter)
if writer, ok := w.(*SliceWriter); ok {
vw.reset(*writer, canonical, escapeHTML)
vw.w = writer
return vw
}
vw.buf = vw.buf[:0]
vw.w = w
return vw
}
// Put inserts a ValueWriter into the pool. If the ValueWriter is not a ExtJSON ValueWriter, nothing
// happens and ok will be false.
func (bvwp *ExtJSONValueWriterPool) Put(vw ValueWriter) (ok bool) {
bvw, ok := vw.(*extJSONValueWriter)
if !ok {
return false
}
if _, ok := bvw.w.(*SliceWriter); ok {
bvw.buf = nil
}
bvw.w = nil
bvwp.pool.Put(bvw)
return true
}
type ejvwState struct {
mode mode
}
type extJSONValueWriter struct {
w io.Writer
buf []byte
stack []ejvwState
frame int64
canonical bool
escapeHTML bool
}
// NewExtJSONValueWriter creates a ValueWriter that writes Extended JSON to w.
func NewExtJSONValueWriter(w io.Writer, canonical, escapeHTML bool) (ValueWriter, error) {
if w == nil {
return nil, errNilWriter
}
return newExtJSONWriter(w, canonical, escapeHTML), nil
}
func newExtJSONWriter(w io.Writer, canonical, escapeHTML bool) *extJSONValueWriter {
stack := make([]ejvwState, 1, 5)
stack[0] = ejvwState{mode: mTopLevel}
return &extJSONValueWriter{
w: w,
buf: []byte{},
stack: stack,
canonical: canonical,
escapeHTML: escapeHTML,
}
}
func newExtJSONWriterFromSlice(buf []byte, canonical, escapeHTML bool) *extJSONValueWriter {
stack := make([]ejvwState, 1, 5)
stack[0] = ejvwState{mode: mTopLevel}
return &extJSONValueWriter{
buf: buf,
stack: stack,
canonical: canonical,
escapeHTML: escapeHTML,
}
}
func (ejvw *extJSONValueWriter) reset(buf []byte, canonical, escapeHTML bool) {
if ejvw.stack == nil {
ejvw.stack = make([]ejvwState, 1, 5)
}
ejvw.stack = ejvw.stack[:1]
ejvw.stack[0] = ejvwState{mode: mTopLevel}
ejvw.canonical = canonical
ejvw.escapeHTML = escapeHTML
ejvw.frame = 0
ejvw.buf = buf
ejvw.w = nil
}
func (ejvw *extJSONValueWriter) advanceFrame() {
if ejvw.frame+1 >= int64(len(ejvw.stack)) { // We need to grow the stack
length := len(ejvw.stack)
if length+1 >= cap(ejvw.stack) {
// double it
buf := make([]ejvwState, 2*cap(ejvw.stack)+1)
copy(buf, ejvw.stack)
ejvw.stack = buf
}
ejvw.stack = ejvw.stack[:length+1]
}
ejvw.frame++
}
func (ejvw *extJSONValueWriter) push(m mode) {
ejvw.advanceFrame()
ejvw.stack[ejvw.frame].mode = m
}
func (ejvw *extJSONValueWriter) pop() {
switch ejvw.stack[ejvw.frame].mode {
case mElement, mValue:
ejvw.frame--
case mDocument, mArray, mCodeWithScope:
ejvw.frame -= 2 // we pop twice to jump over the mElement: mDocument -> mElement -> mDocument/mTopLevel/etc...
}
}
func (ejvw *extJSONValueWriter) invalidTransitionErr(destination mode, name string, modes []mode) error {
te := TransitionError{
name: name,
current: ejvw.stack[ejvw.frame].mode,
destination: destination,
modes: modes,
action: "write",
}
if ejvw.frame != 0 {
te.parent = ejvw.stack[ejvw.frame-1].mode
}
return te
}
func (ejvw *extJSONValueWriter) ensureElementValue(destination mode, callerName string, addmodes ...mode) error {
switch ejvw.stack[ejvw.frame].mode {
case mElement, mValue:
default:
modes := []mode{mElement, mValue}
if addmodes != nil {
modes = append(modes, addmodes...)
}
return ejvw.invalidTransitionErr(destination, callerName, modes)
}
return nil
}
func (ejvw *extJSONValueWriter) writeExtendedSingleValue(key string, value string, quotes bool) {
var s string
if quotes {
s = fmt.Sprintf(`{"$%s":"%s"}`, key, value)
} else {
s = fmt.Sprintf(`{"$%s":%s}`, key, value)
}
ejvw.buf = append(ejvw.buf, []byte(s)...)
}
func (ejvw *extJSONValueWriter) WriteArray() (ArrayWriter, error) {
if err := ejvw.ensureElementValue(mArray, "WriteArray"); err != nil {
return nil, err
}
ejvw.buf = append(ejvw.buf, '[')
ejvw.push(mArray)
return ejvw, nil
}
func (ejvw *extJSONValueWriter) WriteBinary(b []byte) error {
return ejvw.WriteBinaryWithSubtype(b, 0x00)
}
func (ejvw *extJSONValueWriter) WriteBinaryWithSubtype(b []byte, btype byte) error {
if err := ejvw.ensureElementValue(mode(0), "WriteBinaryWithSubtype"); err != nil {
return err
}
var buf bytes.Buffer
buf.WriteString(`{"$binary":{"base64":"`)
buf.WriteString(base64.StdEncoding.EncodeToString(b))
buf.WriteString(fmt.Sprintf(`","subType":"%02x"}},`, btype))
ejvw.buf = append(ejvw.buf, buf.Bytes()...)
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteBoolean(b bool) error {
if err := ejvw.ensureElementValue(mode(0), "WriteBoolean"); err != nil {
return err
}
ejvw.buf = append(ejvw.buf, []byte(strconv.FormatBool(b))...)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteCodeWithScope(code string) (DocumentWriter, error) {
if err := ejvw.ensureElementValue(mCodeWithScope, "WriteCodeWithScope"); err != nil {
return nil, err
}
var buf bytes.Buffer
buf.WriteString(`{"$code":`)
writeStringWithEscapes(code, &buf, ejvw.escapeHTML)
buf.WriteString(`,"$scope":{`)
ejvw.buf = append(ejvw.buf, buf.Bytes()...)
ejvw.push(mCodeWithScope)
return ejvw, nil
}
func (ejvw *extJSONValueWriter) WriteDBPointer(ns string, oid primitive.ObjectID) error {
if err := ejvw.ensureElementValue(mode(0), "WriteDBPointer"); err != nil {
return err
}
var buf bytes.Buffer
buf.WriteString(`{"$dbPointer":{"$ref":"`)
buf.WriteString(ns)
buf.WriteString(`","$id":{"$oid":"`)
buf.WriteString(oid.Hex())
buf.WriteString(`"}}},`)
ejvw.buf = append(ejvw.buf, buf.Bytes()...)
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteDateTime(dt int64) error {
if err := ejvw.ensureElementValue(mode(0), "WriteDateTime"); err != nil {
return err
}
t := time.Unix(dt/1e3, dt%1e3*1e6).UTC()
if ejvw.canonical || t.Year() < 1970 || t.Year() > 9999 {
s := fmt.Sprintf(`{"$numberLong":"%d"}`, dt)
ejvw.writeExtendedSingleValue("date", s, false)
} else {
ejvw.writeExtendedSingleValue("date", t.Format(rfc3339Milli), true)
}
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteDecimal128(d primitive.Decimal128) error {
if err := ejvw.ensureElementValue(mode(0), "WriteDecimal128"); err != nil {
return err
}
ejvw.writeExtendedSingleValue("numberDecimal", d.String(), true)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteDocument() (DocumentWriter, error) {
if ejvw.stack[ejvw.frame].mode == mTopLevel {
ejvw.buf = append(ejvw.buf, '{')
return ejvw, nil
}
if err := ejvw.ensureElementValue(mDocument, "WriteDocument", mTopLevel); err != nil {
return nil, err
}
ejvw.buf = append(ejvw.buf, '{')
ejvw.push(mDocument)
return ejvw, nil
}
func (ejvw *extJSONValueWriter) WriteDouble(f float64) error {
if err := ejvw.ensureElementValue(mode(0), "WriteDouble"); err != nil {
return err
}
s := formatDouble(f)
if ejvw.canonical {
ejvw.writeExtendedSingleValue("numberDouble", s, true)
} else {
switch s {
case "Infinity":
fallthrough
case "-Infinity":
fallthrough
case "NaN":
s = fmt.Sprintf(`{"$numberDouble":"%s"}`, s)
}
ejvw.buf = append(ejvw.buf, []byte(s)...)
}
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteInt32(i int32) error {
if err := ejvw.ensureElementValue(mode(0), "WriteInt32"); err != nil {
return err
}
s := strconv.FormatInt(int64(i), 10)
if ejvw.canonical {
ejvw.writeExtendedSingleValue("numberInt", s, true)
} else {
ejvw.buf = append(ejvw.buf, []byte(s)...)
}
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteInt64(i int64) error {
if err := ejvw.ensureElementValue(mode(0), "WriteInt64"); err != nil {
return err
}
s := strconv.FormatInt(i, 10)
if ejvw.canonical {
ejvw.writeExtendedSingleValue("numberLong", s, true)
} else {
ejvw.buf = append(ejvw.buf, []byte(s)...)
}
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteJavascript(code string) error {
if err := ejvw.ensureElementValue(mode(0), "WriteJavascript"); err != nil {
return err
}
var buf bytes.Buffer
writeStringWithEscapes(code, &buf, ejvw.escapeHTML)
ejvw.writeExtendedSingleValue("code", buf.String(), false)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteMaxKey() error {
if err := ejvw.ensureElementValue(mode(0), "WriteMaxKey"); err != nil {
return err
}
ejvw.writeExtendedSingleValue("maxKey", "1", false)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteMinKey() error {
if err := ejvw.ensureElementValue(mode(0), "WriteMinKey"); err != nil {
return err
}
ejvw.writeExtendedSingleValue("minKey", "1", false)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteNull() error {
if err := ejvw.ensureElementValue(mode(0), "WriteNull"); err != nil {
return err
}
ejvw.buf = append(ejvw.buf, []byte("null")...)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteObjectID(oid primitive.ObjectID) error {
if err := ejvw.ensureElementValue(mode(0), "WriteObjectID"); err != nil {
return err
}
ejvw.writeExtendedSingleValue("oid", oid.Hex(), true)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteRegex(pattern string, options string) error {
if err := ejvw.ensureElementValue(mode(0), "WriteRegex"); err != nil {
return err
}
var buf bytes.Buffer
buf.WriteString(`{"$regularExpression":{"pattern":`)
writeStringWithEscapes(pattern, &buf, ejvw.escapeHTML)
buf.WriteString(`,"options":"`)
buf.WriteString(sortStringAlphebeticAscending(options))
buf.WriteString(`"}},`)
ejvw.buf = append(ejvw.buf, buf.Bytes()...)
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteString(s string) error {
if err := ejvw.ensureElementValue(mode(0), "WriteString"); err != nil {
return err
}
var buf bytes.Buffer
writeStringWithEscapes(s, &buf, ejvw.escapeHTML)
ejvw.buf = append(ejvw.buf, buf.Bytes()...)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteSymbol(symbol string) error {
if err := ejvw.ensureElementValue(mode(0), "WriteSymbol"); err != nil {
return err
}
var buf bytes.Buffer
writeStringWithEscapes(symbol, &buf, ejvw.escapeHTML)
ejvw.writeExtendedSingleValue("symbol", buf.String(), false)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteTimestamp(t uint32, i uint32) error {
if err := ejvw.ensureElementValue(mode(0), "WriteTimestamp"); err != nil {
return err
}
var buf bytes.Buffer
buf.WriteString(`{"$timestamp":{"t":`)
buf.WriteString(strconv.FormatUint(uint64(t), 10))
buf.WriteString(`,"i":`)
buf.WriteString(strconv.FormatUint(uint64(i), 10))
buf.WriteString(`}},`)
ejvw.buf = append(ejvw.buf, buf.Bytes()...)
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteUndefined() error {
if err := ejvw.ensureElementValue(mode(0), "WriteUndefined"); err != nil {
return err
}
ejvw.writeExtendedSingleValue("undefined", "true", false)
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteDocumentElement(key string) (ValueWriter, error) {
switch ejvw.stack[ejvw.frame].mode {
case mDocument, mTopLevel, mCodeWithScope:
var buf bytes.Buffer
writeStringWithEscapes(key, &buf, ejvw.escapeHTML)
ejvw.buf = append(ejvw.buf, []byte(fmt.Sprintf(`%s:`, buf.String()))...)
ejvw.push(mElement)
default:
return nil, ejvw.invalidTransitionErr(mElement, "WriteDocumentElement", []mode{mDocument, mTopLevel, mCodeWithScope})
}
return ejvw, nil
}
func (ejvw *extJSONValueWriter) WriteDocumentEnd() error {
switch ejvw.stack[ejvw.frame].mode {
case mDocument, mTopLevel, mCodeWithScope:
default:
return fmt.Errorf("incorrect mode to end document: %s", ejvw.stack[ejvw.frame].mode)
}
// close the document
if ejvw.buf[len(ejvw.buf)-1] == ',' {
ejvw.buf[len(ejvw.buf)-1] = '}'
} else {
ejvw.buf = append(ejvw.buf, '}')
}
switch ejvw.stack[ejvw.frame].mode {
case mCodeWithScope:
ejvw.buf = append(ejvw.buf, '}')
fallthrough
case mDocument:
ejvw.buf = append(ejvw.buf, ',')
case mTopLevel:
if ejvw.w != nil {
if _, err := ejvw.w.Write(ejvw.buf); err != nil {
return err
}
ejvw.buf = ejvw.buf[:0]
}
}
ejvw.pop()
return nil
}
func (ejvw *extJSONValueWriter) WriteArrayElement() (ValueWriter, error) {
switch ejvw.stack[ejvw.frame].mode {
case mArray:
ejvw.push(mValue)
default:
return nil, ejvw.invalidTransitionErr(mValue, "WriteArrayElement", []mode{mArray})
}
return ejvw, nil
}
func (ejvw *extJSONValueWriter) WriteArrayEnd() error {
switch ejvw.stack[ejvw.frame].mode {
case mArray:
// close the array
if ejvw.buf[len(ejvw.buf)-1] == ',' {
ejvw.buf[len(ejvw.buf)-1] = ']'
} else {
ejvw.buf = append(ejvw.buf, ']')
}
ejvw.buf = append(ejvw.buf, ',')
ejvw.pop()
default:
return fmt.Errorf("incorrect mode to end array: %s", ejvw.stack[ejvw.frame].mode)
}
return nil
}
func formatDouble(f float64) string {
var s string
if math.IsInf(f, 1) {
s = "Infinity"
} else if math.IsInf(f, -1) {
s = "-Infinity"
} else if math.IsNaN(f) {
s = "NaN"
} else {
// Print exactly one decimalType place for integers; otherwise, print as many are necessary to
// perfectly represent it.
s = strconv.FormatFloat(f, 'G', -1, 64)
if !strings.ContainsRune(s, 'E') && !strings.ContainsRune(s, '.') {
s += ".0"
}
}
return s
}
var hexChars = "0123456789abcdef"
func writeStringWithEscapes(s string, buf *bytes.Buffer, escapeHTML bool) {
buf.WriteByte('"')
start := 0
for i := 0; i < len(s); {
if b := s[i]; b < utf8.RuneSelf {
if htmlSafeSet[b] || (!escapeHTML && safeSet[b]) {
i++
continue
}
if start < i {
buf.WriteString(s[start:i])
}
switch b {
case '\\', '"':
buf.WriteByte('\\')
buf.WriteByte(b)
case '\n':
buf.WriteByte('\\')
buf.WriteByte('n')
case '\r':
buf.WriteByte('\\')
buf.WriteByte('r')
case '\t':
buf.WriteByte('\\')
buf.WriteByte('t')
case '\b':
buf.WriteByte('\\')
buf.WriteByte('b')
case '\f':
buf.WriteByte('\\')
buf.WriteByte('f')
default:
// This encodes bytes < 0x20 except for \t, \n and \r.
// If escapeHTML is set, it also escapes <, >, and &
// because they can lead to security holes when
// user-controlled strings are rendered into JSON
// and served to some browsers.
buf.WriteString(`\u00`)
buf.WriteByte(hexChars[b>>4])
buf.WriteByte(hexChars[b&0xF])
}
i++
start = i
continue
}
c, size := utf8.DecodeRuneInString(s[i:])
if c == utf8.RuneError && size == 1 {
if start < i {
buf.WriteString(s[start:i])
}
buf.WriteString(`\ufffd`)
i += size
start = i
continue
}
// U+2028 is LINE SEPARATOR.
// U+2029 is PARAGRAPH SEPARATOR.
// They are both technically valid characters in JSON strings,
// but don't work in JSONP, which has to be evaluated as JavaScript,
// and can lead to security holes there. It is valid JSON to
// escape them, so we do so unconditionally.
// See http://timelessrepo.com/json-isnt-a-javascript-subset for discussion.
if c == '\u2028' || c == '\u2029' {
if start < i {
buf.WriteString(s[start:i])
}
buf.WriteString(`\u202`)
buf.WriteByte(hexChars[c&0xF])
i += size
start = i
continue
}
i += size
}
if start < len(s) {
buf.WriteString(s[start:])
}
buf.WriteByte('"')
}
type sortableString []rune
func (ss sortableString) Len() int {
return len(ss)
}
func (ss sortableString) Less(i, j int) bool {
return ss[i] < ss[j]
}
func (ss sortableString) Swap(i, j int) {
oldI := ss[i]
ss[i] = ss[j]
ss[j] = oldI
}
func sortStringAlphebeticAscending(s string) string {
ss := sortableString([]rune(s))
sort.Sort(ss)
return string([]rune(ss))
}

View File

@@ -0,0 +1,260 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonrw
import (
"fmt"
"io/ioutil"
"reflect"
"strings"
"testing"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
)
func TestExtJSONValueWriter(t *testing.T) {
oid := primitive.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}
testCases := []struct {
name string
fn interface{}
params []interface{}
}{
{
"WriteBinary",
(*extJSONValueWriter).WriteBinary,
[]interface{}{[]byte{0x01, 0x02, 0x03}},
},
{
"WriteBinaryWithSubtype (not 0x02)",
(*extJSONValueWriter).WriteBinaryWithSubtype,
[]interface{}{[]byte{0x01, 0x02, 0x03}, byte(0xFF)},
},
{
"WriteBinaryWithSubtype (0x02)",
(*extJSONValueWriter).WriteBinaryWithSubtype,
[]interface{}{[]byte{0x01, 0x02, 0x03}, byte(0x02)},
},
{
"WriteBoolean",
(*extJSONValueWriter).WriteBoolean,
[]interface{}{true},
},
{
"WriteDBPointer",
(*extJSONValueWriter).WriteDBPointer,
[]interface{}{"bar", oid},
},
{
"WriteDateTime",
(*extJSONValueWriter).WriteDateTime,
[]interface{}{int64(12345678)},
},
{
"WriteDecimal128",
(*extJSONValueWriter).WriteDecimal128,
[]interface{}{primitive.NewDecimal128(10, 20)},
},
{
"WriteDouble",
(*extJSONValueWriter).WriteDouble,
[]interface{}{float64(3.14159)},
},
{
"WriteInt32",
(*extJSONValueWriter).WriteInt32,
[]interface{}{int32(123456)},
},
{
"WriteInt64",
(*extJSONValueWriter).WriteInt64,
[]interface{}{int64(1234567890)},
},
{
"WriteJavascript",
(*extJSONValueWriter).WriteJavascript,
[]interface{}{"var foo = 'bar';"},
},
{
"WriteMaxKey",
(*extJSONValueWriter).WriteMaxKey,
[]interface{}{},
},
{
"WriteMinKey",
(*extJSONValueWriter).WriteMinKey,
[]interface{}{},
},
{
"WriteNull",
(*extJSONValueWriter).WriteNull,
[]interface{}{},
},
{
"WriteObjectID",
(*extJSONValueWriter).WriteObjectID,
[]interface{}{oid},
},
{
"WriteRegex",
(*extJSONValueWriter).WriteRegex,
[]interface{}{"bar", "baz"},
},
{
"WriteString",
(*extJSONValueWriter).WriteString,
[]interface{}{"hello, world!"},
},
{
"WriteSymbol",
(*extJSONValueWriter).WriteSymbol,
[]interface{}{"symbollolz"},
},
{
"WriteTimestamp",
(*extJSONValueWriter).WriteTimestamp,
[]interface{}{uint32(10), uint32(20)},
},
{
"WriteUndefined",
(*extJSONValueWriter).WriteUndefined,
[]interface{}{},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
fn := reflect.ValueOf(tc.fn)
if fn.Kind() != reflect.Func {
t.Fatalf("fn must be of kind Func but it is a %v", fn.Kind())
}
if fn.Type().NumIn() != len(tc.params)+1 || fn.Type().In(0) != reflect.TypeOf((*extJSONValueWriter)(nil)) {
t.Fatalf("fn must have at least one parameter and the first parameter must be a *valueWriter")
}
if fn.Type().NumOut() != 1 || fn.Type().Out(0) != reflect.TypeOf((*error)(nil)).Elem() {
t.Fatalf("fn must have one return value and it must be an error.")
}
params := make([]reflect.Value, 1, len(tc.params)+1)
ejvw := newExtJSONWriter(ioutil.Discard, true, true)
params[0] = reflect.ValueOf(ejvw)
for _, param := range tc.params {
params = append(params, reflect.ValueOf(param))
}
t.Run("incorrect transition", func(t *testing.T) {
results := fn.Call(params)
got := results[0].Interface().(error)
fnName := tc.name
if strings.Contains(fnName, "WriteBinary") {
fnName = "WriteBinaryWithSubtype"
}
want := TransitionError{current: mTopLevel, name: fnName, modes: []mode{mElement, mValue},
action: "write"}
if !compareErrors(got, want) {
t.Errorf("Errors do not match. got %v; want %v", got, want)
}
})
})
}
t.Run("WriteArray", func(t *testing.T) {
ejvw := newExtJSONWriter(ioutil.Discard, true, true)
ejvw.push(mArray)
want := TransitionError{current: mArray, destination: mArray, parent: mTopLevel,
name: "WriteArray", modes: []mode{mElement, mValue}, action: "write"}
_, got := ejvw.WriteArray()
if !compareErrors(got, want) {
t.Errorf("Did not get expected error. got %v; want %v", got, want)
}
})
t.Run("WriteCodeWithScope", func(t *testing.T) {
ejvw := newExtJSONWriter(ioutil.Discard, true, true)
ejvw.push(mArray)
want := TransitionError{current: mArray, destination: mCodeWithScope, parent: mTopLevel,
name: "WriteCodeWithScope", modes: []mode{mElement, mValue}, action: "write"}
_, got := ejvw.WriteCodeWithScope("")
if !compareErrors(got, want) {
t.Errorf("Did not get expected error. got %v; want %v", got, want)
}
})
t.Run("WriteDocument", func(t *testing.T) {
ejvw := newExtJSONWriter(ioutil.Discard, true, true)
ejvw.push(mArray)
want := TransitionError{current: mArray, destination: mDocument, parent: mTopLevel,
name: "WriteDocument", modes: []mode{mElement, mValue, mTopLevel}, action: "write"}
_, got := ejvw.WriteDocument()
if !compareErrors(got, want) {
t.Errorf("Did not get expected error. got %v; want %v", got, want)
}
})
t.Run("WriteDocumentElement", func(t *testing.T) {
ejvw := newExtJSONWriter(ioutil.Discard, true, true)
ejvw.push(mElement)
want := TransitionError{current: mElement,
destination: mElement,
parent: mTopLevel,
name: "WriteDocumentElement",
modes: []mode{mDocument, mTopLevel, mCodeWithScope},
action: "write"}
_, got := ejvw.WriteDocumentElement("")
if !compareErrors(got, want) {
t.Errorf("Did not get expected error. got %v; want %v", got, want)
}
})
t.Run("WriteDocumentEnd", func(t *testing.T) {
ejvw := newExtJSONWriter(ioutil.Discard, true, true)
ejvw.push(mElement)
want := fmt.Errorf("incorrect mode to end document: %s", mElement)
got := ejvw.WriteDocumentEnd()
if !compareErrors(got, want) {
t.Errorf("Did not get expected error. got %v; want %v", got, want)
}
})
t.Run("WriteArrayElement", func(t *testing.T) {
ejvw := newExtJSONWriter(ioutil.Discard, true, true)
ejvw.push(mElement)
want := TransitionError{current: mElement,
destination: mValue,
parent: mTopLevel,
name: "WriteArrayElement",
modes: []mode{mArray},
action: "write"}
_, got := ejvw.WriteArrayElement()
if !compareErrors(got, want) {
t.Errorf("Did not get expected error. got %v; want %v", got, want)
}
})
t.Run("WriteArrayEnd", func(t *testing.T) {
ejvw := newExtJSONWriter(ioutil.Discard, true, true)
ejvw.push(mElement)
want := fmt.Errorf("incorrect mode to end array: %s", mElement)
got := ejvw.WriteArrayEnd()
if !compareErrors(got, want) {
t.Errorf("Did not get expected error. got %v; want %v", got, want)
}
})
t.Run("WriteBytes", func(t *testing.T) {
t.Run("writeElementHeader error", func(t *testing.T) {
ejvw := newExtJSONWriterFromSlice(nil, true, true)
want := TransitionError{current: mTopLevel, destination: mode(0),
name: "WriteBinaryWithSubtype", modes: []mode{mElement, mValue}, action: "write"}
got := ejvw.WriteBinaryWithSubtype(nil, (byte)(bsontype.EmbeddedDocument))
if !compareErrors(got, want) {
t.Errorf("Did not received expected error. got %v; want %v", got, want)
}
})
})
t.Run("FormatDoubleWithExponent", func(t *testing.T) {
want := "3E-12"
got := formatDouble(float64(0.000000000003))
if got != want {
t.Errorf("Did not receive expected string. got %s: want %s", got, want)
}
})
}

View File

@@ -0,0 +1,528 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonrw
import (
"bytes"
"errors"
"fmt"
"io"
"math"
"strconv"
"unicode"
"unicode/utf16"
)
type jsonTokenType byte
const (
jttBeginObject jsonTokenType = iota
jttEndObject
jttBeginArray
jttEndArray
jttColon
jttComma
jttInt32
jttInt64
jttDouble
jttString
jttBool
jttNull
jttEOF
)
type jsonToken struct {
t jsonTokenType
v interface{}
p int
}
type jsonScanner struct {
r io.Reader
buf []byte
pos int
lastReadErr error
}
// nextToken returns the next JSON token if one exists. A token is a character
// of the JSON grammar, a number, a string, or a literal.
func (js *jsonScanner) nextToken() (*jsonToken, error) {
c, err := js.readNextByte()
// keep reading until a non-space is encountered (break on read error or EOF)
for isWhiteSpace(c) && err == nil {
c, err = js.readNextByte()
}
if err == io.EOF {
return &jsonToken{t: jttEOF}, nil
} else if err != nil {
return nil, err
}
// switch on the character
switch c {
case '{':
return &jsonToken{t: jttBeginObject, v: byte('{'), p: js.pos - 1}, nil
case '}':
return &jsonToken{t: jttEndObject, v: byte('}'), p: js.pos - 1}, nil
case '[':
return &jsonToken{t: jttBeginArray, v: byte('['), p: js.pos - 1}, nil
case ']':
return &jsonToken{t: jttEndArray, v: byte(']'), p: js.pos - 1}, nil
case ':':
return &jsonToken{t: jttColon, v: byte(':'), p: js.pos - 1}, nil
case ',':
return &jsonToken{t: jttComma, v: byte(','), p: js.pos - 1}, nil
case '"': // RFC-8259 only allows for double quotes (") not single (')
return js.scanString()
default:
// check if it's a number
if c == '-' || isDigit(c) {
return js.scanNumber(c)
} else if c == 't' || c == 'f' || c == 'n' {
// maybe a literal
return js.scanLiteral(c)
} else {
return nil, fmt.Errorf("invalid JSON input. Position: %d. Character: %c", js.pos-1, c)
}
}
}
// readNextByte attempts to read the next byte from the buffer. If the buffer
// has been exhausted, this function calls readIntoBuf, thus refilling the
// buffer and resetting the read position to 0
func (js *jsonScanner) readNextByte() (byte, error) {
if js.pos >= len(js.buf) {
err := js.readIntoBuf()
if err != nil {
return 0, err
}
}
b := js.buf[js.pos]
js.pos++
return b, nil
}
// readNNextBytes reads n bytes into dst, starting at offset
func (js *jsonScanner) readNNextBytes(dst []byte, n, offset int) error {
var err error
for i := 0; i < n; i++ {
dst[i+offset], err = js.readNextByte()
if err != nil {
return err
}
}
return nil
}
// readIntoBuf reads up to 512 bytes from the scanner's io.Reader into the buffer
func (js *jsonScanner) readIntoBuf() error {
if js.lastReadErr != nil {
js.buf = js.buf[:0]
js.pos = 0
return js.lastReadErr
}
if cap(js.buf) == 0 {
js.buf = make([]byte, 0, 512)
}
n, err := js.r.Read(js.buf[:cap(js.buf)])
if err != nil {
js.lastReadErr = err
if n > 0 {
err = nil
}
}
js.buf = js.buf[:n]
js.pos = 0
return err
}
func isWhiteSpace(c byte) bool {
return c == ' ' || c == '\t' || c == '\r' || c == '\n'
}
func isDigit(c byte) bool {
return unicode.IsDigit(rune(c))
}
func isValueTerminator(c byte) bool {
return c == ',' || c == '}' || c == ']' || isWhiteSpace(c)
}
// getu4 decodes the 4-byte hex sequence from the beginning of s, returning the hex value as a rune,
// or it returns -1. Note that the "\u" from the unicode escape sequence should not be present.
// It is copied and lightly modified from the Go JSON decode function at
// https://github.com/golang/go/blob/1b0a0316802b8048d69da49dc23c5a5ab08e8ae8/src/encoding/json/decode.go#L1169-L1188
func getu4(s []byte) rune {
if len(s) < 4 {
return -1
}
var r rune
for _, c := range s[:4] {
switch {
case '0' <= c && c <= '9':
c = c - '0'
case 'a' <= c && c <= 'f':
c = c - 'a' + 10
case 'A' <= c && c <= 'F':
c = c - 'A' + 10
default:
return -1
}
r = r*16 + rune(c)
}
return r
}
// scanString reads from an opening '"' to a closing '"' and handles escaped characters
func (js *jsonScanner) scanString() (*jsonToken, error) {
var b bytes.Buffer
var c byte
var err error
p := js.pos - 1
for {
c, err = js.readNextByte()
if err != nil {
if err == io.EOF {
return nil, errors.New("end of input in JSON string")
}
return nil, err
}
evalNextChar:
switch c {
case '\\':
c, err = js.readNextByte()
if err != nil {
if err == io.EOF {
return nil, errors.New("end of input in JSON string")
}
return nil, err
}
evalNextEscapeChar:
switch c {
case '"', '\\', '/':
b.WriteByte(c)
case 'b':
b.WriteByte('\b')
case 'f':
b.WriteByte('\f')
case 'n':
b.WriteByte('\n')
case 'r':
b.WriteByte('\r')
case 't':
b.WriteByte('\t')
case 'u':
us := make([]byte, 4)
err = js.readNNextBytes(us, 4, 0)
if err != nil {
return nil, fmt.Errorf("invalid unicode sequence in JSON string: %s", us)
}
rn := getu4(us)
// If the rune we just decoded is the high or low value of a possible surrogate pair,
// try to decode the next sequence as the low value of a surrogate pair. We're
// expecting the next sequence to be another Unicode escape sequence (e.g. "\uDD1E"),
// but need to handle cases where the input is not a valid surrogate pair.
// For more context on unicode surrogate pairs, see:
// https://www.christianfscott.com/rust-chars-vs-go-runes/
// https://www.unicode.org/glossary/#high_surrogate_code_point
if utf16.IsSurrogate(rn) {
c, err = js.readNextByte()
if err != nil {
if err == io.EOF {
return nil, errors.New("end of input in JSON string")
}
return nil, err
}
// If the next value isn't the beginning of a backslash escape sequence, write
// the Unicode replacement character for the surrogate value and goto the
// beginning of the next char eval block.
if c != '\\' {
b.WriteRune(unicode.ReplacementChar)
goto evalNextChar
}
c, err = js.readNextByte()
if err != nil {
if err == io.EOF {
return nil, errors.New("end of input in JSON string")
}
return nil, err
}
// If the next value isn't the beginning of a unicode escape sequence, write the
// Unicode replacement character for the surrogate value and goto the beginning
// of the next escape char eval block.
if c != 'u' {
b.WriteRune(unicode.ReplacementChar)
goto evalNextEscapeChar
}
err = js.readNNextBytes(us, 4, 0)
if err != nil {
return nil, fmt.Errorf("invalid unicode sequence in JSON string: %s", us)
}
rn2 := getu4(us)
// Try to decode the pair of runes as a utf16 surrogate pair. If that fails, write
// the Unicode replacement character for the surrogate value and the 2nd decoded rune.
if rnPair := utf16.DecodeRune(rn, rn2); rnPair != unicode.ReplacementChar {
b.WriteRune(rnPair)
} else {
b.WriteRune(unicode.ReplacementChar)
b.WriteRune(rn2)
}
break
}
b.WriteRune(rn)
default:
return nil, fmt.Errorf("invalid escape sequence in JSON string '\\%c'", c)
}
case '"':
return &jsonToken{t: jttString, v: b.String(), p: p}, nil
default:
b.WriteByte(c)
}
}
}
// scanLiteral reads an unquoted sequence of characters and determines if it is one of
// three valid JSON literals (true, false, null); if so, it returns the appropriate
// jsonToken; otherwise, it returns an error
func (js *jsonScanner) scanLiteral(first byte) (*jsonToken, error) {
p := js.pos - 1
lit := make([]byte, 4)
lit[0] = first
err := js.readNNextBytes(lit, 3, 1)
if err != nil {
return nil, err
}
c5, err := js.readNextByte()
if bytes.Equal([]byte("true"), lit) && (isValueTerminator(c5) || err == io.EOF) {
js.pos = int(math.Max(0, float64(js.pos-1)))
return &jsonToken{t: jttBool, v: true, p: p}, nil
} else if bytes.Equal([]byte("null"), lit) && (isValueTerminator(c5) || err == io.EOF) {
js.pos = int(math.Max(0, float64(js.pos-1)))
return &jsonToken{t: jttNull, v: nil, p: p}, nil
} else if bytes.Equal([]byte("fals"), lit) {
if c5 == 'e' {
c5, err = js.readNextByte()
if isValueTerminator(c5) || err == io.EOF {
js.pos = int(math.Max(0, float64(js.pos-1)))
return &jsonToken{t: jttBool, v: false, p: p}, nil
}
}
}
return nil, fmt.Errorf("invalid JSON literal. Position: %d, literal: %s", p, lit)
}
type numberScanState byte
const (
nssSawLeadingMinus numberScanState = iota
nssSawLeadingZero
nssSawIntegerDigits
nssSawDecimalPoint
nssSawFractionDigits
nssSawExponentLetter
nssSawExponentSign
nssSawExponentDigits
nssDone
nssInvalid
)
// scanNumber reads a JSON number (according to RFC-8259)
func (js *jsonScanner) scanNumber(first byte) (*jsonToken, error) {
var b bytes.Buffer
var s numberScanState
var c byte
var err error
t := jttInt64 // assume it's an int64 until the type can be determined
start := js.pos - 1
b.WriteByte(first)
switch first {
case '-':
s = nssSawLeadingMinus
case '0':
s = nssSawLeadingZero
default:
s = nssSawIntegerDigits
}
for {
c, err = js.readNextByte()
if err != nil && err != io.EOF {
return nil, err
}
switch s {
case nssSawLeadingMinus:
switch c {
case '0':
s = nssSawLeadingZero
b.WriteByte(c)
default:
if isDigit(c) {
s = nssSawIntegerDigits
b.WriteByte(c)
} else {
s = nssInvalid
}
}
case nssSawLeadingZero:
switch c {
case '.':
s = nssSawDecimalPoint
b.WriteByte(c)
case 'e', 'E':
s = nssSawExponentLetter
b.WriteByte(c)
case '}', ']', ',':
s = nssDone
default:
if isWhiteSpace(c) || err == io.EOF {
s = nssDone
} else {
s = nssInvalid
}
}
case nssSawIntegerDigits:
switch c {
case '.':
s = nssSawDecimalPoint
b.WriteByte(c)
case 'e', 'E':
s = nssSawExponentLetter
b.WriteByte(c)
case '}', ']', ',':
s = nssDone
default:
if isWhiteSpace(c) || err == io.EOF {
s = nssDone
} else if isDigit(c) {
s = nssSawIntegerDigits
b.WriteByte(c)
} else {
s = nssInvalid
}
}
case nssSawDecimalPoint:
t = jttDouble
if isDigit(c) {
s = nssSawFractionDigits
b.WriteByte(c)
} else {
s = nssInvalid
}
case nssSawFractionDigits:
switch c {
case 'e', 'E':
s = nssSawExponentLetter
b.WriteByte(c)
case '}', ']', ',':
s = nssDone
default:
if isWhiteSpace(c) || err == io.EOF {
s = nssDone
} else if isDigit(c) {
s = nssSawFractionDigits
b.WriteByte(c)
} else {
s = nssInvalid
}
}
case nssSawExponentLetter:
t = jttDouble
switch c {
case '+', '-':
s = nssSawExponentSign
b.WriteByte(c)
default:
if isDigit(c) {
s = nssSawExponentDigits
b.WriteByte(c)
} else {
s = nssInvalid
}
}
case nssSawExponentSign:
if isDigit(c) {
s = nssSawExponentDigits
b.WriteByte(c)
} else {
s = nssInvalid
}
case nssSawExponentDigits:
switch c {
case '}', ']', ',':
s = nssDone
default:
if isWhiteSpace(c) || err == io.EOF {
s = nssDone
} else if isDigit(c) {
s = nssSawExponentDigits
b.WriteByte(c)
} else {
s = nssInvalid
}
}
}
switch s {
case nssInvalid:
return nil, fmt.Errorf("invalid JSON number. Position: %d", start)
case nssDone:
js.pos = int(math.Max(0, float64(js.pos-1)))
if t != jttDouble {
v, err := strconv.ParseInt(b.String(), 10, 64)
if err == nil {
if v < math.MinInt32 || v > math.MaxInt32 {
return &jsonToken{t: jttInt64, v: v, p: start}, nil
}
return &jsonToken{t: jttInt32, v: int32(v), p: start}, nil
}
}
v, err := strconv.ParseFloat(b.String(), 64)
if err != nil {
return nil, err
}
return &jsonToken{t: jttDouble, v: v, p: start}, nil
}
}
}

View File

@@ -0,0 +1,376 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonrw
import (
"strings"
"testing"
"testing/iotest"
"github.com/google/go-cmp/cmp"
)
func jttDiff(t *testing.T, expected, actual jsonTokenType, desc string) {
if diff := cmp.Diff(expected, actual); diff != "" {
t.Helper()
t.Errorf("%s: Incorrect JSON Token Type (-want, +got): %s\n", desc, diff)
t.FailNow()
}
}
func jtvDiff(t *testing.T, expected, actual interface{}, desc string) {
if diff := cmp.Diff(expected, actual); diff != "" {
t.Helper()
t.Errorf("%s: Incorrect JSON Token Value (-want, +got): %s\n", desc, diff)
t.FailNow()
}
}
func expectNilToken(t *testing.T, v *jsonToken, desc string) {
if v != nil {
t.Helper()
t.Errorf("%s: Expected nil JSON token", desc)
t.FailNow()
}
}
func expectError(t *testing.T, err error, desc string) {
if err == nil {
t.Helper()
t.Errorf("%s: Expected error", desc)
t.FailNow()
}
}
func expectNoError(t *testing.T, err error, desc string) {
if err != nil {
t.Helper()
t.Errorf("%s: Unepexted error: %v", desc, err)
t.FailNow()
}
}
type jsonScannerTestCase struct {
desc string
input string
tokens []jsonToken
}
// length = 512
const longKey = "abcdefghijklmnopqrstuvwxyz" + "abcdefghijklmnopqrstuvwxyz" +
"abcdefghijklmnopqrstuvwxyz" + "abcdefghijklmnopqrstuvwxyz" +
"abcdefghijklmnopqrstuvwxyz" + "abcdefghijklmnopqrstuvwxyz" +
"abcdefghijklmnopqrstuvwxyz" + "abcdefghijklmnopqrstuvwxyz" +
"abcdefghijklmnopqrstuvwxyz" + "abcdefghijklmnopqrstuvwxyz" +
"abcdefghijklmnopqrstuvwxyz" + "abcdefghijklmnopqrstuvwxyz" +
"abcdefghijklmnopqrstuvwxyz" + "abcdefghijklmnopqrstuvwxyz" +
"abcdefghijklmnopqrstuvwxyz" + "abcdefghijklmnopqrstuvwxyz" +
"abcdefghijklmnopqrstuvwxyz" + "abcdefghijklmnopqrstuvwxyz" +
"abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqr"
func TestJsonScannerValidInputs(t *testing.T) {
cases := []jsonScannerTestCase{
{
desc: "empty input", input: "",
tokens: []jsonToken{},
},
{
desc: "empty object", input: "{}",
tokens: []jsonToken{{t: jttBeginObject, v: byte('{')}, {t: jttEndObject, v: byte('}')}},
},
{
desc: "empty array", input: "[]",
tokens: []jsonToken{{t: jttBeginArray, v: byte('[')}, {t: jttEndArray, v: byte(']')}},
},
{
desc: "valid empty string", input: `""`,
tokens: []jsonToken{{t: jttString, v: ""}},
},
{
desc: "valid string--no escaped characters",
input: `"string"`,
tokens: []jsonToken{{t: jttString, v: "string"}},
},
{
desc: "valid string--escaped characters",
input: `"\"\\\/\b\f\n\r\t"`,
tokens: []jsonToken{{t: jttString, v: "\"\\/\b\f\n\r\t"}},
},
{
desc: "valid string--surrogate pair",
input: `"abc \uD834\uDd1e 123"`,
tokens: []jsonToken{{t: jttString, v: "abc 𝄞 123"}},
},
{
desc: "valid string--high surrogate at end of string",
input: `"abc \uD834"`,
tokens: []jsonToken{{t: jttString, v: "abc <20>"}},
},
{
desc: "valid string--low surrogate at end of string",
input: `"abc \uDD1E"`,
tokens: []jsonToken{{t: jttString, v: "abc <20>"}},
},
{
desc: "valid string--high surrogate with non-surrogate Unicode value",
input: `"abc \uDD1E\u00BF"`,
tokens: []jsonToken{{t: jttString, v: "abc <20>¿"}},
},
{
desc: "valid string--high surrogate with non-Unicode escape sequence",
input: `"abc \uDD1E\t"`,
tokens: []jsonToken{{t: jttString, v: "abc <20>\t"}},
},
{
desc: "valid literal--true", input: "true",
tokens: []jsonToken{{t: jttBool, v: true}},
},
{
desc: "valid literal--false", input: "false",
tokens: []jsonToken{{t: jttBool, v: false}},
},
{
desc: "valid literal--null", input: "null",
tokens: []jsonToken{{t: jttNull}},
},
{
desc: "valid int32: 0", input: "0",
tokens: []jsonToken{{t: jttInt32, v: int32(0)}},
},
{
desc: "valid int32: -0", input: "-0",
tokens: []jsonToken{{t: jttInt32, v: int32(0)}},
},
{
desc: "valid int32: 1", input: "1",
tokens: []jsonToken{{t: jttInt32, v: int32(1)}},
},
{
desc: "valid int32: -1", input: "-1",
tokens: []jsonToken{{t: jttInt32, v: int32(-1)}},
},
{
desc: "valid int32: 10", input: "10",
tokens: []jsonToken{{t: jttInt32, v: int32(10)}},
},
{
desc: "valid int32: 1234", input: "1234",
tokens: []jsonToken{{t: jttInt32, v: int32(1234)}},
},
{
desc: "valid int32: -10", input: "-10",
tokens: []jsonToken{{t: jttInt32, v: int32(-10)}},
},
{
desc: "valid int32: -1234", input: "-1234",
tokens: []jsonToken{{t: jttInt32, v: int32(-1234)}},
},
{
desc: "valid int64: 2147483648", input: "2147483648",
tokens: []jsonToken{{t: jttInt64, v: int64(2147483648)}},
},
{
desc: "valid int64: -2147483649", input: "-2147483649",
tokens: []jsonToken{{t: jttInt64, v: int64(-2147483649)}},
},
{
desc: "valid double: 0.0", input: "0.0",
tokens: []jsonToken{{t: jttDouble, v: 0.0}},
},
{
desc: "valid double: -0.0", input: "-0.0",
tokens: []jsonToken{{t: jttDouble, v: 0.0}},
},
{
desc: "valid double: 0.1", input: "0.1",
tokens: []jsonToken{{t: jttDouble, v: 0.1}},
},
{
desc: "valid double: 0.1234", input: "0.1234",
tokens: []jsonToken{{t: jttDouble, v: 0.1234}},
},
{
desc: "valid double: 1.0", input: "1.0",
tokens: []jsonToken{{t: jttDouble, v: 1.0}},
},
{
desc: "valid double: -1.0", input: "-1.0",
tokens: []jsonToken{{t: jttDouble, v: -1.0}},
},
{
desc: "valid double: 1.234", input: "1.234",
tokens: []jsonToken{{t: jttDouble, v: 1.234}},
},
{
desc: "valid double: -1.234", input: "-1.234",
tokens: []jsonToken{{t: jttDouble, v: -1.234}},
},
{
desc: "valid double: 1e10", input: "1e10",
tokens: []jsonToken{{t: jttDouble, v: 1e+10}},
},
{
desc: "valid double: 1E10", input: "1E10",
tokens: []jsonToken{{t: jttDouble, v: 1e+10}},
},
{
desc: "valid double: 1.2e10", input: "1.2e10",
tokens: []jsonToken{{t: jttDouble, v: 1.2e+10}},
},
{
desc: "valid double: 1.2E10", input: "1.2E10",
tokens: []jsonToken{{t: jttDouble, v: 1.2e+10}},
},
{
desc: "valid double: -1.2e10", input: "-1.2e10",
tokens: []jsonToken{{t: jttDouble, v: -1.2e+10}},
},
{
desc: "valid double: -1.2E10", input: "-1.2E10",
tokens: []jsonToken{{t: jttDouble, v: -1.2e+10}},
},
{
desc: "valid double: -1.2e+10", input: "-1.2e+10",
tokens: []jsonToken{{t: jttDouble, v: -1.2e+10}},
},
{
desc: "valid double: -1.2E+10", input: "-1.2E+10",
tokens: []jsonToken{{t: jttDouble, v: -1.2e+10}},
},
{
desc: "valid double: 1.2e-10", input: "1.2e-10",
tokens: []jsonToken{{t: jttDouble, v: 1.2e-10}},
},
{
desc: "valid double: 1.2E-10", input: "1.2e-10",
tokens: []jsonToken{{t: jttDouble, v: 1.2e-10}},
},
{
desc: "valid double: -1.2e-10", input: "-1.2e-10",
tokens: []jsonToken{{t: jttDouble, v: -1.2e-10}},
},
{
desc: "valid double: -1.2E-10", input: "-1.2E-10",
tokens: []jsonToken{{t: jttDouble, v: -1.2e-10}},
},
{
desc: "valid double: 8005332285744496613785600", input: "8005332285744496613785600",
tokens: []jsonToken{{t: jttDouble, v: float64(8005332285744496613785600)}},
},
{
desc: "valid object, only spaces",
input: `{"key": "string", "key2": 2, "key3": {}, "key4": [], "key5": false }`,
tokens: []jsonToken{
{t: jttBeginObject, v: byte('{')}, {t: jttString, v: "key"}, {t: jttColon, v: byte(':')}, {t: jttString, v: "string"},
{t: jttComma, v: byte(',')}, {t: jttString, v: "key2"}, {t: jttColon, v: byte(':')}, {t: jttInt32, v: int32(2)},
{t: jttComma, v: byte(',')}, {t: jttString, v: "key3"}, {t: jttColon, v: byte(':')}, {t: jttBeginObject, v: byte('{')}, {t: jttEndObject, v: byte('}')},
{t: jttComma, v: byte(',')}, {t: jttString, v: "key4"}, {t: jttColon, v: byte(':')}, {t: jttBeginArray, v: byte('[')}, {t: jttEndArray, v: byte(']')},
{t: jttComma, v: byte(',')}, {t: jttString, v: "key5"}, {t: jttColon, v: byte(':')}, {t: jttBool, v: false}, {t: jttEndObject, v: byte('}')},
},
},
{
desc: "valid object, mixed whitespace",
input: `
{ "key" : "string"
, "key2": 2
, "key3": {}
, "key4": []
, "key5": false
}`,
tokens: []jsonToken{
{t: jttBeginObject, v: byte('{')}, {t: jttString, v: "key"}, {t: jttColon, v: byte(':')}, {t: jttString, v: "string"},
{t: jttComma, v: byte(',')}, {t: jttString, v: "key2"}, {t: jttColon, v: byte(':')}, {t: jttInt32, v: int32(2)},
{t: jttComma, v: byte(',')}, {t: jttString, v: "key3"}, {t: jttColon, v: byte(':')}, {t: jttBeginObject, v: byte('{')}, {t: jttEndObject, v: byte('}')},
{t: jttComma, v: byte(',')}, {t: jttString, v: "key4"}, {t: jttColon, v: byte(':')}, {t: jttBeginArray, v: byte('[')}, {t: jttEndArray, v: byte(']')},
{t: jttComma, v: byte(',')}, {t: jttString, v: "key5"}, {t: jttColon, v: byte(':')}, {t: jttBool, v: false}, {t: jttEndObject, v: byte('}')},
},
},
{
desc: "input greater than buffer size",
input: `{"` + longKey + `": 1}`,
tokens: []jsonToken{
{t: jttBeginObject, v: byte('{')}, {t: jttString, v: longKey}, {t: jttColon, v: byte(':')},
{t: jttInt32, v: int32(1)}, {t: jttEndObject, v: byte('}')},
},
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
js := &jsonScanner{r: strings.NewReader(tc.input)}
for _, token := range tc.tokens {
c, err := js.nextToken()
expectNoError(t, err, tc.desc)
jttDiff(t, token.t, c.t, tc.desc)
jtvDiff(t, token.v, c.v, tc.desc)
}
c, err := js.nextToken()
noerr(t, err)
jttDiff(t, jttEOF, c.t, tc.desc)
// testing early EOF reading
js = &jsonScanner{r: iotest.DataErrReader(strings.NewReader(tc.input))}
for _, token := range tc.tokens {
c, err := js.nextToken()
expectNoError(t, err, tc.desc)
jttDiff(t, token.t, c.t, tc.desc)
jtvDiff(t, token.v, c.v, tc.desc)
}
c, err = js.nextToken()
noerr(t, err)
jttDiff(t, jttEOF, c.t, tc.desc)
})
}
}
func TestJsonScannerInvalidInputs(t *testing.T) {
cases := []jsonScannerTestCase{
{desc: "missing quotation", input: `"missing`},
{desc: "invalid escape character--first character", input: `"\invalid"`},
{desc: "invalid escape character--middle", input: `"i\nv\alid"`},
{desc: "invalid escape character--single quote", input: `"f\'oo"`},
{desc: "invalid literal--trueee", input: "trueee"},
{desc: "invalid literal--tire", input: "tire"},
{desc: "invalid literal--nulll", input: "nulll"},
{desc: "invalid literal--fals", input: "fals"},
{desc: "invalid literal--falsee", input: "falsee"},
{desc: "invalid literal--fake", input: "fake"},
{desc: "invalid literal--bad", input: "bad"},
{desc: "invalid number: -", input: "-"},
{desc: "invalid number: --0", input: "--0"},
{desc: "invalid number: -a", input: "-a"},
{desc: "invalid number: 00", input: "00"},
{desc: "invalid number: 01", input: "01"},
{desc: "invalid number: 0-", input: "0-"},
{desc: "invalid number: 1-", input: "1-"},
{desc: "invalid number: 0..", input: "0.."},
{desc: "invalid number: 0.-", input: "0.-"},
{desc: "invalid number: 0..0", input: "0..0"},
{desc: "invalid number: 0.1.0", input: "0.1.0"},
{desc: "invalid number: 0e", input: "0e"},
{desc: "invalid number: 0e.", input: "0e."},
{desc: "invalid number: 0e1.", input: "0e1."},
{desc: "invalid number: 0e1e", input: "0e1e"},
{desc: "invalid number: 0e+.1", input: "0e+.1"},
{desc: "invalid number: 0e+1.", input: "0e+1."},
{desc: "invalid number: 0e+1e", input: "0e+1e"},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
js := &jsonScanner{r: strings.NewReader(tc.input)}
c, err := js.nextToken()
expectNilToken(t, c, tc.desc)
expectError(t, err, tc.desc)
})
}
}

108
mongo/bson/bsonrw/mode.go Normal file
View File

@@ -0,0 +1,108 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonrw
import (
"fmt"
)
type mode int
const (
_ mode = iota
mTopLevel
mDocument
mArray
mValue
mElement
mCodeWithScope
mSpacer
)
func (m mode) String() string {
var str string
switch m {
case mTopLevel:
str = "TopLevel"
case mDocument:
str = "DocumentMode"
case mArray:
str = "ArrayMode"
case mValue:
str = "ValueMode"
case mElement:
str = "ElementMode"
case mCodeWithScope:
str = "CodeWithScopeMode"
case mSpacer:
str = "CodeWithScopeSpacerFrame"
default:
str = "UnknownMode"
}
return str
}
func (m mode) TypeString() string {
var str string
switch m {
case mTopLevel:
str = "TopLevel"
case mDocument:
str = "Document"
case mArray:
str = "Array"
case mValue:
str = "Value"
case mElement:
str = "Element"
case mCodeWithScope:
str = "CodeWithScope"
case mSpacer:
str = "CodeWithScopeSpacer"
default:
str = "Unknown"
}
return str
}
// TransitionError is an error returned when an invalid progressing a
// ValueReader or ValueWriter state machine occurs.
// If read is false, the error is for writing
type TransitionError struct {
name string
parent mode
current mode
destination mode
modes []mode
action string
}
func (te TransitionError) Error() string {
errString := fmt.Sprintf("%s can only %s", te.name, te.action)
if te.destination != mode(0) {
errString = fmt.Sprintf("%s a %s", errString, te.destination.TypeString())
}
errString = fmt.Sprintf("%s while positioned on a", errString)
for ind, m := range te.modes {
if ind != 0 && len(te.modes) > 2 {
errString = fmt.Sprintf("%s,", errString)
}
if ind == len(te.modes)-1 && len(te.modes) > 1 {
errString = fmt.Sprintf("%s or", errString)
}
errString = fmt.Sprintf("%s %s", errString, m.TypeString())
}
errString = fmt.Sprintf("%s but is positioned on a %s", errString, te.current.TypeString())
if te.parent != mode(0) {
errString = fmt.Sprintf("%s with parent %s", errString, te.parent.TypeString())
}
return errString
}

View File

@@ -0,0 +1,63 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonrw
import (
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
)
// ArrayReader is implemented by types that allow reading values from a BSON
// array.
type ArrayReader interface {
ReadValue() (ValueReader, error)
}
// DocumentReader is implemented by types that allow reading elements from a
// BSON document.
type DocumentReader interface {
ReadElement() (string, ValueReader, error)
}
// ValueReader is a generic interface used to read values from BSON. This type
// is implemented by several types with different underlying representations of
// BSON, such as a bson.Document, raw BSON bytes, or extended JSON.
type ValueReader interface {
Type() bsontype.Type
Skip() error
ReadArray() (ArrayReader, error)
ReadBinary() (b []byte, btype byte, err error)
ReadBoolean() (bool, error)
ReadDocument() (DocumentReader, error)
ReadCodeWithScope() (code string, dr DocumentReader, err error)
ReadDBPointer() (ns string, oid primitive.ObjectID, err error)
ReadDateTime() (int64, error)
ReadDecimal128() (primitive.Decimal128, error)
ReadDouble() (float64, error)
ReadInt32() (int32, error)
ReadInt64() (int64, error)
ReadJavascript() (code string, err error)
ReadMaxKey() error
ReadMinKey() error
ReadNull() error
ReadObjectID() (primitive.ObjectID, error)
ReadRegex() (pattern, options string, err error)
ReadString() (string, error)
ReadSymbol() (symbol string, err error)
ReadTimestamp() (t, i uint32, err error)
ReadUndefined() error
}
// BytesReader is a generic interface used to read BSON bytes from a
// ValueReader. This imterface is meant to be a superset of ValueReader, so that
// types that implement ValueReader may also implement this interface.
//
// The bytes of the value will be appended to dst.
type BytesReader interface {
ReadValueBytes(dst []byte) (bsontype.Type, []byte, error)
}

View File

@@ -0,0 +1,874 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonrw
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"math"
"sync"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
)
var _ ValueReader = (*valueReader)(nil)
var vrPool = sync.Pool{
New: func() interface{} {
return new(valueReader)
},
}
// BSONValueReaderPool is a pool for ValueReaders that read BSON.
type BSONValueReaderPool struct {
pool sync.Pool
}
// NewBSONValueReaderPool instantiates a new BSONValueReaderPool.
func NewBSONValueReaderPool() *BSONValueReaderPool {
return &BSONValueReaderPool{
pool: sync.Pool{
New: func() interface{} {
return new(valueReader)
},
},
}
}
// Get retrieves a ValueReader from the pool and uses src as the underlying BSON.
func (bvrp *BSONValueReaderPool) Get(src []byte) ValueReader {
vr := bvrp.pool.Get().(*valueReader)
vr.reset(src)
return vr
}
// Put inserts a ValueReader into the pool. If the ValueReader is not a BSON ValueReader nothing
// is inserted into the pool and ok will be false.
func (bvrp *BSONValueReaderPool) Put(vr ValueReader) (ok bool) {
bvr, ok := vr.(*valueReader)
if !ok {
return false
}
bvr.reset(nil)
bvrp.pool.Put(bvr)
return true
}
// ErrEOA is the error returned when the end of a BSON array has been reached.
var ErrEOA = errors.New("end of array")
// ErrEOD is the error returned when the end of a BSON document has been reached.
var ErrEOD = errors.New("end of document")
type vrState struct {
mode mode
vType bsontype.Type
end int64
}
// valueReader is for reading BSON values.
type valueReader struct {
offset int64
d []byte
stack []vrState
frame int64
}
// NewBSONDocumentReader returns a ValueReader using b for the underlying BSON
// representation. Parameter b must be a BSON Document.
func NewBSONDocumentReader(b []byte) ValueReader {
// TODO(skriptble): There's a lack of symmetry between the reader and writer, since the reader takes a []byte while the
// TODO writer takes an io.Writer. We should have two versions of each, one that takes a []byte and one that takes an
// TODO io.Reader or io.Writer. The []byte version will need to return a thing that can return the finished []byte since
// TODO it might be reallocated when appended to.
return newValueReader(b)
}
// NewBSONValueReader returns a ValueReader that starts in the Value mode instead of in top
// level document mode. This enables the creation of a ValueReader for a single BSON value.
func NewBSONValueReader(t bsontype.Type, val []byte) ValueReader {
stack := make([]vrState, 1, 5)
stack[0] = vrState{
mode: mValue,
vType: t,
}
return &valueReader{
d: val,
stack: stack,
}
}
func newValueReader(b []byte) *valueReader {
stack := make([]vrState, 1, 5)
stack[0] = vrState{
mode: mTopLevel,
}
return &valueReader{
d: b,
stack: stack,
}
}
func (vr *valueReader) reset(b []byte) {
if vr.stack == nil {
vr.stack = make([]vrState, 1, 5)
}
vr.stack = vr.stack[:1]
vr.stack[0] = vrState{mode: mTopLevel}
vr.d = b
vr.offset = 0
vr.frame = 0
}
func (vr *valueReader) advanceFrame() {
if vr.frame+1 >= int64(len(vr.stack)) { // We need to grow the stack
length := len(vr.stack)
if length+1 >= cap(vr.stack) {
// double it
buf := make([]vrState, 2*cap(vr.stack)+1)
copy(buf, vr.stack)
vr.stack = buf
}
vr.stack = vr.stack[:length+1]
}
vr.frame++
// Clean the stack
vr.stack[vr.frame].mode = 0
vr.stack[vr.frame].vType = 0
vr.stack[vr.frame].end = 0
}
func (vr *valueReader) pushDocument() error {
vr.advanceFrame()
vr.stack[vr.frame].mode = mDocument
size, err := vr.readLength()
if err != nil {
return err
}
vr.stack[vr.frame].end = int64(size) + vr.offset - 4
return nil
}
func (vr *valueReader) pushArray() error {
vr.advanceFrame()
vr.stack[vr.frame].mode = mArray
size, err := vr.readLength()
if err != nil {
return err
}
vr.stack[vr.frame].end = int64(size) + vr.offset - 4
return nil
}
func (vr *valueReader) pushElement(t bsontype.Type) {
vr.advanceFrame()
vr.stack[vr.frame].mode = mElement
vr.stack[vr.frame].vType = t
}
func (vr *valueReader) pushValue(t bsontype.Type) {
vr.advanceFrame()
vr.stack[vr.frame].mode = mValue
vr.stack[vr.frame].vType = t
}
func (vr *valueReader) pushCodeWithScope() (int64, error) {
vr.advanceFrame()
vr.stack[vr.frame].mode = mCodeWithScope
size, err := vr.readLength()
if err != nil {
return 0, err
}
vr.stack[vr.frame].end = int64(size) + vr.offset - 4
return int64(size), nil
}
func (vr *valueReader) pop() {
switch vr.stack[vr.frame].mode {
case mElement, mValue:
vr.frame--
case mDocument, mArray, mCodeWithScope:
vr.frame -= 2 // we pop twice to jump over the vrElement: vrDocument -> vrElement -> vrDocument/TopLevel/etc...
}
}
func (vr *valueReader) invalidTransitionErr(destination mode, name string, modes []mode) error {
te := TransitionError{
name: name,
current: vr.stack[vr.frame].mode,
destination: destination,
modes: modes,
action: "read",
}
if vr.frame != 0 {
te.parent = vr.stack[vr.frame-1].mode
}
return te
}
func (vr *valueReader) typeError(t bsontype.Type) error {
return fmt.Errorf("positioned on %s, but attempted to read %s", vr.stack[vr.frame].vType, t)
}
func (vr *valueReader) invalidDocumentLengthError() error {
return fmt.Errorf("document is invalid, end byte is at %d, but null byte found at %d", vr.stack[vr.frame].end, vr.offset)
}
func (vr *valueReader) ensureElementValue(t bsontype.Type, destination mode, callerName string) error {
switch vr.stack[vr.frame].mode {
case mElement, mValue:
if vr.stack[vr.frame].vType != t {
return vr.typeError(t)
}
default:
return vr.invalidTransitionErr(destination, callerName, []mode{mElement, mValue})
}
return nil
}
func (vr *valueReader) Type() bsontype.Type {
return vr.stack[vr.frame].vType
}
func (vr *valueReader) nextElementLength() (int32, error) {
var length int32
var err error
switch vr.stack[vr.frame].vType {
case bsontype.Array, bsontype.EmbeddedDocument, bsontype.CodeWithScope:
length, err = vr.peekLength()
case bsontype.Binary:
length, err = vr.peekLength()
length += 4 + 1 // binary length + subtype byte
case bsontype.Boolean:
length = 1
case bsontype.DBPointer:
length, err = vr.peekLength()
length += 4 + 12 // string length + ObjectID length
case bsontype.DateTime, bsontype.Double, bsontype.Int64, bsontype.Timestamp:
length = 8
case bsontype.Decimal128:
length = 16
case bsontype.Int32:
length = 4
case bsontype.JavaScript, bsontype.String, bsontype.Symbol:
length, err = vr.peekLength()
length += 4
case bsontype.MaxKey, bsontype.MinKey, bsontype.Null, bsontype.Undefined:
length = 0
case bsontype.ObjectID:
length = 12
case bsontype.Regex:
regex := bytes.IndexByte(vr.d[vr.offset:], 0x00)
if regex < 0 {
err = io.EOF
break
}
pattern := bytes.IndexByte(vr.d[vr.offset+int64(regex)+1:], 0x00)
if pattern < 0 {
err = io.EOF
break
}
length = int32(int64(regex) + 1 + int64(pattern) + 1)
default:
return 0, fmt.Errorf("attempted to read bytes of unknown BSON type %v", vr.stack[vr.frame].vType)
}
return length, err
}
func (vr *valueReader) ReadValueBytes(dst []byte) (bsontype.Type, []byte, error) {
switch vr.stack[vr.frame].mode {
case mTopLevel:
length, err := vr.peekLength()
if err != nil {
return bsontype.Type(0), nil, err
}
dst, err = vr.appendBytes(dst, length)
if err != nil {
return bsontype.Type(0), nil, err
}
return bsontype.Type(0), dst, nil
case mElement, mValue:
length, err := vr.nextElementLength()
if err != nil {
return bsontype.Type(0), dst, err
}
dst, err = vr.appendBytes(dst, length)
t := vr.stack[vr.frame].vType
vr.pop()
return t, dst, err
default:
return bsontype.Type(0), nil, vr.invalidTransitionErr(0, "ReadValueBytes", []mode{mElement, mValue})
}
}
func (vr *valueReader) Skip() error {
switch vr.stack[vr.frame].mode {
case mElement, mValue:
default:
return vr.invalidTransitionErr(0, "Skip", []mode{mElement, mValue})
}
length, err := vr.nextElementLength()
if err != nil {
return err
}
err = vr.skipBytes(length)
vr.pop()
return err
}
func (vr *valueReader) ReadArray() (ArrayReader, error) {
if err := vr.ensureElementValue(bsontype.Array, mArray, "ReadArray"); err != nil {
return nil, err
}
err := vr.pushArray()
if err != nil {
return nil, err
}
return vr, nil
}
func (vr *valueReader) ReadBinary() (b []byte, btype byte, err error) {
if err := vr.ensureElementValue(bsontype.Binary, 0, "ReadBinary"); err != nil {
return nil, 0, err
}
length, err := vr.readLength()
if err != nil {
return nil, 0, err
}
btype, err = vr.readByte()
if err != nil {
return nil, 0, err
}
// Check length in case it is an old binary without a length.
if btype == 0x02 && length > 4 {
length, err = vr.readLength()
if err != nil {
return nil, 0, err
}
}
b, err = vr.readBytes(length)
if err != nil {
return nil, 0, err
}
// Make a copy of the returned byte slice because it's just a subslice from the valueReader's
// buffer and is not safe to return in the unmarshaled value.
cp := make([]byte, len(b))
copy(cp, b)
vr.pop()
return cp, btype, nil
}
func (vr *valueReader) ReadBoolean() (bool, error) {
if err := vr.ensureElementValue(bsontype.Boolean, 0, "ReadBoolean"); err != nil {
return false, err
}
b, err := vr.readByte()
if err != nil {
return false, err
}
if b > 1 {
return false, fmt.Errorf("invalid byte for boolean, %b", b)
}
vr.pop()
return b == 1, nil
}
func (vr *valueReader) ReadDocument() (DocumentReader, error) {
switch vr.stack[vr.frame].mode {
case mTopLevel:
// read size
size, err := vr.readLength()
if err != nil {
return nil, err
}
if int(size) != len(vr.d) {
return nil, fmt.Errorf("invalid document length")
}
vr.stack[vr.frame].end = int64(size) + vr.offset - 4
return vr, nil
case mElement, mValue:
if vr.stack[vr.frame].vType != bsontype.EmbeddedDocument {
return nil, vr.typeError(bsontype.EmbeddedDocument)
}
default:
return nil, vr.invalidTransitionErr(mDocument, "ReadDocument", []mode{mTopLevel, mElement, mValue})
}
err := vr.pushDocument()
if err != nil {
return nil, err
}
return vr, nil
}
func (vr *valueReader) ReadCodeWithScope() (code string, dr DocumentReader, err error) {
if err := vr.ensureElementValue(bsontype.CodeWithScope, 0, "ReadCodeWithScope"); err != nil {
return "", nil, err
}
totalLength, err := vr.readLength()
if err != nil {
return "", nil, err
}
strLength, err := vr.readLength()
if err != nil {
return "", nil, err
}
if strLength <= 0 {
return "", nil, fmt.Errorf("invalid string length: %d", strLength)
}
strBytes, err := vr.readBytes(strLength)
if err != nil {
return "", nil, err
}
code = string(strBytes[:len(strBytes)-1])
size, err := vr.pushCodeWithScope()
if err != nil {
return "", nil, err
}
// The total length should equal:
// 4 (total length) + strLength + 4 (the length of str itself) + (document length)
componentsLength := int64(4+strLength+4) + size
if int64(totalLength) != componentsLength {
return "", nil, fmt.Errorf(
"length of CodeWithScope does not match lengths of components; total: %d; components: %d",
totalLength, componentsLength,
)
}
return code, vr, nil
}
func (vr *valueReader) ReadDBPointer() (ns string, oid primitive.ObjectID, err error) {
if err := vr.ensureElementValue(bsontype.DBPointer, 0, "ReadDBPointer"); err != nil {
return "", oid, err
}
ns, err = vr.readString()
if err != nil {
return "", oid, err
}
oidbytes, err := vr.readBytes(12)
if err != nil {
return "", oid, err
}
copy(oid[:], oidbytes)
vr.pop()
return ns, oid, nil
}
func (vr *valueReader) ReadDateTime() (int64, error) {
if err := vr.ensureElementValue(bsontype.DateTime, 0, "ReadDateTime"); err != nil {
return 0, err
}
i, err := vr.readi64()
if err != nil {
return 0, err
}
vr.pop()
return i, nil
}
func (vr *valueReader) ReadDecimal128() (primitive.Decimal128, error) {
if err := vr.ensureElementValue(bsontype.Decimal128, 0, "ReadDecimal128"); err != nil {
return primitive.Decimal128{}, err
}
b, err := vr.readBytes(16)
if err != nil {
return primitive.Decimal128{}, err
}
l := binary.LittleEndian.Uint64(b[0:8])
h := binary.LittleEndian.Uint64(b[8:16])
vr.pop()
return primitive.NewDecimal128(h, l), nil
}
func (vr *valueReader) ReadDouble() (float64, error) {
if err := vr.ensureElementValue(bsontype.Double, 0, "ReadDouble"); err != nil {
return 0, err
}
u, err := vr.readu64()
if err != nil {
return 0, err
}
vr.pop()
return math.Float64frombits(u), nil
}
func (vr *valueReader) ReadInt32() (int32, error) {
if err := vr.ensureElementValue(bsontype.Int32, 0, "ReadInt32"); err != nil {
return 0, err
}
vr.pop()
return vr.readi32()
}
func (vr *valueReader) ReadInt64() (int64, error) {
if err := vr.ensureElementValue(bsontype.Int64, 0, "ReadInt64"); err != nil {
return 0, err
}
vr.pop()
return vr.readi64()
}
func (vr *valueReader) ReadJavascript() (code string, err error) {
if err := vr.ensureElementValue(bsontype.JavaScript, 0, "ReadJavascript"); err != nil {
return "", err
}
vr.pop()
return vr.readString()
}
func (vr *valueReader) ReadMaxKey() error {
if err := vr.ensureElementValue(bsontype.MaxKey, 0, "ReadMaxKey"); err != nil {
return err
}
vr.pop()
return nil
}
func (vr *valueReader) ReadMinKey() error {
if err := vr.ensureElementValue(bsontype.MinKey, 0, "ReadMinKey"); err != nil {
return err
}
vr.pop()
return nil
}
func (vr *valueReader) ReadNull() error {
if err := vr.ensureElementValue(bsontype.Null, 0, "ReadNull"); err != nil {
return err
}
vr.pop()
return nil
}
func (vr *valueReader) ReadObjectID() (primitive.ObjectID, error) {
if err := vr.ensureElementValue(bsontype.ObjectID, 0, "ReadObjectID"); err != nil {
return primitive.ObjectID{}, err
}
oidbytes, err := vr.readBytes(12)
if err != nil {
return primitive.ObjectID{}, err
}
var oid primitive.ObjectID
copy(oid[:], oidbytes)
vr.pop()
return oid, nil
}
func (vr *valueReader) ReadRegex() (string, string, error) {
if err := vr.ensureElementValue(bsontype.Regex, 0, "ReadRegex"); err != nil {
return "", "", err
}
pattern, err := vr.readCString()
if err != nil {
return "", "", err
}
options, err := vr.readCString()
if err != nil {
return "", "", err
}
vr.pop()
return pattern, options, nil
}
func (vr *valueReader) ReadString() (string, error) {
if err := vr.ensureElementValue(bsontype.String, 0, "ReadString"); err != nil {
return "", err
}
vr.pop()
return vr.readString()
}
func (vr *valueReader) ReadSymbol() (symbol string, err error) {
if err := vr.ensureElementValue(bsontype.Symbol, 0, "ReadSymbol"); err != nil {
return "", err
}
vr.pop()
return vr.readString()
}
func (vr *valueReader) ReadTimestamp() (t uint32, i uint32, err error) {
if err := vr.ensureElementValue(bsontype.Timestamp, 0, "ReadTimestamp"); err != nil {
return 0, 0, err
}
i, err = vr.readu32()
if err != nil {
return 0, 0, err
}
t, err = vr.readu32()
if err != nil {
return 0, 0, err
}
vr.pop()
return t, i, nil
}
func (vr *valueReader) ReadUndefined() error {
if err := vr.ensureElementValue(bsontype.Undefined, 0, "ReadUndefined"); err != nil {
return err
}
vr.pop()
return nil
}
func (vr *valueReader) ReadElement() (string, ValueReader, error) {
switch vr.stack[vr.frame].mode {
case mTopLevel, mDocument, mCodeWithScope:
default:
return "", nil, vr.invalidTransitionErr(mElement, "ReadElement", []mode{mTopLevel, mDocument, mCodeWithScope})
}
t, err := vr.readByte()
if err != nil {
return "", nil, err
}
if t == 0 {
if vr.offset != vr.stack[vr.frame].end {
return "", nil, vr.invalidDocumentLengthError()
}
vr.pop()
return "", nil, ErrEOD
}
name, err := vr.readCString()
if err != nil {
return "", nil, err
}
vr.pushElement(bsontype.Type(t))
return name, vr, nil
}
func (vr *valueReader) ReadValue() (ValueReader, error) {
switch vr.stack[vr.frame].mode {
case mArray:
default:
return nil, vr.invalidTransitionErr(mValue, "ReadValue", []mode{mArray})
}
t, err := vr.readByte()
if err != nil {
return nil, err
}
if t == 0 {
if vr.offset != vr.stack[vr.frame].end {
return nil, vr.invalidDocumentLengthError()
}
vr.pop()
return nil, ErrEOA
}
_, err = vr.readCString()
if err != nil {
return nil, err
}
vr.pushValue(bsontype.Type(t))
return vr, nil
}
// readBytes reads length bytes from the valueReader starting at the current offset. Note that the
// returned byte slice is a subslice from the valueReader buffer and must be converted or copied
// before returning in an unmarshaled value.
func (vr *valueReader) readBytes(length int32) ([]byte, error) {
if length < 0 {
return nil, fmt.Errorf("invalid length: %d", length)
}
if vr.offset+int64(length) > int64(len(vr.d)) {
return nil, io.EOF
}
start := vr.offset
vr.offset += int64(length)
return vr.d[start : start+int64(length)], nil
}
func (vr *valueReader) appendBytes(dst []byte, length int32) ([]byte, error) {
if vr.offset+int64(length) > int64(len(vr.d)) {
return nil, io.EOF
}
start := vr.offset
vr.offset += int64(length)
return append(dst, vr.d[start:start+int64(length)]...), nil
}
func (vr *valueReader) skipBytes(length int32) error {
if vr.offset+int64(length) > int64(len(vr.d)) {
return io.EOF
}
vr.offset += int64(length)
return nil
}
func (vr *valueReader) readByte() (byte, error) {
if vr.offset+1 > int64(len(vr.d)) {
return 0x0, io.EOF
}
vr.offset++
return vr.d[vr.offset-1], nil
}
func (vr *valueReader) readCString() (string, error) {
idx := bytes.IndexByte(vr.d[vr.offset:], 0x00)
if idx < 0 {
return "", io.EOF
}
start := vr.offset
// idx does not include the null byte
vr.offset += int64(idx) + 1
return string(vr.d[start : start+int64(idx)]), nil
}
func (vr *valueReader) readString() (string, error) {
length, err := vr.readLength()
if err != nil {
return "", err
}
if int64(length)+vr.offset > int64(len(vr.d)) {
return "", io.EOF
}
if length <= 0 {
return "", fmt.Errorf("invalid string length: %d", length)
}
if vr.d[vr.offset+int64(length)-1] != 0x00 {
return "", fmt.Errorf("string does not end with null byte, but with %v", vr.d[vr.offset+int64(length)-1])
}
start := vr.offset
vr.offset += int64(length)
return string(vr.d[start : start+int64(length)-1]), nil
}
func (vr *valueReader) peekLength() (int32, error) {
if vr.offset+4 > int64(len(vr.d)) {
return 0, io.EOF
}
idx := vr.offset
return (int32(vr.d[idx]) | int32(vr.d[idx+1])<<8 | int32(vr.d[idx+2])<<16 | int32(vr.d[idx+3])<<24), nil
}
func (vr *valueReader) readLength() (int32, error) { return vr.readi32() }
func (vr *valueReader) readi32() (int32, error) {
if vr.offset+4 > int64(len(vr.d)) {
return 0, io.EOF
}
idx := vr.offset
vr.offset += 4
return (int32(vr.d[idx]) | int32(vr.d[idx+1])<<8 | int32(vr.d[idx+2])<<16 | int32(vr.d[idx+3])<<24), nil
}
func (vr *valueReader) readu32() (uint32, error) {
if vr.offset+4 > int64(len(vr.d)) {
return 0, io.EOF
}
idx := vr.offset
vr.offset += 4
return (uint32(vr.d[idx]) | uint32(vr.d[idx+1])<<8 | uint32(vr.d[idx+2])<<16 | uint32(vr.d[idx+3])<<24), nil
}
func (vr *valueReader) readi64() (int64, error) {
if vr.offset+8 > int64(len(vr.d)) {
return 0, io.EOF
}
idx := vr.offset
vr.offset += 8
return int64(vr.d[idx]) | int64(vr.d[idx+1])<<8 | int64(vr.d[idx+2])<<16 | int64(vr.d[idx+3])<<24 |
int64(vr.d[idx+4])<<32 | int64(vr.d[idx+5])<<40 | int64(vr.d[idx+6])<<48 | int64(vr.d[idx+7])<<56, nil
}
func (vr *valueReader) readu64() (uint64, error) {
if vr.offset+8 > int64(len(vr.d)) {
return 0, io.EOF
}
idx := vr.offset
vr.offset += 8
return uint64(vr.d[idx]) | uint64(vr.d[idx+1])<<8 | uint64(vr.d[idx+2])<<16 | uint64(vr.d[idx+3])<<24 |
uint64(vr.d[idx+4])<<32 | uint64(vr.d[idx+5])<<40 | uint64(vr.d[idx+6])<<48 | uint64(vr.d[idx+7])<<56, nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,608 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonrw
import (
"testing"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
type VRWInvoked byte
const (
_ = iota
llvrwReadArray VRWInvoked = 1
llvrwReadBinary
llvrwReadBoolean
llvrwReadDocument
llvrwReadCodeWithScope
llvrwReadDBPointer
llvrwReadDateTime
llvrwReadDecimal128
llvrwReadDouble
llvrwReadInt32
llvrwReadInt64
llvrwReadJavascript
llvrwReadMaxKey
llvrwReadMinKey
llvrwReadNull
llvrwReadObjectID
llvrwReadRegex
llvrwReadString
llvrwReadSymbol
llvrwReadTimestamp
llvrwReadUndefined
llvrwReadElement
llvrwReadValue
llvrwWriteArray
llvrwWriteBinary
llvrwWriteBinaryWithSubtype
llvrwWriteBoolean
llvrwWriteCodeWithScope
llvrwWriteDBPointer
llvrwWriteDateTime
llvrwWriteDecimal128
llvrwWriteDouble
llvrwWriteInt32
llvrwWriteInt64
llvrwWriteJavascript
llvrwWriteMaxKey
llvrwWriteMinKey
llvrwWriteNull
llvrwWriteObjectID
llvrwWriteRegex
llvrwWriteString
llvrwWriteDocument
llvrwWriteSymbol
llvrwWriteTimestamp
llvrwWriteUndefined
llvrwWriteDocumentElement
llvrwWriteDocumentEnd
llvrwWriteArrayElement
llvrwWriteArrayEnd
)
type TestValueReaderWriter struct {
t *testing.T
invoked VRWInvoked
readval interface{}
bsontype bsontype.Type
err error
errAfter VRWInvoked // error after this method is called
}
func (llvrw *TestValueReaderWriter) Type() bsontype.Type {
return llvrw.bsontype
}
func (llvrw *TestValueReaderWriter) Skip() error {
panic("not implemented")
}
func (llvrw *TestValueReaderWriter) ReadArray() (ArrayReader, error) {
llvrw.invoked = llvrwReadArray
if llvrw.errAfter == llvrw.invoked {
return nil, llvrw.err
}
return llvrw, nil
}
func (llvrw *TestValueReaderWriter) ReadBinary() (b []byte, btype byte, err error) {
llvrw.invoked = llvrwReadBinary
if llvrw.errAfter == llvrw.invoked {
return nil, 0x00, llvrw.err
}
switch tt := llvrw.readval.(type) {
case bsoncore.Value:
subtype, data, _, ok := bsoncore.ReadBinary(tt.Data)
if !ok {
llvrw.t.Error("Invalid Value provided for return value of ReadBinary.")
return nil, 0x00, nil
}
return data, subtype, nil
default:
llvrw.t.Errorf("Incorrect type provided for return value of ReadBinary: %T", llvrw.readval)
return nil, 0x00, nil
}
}
func (llvrw *TestValueReaderWriter) ReadBoolean() (bool, error) {
llvrw.invoked = llvrwReadBoolean
if llvrw.errAfter == llvrw.invoked {
return false, llvrw.err
}
b, ok := llvrw.readval.(bool)
if !ok {
llvrw.t.Errorf("Incorrect type provided for return value of ReadBoolean: %T", llvrw.readval)
return false, nil
}
return b, llvrw.err
}
func (llvrw *TestValueReaderWriter) ReadDocument() (DocumentReader, error) {
llvrw.invoked = llvrwReadDocument
if llvrw.errAfter == llvrw.invoked {
return nil, llvrw.err
}
return llvrw, nil
}
func (llvrw *TestValueReaderWriter) ReadCodeWithScope() (code string, dr DocumentReader, err error) {
llvrw.invoked = llvrwReadCodeWithScope
if llvrw.errAfter == llvrw.invoked {
return "", nil, llvrw.err
}
return "", llvrw, nil
}
func (llvrw *TestValueReaderWriter) ReadDBPointer() (ns string, oid primitive.ObjectID, err error) {
llvrw.invoked = llvrwReadDBPointer
if llvrw.errAfter == llvrw.invoked {
return "", primitive.ObjectID{}, llvrw.err
}
switch tt := llvrw.readval.(type) {
case bsoncore.Value:
ns, oid, _, ok := bsoncore.ReadDBPointer(tt.Data)
if !ok {
llvrw.t.Error("Invalid Value instance provided for return value of ReadDBPointer")
return "", primitive.ObjectID{}, nil
}
return ns, oid, nil
default:
llvrw.t.Errorf("Incorrect type provided for return value of ReadDBPointer: %T", llvrw.readval)
return "", primitive.ObjectID{}, nil
}
}
func (llvrw *TestValueReaderWriter) ReadDateTime() (int64, error) {
llvrw.invoked = llvrwReadDateTime
if llvrw.errAfter == llvrw.invoked {
return 0, llvrw.err
}
dt, ok := llvrw.readval.(int64)
if !ok {
llvrw.t.Errorf("Incorrect type provided for return value of ReadDateTime: %T", llvrw.readval)
return 0, nil
}
return dt, nil
}
func (llvrw *TestValueReaderWriter) ReadDecimal128() (primitive.Decimal128, error) {
llvrw.invoked = llvrwReadDecimal128
if llvrw.errAfter == llvrw.invoked {
return primitive.Decimal128{}, llvrw.err
}
d128, ok := llvrw.readval.(primitive.Decimal128)
if !ok {
llvrw.t.Errorf("Incorrect type provided for return value of ReadDecimal128: %T", llvrw.readval)
return primitive.Decimal128{}, nil
}
return d128, nil
}
func (llvrw *TestValueReaderWriter) ReadDouble() (float64, error) {
llvrw.invoked = llvrwReadDouble
if llvrw.errAfter == llvrw.invoked {
return 0, llvrw.err
}
f64, ok := llvrw.readval.(float64)
if !ok {
llvrw.t.Errorf("Incorrect type provided for return value of ReadDouble: %T", llvrw.readval)
return 0, nil
}
return f64, nil
}
func (llvrw *TestValueReaderWriter) ReadInt32() (int32, error) {
llvrw.invoked = llvrwReadInt32
if llvrw.errAfter == llvrw.invoked {
return 0, llvrw.err
}
i32, ok := llvrw.readval.(int32)
if !ok {
llvrw.t.Errorf("Incorrect type provided for return value of ReadInt32: %T", llvrw.readval)
return 0, nil
}
return i32, nil
}
func (llvrw *TestValueReaderWriter) ReadInt64() (int64, error) {
llvrw.invoked = llvrwReadInt64
if llvrw.errAfter == llvrw.invoked {
return 0, llvrw.err
}
i64, ok := llvrw.readval.(int64)
if !ok {
llvrw.t.Errorf("Incorrect type provided for return value of ReadInt64: %T", llvrw.readval)
return 0, nil
}
return i64, nil
}
func (llvrw *TestValueReaderWriter) ReadJavascript() (code string, err error) {
llvrw.invoked = llvrwReadJavascript
if llvrw.errAfter == llvrw.invoked {
return "", llvrw.err
}
js, ok := llvrw.readval.(string)
if !ok {
llvrw.t.Errorf("Incorrect type provided for return value of ReadJavascript: %T", llvrw.readval)
return "", nil
}
return js, nil
}
func (llvrw *TestValueReaderWriter) ReadMaxKey() error {
llvrw.invoked = llvrwReadMaxKey
if llvrw.errAfter == llvrw.invoked {
return llvrw.err
}
return nil
}
func (llvrw *TestValueReaderWriter) ReadMinKey() error {
llvrw.invoked = llvrwReadMinKey
if llvrw.errAfter == llvrw.invoked {
return llvrw.err
}
return nil
}
func (llvrw *TestValueReaderWriter) ReadNull() error {
llvrw.invoked = llvrwReadNull
if llvrw.errAfter == llvrw.invoked {
return llvrw.err
}
return nil
}
func (llvrw *TestValueReaderWriter) ReadObjectID() (primitive.ObjectID, error) {
llvrw.invoked = llvrwReadObjectID
if llvrw.errAfter == llvrw.invoked {
return primitive.ObjectID{}, llvrw.err
}
oid, ok := llvrw.readval.(primitive.ObjectID)
if !ok {
llvrw.t.Errorf("Incorrect type provided for return value of ReadObjectID: %T", llvrw.readval)
return primitive.ObjectID{}, nil
}
return oid, nil
}
func (llvrw *TestValueReaderWriter) ReadRegex() (pattern string, options string, err error) {
llvrw.invoked = llvrwReadRegex
if llvrw.errAfter == llvrw.invoked {
return "", "", llvrw.err
}
switch tt := llvrw.readval.(type) {
case bsoncore.Value:
pattern, options, _, ok := bsoncore.ReadRegex(tt.Data)
if !ok {
llvrw.t.Error("Invalid Value instance provided for ReadRegex")
return "", "", nil
}
return pattern, options, nil
default:
llvrw.t.Errorf("Incorrect type provided for return value of ReadRegex: %T", llvrw.readval)
return "", "", nil
}
}
func (llvrw *TestValueReaderWriter) ReadString() (string, error) {
llvrw.invoked = llvrwReadString
if llvrw.errAfter == llvrw.invoked {
return "", llvrw.err
}
str, ok := llvrw.readval.(string)
if !ok {
llvrw.t.Errorf("Incorrect type provided for return value of ReadString: %T", llvrw.readval)
return "", nil
}
return str, nil
}
func (llvrw *TestValueReaderWriter) ReadSymbol() (symbol string, err error) {
llvrw.invoked = llvrwReadSymbol
if llvrw.errAfter == llvrw.invoked {
return "", llvrw.err
}
switch tt := llvrw.readval.(type) {
case bsoncore.Value:
symbol, _, ok := bsoncore.ReadSymbol(tt.Data)
if !ok {
llvrw.t.Error("Invalid Value instance provided for ReadSymbol")
return "", nil
}
return symbol, nil
default:
llvrw.t.Errorf("Incorrect type provided for return value of ReadSymbol: %T", llvrw.readval)
return "", nil
}
}
func (llvrw *TestValueReaderWriter) ReadTimestamp() (t uint32, i uint32, err error) {
llvrw.invoked = llvrwReadTimestamp
if llvrw.errAfter == llvrw.invoked {
return 0, 0, llvrw.err
}
switch tt := llvrw.readval.(type) {
case bsoncore.Value:
t, i, _, ok := bsoncore.ReadTimestamp(tt.Data)
if !ok {
llvrw.t.Errorf("Invalid Value instance provided for return value of ReadTimestamp")
return 0, 0, nil
}
return t, i, nil
default:
llvrw.t.Errorf("Incorrect type provided for return value of ReadTimestamp: %T", llvrw.readval)
return 0, 0, nil
}
}
func (llvrw *TestValueReaderWriter) ReadUndefined() error {
llvrw.invoked = llvrwReadUndefined
if llvrw.errAfter == llvrw.invoked {
return llvrw.err
}
return nil
}
func (llvrw *TestValueReaderWriter) WriteArray() (ArrayWriter, error) {
llvrw.invoked = llvrwWriteArray
if llvrw.errAfter == llvrw.invoked {
return nil, llvrw.err
}
return llvrw, nil
}
func (llvrw *TestValueReaderWriter) WriteBinary(b []byte) error {
llvrw.invoked = llvrwWriteBinary
if llvrw.errAfter == llvrw.invoked {
return llvrw.err
}
return nil
}
func (llvrw *TestValueReaderWriter) WriteBinaryWithSubtype(b []byte, btype byte) error {
llvrw.invoked = llvrwWriteBinaryWithSubtype
if llvrw.errAfter == llvrw.invoked {
return llvrw.err
}
return nil
}
func (llvrw *TestValueReaderWriter) WriteBoolean(bool) error {
llvrw.invoked = llvrwWriteBoolean
if llvrw.errAfter == llvrw.invoked {
return llvrw.err
}
return nil
}
func (llvrw *TestValueReaderWriter) WriteCodeWithScope(code string) (DocumentWriter, error) {
llvrw.invoked = llvrwWriteCodeWithScope
if llvrw.errAfter == llvrw.invoked {
return nil, llvrw.err
}
return llvrw, nil
}
func (llvrw *TestValueReaderWriter) WriteDBPointer(ns string, oid primitive.ObjectID) error {
llvrw.invoked = llvrwWriteDBPointer
if llvrw.errAfter == llvrw.invoked {
return llvrw.err
}
return nil
}
func (llvrw *TestValueReaderWriter) WriteDateTime(dt int64) error {
llvrw.invoked = llvrwWriteDateTime
if llvrw.errAfter == llvrw.invoked {
return llvrw.err
}
return nil
}
func (llvrw *TestValueReaderWriter) WriteDecimal128(primitive.Decimal128) error {
llvrw.invoked = llvrwWriteDecimal128
if llvrw.errAfter == llvrw.invoked {
return llvrw.err
}
return nil
}
func (llvrw *TestValueReaderWriter) WriteDouble(float64) error {
llvrw.invoked = llvrwWriteDouble
if llvrw.errAfter == llvrw.invoked {
return llvrw.err
}
return nil
}
func (llvrw *TestValueReaderWriter) WriteInt32(int32) error {
llvrw.invoked = llvrwWriteInt32
if llvrw.errAfter == llvrw.invoked {
return llvrw.err
}
return nil
}
func (llvrw *TestValueReaderWriter) WriteInt64(int64) error {
llvrw.invoked = llvrwWriteInt64
if llvrw.errAfter == llvrw.invoked {
return llvrw.err
}
return nil
}
func (llvrw *TestValueReaderWriter) WriteJavascript(code string) error {
llvrw.invoked = llvrwWriteJavascript
if llvrw.errAfter == llvrw.invoked {
return llvrw.err
}
return nil
}
func (llvrw *TestValueReaderWriter) WriteMaxKey() error {
llvrw.invoked = llvrwWriteMaxKey
if llvrw.errAfter == llvrw.invoked {
return llvrw.err
}
return nil
}
func (llvrw *TestValueReaderWriter) WriteMinKey() error {
llvrw.invoked = llvrwWriteMinKey
if llvrw.errAfter == llvrw.invoked {
return llvrw.err
}
return nil
}
func (llvrw *TestValueReaderWriter) WriteNull() error {
llvrw.invoked = llvrwWriteNull
if llvrw.errAfter == llvrw.invoked {
return llvrw.err
}
return nil
}
func (llvrw *TestValueReaderWriter) WriteObjectID(primitive.ObjectID) error {
llvrw.invoked = llvrwWriteObjectID
if llvrw.errAfter == llvrw.invoked {
return llvrw.err
}
return nil
}
func (llvrw *TestValueReaderWriter) WriteRegex(pattern string, options string) error {
llvrw.invoked = llvrwWriteRegex
if llvrw.errAfter == llvrw.invoked {
return llvrw.err
}
return nil
}
func (llvrw *TestValueReaderWriter) WriteString(string) error {
llvrw.invoked = llvrwWriteString
if llvrw.errAfter == llvrw.invoked {
return llvrw.err
}
return nil
}
func (llvrw *TestValueReaderWriter) WriteDocument() (DocumentWriter, error) {
llvrw.invoked = llvrwWriteDocument
if llvrw.errAfter == llvrw.invoked {
return nil, llvrw.err
}
return llvrw, nil
}
func (llvrw *TestValueReaderWriter) WriteSymbol(symbol string) error {
llvrw.invoked = llvrwWriteSymbol
if llvrw.errAfter == llvrw.invoked {
return llvrw.err
}
return nil
}
func (llvrw *TestValueReaderWriter) WriteTimestamp(t uint32, i uint32) error {
llvrw.invoked = llvrwWriteTimestamp
if llvrw.errAfter == llvrw.invoked {
return llvrw.err
}
return nil
}
func (llvrw *TestValueReaderWriter) WriteUndefined() error {
llvrw.invoked = llvrwWriteUndefined
if llvrw.errAfter == llvrw.invoked {
return llvrw.err
}
return nil
}
func (llvrw *TestValueReaderWriter) ReadElement() (string, ValueReader, error) {
llvrw.invoked = llvrwReadElement
if llvrw.errAfter == llvrw.invoked {
return "", nil, llvrw.err
}
return "", llvrw, nil
}
func (llvrw *TestValueReaderWriter) WriteDocumentElement(string) (ValueWriter, error) {
llvrw.invoked = llvrwWriteDocumentElement
if llvrw.errAfter == llvrw.invoked {
return nil, llvrw.err
}
return llvrw, nil
}
func (llvrw *TestValueReaderWriter) WriteDocumentEnd() error {
llvrw.invoked = llvrwWriteDocumentEnd
if llvrw.errAfter == llvrw.invoked {
return llvrw.err
}
return nil
}
func (llvrw *TestValueReaderWriter) ReadValue() (ValueReader, error) {
llvrw.invoked = llvrwReadValue
if llvrw.errAfter == llvrw.invoked {
return nil, llvrw.err
}
return llvrw, nil
}
func (llvrw *TestValueReaderWriter) WriteArrayElement() (ValueWriter, error) {
llvrw.invoked = llvrwWriteArrayElement
if llvrw.errAfter == llvrw.invoked {
return nil, llvrw.err
}
return llvrw, nil
}
func (llvrw *TestValueReaderWriter) WriteArrayEnd() error {
llvrw.invoked = llvrwWriteArrayEnd
if llvrw.errAfter == llvrw.invoked {
return llvrw.err
}
return nil
}

View File

@@ -0,0 +1,606 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonrw
import (
"errors"
"fmt"
"io"
"math"
"strconv"
"strings"
"sync"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
var _ ValueWriter = (*valueWriter)(nil)
var vwPool = sync.Pool{
New: func() interface{} {
return new(valueWriter)
},
}
// BSONValueWriterPool is a pool for BSON ValueWriters.
type BSONValueWriterPool struct {
pool sync.Pool
}
// NewBSONValueWriterPool creates a new pool for ValueWriter instances that write to BSON.
func NewBSONValueWriterPool() *BSONValueWriterPool {
return &BSONValueWriterPool{
pool: sync.Pool{
New: func() interface{} {
return new(valueWriter)
},
},
}
}
// Get retrieves a BSON ValueWriter from the pool and resets it to use w as the destination.
func (bvwp *BSONValueWriterPool) Get(w io.Writer) ValueWriter {
vw := bvwp.pool.Get().(*valueWriter)
// TODO: Having to call reset here with the same buffer doesn't really make sense.
vw.reset(vw.buf)
vw.buf = vw.buf[:0]
vw.w = w
return vw
}
// GetAtModeElement retrieves a ValueWriterFlusher from the pool and resets it to use w as the destination.
func (bvwp *BSONValueWriterPool) GetAtModeElement(w io.Writer) ValueWriterFlusher {
vw := bvwp.Get(w).(*valueWriter)
vw.push(mElement)
return vw
}
// Put inserts a ValueWriter into the pool. If the ValueWriter is not a BSON ValueWriter, nothing
// happens and ok will be false.
func (bvwp *BSONValueWriterPool) Put(vw ValueWriter) (ok bool) {
bvw, ok := vw.(*valueWriter)
if !ok {
return false
}
bvwp.pool.Put(bvw)
return true
}
// This is here so that during testing we can change it and not require
// allocating a 4GB slice.
var maxSize = math.MaxInt32
var errNilWriter = errors.New("cannot create a ValueWriter from a nil io.Writer")
type errMaxDocumentSizeExceeded struct {
size int64
}
func (mdse errMaxDocumentSizeExceeded) Error() string {
return fmt.Sprintf("document size (%d) is larger than the max int32", mdse.size)
}
type vwMode int
const (
_ vwMode = iota
vwTopLevel
vwDocument
vwArray
vwValue
vwElement
vwCodeWithScope
)
func (vm vwMode) String() string {
var str string
switch vm {
case vwTopLevel:
str = "TopLevel"
case vwDocument:
str = "DocumentMode"
case vwArray:
str = "ArrayMode"
case vwValue:
str = "ValueMode"
case vwElement:
str = "ElementMode"
case vwCodeWithScope:
str = "CodeWithScopeMode"
default:
str = "UnknownMode"
}
return str
}
type vwState struct {
mode mode
key string
arrkey int
start int32
}
type valueWriter struct {
w io.Writer
buf []byte
stack []vwState
frame int64
}
func (vw *valueWriter) advanceFrame() {
if vw.frame+1 >= int64(len(vw.stack)) { // We need to grow the stack
length := len(vw.stack)
if length+1 >= cap(vw.stack) {
// double it
buf := make([]vwState, 2*cap(vw.stack)+1)
copy(buf, vw.stack)
vw.stack = buf
}
vw.stack = vw.stack[:length+1]
}
vw.frame++
}
func (vw *valueWriter) push(m mode) {
vw.advanceFrame()
// Clean the stack
vw.stack[vw.frame].mode = m
vw.stack[vw.frame].key = ""
vw.stack[vw.frame].arrkey = 0
vw.stack[vw.frame].start = 0
vw.stack[vw.frame].mode = m
switch m {
case mDocument, mArray, mCodeWithScope:
vw.reserveLength()
}
}
func (vw *valueWriter) reserveLength() {
vw.stack[vw.frame].start = int32(len(vw.buf))
vw.buf = append(vw.buf, 0x00, 0x00, 0x00, 0x00)
}
func (vw *valueWriter) pop() {
switch vw.stack[vw.frame].mode {
case mElement, mValue:
vw.frame--
case mDocument, mArray, mCodeWithScope:
vw.frame -= 2 // we pop twice to jump over the mElement: mDocument -> mElement -> mDocument/mTopLevel/etc...
}
}
// NewBSONValueWriter creates a ValueWriter that writes BSON to w.
//
// This ValueWriter will only write entire documents to the io.Writer and it
// will buffer the document as it is built.
func NewBSONValueWriter(w io.Writer) (ValueWriter, error) {
if w == nil {
return nil, errNilWriter
}
return newValueWriter(w), nil
}
func newValueWriter(w io.Writer) *valueWriter {
vw := new(valueWriter)
stack := make([]vwState, 1, 5)
stack[0] = vwState{mode: mTopLevel}
vw.w = w
vw.stack = stack
return vw
}
func newValueWriterFromSlice(buf []byte) *valueWriter {
vw := new(valueWriter)
stack := make([]vwState, 1, 5)
stack[0] = vwState{mode: mTopLevel}
vw.stack = stack
vw.buf = buf
return vw
}
func (vw *valueWriter) reset(buf []byte) {
if vw.stack == nil {
vw.stack = make([]vwState, 1, 5)
}
vw.stack = vw.stack[:1]
vw.stack[0] = vwState{mode: mTopLevel}
vw.buf = buf
vw.frame = 0
vw.w = nil
}
func (vw *valueWriter) invalidTransitionError(destination mode, name string, modes []mode) error {
te := TransitionError{
name: name,
current: vw.stack[vw.frame].mode,
destination: destination,
modes: modes,
action: "write",
}
if vw.frame != 0 {
te.parent = vw.stack[vw.frame-1].mode
}
return te
}
func (vw *valueWriter) writeElementHeader(t bsontype.Type, destination mode, callerName string, addmodes ...mode) error {
switch vw.stack[vw.frame].mode {
case mElement:
key := vw.stack[vw.frame].key
if !isValidCString(key) {
return errors.New("BSON element key cannot contain null bytes")
}
vw.buf = bsoncore.AppendHeader(vw.buf, t, key)
case mValue:
// TODO: Do this with a cache of the first 1000 or so array keys.
vw.buf = bsoncore.AppendHeader(vw.buf, t, strconv.Itoa(vw.stack[vw.frame].arrkey))
default:
modes := []mode{mElement, mValue}
if addmodes != nil {
modes = append(modes, addmodes...)
}
return vw.invalidTransitionError(destination, callerName, modes)
}
return nil
}
func (vw *valueWriter) WriteValueBytes(t bsontype.Type, b []byte) error {
if err := vw.writeElementHeader(t, mode(0), "WriteValueBytes"); err != nil {
return err
}
vw.buf = append(vw.buf, b...)
vw.pop()
return nil
}
func (vw *valueWriter) WriteArray() (ArrayWriter, error) {
if err := vw.writeElementHeader(bsontype.Array, mArray, "WriteArray"); err != nil {
return nil, err
}
vw.push(mArray)
return vw, nil
}
func (vw *valueWriter) WriteBinary(b []byte) error {
return vw.WriteBinaryWithSubtype(b, 0x00)
}
func (vw *valueWriter) WriteBinaryWithSubtype(b []byte, btype byte) error {
if err := vw.writeElementHeader(bsontype.Binary, mode(0), "WriteBinaryWithSubtype"); err != nil {
return err
}
vw.buf = bsoncore.AppendBinary(vw.buf, btype, b)
vw.pop()
return nil
}
func (vw *valueWriter) WriteBoolean(b bool) error {
if err := vw.writeElementHeader(bsontype.Boolean, mode(0), "WriteBoolean"); err != nil {
return err
}
vw.buf = bsoncore.AppendBoolean(vw.buf, b)
vw.pop()
return nil
}
func (vw *valueWriter) WriteCodeWithScope(code string) (DocumentWriter, error) {
if err := vw.writeElementHeader(bsontype.CodeWithScope, mCodeWithScope, "WriteCodeWithScope"); err != nil {
return nil, err
}
// CodeWithScope is a different than other types because we need an extra
// frame on the stack. In the EndDocument code, we write the document
// length, pop, write the code with scope length, and pop. To simplify the
// pop code, we push a spacer frame that we'll always jump over.
vw.push(mCodeWithScope)
vw.buf = bsoncore.AppendString(vw.buf, code)
vw.push(mSpacer)
vw.push(mDocument)
return vw, nil
}
func (vw *valueWriter) WriteDBPointer(ns string, oid primitive.ObjectID) error {
if err := vw.writeElementHeader(bsontype.DBPointer, mode(0), "WriteDBPointer"); err != nil {
return err
}
vw.buf = bsoncore.AppendDBPointer(vw.buf, ns, oid)
vw.pop()
return nil
}
func (vw *valueWriter) WriteDateTime(dt int64) error {
if err := vw.writeElementHeader(bsontype.DateTime, mode(0), "WriteDateTime"); err != nil {
return err
}
vw.buf = bsoncore.AppendDateTime(vw.buf, dt)
vw.pop()
return nil
}
func (vw *valueWriter) WriteDecimal128(d128 primitive.Decimal128) error {
if err := vw.writeElementHeader(bsontype.Decimal128, mode(0), "WriteDecimal128"); err != nil {
return err
}
vw.buf = bsoncore.AppendDecimal128(vw.buf, d128)
vw.pop()
return nil
}
func (vw *valueWriter) WriteDouble(f float64) error {
if err := vw.writeElementHeader(bsontype.Double, mode(0), "WriteDouble"); err != nil {
return err
}
vw.buf = bsoncore.AppendDouble(vw.buf, f)
vw.pop()
return nil
}
func (vw *valueWriter) WriteInt32(i32 int32) error {
if err := vw.writeElementHeader(bsontype.Int32, mode(0), "WriteInt32"); err != nil {
return err
}
vw.buf = bsoncore.AppendInt32(vw.buf, i32)
vw.pop()
return nil
}
func (vw *valueWriter) WriteInt64(i64 int64) error {
if err := vw.writeElementHeader(bsontype.Int64, mode(0), "WriteInt64"); err != nil {
return err
}
vw.buf = bsoncore.AppendInt64(vw.buf, i64)
vw.pop()
return nil
}
func (vw *valueWriter) WriteJavascript(code string) error {
if err := vw.writeElementHeader(bsontype.JavaScript, mode(0), "WriteJavascript"); err != nil {
return err
}
vw.buf = bsoncore.AppendJavaScript(vw.buf, code)
vw.pop()
return nil
}
func (vw *valueWriter) WriteMaxKey() error {
if err := vw.writeElementHeader(bsontype.MaxKey, mode(0), "WriteMaxKey"); err != nil {
return err
}
vw.pop()
return nil
}
func (vw *valueWriter) WriteMinKey() error {
if err := vw.writeElementHeader(bsontype.MinKey, mode(0), "WriteMinKey"); err != nil {
return err
}
vw.pop()
return nil
}
func (vw *valueWriter) WriteNull() error {
if err := vw.writeElementHeader(bsontype.Null, mode(0), "WriteNull"); err != nil {
return err
}
vw.pop()
return nil
}
func (vw *valueWriter) WriteObjectID(oid primitive.ObjectID) error {
if err := vw.writeElementHeader(bsontype.ObjectID, mode(0), "WriteObjectID"); err != nil {
return err
}
vw.buf = bsoncore.AppendObjectID(vw.buf, oid)
vw.pop()
return nil
}
func (vw *valueWriter) WriteRegex(pattern string, options string) error {
if !isValidCString(pattern) || !isValidCString(options) {
return errors.New("BSON regex values cannot contain null bytes")
}
if err := vw.writeElementHeader(bsontype.Regex, mode(0), "WriteRegex"); err != nil {
return err
}
vw.buf = bsoncore.AppendRegex(vw.buf, pattern, sortStringAlphebeticAscending(options))
vw.pop()
return nil
}
func (vw *valueWriter) WriteString(s string) error {
if err := vw.writeElementHeader(bsontype.String, mode(0), "WriteString"); err != nil {
return err
}
vw.buf = bsoncore.AppendString(vw.buf, s)
vw.pop()
return nil
}
func (vw *valueWriter) WriteDocument() (DocumentWriter, error) {
if vw.stack[vw.frame].mode == mTopLevel {
vw.reserveLength()
return vw, nil
}
if err := vw.writeElementHeader(bsontype.EmbeddedDocument, mDocument, "WriteDocument", mTopLevel); err != nil {
return nil, err
}
vw.push(mDocument)
return vw, nil
}
func (vw *valueWriter) WriteSymbol(symbol string) error {
if err := vw.writeElementHeader(bsontype.Symbol, mode(0), "WriteSymbol"); err != nil {
return err
}
vw.buf = bsoncore.AppendSymbol(vw.buf, symbol)
vw.pop()
return nil
}
func (vw *valueWriter) WriteTimestamp(t uint32, i uint32) error {
if err := vw.writeElementHeader(bsontype.Timestamp, mode(0), "WriteTimestamp"); err != nil {
return err
}
vw.buf = bsoncore.AppendTimestamp(vw.buf, t, i)
vw.pop()
return nil
}
func (vw *valueWriter) WriteUndefined() error {
if err := vw.writeElementHeader(bsontype.Undefined, mode(0), "WriteUndefined"); err != nil {
return err
}
vw.pop()
return nil
}
func (vw *valueWriter) WriteDocumentElement(key string) (ValueWriter, error) {
switch vw.stack[vw.frame].mode {
case mTopLevel, mDocument:
default:
return nil, vw.invalidTransitionError(mElement, "WriteDocumentElement", []mode{mTopLevel, mDocument})
}
vw.push(mElement)
vw.stack[vw.frame].key = key
return vw, nil
}
func (vw *valueWriter) WriteDocumentEnd() error {
switch vw.stack[vw.frame].mode {
case mTopLevel, mDocument:
default:
return fmt.Errorf("incorrect mode to end document: %s", vw.stack[vw.frame].mode)
}
vw.buf = append(vw.buf, 0x00)
err := vw.writeLength()
if err != nil {
return err
}
if vw.stack[vw.frame].mode == mTopLevel {
if err = vw.Flush(); err != nil {
return err
}
}
vw.pop()
if vw.stack[vw.frame].mode == mCodeWithScope {
// We ignore the error here because of the guarantee of writeLength.
// See the docs for writeLength for more info.
_ = vw.writeLength()
vw.pop()
}
return nil
}
func (vw *valueWriter) Flush() error {
if vw.w == nil {
return nil
}
if _, err := vw.w.Write(vw.buf); err != nil {
return err
}
// reset buffer
vw.buf = vw.buf[:0]
return nil
}
func (vw *valueWriter) WriteArrayElement() (ValueWriter, error) {
if vw.stack[vw.frame].mode != mArray {
return nil, vw.invalidTransitionError(mValue, "WriteArrayElement", []mode{mArray})
}
arrkey := vw.stack[vw.frame].arrkey
vw.stack[vw.frame].arrkey++
vw.push(mValue)
vw.stack[vw.frame].arrkey = arrkey
return vw, nil
}
func (vw *valueWriter) WriteArrayEnd() error {
if vw.stack[vw.frame].mode != mArray {
return fmt.Errorf("incorrect mode to end array: %s", vw.stack[vw.frame].mode)
}
vw.buf = append(vw.buf, 0x00)
err := vw.writeLength()
if err != nil {
return err
}
vw.pop()
return nil
}
// NOTE: We assume that if we call writeLength more than once the same function
// within the same function without altering the vw.buf that this method will
// not return an error. If this changes ensure that the following methods are
// updated:
//
// - WriteDocumentEnd
func (vw *valueWriter) writeLength() error {
length := len(vw.buf)
if length > maxSize {
return errMaxDocumentSizeExceeded{size: int64(len(vw.buf))}
}
length = length - int(vw.stack[vw.frame].start)
start := vw.stack[vw.frame].start
vw.buf[start+0] = byte(length)
vw.buf[start+1] = byte(length >> 8)
vw.buf[start+2] = byte(length >> 16)
vw.buf[start+3] = byte(length >> 24)
return nil
}
func isValidCString(cs string) bool {
return !strings.ContainsRune(cs, '\x00')
}

View File

@@ -0,0 +1,368 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonrw
import (
"bytes"
"errors"
"fmt"
"io/ioutil"
"math"
"reflect"
"strings"
"testing"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
func TestNewBSONValueWriter(t *testing.T) {
_, got := NewBSONValueWriter(nil)
want := errNilWriter
if !compareErrors(got, want) {
t.Errorf("Returned error did not match what was expected. got %v; want %v", got, want)
}
vw, got := NewBSONValueWriter(errWriter{})
want = nil
if !compareErrors(got, want) {
t.Errorf("Returned error did not match what was expected. got %v; want %v", got, want)
}
if vw == nil {
t.Errorf("Expected non-nil ValueWriter to be returned from NewBSONValueWriter")
}
}
func TestValueWriter(t *testing.T) {
header := []byte{0x00, 0x00, 0x00, 0x00}
oid := primitive.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C}
testCases := []struct {
name string
fn interface{}
params []interface{}
want []byte
}{
{
"WriteBinary",
(*valueWriter).WriteBinary,
[]interface{}{[]byte{0x01, 0x02, 0x03}},
bsoncore.AppendBinaryElement(header, "foo", 0x00, []byte{0x01, 0x02, 0x03}),
},
{
"WriteBinaryWithSubtype (not 0x02)",
(*valueWriter).WriteBinaryWithSubtype,
[]interface{}{[]byte{0x01, 0x02, 0x03}, byte(0xFF)},
bsoncore.AppendBinaryElement(header, "foo", 0xFF, []byte{0x01, 0x02, 0x03}),
},
{
"WriteBinaryWithSubtype (0x02)",
(*valueWriter).WriteBinaryWithSubtype,
[]interface{}{[]byte{0x01, 0x02, 0x03}, byte(0x02)},
bsoncore.AppendBinaryElement(header, "foo", 0x02, []byte{0x01, 0x02, 0x03}),
},
{
"WriteBoolean",
(*valueWriter).WriteBoolean,
[]interface{}{true},
bsoncore.AppendBooleanElement(header, "foo", true),
},
{
"WriteDBPointer",
(*valueWriter).WriteDBPointer,
[]interface{}{"bar", oid},
bsoncore.AppendDBPointerElement(header, "foo", "bar", oid),
},
{
"WriteDateTime",
(*valueWriter).WriteDateTime,
[]interface{}{int64(12345678)},
bsoncore.AppendDateTimeElement(header, "foo", 12345678),
},
{
"WriteDecimal128",
(*valueWriter).WriteDecimal128,
[]interface{}{primitive.NewDecimal128(10, 20)},
bsoncore.AppendDecimal128Element(header, "foo", primitive.NewDecimal128(10, 20)),
},
{
"WriteDouble",
(*valueWriter).WriteDouble,
[]interface{}{float64(3.14159)},
bsoncore.AppendDoubleElement(header, "foo", 3.14159),
},
{
"WriteInt32",
(*valueWriter).WriteInt32,
[]interface{}{int32(123456)},
bsoncore.AppendInt32Element(header, "foo", 123456),
},
{
"WriteInt64",
(*valueWriter).WriteInt64,
[]interface{}{int64(1234567890)},
bsoncore.AppendInt64Element(header, "foo", 1234567890),
},
{
"WriteJavascript",
(*valueWriter).WriteJavascript,
[]interface{}{"var foo = 'bar';"},
bsoncore.AppendJavaScriptElement(header, "foo", "var foo = 'bar';"),
},
{
"WriteMaxKey",
(*valueWriter).WriteMaxKey,
[]interface{}{},
bsoncore.AppendMaxKeyElement(header, "foo"),
},
{
"WriteMinKey",
(*valueWriter).WriteMinKey,
[]interface{}{},
bsoncore.AppendMinKeyElement(header, "foo"),
},
{
"WriteNull",
(*valueWriter).WriteNull,
[]interface{}{},
bsoncore.AppendNullElement(header, "foo"),
},
{
"WriteObjectID",
(*valueWriter).WriteObjectID,
[]interface{}{oid},
bsoncore.AppendObjectIDElement(header, "foo", oid),
},
{
"WriteRegex",
(*valueWriter).WriteRegex,
[]interface{}{"bar", "baz"},
bsoncore.AppendRegexElement(header, "foo", "bar", "abz"),
},
{
"WriteString",
(*valueWriter).WriteString,
[]interface{}{"hello, world!"},
bsoncore.AppendStringElement(header, "foo", "hello, world!"),
},
{
"WriteSymbol",
(*valueWriter).WriteSymbol,
[]interface{}{"symbollolz"},
bsoncore.AppendSymbolElement(header, "foo", "symbollolz"),
},
{
"WriteTimestamp",
(*valueWriter).WriteTimestamp,
[]interface{}{uint32(10), uint32(20)},
bsoncore.AppendTimestampElement(header, "foo", 10, 20),
},
{
"WriteUndefined",
(*valueWriter).WriteUndefined,
[]interface{}{},
bsoncore.AppendUndefinedElement(header, "foo"),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
fn := reflect.ValueOf(tc.fn)
if fn.Kind() != reflect.Func {
t.Fatalf("fn must be of kind Func but it is a %v", fn.Kind())
}
if fn.Type().NumIn() != len(tc.params)+1 || fn.Type().In(0) != reflect.TypeOf((*valueWriter)(nil)) {
t.Fatalf("fn must have at least one parameter and the first parameter must be a *valueWriter")
}
if fn.Type().NumOut() != 1 || fn.Type().Out(0) != reflect.TypeOf((*error)(nil)).Elem() {
t.Fatalf("fn must have one return value and it must be an error.")
}
params := make([]reflect.Value, 1, len(tc.params)+1)
vw := newValueWriter(ioutil.Discard)
params[0] = reflect.ValueOf(vw)
for _, param := range tc.params {
params = append(params, reflect.ValueOf(param))
}
_, err := vw.WriteDocument()
noerr(t, err)
_, err = vw.WriteDocumentElement("foo")
noerr(t, err)
results := fn.Call(params)
if !results[0].IsValid() {
err = results[0].Interface().(error)
} else {
err = nil
}
noerr(t, err)
got := vw.buf
want := tc.want
if !bytes.Equal(got, want) {
t.Errorf("Bytes are not equal.\n\tgot %v\n\twant %v", got, want)
}
t.Run("incorrect transition", func(t *testing.T) {
vw = newValueWriter(ioutil.Discard)
results := fn.Call(params)
got := results[0].Interface().(error)
fnName := tc.name
if strings.Contains(fnName, "WriteBinary") {
fnName = "WriteBinaryWithSubtype"
}
want := TransitionError{current: mTopLevel, name: fnName, modes: []mode{mElement, mValue},
action: "write"}
if !compareErrors(got, want) {
t.Errorf("Errors do not match. got %v; want %v", got, want)
}
})
})
}
t.Run("WriteArray", func(t *testing.T) {
vw := newValueWriter(ioutil.Discard)
vw.push(mArray)
want := TransitionError{current: mArray, destination: mArray, parent: mTopLevel,
name: "WriteArray", modes: []mode{mElement, mValue}, action: "write"}
_, got := vw.WriteArray()
if !compareErrors(got, want) {
t.Errorf("Did not get expected error. got %v; want %v", got, want)
}
})
t.Run("WriteCodeWithScope", func(t *testing.T) {
vw := newValueWriter(ioutil.Discard)
vw.push(mArray)
want := TransitionError{current: mArray, destination: mCodeWithScope, parent: mTopLevel,
name: "WriteCodeWithScope", modes: []mode{mElement, mValue}, action: "write"}
_, got := vw.WriteCodeWithScope("")
if !compareErrors(got, want) {
t.Errorf("Did not get expected error. got %v; want %v", got, want)
}
})
t.Run("WriteDocument", func(t *testing.T) {
vw := newValueWriter(ioutil.Discard)
vw.push(mArray)
want := TransitionError{current: mArray, destination: mDocument, parent: mTopLevel,
name: "WriteDocument", modes: []mode{mElement, mValue, mTopLevel}, action: "write"}
_, got := vw.WriteDocument()
if !compareErrors(got, want) {
t.Errorf("Did not get expected error. got %v; want %v", got, want)
}
})
t.Run("WriteDocumentElement", func(t *testing.T) {
vw := newValueWriter(ioutil.Discard)
vw.push(mElement)
want := TransitionError{current: mElement,
destination: mElement,
parent: mTopLevel,
name: "WriteDocumentElement",
modes: []mode{mTopLevel, mDocument},
action: "write"}
_, got := vw.WriteDocumentElement("")
if !compareErrors(got, want) {
t.Errorf("Did not get expected error. got %v; want %v", got, want)
}
})
t.Run("WriteDocumentEnd", func(t *testing.T) {
vw := newValueWriter(ioutil.Discard)
vw.push(mElement)
want := fmt.Errorf("incorrect mode to end document: %s", mElement)
got := vw.WriteDocumentEnd()
if !compareErrors(got, want) {
t.Errorf("Did not get expected error. got %v; want %v", got, want)
}
vw.pop()
vw.buf = append(vw.buf, make([]byte, 1023)...)
maxSize = 512
want = errMaxDocumentSizeExceeded{size: 1024}
got = vw.WriteDocumentEnd()
if !compareErrors(got, want) {
t.Errorf("Did not get expected error. got %v; want %v", got, want)
}
maxSize = math.MaxInt32
want = errors.New("what a nice fake error we have here")
vw.w = errWriter{err: want}
got = vw.WriteDocumentEnd()
if !compareErrors(got, want) {
t.Errorf("Did not get expected error. got %v; want %v", got, want)
}
})
t.Run("WriteArrayElement", func(t *testing.T) {
vw := newValueWriter(ioutil.Discard)
vw.push(mElement)
want := TransitionError{current: mElement,
destination: mValue,
parent: mTopLevel,
name: "WriteArrayElement",
modes: []mode{mArray},
action: "write"}
_, got := vw.WriteArrayElement()
if !compareErrors(got, want) {
t.Errorf("Did not get expected error. got %v; want %v", got, want)
}
})
t.Run("WriteArrayEnd", func(t *testing.T) {
vw := newValueWriter(ioutil.Discard)
vw.push(mElement)
want := fmt.Errorf("incorrect mode to end array: %s", mElement)
got := vw.WriteArrayEnd()
if !compareErrors(got, want) {
t.Errorf("Did not get expected error. got %v; want %v", got, want)
}
vw.push(mArray)
vw.buf = append(vw.buf, make([]byte, 1019)...)
maxSize = 512
want = errMaxDocumentSizeExceeded{size: 1024}
got = vw.WriteArrayEnd()
if !compareErrors(got, want) {
t.Errorf("Did not get expected error. got %v; want %v", got, want)
}
maxSize = math.MaxInt32
})
t.Run("WriteBytes", func(t *testing.T) {
t.Run("writeElementHeader error", func(t *testing.T) {
vw := newValueWriterFromSlice(nil)
want := TransitionError{current: mTopLevel, destination: mode(0),
name: "WriteValueBytes", modes: []mode{mElement, mValue}, action: "write"}
got := vw.WriteValueBytes(bsontype.EmbeddedDocument, nil)
if !compareErrors(got, want) {
t.Errorf("Did not received expected error. got %v; want %v", got, want)
}
})
t.Run("success", func(t *testing.T) {
index, doc := bsoncore.ReserveLength(nil)
doc = bsoncore.AppendStringElement(doc, "hello", "world")
doc = append(doc, 0x00)
doc = bsoncore.UpdateLength(doc, index, int32(len(doc)))
index, want := bsoncore.ReserveLength(nil)
want = bsoncore.AppendDocumentElement(want, "foo", doc)
want = append(want, 0x00)
want = bsoncore.UpdateLength(want, index, int32(len(want)))
vw := newValueWriterFromSlice(make([]byte, 0, 512))
_, err := vw.WriteDocument()
noerr(t, err)
_, err = vw.WriteDocumentElement("foo")
noerr(t, err)
err = vw.WriteValueBytes(bsontype.EmbeddedDocument, doc)
noerr(t, err)
err = vw.WriteDocumentEnd()
noerr(t, err)
got := vw.buf
if !bytes.Equal(got, want) {
t.Errorf("Bytes are not equal. got %v; want %v", got, want)
}
})
})
}
type errWriter struct {
err error
}
func (ew errWriter) Write([]byte) (int, error) { return 0, ew.err }

View File

@@ -0,0 +1,78 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsonrw
import (
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
)
// ArrayWriter is the interface used to create a BSON or BSON adjacent array.
// Callers must ensure they call WriteArrayEnd when they have finished creating
// the array.
type ArrayWriter interface {
WriteArrayElement() (ValueWriter, error)
WriteArrayEnd() error
}
// DocumentWriter is the interface used to create a BSON or BSON adjacent
// document. Callers must ensure they call WriteDocumentEnd when they have
// finished creating the document.
type DocumentWriter interface {
WriteDocumentElement(string) (ValueWriter, error)
WriteDocumentEnd() error
}
// ValueWriter is the interface used to write BSON values. Implementations of
// this interface handle creating BSON or BSON adjacent representations of the
// values.
type ValueWriter interface {
WriteArray() (ArrayWriter, error)
WriteBinary(b []byte) error
WriteBinaryWithSubtype(b []byte, btype byte) error
WriteBoolean(bool) error
WriteCodeWithScope(code string) (DocumentWriter, error)
WriteDBPointer(ns string, oid primitive.ObjectID) error
WriteDateTime(dt int64) error
WriteDecimal128(primitive.Decimal128) error
WriteDouble(float64) error
WriteInt32(int32) error
WriteInt64(int64) error
WriteJavascript(code string) error
WriteMaxKey() error
WriteMinKey() error
WriteNull() error
WriteObjectID(primitive.ObjectID) error
WriteRegex(pattern, options string) error
WriteString(string) error
WriteDocument() (DocumentWriter, error)
WriteSymbol(symbol string) error
WriteTimestamp(t, i uint32) error
WriteUndefined() error
}
// ValueWriterFlusher is a superset of ValueWriter that exposes functionality to flush to the underlying buffer.
type ValueWriterFlusher interface {
ValueWriter
Flush() error
}
// BytesWriter is the interface used to write BSON bytes to a ValueWriter.
// This interface is meant to be a superset of ValueWriter, so that types that
// implement ValueWriter may also implement this interface.
type BytesWriter interface {
WriteValueBytes(t bsontype.Type, b []byte) error
}
// SliceWriter allows a pointer to a slice of bytes to be used as an io.Writer.
type SliceWriter []byte
func (sw *SliceWriter) Write(p []byte) (int, error) {
written := len(p)
*sw = append(*sw, p...)
return written, nil
}