Compare commits
	
		
			25 Commits
		
	
	
		
			v0.0.141
			...
			feature/mo
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 9d88ab3a2b | |||
| 2ec88e81f3 | |||
| d471d7c396 | |||
| 21d241f9b1 | |||
| 2569c165f8 | |||
| ee262a94fb | |||
| 7977c0e59c | |||
| ceff0161c6 | |||
| a30da61419 | |||
| b613b122e3 | |||
| d017530444 | |||
| 8de83cc290 | |||
| 603ec82b83 | |||
| 93c4cf31a8 | |||
| dc2d8a9103 | |||
| 6589e8d5cd | |||
| 0006c6859d | |||
| 827b3fc1b7 | |||
| f7dce4a102 | |||
| 45d4fd7101 | |||
| c7df9d2264 | |||
| d0954bf133 | |||
| 8affa81bb9 | |||
| fe9ebf0bab | |||
| a4b5f33d15 | 
							
								
								
									
										6
									
								
								.idea/goext.iml
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										6
									
								
								.idea/goext.iml
									
									
									
										generated
									
									
									
								
							| @@ -1,6 +1,10 @@ | ||||
| <?xml version="1.0" encoding="UTF-8"?> | ||||
| <module type="WEB_MODULE" version="4"> | ||||
|   <component name="Go" enabled="true" /> | ||||
|   <component name="Go" enabled="true"> | ||||
|     <buildTags> | ||||
|       <option name="goVersion" value="1.19" /> | ||||
|     </buildTags> | ||||
|   </component> | ||||
|   <component name="NewModuleRootManager"> | ||||
|     <content url="file://$MODULE_DIR$" /> | ||||
|     <orderEntry type="inheritedJdk" /> | ||||
|   | ||||
							
								
								
									
										5
									
								
								Makefile
									
									
									
									
									
								
							
							
						
						
									
										5
									
								
								Makefile
									
									
									
									
									
								
							| @@ -1,4 +1,6 @@ | ||||
|  | ||||
| .PHONY: run test version update-mongo | ||||
|  | ||||
| run: | ||||
| 	echo "This is a library - can't be run" && false | ||||
|  | ||||
| @@ -9,3 +11,6 @@ test: | ||||
|  | ||||
| version: | ||||
| 	_data/version.sh | ||||
|  | ||||
| update-mongo: | ||||
| 	_data/update-mongo.sh | ||||
							
								
								
									
										57
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										57
									
								
								README.md
									
									
									
									
									
								
							| @@ -10,31 +10,32 @@ Potentially needs `export GOPRIVATE="gogs.mikescher.com"` | ||||
|  | ||||
| ### Packages: | ||||
|  | ||||
| | Name        | Maintainer | Description                                                                                                   | | ||||
| |-------------|------------|---------------------------------------------------------------------------------------------------------------| | ||||
| | langext     | Mike       | General uttility/helper functions, (everything thats missing from go standard library)                        | | ||||
| | mathext     | Mike       | Utility/Helper functions for math                                                                             | | ||||
| | cryptext    | Mike       | Utility/Helper functions for encryption                                                                       | | ||||
| | syncext     | Mike       | Utility/Helper funtions for multi-threading / mutex / channels                                                | | ||||
| | dataext     | Mike       | Various useful data structures                                                                                | | ||||
| | zipext      | Mike       | Utility for zip/gzip/tar etc                                                                                  | | ||||
| |             |            |                                                                                                               | | ||||
| | mongoext    | Mike       | Utility/Helper functions for mongodb                                                                          | | ||||
| | cursortoken | Mike       | MongoDB cursortoken implementation                                                                            | | ||||
| |             |            |                                                                                                               | | ||||
| | totpext     | Mike       | Implementation of TOTP (2-Factor-Auth)                                                                        | | ||||
| | termext     | Mike       | Utilities for terminals (mostly color output)                                                                 | | ||||
| | confext     | Mike       | Parses environment configuration into structs                                                                 | | ||||
| | cmdext      | Mike       | Runner for external commands/processes                                                                        | | ||||
| |             |            |                                                                                                               | | ||||
| | sq          | Mike       | Utility functions for sql based databases                                                                     | | ||||
| | tst         | Mike       | Utility functions for unit tests                                                                              | | ||||
| |             |            |                                                                                                               | | ||||
| | rfctime     | Mike       | Classes for time seriallization, with different marshallign method for mongo and json                         | | ||||
| | gojson      | Mike       | Same interface for marshalling/unmarshalling as go/json, except with proper serialization of null arrays/maps | | ||||
| |             |            |                                                                                                               | | ||||
| | bfcodegen   | Mike       | Various codegen tools (run via go generate)                                                                   | | ||||
| |             |            |                                                                                                               | | ||||
| | rext        | Mike       | Regex Wrapper, wraps regexp with a better interface                                                           | | ||||
| | wmo         | Mike       | Mongo Wrapper, wraps mongodb with a better interface                                                          | | ||||
| |             |            |                                                                                                               | | ||||
| | Name         | Maintainer | Description                                                                                                   | | ||||
| |--------------|------------|---------------------------------------------------------------------------------------------------------------| | ||||
| | langext      | Mike       | General uttility/helper functions, (everything thats missing from go standard library)                        | | ||||
| | mathext      | Mike       | Utility/Helper functions for math                                                                             | | ||||
| | cryptext     | Mike       | Utility/Helper functions for encryption                                                                       | | ||||
| | syncext      | Mike       | Utility/Helper funtions for multi-threading / mutex / channels                                                | | ||||
| | dataext      | Mike       | Various useful data structures                                                                                | | ||||
| | zipext       | Mike       | Utility for zip/gzip/tar etc                                                                                  | | ||||
| | reflectext   | Mike       | Utility for golagn reflection                                                                                 | | ||||
| |              |            |                                                                                                               | | ||||
| | mongoext     | Mike       | Utility/Helper functions for mongodb                                                                          | | ||||
| | cursortoken  | Mike       | MongoDB cursortoken implementation                                                                            | | ||||
| |              |            |                                                                                                               | | ||||
| | totpext      | Mike       | Implementation of TOTP (2-Factor-Auth)                                                                        | | ||||
| | termext      | Mike       | Utilities for terminals (mostly color output)                                                                 | | ||||
| | confext      | Mike       | Parses environment configuration into structs                                                                 | | ||||
| | cmdext       | Mike       | Runner for external commands/processes                                                                        | | ||||
| |              |            |                                                                                                               | | ||||
| | sq           | Mike       | Utility functions for sql based databases                                                                     | | ||||
| | tst          | Mike       | Utility functions for unit tests                                                                              | | ||||
| |              |            |                                                                                                               | | ||||
| | rfctime      | Mike       | Classes for time seriallization, with different marshallign method for mongo and json                         | | ||||
| | gojson       | Mike       | Same interface for marshalling/unmarshalling as go/json, except with proper serialization of null arrays/maps | | ||||
| |              |            |                                                                                                               | | ||||
| | bfcodegen    | Mike       | Various codegen tools (run via go generate)                                                                   | | ||||
| |              |            |                                                                                                               | | ||||
| | rext         | Mike       | Regex Wrapper, wraps regexp with a better interface                                                           | | ||||
| | wmo          | Mike       | Mongo Wrapper, wraps mongodb with a better interface                                                          | | ||||
| |              |            |                                                                                                               | | ||||
							
								
								
									
										80
									
								
								_data/mongo.patch
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										80
									
								
								_data/mongo.patch
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,80 @@ | ||||
| diff --git a/mongo/bson/bsoncodec/struct_codec.go b/mongo/bson/bsoncodec/struct_codec.go | ||||
| --- a/mongo/bson/bsoncodec/struct_codec.go | ||||
| +++ b/mongo/bson/bsoncodec/struct_codec.go | ||||
| @@ -122,6 +122,10 @@ func (sc *StructCodec) EncodeValue(r EncodeContext, vw bsonrw.ValueWriter, val r | ||||
|  	} | ||||
|  	var rv reflect.Value | ||||
|  	for _, desc := range sd.fl { | ||||
| +		if desc.omitAlways { | ||||
| +			continue | ||||
| +		} | ||||
| + | ||||
|  		if desc.inline == nil { | ||||
|  			rv = val.Field(desc.idx) | ||||
|  		} else { | ||||
| @@ -400,15 +404,16 @@ type structDescription struct { | ||||
|  } | ||||
|   | ||||
|  type fieldDescription struct { | ||||
| -	name      string // BSON key name | ||||
| -	fieldName string // struct field name | ||||
| -	idx       int | ||||
| -	omitEmpty bool | ||||
| -	minSize   bool | ||||
| -	truncate  bool | ||||
| -	inline    []int | ||||
| -	encoder   ValueEncoder | ||||
| -	decoder   ValueDecoder | ||||
| +	name       string // BSON key name | ||||
| +	fieldName  string // struct field name | ||||
| +	idx        int | ||||
| +	omitEmpty  bool | ||||
| +	omitAlways bool | ||||
| +	minSize    bool | ||||
| +	truncate   bool | ||||
| +	inline     []int | ||||
| +	encoder    ValueEncoder | ||||
| +	decoder    ValueDecoder | ||||
|  } | ||||
|   | ||||
|  type byIndex []fieldDescription | ||||
| @@ -491,6 +496,7 @@ func (sc *StructCodec) describeStruct(r *Registry, t reflect.Type) (*structDescr | ||||
|  		} | ||||
|  		description.name = stags.Name | ||||
|  		description.omitEmpty = stags.OmitEmpty | ||||
| +		description.omitAlways = stags.OmitAlways | ||||
|  		description.minSize = stags.MinSize | ||||
|  		description.truncate = stags.Truncate | ||||
|   | ||||
| diff --git a/mongo/bson/bsoncodec/struct_tag_parser.go b/mongo/bson/bsoncodec/struct_tag_parser.go | ||||
| --- a/mongo/bson/bsoncodec/struct_tag_parser.go | ||||
| +++ b/mongo/bson/bsoncodec/struct_tag_parser.go | ||||
| @@ -52,12 +52,13 @@ func (stpf StructTagParserFunc) ParseStructTags(sf reflect.StructField) (StructT | ||||
|  // | ||||
|  // TODO(skriptble): Add tags for undefined as nil and for null as nil. | ||||
|  type StructTags struct { | ||||
| -	Name      string | ||||
| -	OmitEmpty bool | ||||
| -	MinSize   bool | ||||
| -	Truncate  bool | ||||
| -	Inline    bool | ||||
| -	Skip      bool | ||||
| +	Name       string | ||||
| +	OmitEmpty  bool | ||||
| +	OmitAlways bool | ||||
| +	MinSize    bool | ||||
| +	Truncate   bool | ||||
| +	Inline     bool | ||||
| +	Skip       bool | ||||
|  } | ||||
|   | ||||
|  // DefaultStructTagParser is the StructTagParser used by the StructCodec by default. | ||||
| @@ -108,6 +109,8 @@ func parseTags(key string, tag string) (StructTags, error) { | ||||
|  		switch str { | ||||
|  		case "omitempty": | ||||
|  			st.OmitEmpty = true | ||||
| +		case "omitalways": | ||||
| +			st.OmitAlways = true | ||||
|  		case "minsize": | ||||
|  			st.MinSize = true | ||||
|  		case "truncate": | ||||
							
								
								
									
										95
									
								
								_data/update-mongo.sh
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										95
									
								
								_data/update-mongo.sh
									
									
									
									
									
										Executable file
									
								
							| @@ -0,0 +1,95 @@ | ||||
| #!/bin/bash | ||||
|  | ||||
|  | ||||
|     set -o nounset   # disallow usage of unset vars  ( set -u ) | ||||
|     set -o errexit   # Exit immediately if a pipeline returns non-zero.  ( set -e ) | ||||
|     set -o errtrace  # Allow the above trap be inherited by all functions in the script.  ( set -E ) | ||||
|     set -o pipefail  # Return value of a pipeline is the value of the last (rightmost) command to exit with a non-zero status | ||||
|     IFS=$'\n\t'      # Set $IFS to only newline and tab. | ||||
|  | ||||
|  | ||||
|  | ||||
| dir="/tmp/mongo_repo_$( uuidgen )" | ||||
|  | ||||
| echo "" | ||||
| echo "> Clone https://github.dev/mongodb/mongo-go-driver" | ||||
| echo "" | ||||
|  | ||||
| git clone "https://github.com/mongodb/mongo-go-driver" "$dir" | ||||
|  | ||||
| pushd "$dir" | ||||
|  | ||||
| git fetch --tags | ||||
|  | ||||
| latestTag="$( git describe --tags `git rev-list --tags --max-count=1` )" | ||||
|  | ||||
| git -c "advice.detachedHead=false" checkout $latestTag | ||||
|  | ||||
| latestSHA="$( git rev-parse HEAD )" | ||||
|  | ||||
| popd | ||||
|  | ||||
| existingTag=$( cat mongoPatchVersion.go | grep -oP "(?<=const MongoCloneTag = \")([A-Za-z0-9.]+)(?=\")" ) | ||||
| existingSHA=$( cat mongoPatchVersion.go | grep -oP "(?<=const MongoCloneCommit = \")([A-Za-z0-9.]+)(?=\")" ) | ||||
|  | ||||
| echo "====================================" | ||||
| echo "ID  (online) $latestSHA" | ||||
| echo "ID  (local)  $existingSHA" | ||||
| echo "Tag (online) $latestTag" | ||||
| echo "Tag (local)  $existingTag" | ||||
| echo "====================================" | ||||
|  | ||||
| if [[ "$latestTag" == "$existingTag" ]]; then | ||||
|   echo "Nothing to do" | ||||
|   rm -rf "$dir" | ||||
|   exit 0 | ||||
| fi | ||||
|  | ||||
| echo "" | ||||
| echo "> Copy repository" | ||||
| echo "" | ||||
|  | ||||
| rm -rf mongo | ||||
| cp -r "$dir" "mongo" | ||||
| rm -rf "$dir" | ||||
|  | ||||
| echo "" | ||||
| echo "> Clean repository" | ||||
| echo "" | ||||
|  | ||||
| rm -rf "mongo/.git" | ||||
| rm -rf "mongo/.evergreen" | ||||
| rm -rf "mongo/cmd" | ||||
| rm -rf "mongo/docs" | ||||
| rm -rf "mongo/etc" | ||||
| rm -rf "mongo/examples" | ||||
| rm -rf "mongo/testdata" | ||||
| rm -rf "mongo/benchmark" | ||||
| rm -rf "mongo/vendor" | ||||
| rm -rf "mongo/internal/test" | ||||
| rm -rf "mongo/go.mod" | ||||
| rm -rf "mongo/go.sum" | ||||
|  | ||||
| echo "" | ||||
| echo "> Update mongoPatchVersion.go" | ||||
| echo "" | ||||
|  | ||||
| { | ||||
|  | ||||
|   printf "package goext\n" | ||||
|   printf "\n" | ||||
|   printf "// %s\n" "$( date +"%Y-%m-%d %H:%M:%S%z" )" | ||||
|   printf "\n" | ||||
|   printf "const MongoCloneTag = \"%s\"\n" "$latestTag" | ||||
|   printf "const MongoCloneCommit = \"%s\"\n" "$latestSHA" | ||||
|  | ||||
| } > mongoPatchVersion.go | ||||
|  | ||||
| echo "" | ||||
| echo "> Patch mongo" | ||||
| echo "" | ||||
|  | ||||
| git apply -v _data/mongo.patch | ||||
|  | ||||
| echo "" | ||||
| echo "Done." | ||||
| @@ -23,6 +23,8 @@ fi | ||||
|  | ||||
| git pull --ff | ||||
|  | ||||
| go get -u ./... | ||||
|  | ||||
| curr_vers=$(git describe --tags --abbrev=0 | sed 's/v//g') | ||||
|  | ||||
| next_ver=$(echo "$curr_vers" | awk -F. -v OFS=. 'NF==1{print ++$NF}; NF>1{if(length($NF+1)>length($NF))$(NF-1)++; $NF=sprintf("%0*d", length($NF), ($NF+1)%(10^length($NF))); print}') | ||||
| @@ -32,6 +34,8 @@ echo "> Current Version: ${curr_vers}" | ||||
| echo "> Next    Version: ${next_ver}" | ||||
| echo "" | ||||
|  | ||||
| printf "package goext\n\nconst GoextVersion = \"%s\"\n\nconst GoextVersionTimestamp = \"%s\"\n" "${next_ver}" "$( date +"%Y-%m-%dT%H:%M:%S%z" )" > "goextVersion.go" | ||||
|  | ||||
| git add --verbose . | ||||
|  | ||||
| msg="v${next_ver}" | ||||
|   | ||||
| @@ -3,11 +3,15 @@ package bfcodegen | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"gogs.mikescher.com/BlackForestBytes/goext" | ||||
| 	"gogs.mikescher.com/BlackForestBytes/goext/cmdext" | ||||
| 	"gogs.mikescher.com/BlackForestBytes/goext/cryptext" | ||||
| 	"gogs.mikescher.com/BlackForestBytes/goext/langext" | ||||
| 	"gogs.mikescher.com/BlackForestBytes/goext/rext" | ||||
| 	"io" | ||||
| 	"os" | ||||
| 	"path" | ||||
| 	"path/filepath" | ||||
| 	"regexp" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| @@ -21,6 +25,7 @@ type EnumDefVal struct { | ||||
|  | ||||
| type EnumDef struct { | ||||
| 	File         string | ||||
| 	FileRelative string | ||||
| 	EnumTypeName string | ||||
| 	Type         string | ||||
| 	Values       []EnumDefVal | ||||
| @@ -32,6 +37,8 @@ var rexEnumDef = rext.W(regexp.MustCompile("^\\s*type\\s+(?P<name>[A-Za-z0-9_]+) | ||||
|  | ||||
| var rexValueDef = rext.W(regexp.MustCompile("^\\s*(?P<name>[A-Za-z0-9_]+)\\s+(?P<type>[A-Za-z0-9_]+)\\s*=\\s*(?P<value>(\"[A-Za-z0-9_:]+\"|[0-9]+))\\s*(//(?P<descr>.*))?.*$")) | ||||
|  | ||||
| var rexChecksumConst = rext.W(regexp.MustCompile("const ChecksumGenerator = \"(?P<cs>[A-Za-z0-9_]*)\"")) | ||||
|  | ||||
| func GenerateEnumSpecs(sourceDir string, destFile string) error { | ||||
|  | ||||
| 	files, err := os.ReadDir(sourceDir) | ||||
| @@ -39,17 +46,46 @@ func GenerateEnumSpecs(sourceDir string, destFile string) error { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	oldChecksum := "N/A" | ||||
| 	if _, err := os.Stat(destFile); !os.IsNotExist(err) { | ||||
| 		content, err := os.ReadFile(destFile) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		if m, ok := rexChecksumConst.MatchFirst(string(content)); ok { | ||||
| 			oldChecksum = m.GroupByName("cs").Value() | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	files = langext.ArrFilter(files, func(v os.DirEntry) bool { return v.Name() != path.Base(destFile) }) | ||||
| 	files = langext.ArrFilter(files, func(v os.DirEntry) bool { return strings.HasSuffix(v.Name(), ".go") }) | ||||
| 	langext.SortBy(files, func(v os.DirEntry) string { return v.Name() }) | ||||
|  | ||||
| 	newChecksumStr := goext.GoextVersion | ||||
| 	for _, f := range files { | ||||
| 		content, err := os.ReadFile(path.Join(sourceDir, f.Name())) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		newChecksumStr += "\n" + f.Name() + "\t" + cryptext.BytesSha256(content) | ||||
| 	} | ||||
|  | ||||
| 	newChecksum := cryptext.BytesSha256([]byte(newChecksumStr)) | ||||
|  | ||||
| 	if newChecksum != oldChecksum { | ||||
| 		fmt.Printf("[EnumGenerate] Checksum has changed ( %s -> %s ), will generate new file\n\n", oldChecksum, newChecksum) | ||||
| 	} else { | ||||
| 		fmt.Printf("[EnumGenerate] Checksum unchanged ( %s ), nothing to do\n", oldChecksum) | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	allEnums := make([]EnumDef, 0) | ||||
|  | ||||
| 	pkgname := "" | ||||
|  | ||||
| 	for _, f := range files { | ||||
| 		if !strings.HasSuffix(f.Name(), ".go") { | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		fmt.Printf("========= %s =========\n\n", f.Name()) | ||||
| 		fileEnums, pn, err := processFile(f.Name()) | ||||
| 		fileEnums, pn, err := processFile(sourceDir, path.Join(sourceDir, f.Name())) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| @@ -67,7 +103,7 @@ func GenerateEnumSpecs(sourceDir string, destFile string) error { | ||||
| 		return errors.New("no package name found in any file") | ||||
| 	} | ||||
|  | ||||
| 	err = os.WriteFile(destFile, []byte(fmtOutput(allEnums, pkgname)), 0o755) | ||||
| 	err = os.WriteFile(destFile, []byte(fmtOutput(newChecksum, allEnums, pkgname)), 0o755) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -89,7 +125,7 @@ func GenerateEnumSpecs(sourceDir string, destFile string) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func processFile(fn string) ([]EnumDef, string, error) { | ||||
| func processFile(basedir string, fn string) ([]EnumDef, string, error) { | ||||
| 	file, err := os.Open(fn) | ||||
| 	if err != nil { | ||||
| 		return nil, "", err | ||||
| @@ -119,8 +155,15 @@ func processFile(fn string) ([]EnumDef, string, error) { | ||||
| 		} | ||||
|  | ||||
| 		if match, ok := rexEnumDef.MatchFirst(line); ok { | ||||
|  | ||||
| 			rfp, err := filepath.Rel(basedir, fn) | ||||
| 			if err != nil { | ||||
| 				return nil, "", err | ||||
| 			} | ||||
|  | ||||
| 			def := EnumDef{ | ||||
| 				File:         fn, | ||||
| 				FileRelative: rfp, | ||||
| 				EnumTypeName: match.GroupByName("name").Value(), | ||||
| 				Type:         match.GroupByName("type").Value(), | ||||
| 				Values:       make([]EnumDefVal, 0), | ||||
| @@ -159,8 +202,8 @@ func processFile(fn string) ([]EnumDef, string, error) { | ||||
| 	return enums, pkgname, nil | ||||
| } | ||||
|  | ||||
| func fmtOutput(enums []EnumDef, pkgname string) string { | ||||
| 	str := "// Code generated by permissions_gen.sh DO NOT EDIT.\n" | ||||
| func fmtOutput(cs string, enums []EnumDef, pkgname string) string { | ||||
| 	str := "// Code generated by enum-generate.go DO NOT EDIT.\n" | ||||
| 	str += "\n" | ||||
| 	str += "package " + pkgname + "\n" | ||||
| 	str += "\n" | ||||
| @@ -168,6 +211,9 @@ func fmtOutput(enums []EnumDef, pkgname string) string { | ||||
| 	str += "import \"gogs.mikescher.com/BlackForestBytes/goext/langext\"" + "\n" | ||||
| 	str += "\n" | ||||
|  | ||||
| 	str += "const ChecksumGenerator = \"" + cs + "\"" + "\n" | ||||
| 	str += "\n" | ||||
|  | ||||
| 	str += "type Enum interface {" + "\n" | ||||
| 	str += "    Valid() bool" + "\n" | ||||
| 	str += "    ValuesAny() []any" + "\n" | ||||
| @@ -202,7 +248,7 @@ func fmtOutput(enums []EnumDef, pkgname string) string { | ||||
|  | ||||
| 		str += "// ================================ " + enumdef.EnumTypeName + " ================================" + "\n" | ||||
| 		str += "//" + "\n" | ||||
| 		str += "// File:       " + enumdef.File + "\n" | ||||
| 		str += "// File:       " + enumdef.FileRelative + "\n" | ||||
| 		str += "// StringEnum: " + langext.Conditional(hasStr, "true", "false") + "\n" | ||||
| 		str += "// DescrEnum:  " + langext.Conditional(hasDescr, "true", "false") + "\n" | ||||
| 		str += "//" + "\n" | ||||
|   | ||||
							
								
								
									
										15
									
								
								bfcodegen/enum-generate_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										15
									
								
								bfcodegen/enum-generate_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,15 @@ | ||||
| package bfcodegen | ||||
|  | ||||
| import ( | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| func TestApplyEnvOverridesSimple(t *testing.T) { | ||||
|  | ||||
| 	err := GenerateEnumSpecs("/home/mike/Code/reiff/badennet/bnet-backend/models", "/home/mike/Code/reiff/badennet/bnet-backend/models/enums_gen.go") | ||||
| 	if err != nil { | ||||
| 		t.Error(err) | ||||
| 		t.Fail() | ||||
| 	} | ||||
|  | ||||
| } | ||||
| @@ -178,9 +178,9 @@ func parseEnvToValue(envval string, fullEnvKey string, rvtype reflect.Type) (ref | ||||
| 		if strings.TrimSpace(strings.ToLower(envval)) == "true" { | ||||
| 			return reflect.ValueOf(true).Convert(rvtype), nil | ||||
| 		} else if strings.TrimSpace(strings.ToLower(envval)) == "false" { | ||||
| 			return reflect.ValueOf(true).Convert(rvtype), nil | ||||
| 		} else if strings.TrimSpace(strings.ToLower(envval)) == "1" { | ||||
| 			return reflect.ValueOf(false).Convert(rvtype), nil | ||||
| 		} else if strings.TrimSpace(strings.ToLower(envval)) == "1" { | ||||
| 			return reflect.ValueOf(true).Convert(rvtype), nil | ||||
| 		} else if strings.TrimSpace(strings.ToLower(envval)) == "0" { | ||||
| 			return reflect.ValueOf(false).Convert(rvtype), nil | ||||
| 		} else { | ||||
|   | ||||
							
								
								
									
										34
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										34
									
								
								go.mod
									
									
									
									
									
								
							| @@ -3,22 +3,30 @@ module gogs.mikescher.com/BlackForestBytes/goext | ||||
| go 1.19 | ||||
|  | ||||
| require ( | ||||
| 	github.com/golang/snappy v0.0.4 | ||||
| 	github.com/google/go-cmp v0.5.9 | ||||
| 	github.com/jmoiron/sqlx v1.3.5 | ||||
| 	go.mongodb.org/mongo-driver v1.11.1 | ||||
| 	golang.org/x/crypto v0.4.0 | ||||
| 	golang.org/x/sys v0.3.0 | ||||
| 	golang.org/x/term v0.3.0 | ||||
| 	github.com/klauspost/compress v1.16.6 | ||||
| 	github.com/kr/pretty v0.1.0 | ||||
| 	github.com/montanaflynn/stats v0.7.1 | ||||
| 	github.com/pkg/errors v0.9.1 | ||||
| 	github.com/stretchr/testify v1.8.4 | ||||
| 	github.com/tidwall/pretty v1.0.0 | ||||
| 	github.com/xdg-go/scram v1.1.2 | ||||
| 	github.com/xdg-go/stringprep v1.0.4 | ||||
| 	github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a | ||||
| 	go.mongodb.org/mongo-driver v1.11.7 | ||||
| 	golang.org/x/crypto v0.10.0 | ||||
| 	golang.org/x/sync v0.3.0 | ||||
| 	golang.org/x/sys v0.9.0 | ||||
| 	golang.org/x/term v0.9.0 | ||||
| ) | ||||
|  | ||||
| require ( | ||||
| 	github.com/golang/snappy v0.0.1 // indirect | ||||
| 	github.com/klauspost/compress v1.13.6 // indirect | ||||
| 	github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe // indirect | ||||
| 	github.com/pkg/errors v0.9.1 // indirect | ||||
| 	github.com/davecgh/go-spew v1.1.1 // indirect | ||||
| 	github.com/kr/text v0.1.0 // indirect | ||||
| 	github.com/pmezard/go-difflib v1.0.0 // indirect | ||||
| 	github.com/xdg-go/pbkdf2 v1.0.0 // indirect | ||||
| 	github.com/xdg-go/scram v1.1.1 // indirect | ||||
| 	github.com/xdg-go/stringprep v1.0.3 // indirect | ||||
| 	github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d // indirect | ||||
| 	golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect | ||||
| 	golang.org/x/text v0.5.0 // indirect | ||||
| 	golang.org/x/text v0.10.0 // indirect | ||||
| 	gopkg.in/yaml.v3 v3.0.1 // indirect | ||||
| ) | ||||
|   | ||||
							
								
								
									
										71
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										71
									
								
								go.sum
									
									
									
									
									
								
							| @@ -3,14 +3,17 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c | ||||
| github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= | ||||
| github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= | ||||
| github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= | ||||
| github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4= | ||||
| github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= | ||||
| github.com/google/go-cmp v0.5.2 h1:X2ev0eStA3AbceY54o37/0PQ/UWqKEiiO2dKL5OPaFM= | ||||
| github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= | ||||
| github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= | ||||
| github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= | ||||
| github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= | ||||
| github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= | ||||
| github.com/jmoiron/sqlx v1.3.5 h1:vFFPA71p1o5gAeqtEAwLU4dnX2napprKtHr7PYIcN3g= | ||||
| github.com/jmoiron/sqlx v1.3.5/go.mod h1:nRVWtLre0KfCLJvgxzCsLVMogSvQ1zNJtpYr2Ccp0mQ= | ||||
| github.com/klauspost/compress v1.13.6 h1:P76CopJELS0TiO2mebmnzgWaajssP/EszplttgQxcgc= | ||||
| github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= | ||||
| github.com/klauspost/compress v1.16.6 h1:91SKEy4K37vkp255cJ8QesJhjyRO0hn9i9G0GoUwLsk= | ||||
| github.com/klauspost/compress v1.16.6/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= | ||||
| github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= | ||||
| github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= | ||||
| github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= | ||||
| @@ -20,49 +23,77 @@ github.com/lib/pq v1.2.0 h1:LXpIM/LZ5xGFhOpXAQUIMM1HdyqzVYM13zNdjCEEcA0= | ||||
| github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= | ||||
| github.com/mattn/go-sqlite3 v1.14.6 h1:dNPt6NO46WmLVt2DLNpwczCmdV5boIZ6g/tlDrlRUbg= | ||||
| github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= | ||||
| github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe h1:iruDEfMl2E6fbMZ9s0scYfZQ84/6SPL6zC8ACM2oIL0= | ||||
| github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc= | ||||
| github.com/montanaflynn/stats v0.7.1 h1:etflOAAHORrCC44V+aR6Ftzort912ZU+YLiSTuV8eaE= | ||||
| github.com/montanaflynn/stats v0.7.1/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow= | ||||
| github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= | ||||
| github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= | ||||
| github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= | ||||
| github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= | ||||
| github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= | ||||
| github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= | ||||
| github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= | ||||
| github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= | ||||
| github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= | ||||
| github.com/tidwall/pretty v1.0.0 h1:HsD+QiTn7sK6flMKIvNmpqz1qrpP3Ps6jOKIKMooyg4= | ||||
| github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= | ||||
| github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= | ||||
| github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= | ||||
| github.com/xdg-go/scram v1.1.1 h1:VOMT+81stJgXW3CpHyqHN3AXDYIMsx56mEFrB37Mb/E= | ||||
| github.com/xdg-go/scram v1.1.1/go.mod h1:RaEWvsqvNKKvBPvcKeFjrG2cJqOkHTiyTpzz23ni57g= | ||||
| github.com/xdg-go/stringprep v1.0.3 h1:kdwGpVNwPFtjs98xCGkHjQtGKh86rDcRZN17QEMCOIs= | ||||
| github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY= | ||||
| github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4= | ||||
| github.com/xdg-go/stringprep v1.0.3/go.mod h1:W3f5j4i+9rC0kuIEJL0ky1VpHXQU3ocBgklLGvcBnW8= | ||||
| github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d h1:splanxYIlg+5LfHAM6xpdFEAYOk8iySO56hMFq6uLyA= | ||||
| github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8= | ||||
| github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM= | ||||
| github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7JulP+udvsHwJoVG1YGAP6VLg4y9I5dyZdqmA= | ||||
| go.mongodb.org/mongo-driver v1.11.1 h1:QP0znIRTuL0jf1oBQoAoM0C6ZJfBK4kx0Uumtv1A7w8= | ||||
| go.mongodb.org/mongo-driver v1.11.1/go.mod h1:s7p5vEtfbeR1gYi6pnj3c3/urpbLv2T5Sfd6Rp2HBB8= | ||||
| github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a h1:fZHgsYlfvtyqToslyjUt3VOPF4J7aK/3MPcK7xp3PDk= | ||||
| github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a/go.mod h1:ul22v+Nro/R083muKhosV54bj5niojjWZvU8xrevuH4= | ||||
| github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= | ||||
| go.mongodb.org/mongo-driver v1.11.7 h1:LIwYxASDLGUg/8wOhgOOZhX8tQa/9tgZPgzZoVqJvcs= | ||||
| go.mongodb.org/mongo-driver v1.11.7/go.mod h1:G9TgswdsWjX4tmDA5zfs2+6AEPpYJwqblyjsfuh8oXY= | ||||
| golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= | ||||
| golang.org/x/crypto v0.0.0-20200302210943-78000ba7a073/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= | ||||
| golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= | ||||
| golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= | ||||
| golang.org/x/crypto v0.4.0 h1:UVQgzMY87xqpKNgb+kDsll2Igd33HszWHFLmpaRMq/8= | ||||
| golang.org/x/crypto v0.4.0/go.mod h1:3quD/ATkf6oY+rnes5c3ExXTbLc8mueNue5/DoinL80= | ||||
| golang.org/x/crypto v0.10.0 h1:LKqV2xt9+kDzSTfOhx4FrkEBcMrAgHSYgzywV9zcGmM= | ||||
| golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I= | ||||
| golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= | ||||
| golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= | ||||
| golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= | ||||
| golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= | ||||
| golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= | ||||
| golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ= | ||||
| golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= | ||||
| golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | ||||
| golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | ||||
| golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | ||||
| golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= | ||||
| golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= | ||||
| golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= | ||||
| golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||||
| golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||||
| golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||||
| golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.3.0 h1:w8ZOecv6NaNa/zC8944JTU3vz4u6Lagfk4RPQxv92NQ= | ||||
| golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s= | ||||
| golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||||
| golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= | ||||
| golang.org/x/term v0.3.0 h1:qoo4akIqOcDME5bhc/NgxUdovd6BSS2uMsVjB56q1xI= | ||||
| golang.org/x/term v0.3.0/go.mod h1:q750SLmJuPmVoN1blW3UFBPREJfb1KmY3vwxfr+nFDA= | ||||
| golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= | ||||
| golang.org/x/term v0.9.0 h1:GRRCnKYhdQrD8kfRAdQ6Zcw1P0OcELxGLKJvtjVMZ28= | ||||
| golang.org/x/term v0.9.0/go.mod h1:M6DEAAIenWoTxdKrOltXcmDY3rSplQUkrvaDU5FcQyo= | ||||
| golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= | ||||
| golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= | ||||
| golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= | ||||
| golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= | ||||
| golang.org/x/text v0.5.0 h1:OLmvp0KP+FVG99Ct/qFiL/Fhk4zp4QQnZ7b2U+5piUM= | ||||
| golang.org/x/text v0.5.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= | ||||
| golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= | ||||
| golang.org/x/text v0.10.0 h1:UpjohKhiEgNc0CSauXmwYftY1+LlaC75SJwh0SgCX58= | ||||
| golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= | ||||
| golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= | ||||
| golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= | ||||
| golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= | ||||
| golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= | ||||
| golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= | ||||
| golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= | ||||
| gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= | ||||
| gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= | ||||
| gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= | ||||
| gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= | ||||
| gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= | ||||
|   | ||||
							
								
								
									
										5
									
								
								goextVersion.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								goextVersion.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,5 @@ | ||||
| package goext | ||||
|  | ||||
| const GoextVersion = "0.0.166" | ||||
|  | ||||
| const GoextVersionTimestamp = "2023-06-19T10:25:41+0200" | ||||
							
								
								
									
										16
									
								
								mongo/.errcheck-excludes
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										16
									
								
								mongo/.errcheck-excludes
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,16 @@ | ||||
| (go.mongodb.org/mongo-driver/x/mongo/driver.Connection).Close | ||||
| (*go.mongodb.org/mongo-driver/x/network/connection.connection).Close | ||||
| (go.mongodb.org/mongo-driver/x/network/connection.Connection).Close | ||||
| (*go.mongodb.org/mongo-driver/x/mongo/driver/topology.connection).close | ||||
| (*go.mongodb.org/mongo-driver/x/mongo/driver/topology.Topology).Unsubscribe | ||||
| (*go.mongodb.org/mongo-driver/x/mongo/driver/topology.Server).Close | ||||
| (*go.mongodb.org/mongo-driver/x/network/connection.pool).closeConnection | ||||
| (*go.mongodb.org/mongo-driver/x/mongo/driver/topology.pool).close | ||||
| (go.mongodb.org/mongo-driver/x/network/wiremessage.ReadWriteCloser).Close | ||||
| (*go.mongodb.org/mongo-driver/mongo.Cursor).Close | ||||
| (*go.mongodb.org/mongo-driver/mongo.ChangeStream).Close | ||||
| (*go.mongodb.org/mongo-driver/mongo.Client).Disconnect | ||||
| (net.Conn).Close | ||||
| encoding/pem.Encode | ||||
| fmt.Fprintf | ||||
| fmt.Fprint | ||||
							
								
								
									
										13
									
								
								mongo/.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								mongo/.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @@ -0,0 +1,13 @@ | ||||
| .vscode | ||||
| debug | ||||
| .idea | ||||
| *.iml | ||||
| *.ipr | ||||
| *.iws | ||||
| .idea | ||||
| *.sublime-project | ||||
| *.sublime-workspace | ||||
| driver-test-data.tar.gz | ||||
| perf | ||||
| **mongocryptd.pid | ||||
| *.test | ||||
							
								
								
									
										3
									
								
								mongo/.gitmodules
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								mongo/.gitmodules
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @@ -0,0 +1,3 @@ | ||||
| [submodule "specifications"] | ||||
| 	path = specifications | ||||
| 	url = git@github.com:mongodb/specifications.git | ||||
							
								
								
									
										123
									
								
								mongo/.golangci.yml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										123
									
								
								mongo/.golangci.yml
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,123 @@ | ||||
| run: | ||||
|   timeout: 5m | ||||
|  | ||||
| linters: | ||||
|   disable-all: true | ||||
|   # TODO(GODRIVER-2156): Enable all commented-out linters. | ||||
|   enable: | ||||
|     - errcheck | ||||
|     # - errorlint | ||||
|     - gocritic | ||||
|     - goimports | ||||
|     - gosimple | ||||
|     - gosec | ||||
|     - govet | ||||
|     - ineffassign | ||||
|     - makezero | ||||
|     - misspell | ||||
|     - nakedret | ||||
|     - paralleltest | ||||
|     - prealloc | ||||
|     - revive | ||||
|     - staticcheck | ||||
|     - typecheck | ||||
|     - unused | ||||
|     - unconvert | ||||
|     - unparam | ||||
|  | ||||
| linters-settings: | ||||
|   errcheck: | ||||
|     exclude: .errcheck-excludes | ||||
|   gocritic: | ||||
|     enabled-checks: | ||||
|       # Detects suspicious append result assignments. E.g. "b := append(a, 1, 2, 3)" | ||||
|       - appendAssign | ||||
|   govet: | ||||
|     disable: | ||||
|       - cgocall | ||||
|       - composites | ||||
|   paralleltest: | ||||
|     # Ignore missing calls to `t.Parallel()` and only report incorrect uses of `t.Parallel()`. | ||||
|     ignore-missing: true | ||||
|   staticcheck: | ||||
|     checks: [ | ||||
|       "all", | ||||
|       "-SA1019", # Disable deprecation warnings for now. | ||||
|       "-SA1012", # Disable "do not pass a nil Context" to allow testing nil contexts in tests. | ||||
|     ] | ||||
|  | ||||
| issues: | ||||
|   exclude-use-default: false | ||||
|   exclude: | ||||
|     # Add all default excluded issues except issues related to exported types/functions not having | ||||
|     # comments; we want those warnings. The defaults are copied from the "--exclude-use-default" | ||||
|     # documentation on https://golangci-lint.run/usage/configuration/#command-line-options | ||||
|     ## Defaults ## | ||||
|     # EXC0001 errcheck: Almost all programs ignore errors on these functions and in most cases it's ok | ||||
|     - Error return value of .((os\.)?std(out|err)\..*|.*Close|.*Flush|os\.Remove(All)?|.*print(f|ln)?|os\.(Un)?Setenv). is not checked | ||||
|     # EXC0003 golint: False positive when tests are defined in package 'test' | ||||
|     - func name will be used as test\.Test.* by other packages, and that stutters; consider calling this | ||||
|     # EXC0004 govet: Common false positives | ||||
|     - (possible misuse of unsafe.Pointer|should have signature) | ||||
|     # EXC0005 staticcheck: Developers tend to write in C-style with an explicit 'break' in a 'switch', so it's ok to ignore | ||||
|     - ineffective break statement. Did you mean to break out of the outer loop | ||||
|     # EXC0006 gosec: Too many false-positives on 'unsafe' usage | ||||
|     - Use of unsafe calls should be audited | ||||
|     # EXC0007 gosec: Too many false-positives for parametrized shell calls | ||||
|     - Subprocess launch(ed with variable|ing should be audited) | ||||
|     # EXC0008 gosec: Duplicated errcheck checks | ||||
|     - (G104|G307) | ||||
|     # EXC0009 gosec: Too many issues in popular repos | ||||
|     - (Expect directory permissions to be 0750 or less|Expect file permissions to be 0600 or less) | ||||
|     # EXC0010 gosec: False positive is triggered by 'src, err := ioutil.ReadFile(filename)' | ||||
|     - Potential file inclusion via variable | ||||
|     ## End Defaults ## | ||||
|  | ||||
|     # Ignore capitalization warning for this weird field name. | ||||
|     - "var-naming: struct field CqCssWxW should be CqCSSWxW" | ||||
|     # Ignore warnings for common "wiremessage.Read..." usage because the safest way to use that API | ||||
|     # is by assigning possibly unused returned byte buffers. | ||||
|     - "SA4006: this value of `wm` is never used" | ||||
|     - "SA4006: this value of `rem` is never used" | ||||
|     - "ineffectual assignment to wm" | ||||
|     - "ineffectual assignment to rem" | ||||
|  | ||||
|   skip-dirs-use-default: false | ||||
|   skip-dirs: | ||||
|     - (^|/)vendor($|/) | ||||
|     - (^|/)testdata($|/) | ||||
|     - (^|/)etc($|/) | ||||
|   exclude-rules: | ||||
|     # Ignore some linters for example code that is intentionally simplified. | ||||
|     - path: examples/ | ||||
|       linters: | ||||
|         - revive | ||||
|         - errcheck | ||||
|     # Disable unused code linters for the copy/pasted "awsv4" package. | ||||
|     - path: x/mongo/driver/auth/internal/awsv4 | ||||
|       linters: | ||||
|         - unused | ||||
|     # Disable "unused" linter for code files that depend on the "mongocrypt.MongoCrypt" type because | ||||
|     # the linter build doesn't work correctly with CGO enabled. As a result, all calls to a | ||||
|     # "mongocrypt.MongoCrypt" API appear to always panic (see mongocrypt_not_enabled.go), leading | ||||
|     # to confusing messages about unused code. | ||||
|     - path: x/mongo/driver/crypt.go|mongo/(crypt_retrievers|mongocryptd).go | ||||
|       linters: | ||||
|         - unused | ||||
|     # Ignore "TLS MinVersion too low", "TLS InsecureSkipVerify set true", and "Use of weak random | ||||
|     # number generator (math/rand instead of crypto/rand)" in tests. | ||||
|     - path: _test\.go | ||||
|       text: G401|G402|G404 | ||||
|       linters: | ||||
|         - gosec | ||||
|     # Ignore missing comments for exported variable/function/type for code in the "internal" and | ||||
|     # "benchmark" directories. | ||||
|     - path: (internal\/|benchmark\/) | ||||
|       text: exported (.+) should have comment( \(or a comment on this block\))? or be unexported | ||||
|     # Ignore missing package comments for directories that aren't frequently used by external users. | ||||
|     - path: (internal\/|benchmark\/|x\/|cmd\/|mongo\/integration\/) | ||||
|       text: should have a package comment | ||||
|     # Disable unused linter for "golang.org/x/exp/rand" package in internal/randutil/rand. | ||||
|     - path: internal/randutil/rand | ||||
|       linters: | ||||
|         - unused | ||||
							
								
								
									
										201
									
								
								mongo/LICENSE
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										201
									
								
								mongo/LICENSE
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,201 @@ | ||||
|                                  Apache License | ||||
|                            Version 2.0, January 2004 | ||||
|                         http://www.apache.org/licenses/ | ||||
|  | ||||
|    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION | ||||
|  | ||||
|    1. Definitions. | ||||
|  | ||||
|       "License" shall mean the terms and conditions for use, reproduction, | ||||
|       and distribution as defined by Sections 1 through 9 of this document. | ||||
|  | ||||
|       "Licensor" shall mean the copyright owner or entity authorized by | ||||
|       the copyright owner that is granting the License. | ||||
|  | ||||
|       "Legal Entity" shall mean the union of the acting entity and all | ||||
|       other entities that control, are controlled by, or are under common | ||||
|       control with that entity. For the purposes of this definition, | ||||
|       "control" means (i) the power, direct or indirect, to cause the | ||||
|       direction or management of such entity, whether by contract or | ||||
|       otherwise, or (ii) ownership of fifty percent (50%) or more of the | ||||
|       outstanding shares, or (iii) beneficial ownership of such entity. | ||||
|  | ||||
|       "You" (or "Your") shall mean an individual or Legal Entity | ||||
|       exercising permissions granted by this License. | ||||
|  | ||||
|       "Source" form shall mean the preferred form for making modifications, | ||||
|       including but not limited to software source code, documentation | ||||
|       source, and configuration files. | ||||
|  | ||||
|       "Object" form shall mean any form resulting from mechanical | ||||
|       transformation or translation of a Source form, including but | ||||
|       not limited to compiled object code, generated documentation, | ||||
|       and conversions to other media types. | ||||
|  | ||||
|       "Work" shall mean the work of authorship, whether in Source or | ||||
|       Object form, made available under the License, as indicated by a | ||||
|       copyright notice that is included in or attached to the work | ||||
|       (an example is provided in the Appendix below). | ||||
|  | ||||
|       "Derivative Works" shall mean any work, whether in Source or Object | ||||
|       form, that is based on (or derived from) the Work and for which the | ||||
|       editorial revisions, annotations, elaborations, or other modifications | ||||
|       represent, as a whole, an original work of authorship. For the purposes | ||||
|       of this License, Derivative Works shall not include works that remain | ||||
|       separable from, or merely link (or bind by name) to the interfaces of, | ||||
|       the Work and Derivative Works thereof. | ||||
|  | ||||
|       "Contribution" shall mean any work of authorship, including | ||||
|       the original version of the Work and any modifications or additions | ||||
|       to that Work or Derivative Works thereof, that is intentionally | ||||
|       submitted to Licensor for inclusion in the Work by the copyright owner | ||||
|       or by an individual or Legal Entity authorized to submit on behalf of | ||||
|       the copyright owner. For the purposes of this definition, "submitted" | ||||
|       means any form of electronic, verbal, or written communication sent | ||||
|       to the Licensor or its representatives, including but not limited to | ||||
|       communication on electronic mailing lists, source code control systems, | ||||
|       and issue tracking systems that are managed by, or on behalf of, the | ||||
|       Licensor for the purpose of discussing and improving the Work, but | ||||
|       excluding communication that is conspicuously marked or otherwise | ||||
|       designated in writing by the copyright owner as "Not a Contribution." | ||||
|  | ||||
|       "Contributor" shall mean Licensor and any individual or Legal Entity | ||||
|       on behalf of whom a Contribution has been received by Licensor and | ||||
|       subsequently incorporated within the Work. | ||||
|  | ||||
|    2. Grant of Copyright License. Subject to the terms and conditions of | ||||
|       this License, each Contributor hereby grants to You a perpetual, | ||||
|       worldwide, non-exclusive, no-charge, royalty-free, irrevocable | ||||
|       copyright license to reproduce, prepare Derivative Works of, | ||||
|       publicly display, publicly perform, sublicense, and distribute the | ||||
|       Work and such Derivative Works in Source or Object form. | ||||
|  | ||||
|    3. Grant of Patent License. Subject to the terms and conditions of | ||||
|       this License, each Contributor hereby grants to You a perpetual, | ||||
|       worldwide, non-exclusive, no-charge, royalty-free, irrevocable | ||||
|       (except as stated in this section) patent license to make, have made, | ||||
|       use, offer to sell, sell, import, and otherwise transfer the Work, | ||||
|       where such license applies only to those patent claims licensable | ||||
|       by such Contributor that are necessarily infringed by their | ||||
|       Contribution(s) alone or by combination of their Contribution(s) | ||||
|       with the Work to which such Contribution(s) was submitted. If You | ||||
|       institute patent litigation against any entity (including a | ||||
|       cross-claim or counterclaim in a lawsuit) alleging that the Work | ||||
|       or a Contribution incorporated within the Work constitutes direct | ||||
|       or contributory patent infringement, then any patent licenses | ||||
|       granted to You under this License for that Work shall terminate | ||||
|       as of the date such litigation is filed. | ||||
|  | ||||
|    4. Redistribution. You may reproduce and distribute copies of the | ||||
|       Work or Derivative Works thereof in any medium, with or without | ||||
|       modifications, and in Source or Object form, provided that You | ||||
|       meet the following conditions: | ||||
|  | ||||
|       (a) You must give any other recipients of the Work or | ||||
|           Derivative Works a copy of this License; and | ||||
|  | ||||
|       (b) You must cause any modified files to carry prominent notices | ||||
|           stating that You changed the files; and | ||||
|  | ||||
|       (c) You must retain, in the Source form of any Derivative Works | ||||
|           that You distribute, all copyright, patent, trademark, and | ||||
|           attribution notices from the Source form of the Work, | ||||
|           excluding those notices that do not pertain to any part of | ||||
|           the Derivative Works; and | ||||
|  | ||||
|       (d) If the Work includes a "NOTICE" text file as part of its | ||||
|           distribution, then any Derivative Works that You distribute must | ||||
|           include a readable copy of the attribution notices contained | ||||
|           within such NOTICE file, excluding those notices that do not | ||||
|           pertain to any part of the Derivative Works, in at least one | ||||
|           of the following places: within a NOTICE text file distributed | ||||
|           as part of the Derivative Works; within the Source form or | ||||
|           documentation, if provided along with the Derivative Works; or, | ||||
|           within a display generated by the Derivative Works, if and | ||||
|           wherever such third-party notices normally appear. The contents | ||||
|           of the NOTICE file are for informational purposes only and | ||||
|           do not modify the License. You may add Your own attribution | ||||
|           notices within Derivative Works that You distribute, alongside | ||||
|           or as an addendum to the NOTICE text from the Work, provided | ||||
|           that such additional attribution notices cannot be construed | ||||
|           as modifying the License. | ||||
|  | ||||
|       You may add Your own copyright statement to Your modifications and | ||||
|       may provide additional or different license terms and conditions | ||||
|       for use, reproduction, or distribution of Your modifications, or | ||||
|       for any such Derivative Works as a whole, provided Your use, | ||||
|       reproduction, and distribution of the Work otherwise complies with | ||||
|       the conditions stated in this License. | ||||
|  | ||||
|    5. Submission of Contributions. Unless You explicitly state otherwise, | ||||
|       any Contribution intentionally submitted for inclusion in the Work | ||||
|       by You to the Licensor shall be under the terms and conditions of | ||||
|       this License, without any additional terms or conditions. | ||||
|       Notwithstanding the above, nothing herein shall supersede or modify | ||||
|       the terms of any separate license agreement you may have executed | ||||
|       with Licensor regarding such Contributions. | ||||
|  | ||||
|    6. Trademarks. This License does not grant permission to use the trade | ||||
|       names, trademarks, service marks, or product names of the Licensor, | ||||
|       except as required for reasonable and customary use in describing the | ||||
|       origin of the Work and reproducing the content of the NOTICE file. | ||||
|  | ||||
|    7. Disclaimer of Warranty. Unless required by applicable law or | ||||
|       agreed to in writing, Licensor provides the Work (and each | ||||
|       Contributor provides its Contributions) on an "AS IS" BASIS, | ||||
|       WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
|       implied, including, without limitation, any warranties or conditions | ||||
|       of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A | ||||
|       PARTICULAR PURPOSE. You are solely responsible for determining the | ||||
|       appropriateness of using or redistributing the Work and assume any | ||||
|       risks associated with Your exercise of permissions under this License. | ||||
|  | ||||
|    8. Limitation of Liability. In no event and under no legal theory, | ||||
|       whether in tort (including negligence), contract, or otherwise, | ||||
|       unless required by applicable law (such as deliberate and grossly | ||||
|       negligent acts) or agreed to in writing, shall any Contributor be | ||||
|       liable to You for damages, including any direct, indirect, special, | ||||
|       incidental, or consequential damages of any character arising as a | ||||
|       result of this License or out of the use or inability to use the | ||||
|       Work (including but not limited to damages for loss of goodwill, | ||||
|       work stoppage, computer failure or malfunction, or any and all | ||||
|       other commercial damages or losses), even if such Contributor | ||||
|       has been advised of the possibility of such damages. | ||||
|  | ||||
|    9. Accepting Warranty or Additional Liability. While redistributing | ||||
|       the Work or Derivative Works thereof, You may choose to offer, | ||||
|       and charge a fee for, acceptance of support, warranty, indemnity, | ||||
|       or other liability obligations and/or rights consistent with this | ||||
|       License. However, in accepting such obligations, You may act only | ||||
|       on Your own behalf and on Your sole responsibility, not on behalf | ||||
|       of any other Contributor, and only if You agree to indemnify, | ||||
|       defend, and hold each Contributor harmless for any liability | ||||
|       incurred by, or claims asserted against, such Contributor by reason | ||||
|       of your accepting any such warranty or additional liability. | ||||
|  | ||||
|    END OF TERMS AND CONDITIONS | ||||
|  | ||||
|    APPENDIX: How to apply the Apache License to your work. | ||||
|  | ||||
|       To apply the Apache License to your work, attach the following | ||||
|       boilerplate notice, with the fields enclosed by brackets "[]" | ||||
|       replaced with your own identifying information. (Don't include | ||||
|       the brackets!)  The text should be enclosed in the appropriate | ||||
|       comment syntax for the file format. We also recommend that a | ||||
|       file or class name and description of purpose be included on the | ||||
|       same "printed page" as the copyright notice for easier | ||||
|       identification within third-party archives. | ||||
|  | ||||
|    Copyright [yyyy] [name of copyright owner] | ||||
|  | ||||
|    Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|    you may not use this file except in compliance with the License. | ||||
|    You may obtain a copy of the License at | ||||
|  | ||||
|        http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
|    Unless required by applicable law or agreed to in writing, software | ||||
|    distributed under the License is distributed on an "AS IS" BASIS, | ||||
|    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|    See the License for the specific language governing permissions and | ||||
|    limitations under the License. | ||||
							
								
								
									
										210
									
								
								mongo/Makefile
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										210
									
								
								mongo/Makefile
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,210 @@ | ||||
| ATLAS_URIS = "$(ATLAS_FREE)" "$(ATLAS_REPLSET)" "$(ATLAS_SHARD)" "$(ATLAS_TLS11)" "$(ATLAS_TLS12)" "$(ATLAS_FREE_SRV)" "$(ATLAS_REPLSET_SRV)" "$(ATLAS_SHARD_SRV)" "$(ATLAS_TLS11_SRV)" "$(ATLAS_TLS12_SRV)" "$(ATLAS_SERVERLESS)" "$(ATLAS_SERVERLESS_SRV)" | ||||
| TEST_TIMEOUT = 1800 | ||||
|  | ||||
| ### Utility targets. ### | ||||
| .PHONY: default | ||||
| default: build check-license check-fmt check-modules lint test-short | ||||
|  | ||||
| .PHONY: add-license | ||||
| add-license: | ||||
| 	etc/check_license.sh -a | ||||
|  | ||||
| .PHONY: check-license | ||||
| check-license: | ||||
| 	etc/check_license.sh | ||||
|  | ||||
| .PHONY: build | ||||
| build: cross-compile build-tests build-compile-check | ||||
| 	go build ./... | ||||
| 	go build $(BUILD_TAGS) ./... | ||||
|  | ||||
| # Use ^$ to match no tests so that no tests are actually run but all tests are | ||||
| # compiled. Run with -short to ensure none of the TestMain functions try to | ||||
| # connect to a server. | ||||
| .PHONY: build-tests | ||||
| build-tests: | ||||
| 	go test -short $(BUILD_TAGS) -run ^$$ ./... | ||||
|  | ||||
| .PHONY: build-compile-check | ||||
| build-compile-check: | ||||
| 	etc/compile_check.sh | ||||
|  | ||||
| # Cross-compiling on Linux for architectures 386, arm, arm64, amd64, ppc64le, and s390x. | ||||
| # Omit any build tags because we don't expect our build environment to support compiling the C | ||||
| # libraries for other architectures. | ||||
| .PHONY: cross-compile | ||||
| cross-compile: | ||||
| 	GOOS=linux GOARCH=386 go build ./... | ||||
| 	GOOS=linux GOARCH=arm go build ./... | ||||
| 	GOOS=linux GOARCH=arm64 go build ./... | ||||
| 	GOOS=linux GOARCH=amd64 go build ./... | ||||
| 	GOOS=linux GOARCH=ppc64le go build ./... | ||||
| 	GOOS=linux GOARCH=s390x go build ./... | ||||
|  | ||||
| .PHONY: install-lll | ||||
| install-lll: | ||||
| 	go install github.com/walle/lll/...@latest | ||||
|  | ||||
| .PHONY: check-fmt | ||||
| check-fmt: install-lll | ||||
| 	etc/check_fmt.sh | ||||
|  | ||||
| # check-modules runs "go mod tidy" then "go mod vendor" and exits with a non-zero exit code if there | ||||
| # are any module or vendored modules changes. The intent is to confirm two properties: | ||||
| # | ||||
| # 1. Exactly the required modules are declared as dependencies. We should always be able to run | ||||
| # "go mod tidy" and expect that no unrelated changes are made to the "go.mod" file. | ||||
| # | ||||
| # 2. All required modules are copied into the vendor/ directory and are an exact copy of the | ||||
| # original module source code (i.e. the vendored modules are not modified from their original code). | ||||
| .PHONY: check-modules | ||||
| check-modules: | ||||
| 	go mod tidy -v | ||||
| 	go mod vendor | ||||
| 	git diff --exit-code go.mod go.sum ./vendor | ||||
|  | ||||
| .PHONY: doc | ||||
| doc: | ||||
| 	godoc -http=:6060 -index | ||||
|  | ||||
| .PHONY: fmt | ||||
| fmt: | ||||
| 	go fmt ./... | ||||
|  | ||||
| .PHONY: install-golangci-lint | ||||
| install-golangci-lint: | ||||
| 	go install github.com/golangci/golangci-lint/cmd/golangci-lint@v1.51.0 | ||||
|  | ||||
| # Lint with various GOOS and GOARCH targets to catch static analysis failures that may only affect | ||||
| # specific operating systems or architectures. For example, staticcheck will only check for 64-bit | ||||
| # alignment of atomically accessed variables on 32-bit architectures (see | ||||
| # https://staticcheck.io/docs/checks#SA1027) | ||||
| .PHONY: lint | ||||
| lint: install-golangci-lint | ||||
| 	GOOS=linux GOARCH=386 golangci-lint run --config .golangci.yml ./... | ||||
| 	GOOS=linux GOARCH=arm golangci-lint run --config .golangci.yml ./... | ||||
| 	GOOS=linux GOARCH=arm64 golangci-lint run --config .golangci.yml ./... | ||||
| 	GOOS=linux GOARCH=amd64 golangci-lint run --config .golangci.yml ./... | ||||
| 	GOOS=linux GOARCH=ppc64le golangci-lint run --config .golangci.yml ./... | ||||
| 	GOOS=linux GOARCH=s390x golangci-lint run --config .golangci.yml ./... | ||||
|  | ||||
| .PHONY: update-notices | ||||
| update-notices: | ||||
| 	etc/generate_notices.pl > THIRD-PARTY-NOTICES | ||||
|  | ||||
| ### Local testing targets. ### | ||||
| .PHONY: test | ||||
| test: | ||||
| 	go test $(BUILD_TAGS) -timeout $(TEST_TIMEOUT)s -p 1 ./... | ||||
|  | ||||
| .PHONY: test-cover | ||||
| test-cover: | ||||
| 	go test $(BUILD_TAGS) -timeout $(TEST_TIMEOUT)s -cover $(COVER_ARGS) -p 1 ./... | ||||
|  | ||||
| .PHONY: test-race | ||||
| test-race: | ||||
| 	go test $(BUILD_TAGS) -timeout $(TEST_TIMEOUT)s -race -p 1 ./... | ||||
|  | ||||
| .PHONY: test-short | ||||
| test-short: | ||||
| 	go test $(BUILD_TAGS) -timeout 60s -short ./... | ||||
|  | ||||
| ### Evergreen specific targets. ### | ||||
| .PHONY: build-aws-ecs-test | ||||
| build-aws-ecs-test: | ||||
| 	go build $(BUILD_TAGS) ./cmd/testaws/main.go | ||||
|  | ||||
| .PHONY: evg-test | ||||
| evg-test: | ||||
| 	go test -exec "env PKG_CONFIG_PATH=$(PKG_CONFIG_PATH) LD_LIBRARY_PATH=$(LD_LIBRARY_PATH)" $(BUILD_TAGS) -v -timeout $(TEST_TIMEOUT)s -p 1 ./... >> test.suite | ||||
|  | ||||
| .PHONY: evg-test-atlas | ||||
| evg-test-atlas: | ||||
| 	go run ./cmd/testatlas/main.go $(ATLAS_URIS) | ||||
|  | ||||
| .PHONY: evg-test-atlas-data-lake | ||||
| evg-test-atlas-data-lake: | ||||
| 	ATLAS_DATA_LAKE_INTEGRATION_TEST=true go test -v ./mongo/integration -run TestUnifiedSpecs/atlas-data-lake-testing >> spec_test.suite | ||||
| 	ATLAS_DATA_LAKE_INTEGRATION_TEST=true go test -v ./mongo/integration -run TestAtlasDataLake >> spec_test.suite | ||||
|  | ||||
| .PHONY: evg-test-enterprise-auth | ||||
| evg-test-enterprise-auth: | ||||
| 	go run -tags gssapi ./cmd/testentauth/main.go | ||||
|  | ||||
| .PHONY: evg-test-kmip | ||||
| evg-test-kmip: | ||||
| 	go test -exec "env PKG_CONFIG_PATH=$(PKG_CONFIG_PATH) LD_LIBRARY_PATH=$(LD_LIBRARY_PATH)" $(BUILD_TAGS) -v -timeout $(TEST_TIMEOUT)s ./mongo/integration -run TestClientSideEncryptionSpec/kmipKMS >> test.suite | ||||
| 	go test -exec "env PKG_CONFIG_PATH=$(PKG_CONFIG_PATH) LD_LIBRARY_PATH=$(LD_LIBRARY_PATH)" $(BUILD_TAGS) -v -timeout $(TEST_TIMEOUT)s ./mongo/integration -run TestClientSideEncryptionProse/data_key_and_double_encryption >> test.suite | ||||
| 	go test -exec "env PKG_CONFIG_PATH=$(PKG_CONFIG_PATH) LD_LIBRARY_PATH=$(LD_LIBRARY_PATH)" $(BUILD_TAGS) -v -timeout $(TEST_TIMEOUT)s ./mongo/integration -run TestClientSideEncryptionProse/corpus >> test.suite | ||||
| 	go test -exec "env PKG_CONFIG_PATH=$(PKG_CONFIG_PATH) LD_LIBRARY_PATH=$(LD_LIBRARY_PATH)" $(BUILD_TAGS) -v -timeout $(TEST_TIMEOUT)s ./mongo/integration -run TestClientSideEncryptionProse/custom_endpoint >> test.suite | ||||
| 	go test -exec "env PKG_CONFIG_PATH=$(PKG_CONFIG_PATH) LD_LIBRARY_PATH=$(LD_LIBRARY_PATH)" $(BUILD_TAGS) -v -timeout $(TEST_TIMEOUT)s ./mongo/integration -run TestClientSideEncryptionProse/kms_tls_options_test >> test.suite | ||||
|  | ||||
| .PHONY: evg-test-kms | ||||
| evg-test-kms: | ||||
| 	go test -exec "env PKG_CONFIG_PATH=$(PKG_CONFIG_PATH) LD_LIBRARY_PATH=$(LD_LIBRARY_PATH)" $(BUILD_TAGS) -v -timeout $(TEST_TIMEOUT)s ./mongo/integration -run TestClientSideEncryptionProse/kms_tls_tests >> test.suite | ||||
|  | ||||
| .PHONY: evg-test-load-balancers | ||||
| evg-test-load-balancers: | ||||
| 	# Load balancer should be tested with all unified tests as well as tests in the following | ||||
| 	# components: retryable reads, retryable writes, change streams, initial DNS seedlist discovery. | ||||
| 	go test $(BUILD_TAGS) ./mongo/integration -run TestUnifiedSpecs/retryable-reads -v -timeout $(TEST_TIMEOUT)s >> test.suite | ||||
| 	go test $(BUILD_TAGS) ./mongo/integration -run TestRetryableWritesSpec -v -timeout $(TEST_TIMEOUT)s >> test.suite | ||||
| 	go test $(BUILD_TAGS) ./mongo/integration -run TestChangeStreamSpec -v -timeout $(TEST_TIMEOUT)s >> test.suite | ||||
| 	go test $(BUILD_TAGS) ./mongo/integration -run TestInitialDNSSeedlistDiscoverySpec/load_balanced -v -timeout $(TEST_TIMEOUT)s >> test.suite | ||||
| 	go test $(BUILD_TAGS) ./mongo/integration -run TestLoadBalancerSupport -v -timeout $(TEST_TIMEOUT)s >> test.suite | ||||
| 	go test $(BUILD_TAGS) ./mongo/integration/unified -run TestUnifiedSpec -v -timeout $(TEST_TIMEOUT)s >> test.suite | ||||
|  | ||||
| .PHONY: evg-test-ocsp | ||||
| evg-test-ocsp: | ||||
| 	go test -v ./mongo -run TestOCSP $(OCSP_TLS_SHOULD_SUCCEED) >> test.suite | ||||
|  | ||||
| .PHONY: evg-test-serverless | ||||
| evg-test-serverless: | ||||
| 	# Serverless should be tested with all unified tests as well as tests in the following components: CRUD, load balancer, | ||||
| 	# retryable reads, retryable writes, sessions, transactions and cursor behavior. | ||||
| 	go test $(BUILD_TAGS) ./mongo/integration -run TestCrudSpec -v -timeout $(TEST_TIMEOUT)s >> test.suite | ||||
| 	go test $(BUILD_TAGS) ./mongo/integration -run TestWriteErrorsWithLabels -v -timeout $(TEST_TIMEOUT)s >> test.suite | ||||
| 	go test $(BUILD_TAGS) ./mongo/integration -run TestWriteErrorsDetails -v -timeout $(TEST_TIMEOUT)s >> test.suite | ||||
| 	go test $(BUILD_TAGS) ./mongo/integration -run TestHintErrors -v -timeout $(TEST_TIMEOUT)s >> test.suite | ||||
| 	go test $(BUILD_TAGS) ./mongo/integration -run TestWriteConcernError -v -timeout $(TEST_TIMEOUT)s >> test.suite | ||||
| 	go test $(BUILD_TAGS) ./mongo/integration -run TestErrorsCodeNamePropagated -v -timeout $(TEST_TIMEOUT)s >> test.suite | ||||
| 	go test $(BUILD_TAGS) ./mongo/integration -run TestLoadBalancerSupport -v -timeout $(TEST_TIMEOUT)s >> test.suite | ||||
| 	go test $(BUILD_TAGS) ./mongo/integration -run TestUnifiedSpecs/retryable-reads -v -timeout $(TEST_TIMEOUT)s >> test.suite | ||||
| 	go test $(BUILD_TAGS) ./mongo/integration -run TestRetryableReadsProse -v -timeout $(TEST_TIMEOUT)s >> test.suite | ||||
| 	go test $(BUILD_TAGS) ./mongo/integration -run TestRetryableWritesSpec -v -timeout $(TEST_TIMEOUT)s >> test.suite | ||||
| 	go test $(BUILD_TAGS) ./mongo/integration -run TestRetryableWritesProse -v -timeout $(TEST_TIMEOUT)s >> test.suite | ||||
| 	go test $(BUILD_TAGS) ./mongo/integration -run TestUnifiedSpecs/sessions -v -timeout $(TEST_TIMEOUT)s >> test.suite | ||||
| 	go test $(BUILD_TAGS) ./mongo/integration -run TestSessionsProse -v -timeout $(TEST_TIMEOUT)s >> test.suite | ||||
| 	go test $(BUILD_TAGS) ./mongo/integration -run TestUnifiedSpecs/transactions/legacy -v -timeout $(TEST_TIMEOUT)s >> test.suite | ||||
| 	go test $(BUILD_TAGS) ./mongo/integration -run TestConvenientTransactions -v -timeout $(TEST_TIMEOUT)s >> test.suite | ||||
| 	go test $(BUILD_TAGS) ./mongo/integration -run TestCursor -v -timeout $(TEST_TIMEOUT)s >> test.suite | ||||
| 	go test $(BUILD_TAGS) ./mongo/integration/unified -run TestUnifiedSpec -v -timeout $(TEST_TIMEOUT)s >> test.suite | ||||
| 	go test -exec "env PKG_CONFIG_PATH=$(PKG_CONFIG_PATH) LD_LIBRARY_PATH=$(LD_LIBRARY_PATH)" $(BUILD_TAGS) -v -timeout $(TEST_TIMEOUT)s ./mongo/integration -run TestClientSideEncryptionSpec >> test.suite | ||||
| 	go test -exec "env PKG_CONFIG_PATH=$(PKG_CONFIG_PATH) LD_LIBRARY_PATH=$(LD_LIBRARY_PATH)" $(BUILD_TAGS) -v -timeout $(TEST_TIMEOUT)s ./mongo/integration -run TestClientSideEncryptionProse >> test.suite | ||||
|  | ||||
| .PHONY: evg-test-versioned-api | ||||
| evg-test-versioned-api: | ||||
| 	# Versioned API related tests are in the mongo, integration and unified packages. | ||||
| 	go test -exec "env PKG_CONFIG_PATH=$(PKG_CONFIG_PATH) LD_LIBRARY_PATH=$(LD_LIBRARY_PATH)" $(BUILD_TAGS) -v -timeout $(TEST_TIMEOUT)s ./mongo >> test.suite | ||||
| 	go test -exec "env PKG_CONFIG_PATH=$(PKG_CONFIG_PATH) LD_LIBRARY_PATH=$(LD_LIBRARY_PATH)" $(BUILD_TAGS) -v -timeout $(TEST_TIMEOUT)s ./mongo/integration >> test.suite | ||||
| 	go test -exec "env PKG_CONFIG_PATH=$(PKG_CONFIG_PATH) LD_LIBRARY_PATH=$(LD_LIBRARY_PATH)" $(BUILD_TAGS) -v -timeout $(TEST_TIMEOUT)s ./mongo/integration/unified >> test.suite | ||||
|  | ||||
| .PHONY: build-gcpkms-test | ||||
| build-gcpkms-test: | ||||
| 	go build $(BUILD_TAGS) ./cmd/testgcpkms | ||||
|  | ||||
| ### Benchmark specific targets and support. ### | ||||
| .PHONY: benchmark | ||||
| benchmark:perf | ||||
| 	go test $(BUILD_TAGS) -benchmem -bench=. ./benchmark | ||||
|  | ||||
| .PHONY: driver-benchmark | ||||
| driver-benchmark:perf | ||||
| 	@go run cmd/godriver-benchmark/main.go | tee perf.suite | ||||
|  | ||||
| perf:driver-test-data.tar.gz | ||||
| 	tar -zxf $< $(if $(eq $(UNAME_S),Darwin),-s , --transform=s)/testdata/perf/ | ||||
| 	@touch $@ | ||||
|  | ||||
| driver-test-data.tar.gz: | ||||
| 	curl --retry 5 "https://s3.amazonaws.com/boxes.10gen.com/build/driver-test-data.tar.gz" -o driver-test-data.tar.gz --silent --max-time 120 | ||||
							
								
								
									
										251
									
								
								mongo/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										251
									
								
								mongo/README.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,251 @@ | ||||
| <p align="center"><img src="etc/assets/mongo-gopher.png" width="250"></p> | ||||
| <p align="center"> | ||||
|   <a href="https://goreportcard.com/report/go.mongodb.org/mongo-driver"><img src="https://goreportcard.com/badge/go.mongodb.org/mongo-driver"></a> | ||||
|   <a href="https://pkg.go.dev/go.mongodb.org/mongo-driver/mongo"><img src="etc/assets/godev-mongo-blue.svg" alt="docs"></a> | ||||
|   <a href="https://pkg.go.dev/go.mongodb.org/mongo-driver/bson"><img src="etc/assets/godev-bson-blue.svg" alt="docs"></a> | ||||
|   <a href="https://www.mongodb.com/docs/drivers/go/current/"><img src="etc/assets/docs-mongodb-green.svg"></a> | ||||
| </p> | ||||
|  | ||||
| # MongoDB Go Driver | ||||
|  | ||||
| The MongoDB supported driver for Go. | ||||
|  | ||||
| ------------------------- | ||||
| - [Requirements](#requirements) | ||||
| - [Installation](#installation) | ||||
| - [Usage](#usage) | ||||
| - [Feedback](#feedback) | ||||
| - [Testing / Development](#testing--development) | ||||
| - [Continuous Integration](#continuous-integration) | ||||
| - [License](#license) | ||||
|  | ||||
| ------------------------- | ||||
| ## Requirements | ||||
|  | ||||
| - Go 1.13 or higher. We aim to support the latest versions of Go. | ||||
|   - Go 1.20 or higher is required to run the driver test suite. | ||||
| - MongoDB 3.6 and higher. | ||||
|  | ||||
| ------------------------- | ||||
| ## Installation | ||||
|  | ||||
| The recommended way to get started using the MongoDB Go driver is by using Go modules to install the dependency in | ||||
| your project. This can be done either by importing packages from `go.mongodb.org/mongo-driver` and having the build | ||||
| step install the dependency or by explicitly running | ||||
|  | ||||
| ```bash | ||||
| go get go.mongodb.org/mongo-driver/mongo | ||||
| ``` | ||||
|  | ||||
| When using a version of Go that does not support modules, the driver can be installed using `dep` by running | ||||
|  | ||||
| ```bash | ||||
| dep ensure -add "go.mongodb.org/mongo-driver/mongo" | ||||
| ``` | ||||
|  | ||||
| ------------------------- | ||||
| ## Usage | ||||
|  | ||||
| To get started with the driver, import the `mongo` package and create a `mongo.Client` with the `Connect` function: | ||||
|  | ||||
| ```go | ||||
| import ( | ||||
|     "context" | ||||
|     "time" | ||||
|  | ||||
|     "go.mongodb.org/mongo-driver/mongo" | ||||
|     "go.mongodb.org/mongo-driver/mongo/options" | ||||
|     "go.mongodb.org/mongo-driver/mongo/readpref" | ||||
| ) | ||||
|  | ||||
| ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) | ||||
| defer cancel() | ||||
| client, err := mongo.Connect(ctx, options.Client().ApplyURI("mongodb://localhost:27017")) | ||||
| ``` | ||||
|  | ||||
| Make sure to defer a call to `Disconnect` after instantiating your client: | ||||
|  | ||||
| ```go | ||||
| defer func() { | ||||
|     if err = client.Disconnect(ctx); err != nil { | ||||
|         panic(err) | ||||
|     } | ||||
| }() | ||||
| ``` | ||||
|  | ||||
| For more advanced configuration and authentication, see the [documentation for mongo.Connect](https://pkg.go.dev/go.mongodb.org/mongo-driver/mongo#Connect). | ||||
|  | ||||
| Calling `Connect` does not block for server discovery. If you wish to know if a MongoDB server has been found and connected to, | ||||
| use the `Ping` method: | ||||
|  | ||||
| ```go | ||||
| ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second) | ||||
| defer cancel() | ||||
| err = client.Ping(ctx, readpref.Primary()) | ||||
| ``` | ||||
|  | ||||
| To insert a document into a collection, first retrieve a `Database` and then `Collection` instance from the `Client`: | ||||
|  | ||||
| ```go | ||||
| collection := client.Database("testing").Collection("numbers") | ||||
| ``` | ||||
|  | ||||
| The `Collection` instance can then be used to insert documents: | ||||
|  | ||||
| ```go | ||||
| ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) | ||||
| defer cancel() | ||||
| res, err := collection.InsertOne(ctx, bson.D{{"name", "pi"}, {"value", 3.14159}}) | ||||
| id := res.InsertedID | ||||
| ``` | ||||
|  | ||||
| To use `bson.D`, you will need to add `"go.mongodb.org/mongo-driver/bson"` to your imports. | ||||
|  | ||||
| Your import statement should now look like this: | ||||
|  | ||||
| ```go | ||||
| import ( | ||||
|     "context" | ||||
|     "log" | ||||
|     "time" | ||||
|  | ||||
|     "go.mongodb.org/mongo-driver/bson" | ||||
|     "go.mongodb.org/mongo-driver/mongo" | ||||
|     "go.mongodb.org/mongo-driver/mongo/options" | ||||
|     "go.mongodb.org/mongo-driver/mongo/readpref" | ||||
| ) | ||||
| ``` | ||||
|  | ||||
| Several query methods return a cursor, which can be used like this: | ||||
|  | ||||
| ```go | ||||
| ctx, cancel = context.WithTimeout(context.Background(), 30*time.Second) | ||||
| defer cancel() | ||||
| cur, err := collection.Find(ctx, bson.D{}) | ||||
| if err != nil { log.Fatal(err) } | ||||
| defer cur.Close(ctx) | ||||
| for cur.Next(ctx) { | ||||
|     var result bson.D | ||||
|     err := cur.Decode(&result) | ||||
|     if err != nil { log.Fatal(err) } | ||||
|     // do something with result.... | ||||
| } | ||||
| if err := cur.Err(); err != nil { | ||||
|     log.Fatal(err) | ||||
| } | ||||
| ``` | ||||
|  | ||||
| For methods that return a single item, a `SingleResult` instance is returned: | ||||
|  | ||||
| ```go | ||||
| var result struct { | ||||
|     Value float64 | ||||
| } | ||||
| filter := bson.D{{"name", "pi"}} | ||||
| ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) | ||||
| defer cancel() | ||||
| err = collection.FindOne(ctx, filter).Decode(&result) | ||||
| if err == mongo.ErrNoDocuments { | ||||
|     // Do something when no record was found | ||||
|     fmt.Println("record does not exist") | ||||
| } else if err != nil { | ||||
|     log.Fatal(err) | ||||
| } | ||||
| // Do something with result... | ||||
| ``` | ||||
|  | ||||
| Additional examples and documentation can be found under the examples directory and [on the MongoDB Documentation website](https://www.mongodb.com/docs/drivers/go/current/). | ||||
|  | ||||
| ------------------------- | ||||
| ## Feedback | ||||
|  | ||||
| For help with the driver, please post in the [MongoDB Community Forums](https://developer.mongodb.com/community/forums/tag/golang/). | ||||
|  | ||||
| New features and bugs can be reported on jira: https://jira.mongodb.org/browse/GODRIVER | ||||
|  | ||||
| ------------------------- | ||||
| ## Testing / Development | ||||
|  | ||||
| The driver tests can be run against several database configurations. The most simple configuration is a standalone mongod with no auth, no ssl, and no compression. To run these basic driver tests, make sure a standalone MongoDB server instance is running at localhost:27017. To run the tests, you can run `make` (on Windows, run `nmake`). This will run coverage, run go-lint, run go-vet, and build the examples. | ||||
|  | ||||
| ### Testing Different Topologies | ||||
|  | ||||
| To test a **replica set** or **sharded cluster**, set `MONGODB_URI="<connection-string>"` for the `make` command. | ||||
| For example, for a local replica set named `rs1` comprised of three nodes on ports 27017, 27018, and 27019: | ||||
|  | ||||
| ``` | ||||
| MONGODB_URI="mongodb://localhost:27017,localhost:27018,localhost:27019/?replicaSet=rs1" make | ||||
| ``` | ||||
|  | ||||
| ### Testing Auth and TLS | ||||
|  | ||||
| To test authentication and TLS, first set up a MongoDB cluster with auth and TLS configured. Testing authentication requires a user with the `root` role on the `admin` database. Here is an example command that would run a mongod with TLS correctly configured for tests. Either set or replace PATH_TO_SERVER_KEY_FILE and PATH_TO_CA_FILE with paths to their respective files: | ||||
|  | ||||
| ``` | ||||
| mongod \ | ||||
| --auth \ | ||||
| --tlsMode requireTLS \ | ||||
| --tlsCertificateKeyFile $PATH_TO_SERVER_KEY_FILE \ | ||||
| --tlsCAFile $PATH_TO_CA_FILE \ | ||||
| --tlsAllowInvalidCertificates | ||||
| ``` | ||||
|  | ||||
| To run the tests with `make`, set: | ||||
| - `MONGO_GO_DRIVER_CA_FILE` to the location of the CA file used by the database | ||||
| - `MONGO_GO_DRIVER_KEY_FILE` to the location of the client key file | ||||
| - `MONGO_GO_DRIVER_PKCS8_ENCRYPTED_KEY_FILE` to the location of the pkcs8 client key file encrypted with the password string: `password` | ||||
| - `MONGO_GO_DRIVER_PKCS8_UNENCRYPTED_KEY_FILE` to the location of the unencrypted pkcs8 key file | ||||
| - `MONGODB_URI` to the connection string of the server | ||||
| - `AUTH=auth` | ||||
| - `SSL=ssl` | ||||
|  | ||||
| For example: | ||||
|  | ||||
| ``` | ||||
| AUTH=auth SSL=ssl \ | ||||
| MONGO_GO_DRIVER_CA_FILE=$PATH_TO_CA_FILE \ | ||||
| MONGO_GO_DRIVER_KEY_FILE=$PATH_TO_CLIENT_KEY_FILE \ | ||||
| MONGO_GO_DRIVER_PKCS8_ENCRYPTED_KEY_FILE=$PATH_TO_ENCRYPTED_KEY_FILE \ | ||||
| MONGO_GO_DRIVER_PKCS8_UNENCRYPTED_KEY_FILE=$PATH_TO_UNENCRYPTED_KEY_FILE \ | ||||
| MONGODB_URI="mongodb://user:password@localhost:27017/?authSource=admin" \ | ||||
| make | ||||
| ``` | ||||
|  | ||||
| Notes: | ||||
| - The `--tlsAllowInvalidCertificates` flag is required on the server for the test suite to work correctly. | ||||
| - The test suite requires the auth database to be set with `?authSource=admin`, not `/admin`. | ||||
|  | ||||
| ### Testing Compression | ||||
|  | ||||
| The MongoDB Go Driver supports wire protocol compression using Snappy, zLib, or zstd. To run tests with wire protocol compression, set `MONGO_GO_DRIVER_COMPRESSOR` to `snappy`, `zlib`, or `zstd`.  For example: | ||||
|  | ||||
| ``` | ||||
| MONGO_GO_DRIVER_COMPRESSOR=snappy make | ||||
| ``` | ||||
|  | ||||
| Ensure the [`--networkMessageCompressors` flag](https://www.mongodb.com/docs/manual/reference/program/mongod/#cmdoption-mongod-networkmessagecompressors) on mongod or mongos includes `zlib` if testing zLib compression. | ||||
|  | ||||
| ------------------------- | ||||
| ## Contribution | ||||
|  | ||||
| Check out the [project page](https://jira.mongodb.org/browse/GODRIVER) for tickets that need completing. See our [contribution guidelines](docs/CONTRIBUTING.md) for details. | ||||
|  | ||||
| ------------------------- | ||||
| ## Continuous Integration | ||||
|  | ||||
| Commits to master are run automatically on [evergreen](https://evergreen.mongodb.com/waterfall/mongo-go-driver). | ||||
|  | ||||
| ------------------------- | ||||
| ## Frequently Encountered Issues | ||||
|  | ||||
| See our [common issues](docs/common-issues.md) documentation for troubleshooting frequently encountered issues. | ||||
|  | ||||
| ------------------------- | ||||
| ## Thanks and Acknowledgement | ||||
|  | ||||
| <a href="https://github.com/ashleymcnamara">@ashleymcnamara</a> - Mongo Gopher Artwork | ||||
|  | ||||
| ------------------------- | ||||
| ## License | ||||
|  | ||||
| The MongoDB Go Driver is licensed under the [Apache License](LICENSE). | ||||
							
								
								
									
										1554
									
								
								mongo/THIRD-PARTY-NOTICES
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1554
									
								
								mongo/THIRD-PARTY-NOTICES
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										307
									
								
								mongo/bson/benchmark_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										307
									
								
								mongo/bson/benchmark_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,307 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bson | ||||
|  | ||||
| import ( | ||||
| 	"compress/gzip" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"io/ioutil" | ||||
| 	"os" | ||||
| 	"path" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| type encodetest struct { | ||||
| 	Field1String  string | ||||
| 	Field1Int64   int64 | ||||
| 	Field1Float64 float64 | ||||
| 	Field2String  string | ||||
| 	Field2Int64   int64 | ||||
| 	Field2Float64 float64 | ||||
| 	Field3String  string | ||||
| 	Field3Int64   int64 | ||||
| 	Field3Float64 float64 | ||||
| 	Field4String  string | ||||
| 	Field4Int64   int64 | ||||
| 	Field4Float64 float64 | ||||
| } | ||||
|  | ||||
| type nestedtest1 struct { | ||||
| 	Nested nestedtest2 | ||||
| } | ||||
|  | ||||
| type nestedtest2 struct { | ||||
| 	Nested nestedtest3 | ||||
| } | ||||
|  | ||||
| type nestedtest3 struct { | ||||
| 	Nested nestedtest4 | ||||
| } | ||||
|  | ||||
| type nestedtest4 struct { | ||||
| 	Nested nestedtest5 | ||||
| } | ||||
|  | ||||
| type nestedtest5 struct { | ||||
| 	Nested nestedtest6 | ||||
| } | ||||
|  | ||||
| type nestedtest6 struct { | ||||
| 	Nested nestedtest7 | ||||
| } | ||||
|  | ||||
| type nestedtest7 struct { | ||||
| 	Nested nestedtest8 | ||||
| } | ||||
|  | ||||
| type nestedtest8 struct { | ||||
| 	Nested nestedtest9 | ||||
| } | ||||
|  | ||||
| type nestedtest9 struct { | ||||
| 	Nested nestedtest10 | ||||
| } | ||||
|  | ||||
| type nestedtest10 struct { | ||||
| 	Nested nestedtest11 | ||||
| } | ||||
|  | ||||
| type nestedtest11 struct { | ||||
| 	Nested encodetest | ||||
| } | ||||
|  | ||||
| var encodetestInstance = encodetest{ | ||||
| 	Field1String:  "foo", | ||||
| 	Field1Int64:   1, | ||||
| 	Field1Float64: 3.0, | ||||
| 	Field2String:  "bar", | ||||
| 	Field2Int64:   2, | ||||
| 	Field2Float64: 3.1, | ||||
| 	Field3String:  "baz", | ||||
| 	Field3Int64:   3, | ||||
| 	Field3Float64: 3.14, | ||||
| 	Field4String:  "qux", | ||||
| 	Field4Int64:   4, | ||||
| 	Field4Float64: 3.141, | ||||
| } | ||||
|  | ||||
| var nestedInstance = nestedtest1{ | ||||
| 	nestedtest2{ | ||||
| 		nestedtest3{ | ||||
| 			nestedtest4{ | ||||
| 				nestedtest5{ | ||||
| 					nestedtest6{ | ||||
| 						nestedtest7{ | ||||
| 							nestedtest8{ | ||||
| 								nestedtest9{ | ||||
| 									nestedtest10{ | ||||
| 										nestedtest11{ | ||||
| 											encodetest{ | ||||
| 												Field1String:  "foo", | ||||
| 												Field1Int64:   1, | ||||
| 												Field1Float64: 3.0, | ||||
| 												Field2String:  "bar", | ||||
| 												Field2Int64:   2, | ||||
| 												Field2Float64: 3.1, | ||||
| 												Field3String:  "baz", | ||||
| 												Field3Int64:   3, | ||||
| 												Field3Float64: 3.14, | ||||
| 												Field4String:  "qux", | ||||
| 												Field4Int64:   4, | ||||
| 												Field4Float64: 3.141, | ||||
| 											}, | ||||
| 										}, | ||||
| 									}, | ||||
| 								}, | ||||
| 							}, | ||||
| 						}, | ||||
| 					}, | ||||
| 				}, | ||||
| 			}, | ||||
| 		}, | ||||
| 	}, | ||||
| } | ||||
|  | ||||
| const extendedBSONDir = "../testdata/extended_bson" | ||||
|  | ||||
| // readExtJSONFile reads the GZIP-compressed extended JSON document from the given filename in the | ||||
| // "extended BSON" test data directory (../testdata/extended_bson) and returns it as a | ||||
| // map[string]interface{}. It panics on any errors. | ||||
| func readExtJSONFile(filename string) map[string]interface{} { | ||||
| 	filePath := path.Join(extendedBSONDir, filename) | ||||
| 	file, err := os.Open(filePath) | ||||
| 	if err != nil { | ||||
| 		panic(fmt.Sprintf("error opening file %q: %s", filePath, err)) | ||||
| 	} | ||||
| 	defer func() { | ||||
| 		_ = file.Close() | ||||
| 	}() | ||||
|  | ||||
| 	gz, err := gzip.NewReader(file) | ||||
| 	if err != nil { | ||||
| 		panic(fmt.Sprintf("error creating GZIP reader: %s", err)) | ||||
| 	} | ||||
| 	defer func() { | ||||
| 		_ = gz.Close() | ||||
| 	}() | ||||
|  | ||||
| 	data, err := ioutil.ReadAll(gz) | ||||
| 	if err != nil { | ||||
| 		panic(fmt.Sprintf("error reading GZIP contents of file: %s", err)) | ||||
| 	} | ||||
|  | ||||
| 	var v map[string]interface{} | ||||
| 	err = UnmarshalExtJSON(data, false, &v) | ||||
| 	if err != nil { | ||||
| 		panic(fmt.Sprintf("error unmarshalling extended JSON: %s", err)) | ||||
| 	} | ||||
|  | ||||
| 	return v | ||||
| } | ||||
|  | ||||
| func BenchmarkMarshal(b *testing.B) { | ||||
| 	cases := []struct { | ||||
| 		desc  string | ||||
| 		value interface{} | ||||
| 	}{ | ||||
| 		{ | ||||
| 			desc:  "simple struct", | ||||
| 			value: encodetestInstance, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc:  "nested struct", | ||||
| 			value: nestedInstance, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc:  "deep_bson.json.gz", | ||||
| 			value: readExtJSONFile("deep_bson.json.gz"), | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc:  "flat_bson.json.gz", | ||||
| 			value: readExtJSONFile("flat_bson.json.gz"), | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc:  "full_bson.json.gz", | ||||
| 			value: readExtJSONFile("full_bson.json.gz"), | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for _, tc := range cases { | ||||
| 		b.Run(tc.desc, func(b *testing.B) { | ||||
| 			b.Run("BSON", func(b *testing.B) { | ||||
| 				for i := 0; i < b.N; i++ { | ||||
| 					_, err := Marshal(tc.value) | ||||
| 					if err != nil { | ||||
| 						b.Errorf("error marshalling BSON: %s", err) | ||||
| 					} | ||||
| 				} | ||||
| 			}) | ||||
|  | ||||
| 			b.Run("extJSON", func(b *testing.B) { | ||||
| 				for i := 0; i < b.N; i++ { | ||||
| 					_, err := MarshalExtJSON(tc.value, true, false) | ||||
| 					if err != nil { | ||||
| 						b.Errorf("error marshalling extended JSON: %s", err) | ||||
| 					} | ||||
| 				} | ||||
| 			}) | ||||
|  | ||||
| 			b.Run("JSON", func(b *testing.B) { | ||||
| 				for i := 0; i < b.N; i++ { | ||||
| 					_, err := json.Marshal(tc.value) | ||||
| 					if err != nil { | ||||
| 						b.Errorf("error marshalling JSON: %s", err) | ||||
| 					} | ||||
| 				} | ||||
| 			}) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func BenchmarkUnmarshal(b *testing.B) { | ||||
| 	cases := []struct { | ||||
| 		desc  string | ||||
| 		value interface{} | ||||
| 	}{ | ||||
| 		{ | ||||
| 			desc:  "simple struct", | ||||
| 			value: encodetestInstance, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc:  "nested struct", | ||||
| 			value: nestedInstance, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc:  "deep_bson.json.gz", | ||||
| 			value: readExtJSONFile("deep_bson.json.gz"), | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc:  "flat_bson.json.gz", | ||||
| 			value: readExtJSONFile("flat_bson.json.gz"), | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc:  "full_bson.json.gz", | ||||
| 			value: readExtJSONFile("full_bson.json.gz"), | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for _, tc := range cases { | ||||
| 		b.Run(tc.desc, func(b *testing.B) { | ||||
| 			b.Run("BSON", func(b *testing.B) { | ||||
| 				data, err := Marshal(tc.value) | ||||
| 				if err != nil { | ||||
| 					b.Errorf("error marshalling BSON: %s", err) | ||||
| 					return | ||||
| 				} | ||||
|  | ||||
| 				b.ResetTimer() | ||||
| 				var v2 map[string]interface{} | ||||
| 				for i := 0; i < b.N; i++ { | ||||
| 					err := Unmarshal(data, &v2) | ||||
| 					if err != nil { | ||||
| 						b.Errorf("error unmarshalling BSON: %s", err) | ||||
| 					} | ||||
| 				} | ||||
| 			}) | ||||
|  | ||||
| 			b.Run("extJSON", func(b *testing.B) { | ||||
| 				data, err := MarshalExtJSON(tc.value, true, false) | ||||
| 				if err != nil { | ||||
| 					b.Errorf("error marshalling extended JSON: %s", err) | ||||
| 					return | ||||
| 				} | ||||
|  | ||||
| 				b.ResetTimer() | ||||
| 				var v2 map[string]interface{} | ||||
| 				for i := 0; i < b.N; i++ { | ||||
| 					err := UnmarshalExtJSON(data, true, &v2) | ||||
| 					if err != nil { | ||||
| 						b.Errorf("error unmarshalling extended JSON: %s", err) | ||||
| 					} | ||||
| 				} | ||||
| 			}) | ||||
|  | ||||
| 			b.Run("JSON", func(b *testing.B) { | ||||
| 				data, err := json.Marshal(tc.value) | ||||
| 				if err != nil { | ||||
| 					b.Errorf("error marshalling JSON: %s", err) | ||||
| 					return | ||||
| 				} | ||||
|  | ||||
| 				b.ResetTimer() | ||||
| 				var v2 map[string]interface{} | ||||
| 				for i := 0; i < b.N; i++ { | ||||
| 					err := json.Unmarshal(data, &v2) | ||||
| 					if err != nil { | ||||
| 						b.Errorf("error unmarshalling JSON: %s", err) | ||||
| 					} | ||||
| 				} | ||||
| 			}) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										50
									
								
								mongo/bson/bson.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										50
									
								
								mongo/bson/bson.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,50 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
| // | ||||
| // Based on gopkg.in/mgo.v2/bson by Gustavo Niemeyer | ||||
| // See THIRD-PARTY-NOTICES for original license terms. | ||||
|  | ||||
| package bson // import "go.mongodb.org/mongo-driver/bson" | ||||
|  | ||||
| import ( | ||||
| 	"go.mongodb.org/mongo-driver/bson/primitive" | ||||
| ) | ||||
|  | ||||
| // Zeroer allows custom struct types to implement a report of zero | ||||
| // state. All struct types that don't implement Zeroer or where IsZero | ||||
| // returns false are considered to be not zero. | ||||
| type Zeroer interface { | ||||
| 	IsZero() bool | ||||
| } | ||||
|  | ||||
| // D is an ordered representation of a BSON document. This type should be used when the order of the elements matters, | ||||
| // such as MongoDB command documents. If the order of the elements does not matter, an M should be used instead. | ||||
| // | ||||
| // A D should not be constructed with duplicate key names, as that can cause undefined server behavior. | ||||
| // | ||||
| // Example usage: | ||||
| // | ||||
| //	bson.D{{"foo", "bar"}, {"hello", "world"}, {"pi", 3.14159}} | ||||
| type D = primitive.D | ||||
|  | ||||
| // E represents a BSON element for a D. It is usually used inside a D. | ||||
| type E = primitive.E | ||||
|  | ||||
| // M is an unordered representation of a BSON document. This type should be used when the order of the elements does not | ||||
| // matter. This type is handled as a regular map[string]interface{} when encoding and decoding. Elements will be | ||||
| // serialized in an undefined, random order. If the order of the elements matters, a D should be used instead. | ||||
| // | ||||
| // Example usage: | ||||
| // | ||||
| //	bson.M{"foo": "bar", "hello": "world", "pi": 3.14159} | ||||
| type M = primitive.M | ||||
|  | ||||
| // An A is an ordered representation of a BSON array. | ||||
| // | ||||
| // Example usage: | ||||
| // | ||||
| //	bson.A{"bar", "world", 3.14159, bson.D{{"qux", 12345}}} | ||||
| type A = primitive.A | ||||
							
								
								
									
										530
									
								
								mongo/bson/bson_corpus_spec_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										530
									
								
								mongo/bson/bson_corpus_spec_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,530 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bson | ||||
|  | ||||
| import ( | ||||
| 	"encoding/hex" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"math" | ||||
| 	"os" | ||||
| 	"path" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| 	"unicode" | ||||
| 	"unicode/utf8" | ||||
|  | ||||
| 	"github.com/google/go-cmp/cmp" | ||||
| 	"github.com/stretchr/testify/require" | ||||
| 	"github.com/tidwall/pretty" | ||||
| 	"go.mongodb.org/mongo-driver/bson/primitive" | ||||
| 	"go.mongodb.org/mongo-driver/internal/testutil/assert" | ||||
| ) | ||||
|  | ||||
| type testCase struct { | ||||
| 	Description  string                `json:"description"` | ||||
| 	BsonType     string                `json:"bson_type"` | ||||
| 	TestKey      *string               `json:"test_key"` | ||||
| 	Valid        []validityTestCase    `json:"valid"` | ||||
| 	DecodeErrors []decodeErrorTestCase `json:"decodeErrors"` | ||||
| 	ParseErrors  []parseErrorTestCase  `json:"parseErrors"` | ||||
| 	Deprecated   *bool                 `json:"deprecated"` | ||||
| } | ||||
|  | ||||
| type validityTestCase struct { | ||||
| 	Description       string  `json:"description"` | ||||
| 	CanonicalBson     string  `json:"canonical_bson"` | ||||
| 	CanonicalExtJSON  string  `json:"canonical_extjson"` | ||||
| 	RelaxedExtJSON    *string `json:"relaxed_extjson"` | ||||
| 	DegenerateBSON    *string `json:"degenerate_bson"` | ||||
| 	DegenerateExtJSON *string `json:"degenerate_extjson"` | ||||
| 	ConvertedBSON     *string `json:"converted_bson"` | ||||
| 	ConvertedExtJSON  *string `json:"converted_extjson"` | ||||
| 	Lossy             *bool   `json:"lossy"` | ||||
| } | ||||
|  | ||||
| type decodeErrorTestCase struct { | ||||
| 	Description string `json:"description"` | ||||
| 	Bson        string `json:"bson"` | ||||
| } | ||||
|  | ||||
| type parseErrorTestCase struct { | ||||
| 	Description string `json:"description"` | ||||
| 	String      string `json:"string"` | ||||
| } | ||||
|  | ||||
| const dataDir = "../testdata/bson-corpus/" | ||||
|  | ||||
| func findJSONFilesInDir(dir string) ([]string, error) { | ||||
| 	files := make([]string, 0) | ||||
|  | ||||
| 	entries, err := os.ReadDir(dir) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	for _, entry := range entries { | ||||
| 		if entry.IsDir() || path.Ext(entry.Name()) != ".json" { | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		files = append(files, entry.Name()) | ||||
| 	} | ||||
|  | ||||
| 	return files, nil | ||||
| } | ||||
|  | ||||
| // seedExtJSON will add the byte representation of the "extJSON" string to the fuzzer's coprus. | ||||
| func seedExtJSON(f *testing.F, extJSON string, extJSONType string, desc string) { | ||||
| 	jbytes, err := jsonToBytes(extJSON, extJSONType, desc) | ||||
| 	if err != nil { | ||||
| 		f.Fatalf("failed to convert JSON to bytes: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	f.Add(jbytes) | ||||
| } | ||||
|  | ||||
| // seedTestCase will add the byte representation for each "extJSON" string of each valid test case to the fuzzer's | ||||
| // corpus. | ||||
| func seedTestCase(f *testing.F, tcase *testCase) { | ||||
| 	for _, vtc := range tcase.Valid { | ||||
| 		seedExtJSON(f, vtc.CanonicalExtJSON, "canonical", vtc.Description) | ||||
|  | ||||
| 		// Seed the relaxed extended JSON. | ||||
| 		if vtc.RelaxedExtJSON != nil { | ||||
| 			seedExtJSON(f, *vtc.RelaxedExtJSON, "relaxed", vtc.Description) | ||||
| 		} | ||||
|  | ||||
| 		// Seed the degenerate extended JSON. | ||||
| 		if vtc.DegenerateExtJSON != nil { | ||||
| 			seedExtJSON(f, *vtc.DegenerateExtJSON, "degenerate", vtc.Description) | ||||
| 		} | ||||
|  | ||||
| 		// Seed the converted extended JSON. | ||||
| 		if vtc.ConvertedExtJSON != nil { | ||||
| 			seedExtJSON(f, *vtc.ConvertedExtJSON, "converted", vtc.Description) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // seedBSONCorpus will unmarshal the data from "testdata/bson-corpus" into a slice of "testCase" structs and then | ||||
| // marshal the "*_extjson" field of each "validityTestCase" into a slice of bytes to seed the fuzz corpus. | ||||
| func seedBSONCorpus(f *testing.F) { | ||||
| 	fileNames, err := findJSONFilesInDir(dataDir) | ||||
| 	if err != nil { | ||||
| 		f.Fatalf("failed to find JSON files in directory %q: %v", dataDir, err) | ||||
| 	} | ||||
|  | ||||
| 	for _, fileName := range fileNames { | ||||
| 		filePath := path.Join(dataDir, fileName) | ||||
|  | ||||
| 		file, err := os.Open(filePath) | ||||
| 		if err != nil { | ||||
| 			f.Fatalf("failed to open file %q: %v", filePath, err) | ||||
| 		} | ||||
|  | ||||
| 		var tcase testCase | ||||
| 		if err := json.NewDecoder(file).Decode(&tcase); err != nil { | ||||
| 			f.Fatal(err) | ||||
| 		} | ||||
|  | ||||
| 		seedTestCase(f, &tcase) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func needsEscapedUnicode(bsonType string) bool { | ||||
| 	return bsonType == "0x02" || bsonType == "0x0D" || bsonType == "0x0E" || bsonType == "0x0F" | ||||
| } | ||||
|  | ||||
| func unescapeUnicode(s, bsonType string) string { | ||||
| 	if !needsEscapedUnicode(bsonType) { | ||||
| 		return s | ||||
| 	} | ||||
|  | ||||
| 	newS := "" | ||||
|  | ||||
| 	for i := 0; i < len(s); i++ { | ||||
| 		c := s[i] | ||||
| 		switch c { | ||||
| 		case '\\': | ||||
| 			switch s[i+1] { | ||||
| 			case 'u': | ||||
| 				us := s[i : i+6] | ||||
| 				u, err := strconv.Unquote(strings.Replace(strconv.Quote(us), `\\u`, `\u`, 1)) | ||||
| 				if err != nil { | ||||
| 					return "" | ||||
| 				} | ||||
| 				for _, r := range u { | ||||
| 					if r < ' ' { | ||||
| 						newS += fmt.Sprintf(`\u%04x`, r) | ||||
| 					} else { | ||||
| 						newS += string(r) | ||||
| 					} | ||||
| 				} | ||||
| 				i += 5 | ||||
| 			default: | ||||
| 				newS += string(c) | ||||
| 			} | ||||
| 		default: | ||||
| 			if c > unicode.MaxASCII { | ||||
| 				r, size := utf8.DecodeRune([]byte(s[i:])) | ||||
| 				newS += string(r) | ||||
| 				i += size - 1 | ||||
| 			} else { | ||||
| 				newS += string(c) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return newS | ||||
| } | ||||
|  | ||||
| func formatDouble(f float64) string { | ||||
| 	var s string | ||||
| 	if math.IsInf(f, 1) { | ||||
| 		s = "Infinity" | ||||
| 	} else if math.IsInf(f, -1) { | ||||
| 		s = "-Infinity" | ||||
| 	} else if math.IsNaN(f) { | ||||
| 		s = "NaN" | ||||
| 	} else { | ||||
| 		// Print exactly one decimalType place for integers; otherwise, print as many are necessary to | ||||
| 		// perfectly represent it. | ||||
| 		s = strconv.FormatFloat(f, 'G', -1, 64) | ||||
| 		if !strings.ContainsRune(s, 'E') && !strings.ContainsRune(s, '.') { | ||||
| 			s += ".0" | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return s | ||||
| } | ||||
|  | ||||
| func normalizeCanonicalDouble(t *testing.T, key string, cEJ string) string { | ||||
| 	// Unmarshal string into map | ||||
| 	cEJMap := make(map[string]map[string]string) | ||||
| 	err := json.Unmarshal([]byte(cEJ), &cEJMap) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	// Parse the float contained by the map. | ||||
| 	expectedString := cEJMap[key]["$numberDouble"] | ||||
| 	expectedFloat, err := strconv.ParseFloat(expectedString, 64) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	// Normalize the string | ||||
| 	return fmt.Sprintf(`{"%s":{"$numberDouble":"%s"}}`, key, formatDouble(expectedFloat)) | ||||
| } | ||||
|  | ||||
| func normalizeRelaxedDouble(t *testing.T, key string, rEJ string) string { | ||||
| 	// Unmarshal string into map | ||||
| 	rEJMap := make(map[string]float64) | ||||
| 	err := json.Unmarshal([]byte(rEJ), &rEJMap) | ||||
| 	if err != nil { | ||||
| 		return normalizeCanonicalDouble(t, key, rEJ) | ||||
| 	} | ||||
|  | ||||
| 	// Parse the float contained by the map. | ||||
| 	expectedFloat := rEJMap[key] | ||||
|  | ||||
| 	// Normalize the string | ||||
| 	return fmt.Sprintf(`{"%s":%s}`, key, formatDouble(expectedFloat)) | ||||
| } | ||||
|  | ||||
| // bsonToNative decodes the BSON bytes (b) into a native Document | ||||
| func bsonToNative(t *testing.T, b []byte, bType, testDesc string) D { | ||||
| 	var doc D | ||||
| 	err := Unmarshal(b, &doc) | ||||
| 	expectNoError(t, err, fmt.Sprintf("%s: decoding %s BSON", testDesc, bType)) | ||||
| 	return doc | ||||
| } | ||||
|  | ||||
| // nativeToBSON encodes the native Document (doc) into canonical BSON and compares it to the expected | ||||
| // canonical BSON (cB) | ||||
| func nativeToBSON(t *testing.T, cB []byte, doc D, testDesc, bType, docSrcDesc string) { | ||||
| 	actual, err := Marshal(doc) | ||||
| 	expectNoError(t, err, fmt.Sprintf("%s: encoding %s BSON", testDesc, bType)) | ||||
|  | ||||
| 	if diff := cmp.Diff(cB, actual); diff != "" { | ||||
| 		t.Errorf("%s: 'native_to_bson(%s) = cB' failed (-want, +got):\n-%v\n+%v\n", | ||||
| 			testDesc, docSrcDesc, cB, actual) | ||||
| 		t.FailNow() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // jsonToNative decodes the extended JSON string (ej) into a native Document | ||||
| func jsonToNative(ej, ejType, testDesc string) (D, error) { | ||||
| 	var doc D | ||||
| 	if err := UnmarshalExtJSON([]byte(ej), ejType != "relaxed", &doc); err != nil { | ||||
| 		return nil, fmt.Errorf("%s: decoding %s extended JSON: %w", testDesc, ejType, err) | ||||
| 	} | ||||
| 	return doc, nil | ||||
| } | ||||
|  | ||||
| // jsonToBytes decodes the extended JSON string (ej) into canonical BSON and then encodes it into a byte slice. | ||||
| func jsonToBytes(ej, ejType, testDesc string) ([]byte, error) { | ||||
| 	native, err := jsonToNative(ej, ejType, testDesc) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	b, err := Marshal(native) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("%s: encoding %s BSON: %w", testDesc, ejType, err) | ||||
| 	} | ||||
|  | ||||
| 	return b, nil | ||||
| } | ||||
|  | ||||
| // nativeToJSON encodes the native Document (doc) into an extended JSON string | ||||
| func nativeToJSON(t *testing.T, ej string, doc D, testDesc, ejType, ejShortName, docSrcDesc string) { | ||||
| 	actualEJ, err := MarshalExtJSON(doc, ejType != "relaxed", true) | ||||
| 	expectNoError(t, err, fmt.Sprintf("%s: encoding %s extended JSON", testDesc, ejType)) | ||||
|  | ||||
| 	if diff := cmp.Diff(ej, string(actualEJ)); diff != "" { | ||||
| 		t.Errorf("%s: 'native_to_%s_extended_json(%s) = %s' failed (-want, +got):\n%s\n", | ||||
| 			testDesc, ejType, docSrcDesc, ejShortName, diff) | ||||
| 		t.FailNow() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func runTest(t *testing.T, file string) { | ||||
| 	filepath := path.Join(dataDir, file) | ||||
| 	content, err := os.ReadFile(filepath) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	// Remove ".json" from filename. | ||||
| 	file = file[:len(file)-5] | ||||
| 	testName := "bson_corpus--" + file | ||||
|  | ||||
| 	t.Run(testName, func(t *testing.T) { | ||||
| 		var test testCase | ||||
| 		require.NoError(t, json.Unmarshal(content, &test)) | ||||
|  | ||||
| 		t.Run("valid", func(t *testing.T) { | ||||
| 			for _, v := range test.Valid { | ||||
| 				t.Run(v.Description, func(t *testing.T) { | ||||
| 					// get canonical BSON | ||||
| 					cB, err := hex.DecodeString(v.CanonicalBson) | ||||
| 					expectNoError(t, err, fmt.Sprintf("%s: reading canonical BSON", v.Description)) | ||||
|  | ||||
| 					// get canonical extended JSON | ||||
| 					cEJ := unescapeUnicode(string(pretty.Ugly([]byte(v.CanonicalExtJSON))), test.BsonType) | ||||
| 					if test.BsonType == "0x01" { | ||||
| 						cEJ = normalizeCanonicalDouble(t, *test.TestKey, cEJ) | ||||
| 					} | ||||
|  | ||||
| 					/*** canonical BSON round-trip tests ***/ | ||||
| 					doc := bsonToNative(t, cB, "canonical", v.Description) | ||||
|  | ||||
| 					// native_to_bson(bson_to_native(cB)) = cB | ||||
| 					nativeToBSON(t, cB, doc, v.Description, "canonical", "bson_to_native(cB)") | ||||
|  | ||||
| 					// native_to_canonical_extended_json(bson_to_native(cB)) = cEJ | ||||
| 					nativeToJSON(t, cEJ, doc, v.Description, "canonical", "cEJ", "bson_to_native(cB)") | ||||
|  | ||||
| 					// native_to_relaxed_extended_json(bson_to_native(cB)) = rEJ (if rEJ exists) | ||||
| 					if v.RelaxedExtJSON != nil { | ||||
| 						rEJ := unescapeUnicode(string(pretty.Ugly([]byte(*v.RelaxedExtJSON))), test.BsonType) | ||||
| 						if test.BsonType == "0x01" { | ||||
| 							rEJ = normalizeRelaxedDouble(t, *test.TestKey, rEJ) | ||||
| 						} | ||||
|  | ||||
| 						nativeToJSON(t, rEJ, doc, v.Description, "relaxed", "rEJ", "bson_to_native(cB)") | ||||
|  | ||||
| 						/*** relaxed extended JSON round-trip tests (if exists) ***/ | ||||
| 						doc, err = jsonToNative(rEJ, "relaxed", v.Description) | ||||
| 						require.NoError(t, err) | ||||
|  | ||||
| 						// native_to_relaxed_extended_json(json_to_native(rEJ)) = rEJ | ||||
| 						nativeToJSON(t, rEJ, doc, v.Description, "relaxed", "eJR", "json_to_native(rEJ)") | ||||
| 					} | ||||
|  | ||||
| 					/*** canonical extended JSON round-trip tests ***/ | ||||
| 					doc, err = jsonToNative(cEJ, "canonical", v.Description) | ||||
| 					require.NoError(t, err) | ||||
|  | ||||
| 					// native_to_canonical_extended_json(json_to_native(cEJ)) = cEJ | ||||
| 					nativeToJSON(t, cEJ, doc, v.Description, "canonical", "cEJ", "json_to_native(cEJ)") | ||||
|  | ||||
| 					// native_to_bson(json_to_native(cEJ)) = cb (unless lossy) | ||||
| 					if v.Lossy == nil || !*v.Lossy { | ||||
| 						nativeToBSON(t, cB, doc, v.Description, "canonical", "json_to_native(cEJ)") | ||||
| 					} | ||||
|  | ||||
| 					/*** degenerate BSON round-trip tests (if exists) ***/ | ||||
| 					if v.DegenerateBSON != nil { | ||||
| 						dB, err := hex.DecodeString(*v.DegenerateBSON) | ||||
| 						expectNoError(t, err, fmt.Sprintf("%s: reading degenerate BSON", v.Description)) | ||||
|  | ||||
| 						doc = bsonToNative(t, dB, "degenerate", v.Description) | ||||
|  | ||||
| 						// native_to_bson(bson_to_native(dB)) = cB | ||||
| 						nativeToBSON(t, cB, doc, v.Description, "degenerate", "bson_to_native(dB)") | ||||
| 					} | ||||
|  | ||||
| 					/*** degenerate JSON round-trip tests (if exists) ***/ | ||||
| 					if v.DegenerateExtJSON != nil { | ||||
| 						dEJ := unescapeUnicode(string(pretty.Ugly([]byte(*v.DegenerateExtJSON))), test.BsonType) | ||||
| 						if test.BsonType == "0x01" { | ||||
| 							dEJ = normalizeCanonicalDouble(t, *test.TestKey, dEJ) | ||||
| 						} | ||||
|  | ||||
| 						doc, err = jsonToNative(dEJ, "degenerate canonical", v.Description) | ||||
| 						require.NoError(t, err) | ||||
|  | ||||
| 						// native_to_canonical_extended_json(json_to_native(dEJ)) = cEJ | ||||
| 						nativeToJSON(t, cEJ, doc, v.Description, "degenerate canonical", "cEJ", "json_to_native(dEJ)") | ||||
|  | ||||
| 						// native_to_bson(json_to_native(dEJ)) = cB (unless lossy) | ||||
| 						if v.Lossy == nil || !*v.Lossy { | ||||
| 							nativeToBSON(t, cB, doc, v.Description, "canonical", "json_to_native(dEJ)") | ||||
| 						} | ||||
| 					} | ||||
| 				}) | ||||
| 			} | ||||
| 		}) | ||||
|  | ||||
| 		t.Run("decode error", func(t *testing.T) { | ||||
| 			for _, d := range test.DecodeErrors { | ||||
| 				t.Run(d.Description, func(t *testing.T) { | ||||
| 					b, err := hex.DecodeString(d.Bson) | ||||
| 					expectNoError(t, err, d.Description) | ||||
|  | ||||
| 					var doc D | ||||
| 					err = Unmarshal(b, &doc) | ||||
|  | ||||
| 					// The driver unmarshals invalid UTF-8 strings without error. Loop over the unmarshalled elements | ||||
| 					// and assert that there was no error if any of the string or DBPointer values contain invalid UTF-8 | ||||
| 					// characters. | ||||
| 					for _, elem := range doc { | ||||
| 						str, ok := elem.Value.(string) | ||||
| 						invalidString := ok && !utf8.ValidString(str) | ||||
| 						dbPtr, ok := elem.Value.(primitive.DBPointer) | ||||
| 						invalidDBPtr := ok && !utf8.ValidString(dbPtr.DB) | ||||
|  | ||||
| 						if invalidString || invalidDBPtr { | ||||
| 							expectNoError(t, err, d.Description) | ||||
| 							return | ||||
| 						} | ||||
| 					} | ||||
|  | ||||
| 					expectError(t, err, fmt.Sprintf("%s: expected decode error", d.Description)) | ||||
| 				}) | ||||
| 			} | ||||
| 		}) | ||||
|  | ||||
| 		t.Run("parse error", func(t *testing.T) { | ||||
| 			for _, p := range test.ParseErrors { | ||||
| 				t.Run(p.Description, func(t *testing.T) { | ||||
| 					s := unescapeUnicode(p.String, test.BsonType) | ||||
| 					if test.BsonType == "0x13" { | ||||
| 						s = fmt.Sprintf(`{"decimal128": {"$numberDecimal": "%s"}}`, s) | ||||
| 					} | ||||
|  | ||||
| 					switch test.BsonType { | ||||
| 					case "0x00", "0x05", "0x13": | ||||
| 						var doc D | ||||
| 						err := UnmarshalExtJSON([]byte(s), true, &doc) | ||||
| 						// Null bytes are validated when marshaling to BSON | ||||
| 						if strings.Contains(p.Description, "Null") { | ||||
| 							_, err = Marshal(doc) | ||||
| 						} | ||||
| 						expectError(t, err, fmt.Sprintf("%s: expected parse error", p.Description)) | ||||
| 					default: | ||||
| 						t.Errorf("Update test to check for parse errors for type %s", test.BsonType) | ||||
| 						t.Fail() | ||||
| 					} | ||||
| 				}) | ||||
| 			} | ||||
| 		}) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func Test_BsonCorpus(t *testing.T) { | ||||
| 	jsonFiles, err := findJSONFilesInDir(dataDir) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("error finding JSON files in %s: %v", dataDir, err) | ||||
| 	} | ||||
|  | ||||
| 	for _, file := range jsonFiles { | ||||
| 		runTest(t, file) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func expectNoError(t *testing.T, err error, desc string) { | ||||
| 	if err != nil { | ||||
| 		t.Helper() | ||||
| 		t.Errorf("%s: Unepexted error: %v", desc, err) | ||||
| 		t.FailNow() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func expectError(t *testing.T, err error, desc string) { | ||||
| 	if err == nil { | ||||
| 		t.Helper() | ||||
| 		t.Errorf("%s: Expected error", desc) | ||||
| 		t.FailNow() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestRelaxedUUIDValidation(t *testing.T) { | ||||
| 	testCases := []struct { | ||||
| 		description       string | ||||
| 		canonicalExtJSON  string | ||||
| 		degenerateExtJSON string | ||||
| 		expectedErr       string | ||||
| 	}{ | ||||
| 		{ | ||||
| 			"valid uuid", | ||||
| 			"{\"x\" : { \"$binary\" : {\"base64\" : \"c//SZESzTGmQ6OfR38A11A==\", \"subType\" : \"04\"}}}", | ||||
| 			"{\"x\" : { \"$uuid\" : \"73ffd264-44b3-4c69-90e8-e7d1dfc035d4\"}}", | ||||
| 			"", | ||||
| 		}, | ||||
| 		{ | ||||
| 			"invalid uuid--no hyphens", | ||||
| 			"", | ||||
| 			"{\"x\" : { \"$uuid\" : \"73ffd26444b34c6990e8e7d1dfc035d4\"}}", | ||||
| 			"$uuid value does not follow RFC 4122 format regarding length and hyphens", | ||||
| 		}, | ||||
| 		{ | ||||
| 			"invalid uuid--trailing hyphens", | ||||
| 			"", | ||||
| 			"{\"x\" : { \"$uuid\" : \"73ffd264-44b3-4c69-90e8-e7d1dfc035--\"}}", | ||||
| 			"$uuid value does not follow RFC 4122 format regarding length and hyphens", | ||||
| 		}, | ||||
| 		{ | ||||
| 			"invalid uuid--malformed hex", | ||||
| 			"", | ||||
| 			"{\"x\" : { \"$uuid\" : \"q3@fd26l-44b3-4c69-90e8-e7d1dfc035d4\"}}", | ||||
| 			"$uuid value does not follow RFC 4122 format regarding hex bytes: encoding/hex: invalid byte: U+0071 'q'", | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for _, tc := range testCases { | ||||
| 		t.Run(tc.description, func(t *testing.T) { | ||||
| 			// get canonical extended JSON | ||||
| 			cEJ := unescapeUnicode(string(pretty.Ugly([]byte(tc.canonicalExtJSON))), "0x05") | ||||
|  | ||||
| 			// get degenerate extended JSON | ||||
| 			dEJ := unescapeUnicode(string(pretty.Ugly([]byte(tc.degenerateExtJSON))), "0x05") | ||||
|  | ||||
| 			// convert dEJ to native doc | ||||
| 			var doc D | ||||
| 			err := UnmarshalExtJSON([]byte(dEJ), true, &doc) | ||||
|  | ||||
| 			if tc.expectedErr != "" { | ||||
| 				assert.Equal(t, tc.expectedErr, err.Error(), "expected error %v, got %v", tc.expectedErr, err) | ||||
| 			} else { | ||||
| 				assert.Nil(t, err, "expected no error, got error: %v", err) | ||||
|  | ||||
| 				// Marshal doc into extended JSON and compare with cEJ | ||||
| 				nativeToJSON(t, cEJ, doc, tc.description, "degenerate canonical", "cEJ", "json_to_native(dEJ)") | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
|  | ||||
| } | ||||
							
								
								
									
										279
									
								
								mongo/bson/bson_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										279
									
								
								mongo/bson/bson_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,279 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bson | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/google/go-cmp/cmp" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsoncodec" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsonoptions" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsontype" | ||||
| 	"go.mongodb.org/mongo-driver/internal/testutil/assert" | ||||
| 	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore" | ||||
| ) | ||||
|  | ||||
| func noerr(t *testing.T, err error) { | ||||
| 	if err != nil { | ||||
| 		t.Helper() | ||||
| 		t.Errorf("Unexpected error: (%T)%v", err, err) | ||||
| 		t.FailNow() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestTimeRoundTrip(t *testing.T) { | ||||
| 	val := struct { | ||||
| 		Value time.Time | ||||
| 		ID    string | ||||
| 	}{ | ||||
| 		ID: "time-rt-test", | ||||
| 	} | ||||
|  | ||||
| 	if !val.Value.IsZero() { | ||||
| 		t.Errorf("Did not get zero time as expected.") | ||||
| 	} | ||||
|  | ||||
| 	bsonOut, err := Marshal(val) | ||||
| 	noerr(t, err) | ||||
| 	rtval := struct { | ||||
| 		Value time.Time | ||||
| 		ID    string | ||||
| 	}{} | ||||
|  | ||||
| 	err = Unmarshal(bsonOut, &rtval) | ||||
| 	noerr(t, err) | ||||
| 	if !cmp.Equal(val, rtval) { | ||||
| 		t.Errorf("Did not round trip properly. got %v; want %v", val, rtval) | ||||
| 	} | ||||
| 	if !rtval.Value.IsZero() { | ||||
| 		t.Errorf("Did not get zero time as expected.") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestNonNullTimeRoundTrip(t *testing.T) { | ||||
| 	now := time.Now() | ||||
| 	now = time.Unix(now.Unix(), 0) | ||||
| 	val := struct { | ||||
| 		Value time.Time | ||||
| 		ID    string | ||||
| 	}{ | ||||
| 		ID:    "time-rt-test", | ||||
| 		Value: now, | ||||
| 	} | ||||
|  | ||||
| 	bsonOut, err := Marshal(val) | ||||
| 	noerr(t, err) | ||||
| 	rtval := struct { | ||||
| 		Value time.Time | ||||
| 		ID    string | ||||
| 	}{} | ||||
|  | ||||
| 	err = Unmarshal(bsonOut, &rtval) | ||||
| 	noerr(t, err) | ||||
| 	if !cmp.Equal(val, rtval) { | ||||
| 		t.Errorf("Did not round trip properly. got %v; want %v", val, rtval) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestD(t *testing.T) { | ||||
| 	t.Run("can marshal", func(t *testing.T) { | ||||
| 		d := D{{"foo", "bar"}, {"hello", "world"}, {"pi", 3.14159}} | ||||
| 		idx, want := bsoncore.AppendDocumentStart(nil) | ||||
| 		want = bsoncore.AppendStringElement(want, "foo", "bar") | ||||
| 		want = bsoncore.AppendStringElement(want, "hello", "world") | ||||
| 		want = bsoncore.AppendDoubleElement(want, "pi", 3.14159) | ||||
| 		want, err := bsoncore.AppendDocumentEnd(want, idx) | ||||
| 		noerr(t, err) | ||||
| 		got, err := Marshal(d) | ||||
| 		noerr(t, err) | ||||
| 		if !bytes.Equal(got, want) { | ||||
| 			t.Errorf("Marshaled documents do not match. got %v; want %v", Raw(got), Raw(want)) | ||||
| 		} | ||||
| 	}) | ||||
| 	t.Run("can unmarshal", func(t *testing.T) { | ||||
| 		want := D{{"foo", "bar"}, {"hello", "world"}, {"pi", 3.14159}} | ||||
| 		idx, doc := bsoncore.AppendDocumentStart(nil) | ||||
| 		doc = bsoncore.AppendStringElement(doc, "foo", "bar") | ||||
| 		doc = bsoncore.AppendStringElement(doc, "hello", "world") | ||||
| 		doc = bsoncore.AppendDoubleElement(doc, "pi", 3.14159) | ||||
| 		doc, err := bsoncore.AppendDocumentEnd(doc, idx) | ||||
| 		noerr(t, err) | ||||
| 		var got D | ||||
| 		err = Unmarshal(doc, &got) | ||||
| 		noerr(t, err) | ||||
| 		if !cmp.Equal(got, want) { | ||||
| 			t.Errorf("Unmarshaled documents do not match. got %v; want %v", got, want) | ||||
| 		} | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| type stringerString string | ||||
|  | ||||
| func (ss stringerString) String() string { | ||||
| 	return "bar" | ||||
| } | ||||
|  | ||||
| type keyBool bool | ||||
|  | ||||
| func (kb keyBool) MarshalKey() (string, error) { | ||||
| 	return fmt.Sprintf("%v", kb), nil | ||||
| } | ||||
|  | ||||
| func (kb *keyBool) UnmarshalKey(key string) error { | ||||
| 	switch key { | ||||
| 	case "true": | ||||
| 		*kb = true | ||||
| 	case "false": | ||||
| 		*kb = false | ||||
| 	default: | ||||
| 		return fmt.Errorf("invalid bool value %v", key) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| type keyStruct struct { | ||||
| 	val int64 | ||||
| } | ||||
|  | ||||
| func (k keyStruct) MarshalText() (text []byte, err error) { | ||||
| 	str := strconv.FormatInt(k.val, 10) | ||||
|  | ||||
| 	return []byte(str), nil | ||||
| } | ||||
|  | ||||
| func (k *keyStruct) UnmarshalText(text []byte) error { | ||||
| 	val, err := strconv.ParseInt(string(text), 10, 64) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	*k = keyStruct{ | ||||
| 		val: val, | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func TestMapCodec(t *testing.T) { | ||||
| 	t.Run("EncodeKeysWithStringer", func(t *testing.T) { | ||||
| 		strstr := stringerString("foo") | ||||
| 		mapObj := map[stringerString]int{strstr: 1} | ||||
| 		testCases := []struct { | ||||
| 			name string | ||||
| 			opts *bsonoptions.MapCodecOptions | ||||
| 			key  string | ||||
| 		}{ | ||||
| 			{"default", bsonoptions.MapCodec(), "foo"}, | ||||
| 			{"true", bsonoptions.MapCodec().SetEncodeKeysWithStringer(true), "bar"}, | ||||
| 			{"false", bsonoptions.MapCodec().SetEncodeKeysWithStringer(false), "foo"}, | ||||
| 		} | ||||
| 		for _, tc := range testCases { | ||||
| 			t.Run(tc.name, func(t *testing.T) { | ||||
| 				mapCodec := bsoncodec.NewMapCodec(tc.opts) | ||||
| 				mapRegistry := NewRegistryBuilder().RegisterDefaultEncoder(reflect.Map, mapCodec).Build() | ||||
| 				val, err := MarshalWithRegistry(mapRegistry, mapObj) | ||||
| 				assert.Nil(t, err, "Marshal error: %v", err) | ||||
| 				assert.True(t, strings.Contains(string(val), tc.key), "expected result to contain %v, got: %v", tc.key, string(val)) | ||||
| 			}) | ||||
| 		} | ||||
| 	}) | ||||
|  | ||||
| 	t.Run("keys implements keyMarshaler and keyUnmarshaler", func(t *testing.T) { | ||||
| 		mapObj := map[keyBool]int{keyBool(true): 1} | ||||
|  | ||||
| 		doc, err := Marshal(mapObj) | ||||
| 		assert.Nil(t, err, "Marshal error: %v", err) | ||||
| 		idx, want := bsoncore.AppendDocumentStart(nil) | ||||
| 		want = bsoncore.AppendInt32Element(want, "true", 1) | ||||
| 		want, _ = bsoncore.AppendDocumentEnd(want, idx) | ||||
| 		assert.Equal(t, want, doc, "expected result %v, got %v", string(want), string(doc)) | ||||
|  | ||||
| 		var got map[keyBool]int | ||||
| 		err = Unmarshal(doc, &got) | ||||
| 		assert.Nil(t, err, "Unmarshal error: %v", err) | ||||
| 		assert.Equal(t, mapObj, got, "expected result %v, got %v", mapObj, got) | ||||
|  | ||||
| 	}) | ||||
|  | ||||
| 	t.Run("keys implements encoding.TextMarshaler and encoding.TextUnmarshaler", func(t *testing.T) { | ||||
| 		mapObj := map[keyStruct]int{ | ||||
| 			{val: 10}: 100, | ||||
| 		} | ||||
|  | ||||
| 		doc, err := Marshal(mapObj) | ||||
| 		assert.Nil(t, err, "Marshal error: %v", err) | ||||
| 		idx, want := bsoncore.AppendDocumentStart(nil) | ||||
| 		want = bsoncore.AppendInt32Element(want, "10", 100) | ||||
| 		want, _ = bsoncore.AppendDocumentEnd(want, idx) | ||||
| 		assert.Equal(t, want, doc, "expected result %v, got %v", string(want), string(doc)) | ||||
|  | ||||
| 		var got map[keyStruct]int | ||||
| 		err = Unmarshal(doc, &got) | ||||
| 		assert.Nil(t, err, "Unmarshal error: %v", err) | ||||
| 		assert.Equal(t, mapObj, got, "expected result %v, got %v", mapObj, got) | ||||
|  | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func TestExtJSONEscapeKey(t *testing.T) { | ||||
| 	doc := D{{Key: "\\usb#", Value: int32(1)}} | ||||
| 	b, err := MarshalExtJSON(&doc, false, false) | ||||
| 	noerr(t, err) | ||||
|  | ||||
| 	want := "{\"\\\\usb#\":1}" | ||||
| 	if diff := cmp.Diff(want, string(b)); diff != "" { | ||||
| 		t.Errorf("Marshaled documents do not match. got %v, want %v", string(b), want) | ||||
| 	} | ||||
|  | ||||
| 	var got D | ||||
| 	err = UnmarshalExtJSON(b, false, &got) | ||||
| 	noerr(t, err) | ||||
| 	if !cmp.Equal(got, doc) { | ||||
| 		t.Errorf("Unmarshaled documents do not match. got %v; want %v", got, doc) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestBsoncoreArray(t *testing.T) { | ||||
| 	type BSONDocumentArray struct { | ||||
| 		Array []D `bson:"array"` | ||||
| 	} | ||||
|  | ||||
| 	type BSONArray struct { | ||||
| 		Array bsoncore.Array `bson:"array"` | ||||
| 	} | ||||
|  | ||||
| 	bda := BSONDocumentArray{ | ||||
| 		Array: []D{ | ||||
| 			{{"x", 1}}, | ||||
| 			{{"x", 2}}, | ||||
| 			{{"x", 3}}, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	expectedBSON, err := Marshal(bda) | ||||
| 	assert.Nil(t, err, "Marshal bsoncore.Document array error: %v", err) | ||||
|  | ||||
| 	var ba BSONArray | ||||
| 	err = Unmarshal(expectedBSON, &ba) | ||||
| 	assert.Nil(t, err, "Unmarshal error: %v", err) | ||||
|  | ||||
| 	actualBSON, err := Marshal(ba) | ||||
| 	assert.Nil(t, err, "Marshal bsoncore.Array error: %v", err) | ||||
|  | ||||
| 	assert.Equal(t, expectedBSON, actualBSON, | ||||
| 		"expected BSON to be %v after Marshalling again; got %v", expectedBSON, actualBSON) | ||||
|  | ||||
| 	doc := bsoncore.Document(actualBSON) | ||||
| 	v := doc.Lookup("array") | ||||
| 	assert.Equal(t, bsontype.Array, v.Type, "expected type array, got %v", v.Type) | ||||
| } | ||||
							
								
								
									
										50
									
								
								mongo/bson/bsoncodec/array_codec.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										50
									
								
								mongo/bson/bsoncodec/array_codec.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,50 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsoncodec | ||||
|  | ||||
| import ( | ||||
| 	"reflect" | ||||
|  | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsonrw" | ||||
| 	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore" | ||||
| ) | ||||
|  | ||||
| // ArrayCodec is the Codec used for bsoncore.Array values. | ||||
| type ArrayCodec struct{} | ||||
|  | ||||
| var defaultArrayCodec = NewArrayCodec() | ||||
|  | ||||
| // NewArrayCodec returns an ArrayCodec. | ||||
| func NewArrayCodec() *ArrayCodec { | ||||
| 	return &ArrayCodec{} | ||||
| } | ||||
|  | ||||
| // EncodeValue is the ValueEncoder for bsoncore.Array values. | ||||
| func (ac *ArrayCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { | ||||
| 	if !val.IsValid() || val.Type() != tCoreArray { | ||||
| 		return ValueEncoderError{Name: "CoreArrayEncodeValue", Types: []reflect.Type{tCoreArray}, Received: val} | ||||
| 	} | ||||
|  | ||||
| 	arr := val.Interface().(bsoncore.Array) | ||||
| 	return bsonrw.Copier{}.CopyArrayFromBytes(vw, arr) | ||||
| } | ||||
|  | ||||
| // DecodeValue is the ValueDecoder for bsoncore.Array values. | ||||
| func (ac *ArrayCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { | ||||
| 	if !val.CanSet() || val.Type() != tCoreArray { | ||||
| 		return ValueDecoderError{Name: "CoreArrayDecodeValue", Types: []reflect.Type{tCoreArray}, Received: val} | ||||
| 	} | ||||
|  | ||||
| 	if val.IsNil() { | ||||
| 		val.Set(reflect.MakeSlice(val.Type(), 0, 0)) | ||||
| 	} | ||||
|  | ||||
| 	val.SetLen(0) | ||||
| 	arr, err := bsonrw.Copier{}.AppendArrayBytes(val.Interface().(bsoncore.Array), vr) | ||||
| 	val.Set(reflect.ValueOf(arr)) | ||||
| 	return err | ||||
| } | ||||
							
								
								
									
										238
									
								
								mongo/bson/bsoncodec/bsoncodec.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										238
									
								
								mongo/bson/bsoncodec/bsoncodec.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,238 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsoncodec // import "go.mongodb.org/mongo-driver/bson/bsoncodec" | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"strings" | ||||
|  | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsonrw" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsontype" | ||||
| 	"go.mongodb.org/mongo-driver/bson/primitive" | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	emptyValue = reflect.Value{} | ||||
| ) | ||||
|  | ||||
| // Marshaler is an interface implemented by types that can marshal themselves | ||||
| // into a BSON document represented as bytes. The bytes returned must be a valid | ||||
| // BSON document if the error is nil. | ||||
| type Marshaler interface { | ||||
| 	MarshalBSON() ([]byte, error) | ||||
| } | ||||
|  | ||||
| // ValueMarshaler is an interface implemented by types that can marshal | ||||
| // themselves into a BSON value as bytes. The type must be the valid type for | ||||
| // the bytes returned. The bytes and byte type together must be valid if the | ||||
| // error is nil. | ||||
| type ValueMarshaler interface { | ||||
| 	MarshalBSONValue() (bsontype.Type, []byte, error) | ||||
| } | ||||
|  | ||||
| // Unmarshaler is an interface implemented by types that can unmarshal a BSON | ||||
| // document representation of themselves. The BSON bytes can be assumed to be | ||||
| // valid. UnmarshalBSON must copy the BSON bytes if it wishes to retain the data | ||||
| // after returning. | ||||
| type Unmarshaler interface { | ||||
| 	UnmarshalBSON([]byte) error | ||||
| } | ||||
|  | ||||
| // ValueUnmarshaler is an interface implemented by types that can unmarshal a | ||||
| // BSON value representation of themselves. The BSON bytes and type can be | ||||
| // assumed to be valid. UnmarshalBSONValue must copy the BSON value bytes if it | ||||
| // wishes to retain the data after returning. | ||||
| type ValueUnmarshaler interface { | ||||
| 	UnmarshalBSONValue(bsontype.Type, []byte) error | ||||
| } | ||||
|  | ||||
| // ValueEncoderError is an error returned from a ValueEncoder when the provided value can't be | ||||
| // encoded by the ValueEncoder. | ||||
| type ValueEncoderError struct { | ||||
| 	Name     string | ||||
| 	Types    []reflect.Type | ||||
| 	Kinds    []reflect.Kind | ||||
| 	Received reflect.Value | ||||
| } | ||||
|  | ||||
| func (vee ValueEncoderError) Error() string { | ||||
| 	typeKinds := make([]string, 0, len(vee.Types)+len(vee.Kinds)) | ||||
| 	for _, t := range vee.Types { | ||||
| 		typeKinds = append(typeKinds, t.String()) | ||||
| 	} | ||||
| 	for _, k := range vee.Kinds { | ||||
| 		if k == reflect.Map { | ||||
| 			typeKinds = append(typeKinds, "map[string]*") | ||||
| 			continue | ||||
| 		} | ||||
| 		typeKinds = append(typeKinds, k.String()) | ||||
| 	} | ||||
| 	received := vee.Received.Kind().String() | ||||
| 	if vee.Received.IsValid() { | ||||
| 		received = vee.Received.Type().String() | ||||
| 	} | ||||
| 	return fmt.Sprintf("%s can only encode valid %s, but got %s", vee.Name, strings.Join(typeKinds, ", "), received) | ||||
| } | ||||
|  | ||||
| // ValueDecoderError is an error returned from a ValueDecoder when the provided value can't be | ||||
| // decoded by the ValueDecoder. | ||||
| type ValueDecoderError struct { | ||||
| 	Name     string | ||||
| 	Types    []reflect.Type | ||||
| 	Kinds    []reflect.Kind | ||||
| 	Received reflect.Value | ||||
| } | ||||
|  | ||||
| func (vde ValueDecoderError) Error() string { | ||||
| 	typeKinds := make([]string, 0, len(vde.Types)+len(vde.Kinds)) | ||||
| 	for _, t := range vde.Types { | ||||
| 		typeKinds = append(typeKinds, t.String()) | ||||
| 	} | ||||
| 	for _, k := range vde.Kinds { | ||||
| 		if k == reflect.Map { | ||||
| 			typeKinds = append(typeKinds, "map[string]*") | ||||
| 			continue | ||||
| 		} | ||||
| 		typeKinds = append(typeKinds, k.String()) | ||||
| 	} | ||||
| 	received := vde.Received.Kind().String() | ||||
| 	if vde.Received.IsValid() { | ||||
| 		received = vde.Received.Type().String() | ||||
| 	} | ||||
| 	return fmt.Sprintf("%s can only decode valid and settable %s, but got %s", vde.Name, strings.Join(typeKinds, ", "), received) | ||||
| } | ||||
|  | ||||
| // EncodeContext is the contextual information required for a Codec to encode a | ||||
| // value. | ||||
| type EncodeContext struct { | ||||
| 	*Registry | ||||
| 	MinSize bool | ||||
| } | ||||
|  | ||||
| // DecodeContext is the contextual information required for a Codec to decode a | ||||
| // value. | ||||
| type DecodeContext struct { | ||||
| 	*Registry | ||||
| 	Truncate bool | ||||
|  | ||||
| 	// Ancestor is the type of a containing document. This is mainly used to determine what type | ||||
| 	// should be used when decoding an embedded document into an empty interface. For example, if | ||||
| 	// Ancestor is a bson.M, BSON embedded document values being decoded into an empty interface | ||||
| 	// will be decoded into a bson.M. | ||||
| 	// | ||||
| 	// Deprecated: Use DefaultDocumentM or DefaultDocumentD instead. | ||||
| 	Ancestor reflect.Type | ||||
|  | ||||
| 	// defaultDocumentType specifies the Go type to decode top-level and nested BSON documents into. In particular, the | ||||
| 	// usage for this field is restricted to data typed as "interface{}" or "map[string]interface{}". If DocumentType is | ||||
| 	// set to a type that a BSON document cannot be unmarshaled into (e.g. "string"), unmarshalling will result in an | ||||
| 	// error. DocumentType overrides the Ancestor field. | ||||
| 	defaultDocumentType reflect.Type | ||||
| } | ||||
|  | ||||
| // DefaultDocumentM will decode empty documents using the primitive.M type. This behavior is restricted to data typed as | ||||
| // "interface{}" or "map[string]interface{}". | ||||
| func (dc *DecodeContext) DefaultDocumentM() { | ||||
| 	dc.defaultDocumentType = reflect.TypeOf(primitive.M{}) | ||||
| } | ||||
|  | ||||
| // DefaultDocumentD will decode empty documents using the primitive.D type. This behavior is restricted to data typed as | ||||
| // "interface{}" or "map[string]interface{}". | ||||
| func (dc *DecodeContext) DefaultDocumentD() { | ||||
| 	dc.defaultDocumentType = reflect.TypeOf(primitive.D{}) | ||||
| } | ||||
|  | ||||
| // ValueCodec is the interface that groups the methods to encode and decode | ||||
| // values. | ||||
| type ValueCodec interface { | ||||
| 	ValueEncoder | ||||
| 	ValueDecoder | ||||
| } | ||||
|  | ||||
| // ValueEncoder is the interface implemented by types that can handle the encoding of a value. | ||||
| type ValueEncoder interface { | ||||
| 	EncodeValue(EncodeContext, bsonrw.ValueWriter, reflect.Value) error | ||||
| } | ||||
|  | ||||
| // ValueEncoderFunc is an adapter function that allows a function with the correct signature to be | ||||
| // used as a ValueEncoder. | ||||
| type ValueEncoderFunc func(EncodeContext, bsonrw.ValueWriter, reflect.Value) error | ||||
|  | ||||
| // EncodeValue implements the ValueEncoder interface. | ||||
| func (fn ValueEncoderFunc) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { | ||||
| 	return fn(ec, vw, val) | ||||
| } | ||||
|  | ||||
| // ValueDecoder is the interface implemented by types that can handle the decoding of a value. | ||||
| type ValueDecoder interface { | ||||
| 	DecodeValue(DecodeContext, bsonrw.ValueReader, reflect.Value) error | ||||
| } | ||||
|  | ||||
| // ValueDecoderFunc is an adapter function that allows a function with the correct signature to be | ||||
| // used as a ValueDecoder. | ||||
| type ValueDecoderFunc func(DecodeContext, bsonrw.ValueReader, reflect.Value) error | ||||
|  | ||||
| // DecodeValue implements the ValueDecoder interface. | ||||
| func (fn ValueDecoderFunc) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { | ||||
| 	return fn(dc, vr, val) | ||||
| } | ||||
|  | ||||
| // typeDecoder is the interface implemented by types that can handle the decoding of a value given its type. | ||||
| type typeDecoder interface { | ||||
| 	decodeType(DecodeContext, bsonrw.ValueReader, reflect.Type) (reflect.Value, error) | ||||
| } | ||||
|  | ||||
| // typeDecoderFunc is an adapter function that allows a function with the correct signature to be used as a typeDecoder. | ||||
| type typeDecoderFunc func(DecodeContext, bsonrw.ValueReader, reflect.Type) (reflect.Value, error) | ||||
|  | ||||
| func (fn typeDecoderFunc) decodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { | ||||
| 	return fn(dc, vr, t) | ||||
| } | ||||
|  | ||||
| // decodeAdapter allows two functions with the correct signatures to be used as both a ValueDecoder and typeDecoder. | ||||
| type decodeAdapter struct { | ||||
| 	ValueDecoderFunc | ||||
| 	typeDecoderFunc | ||||
| } | ||||
|  | ||||
| var _ ValueDecoder = decodeAdapter{} | ||||
| var _ typeDecoder = decodeAdapter{} | ||||
|  | ||||
| // decodeTypeOrValue calls decoder.decodeType is decoder is a typeDecoder. Otherwise, it allocates a new element of type | ||||
| // t and calls decoder.DecodeValue on it. | ||||
| func decodeTypeOrValue(decoder ValueDecoder, dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { | ||||
| 	td, _ := decoder.(typeDecoder) | ||||
| 	return decodeTypeOrValueWithInfo(decoder, td, dc, vr, t, true) | ||||
| } | ||||
|  | ||||
| func decodeTypeOrValueWithInfo(vd ValueDecoder, td typeDecoder, dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type, convert bool) (reflect.Value, error) { | ||||
| 	if td != nil { | ||||
| 		val, err := td.decodeType(dc, vr, t) | ||||
| 		if err == nil && convert && val.Type() != t { | ||||
| 			// This conversion step is necessary for slices and maps. If a user declares variables like: | ||||
| 			// | ||||
| 			// type myBool bool | ||||
| 			// var m map[string]myBool | ||||
| 			// | ||||
| 			// and tries to decode BSON bytes into the map, the decoding will fail if this conversion is not present | ||||
| 			// because we'll try to assign a value of type bool to one of type myBool. | ||||
| 			val = val.Convert(t) | ||||
| 		} | ||||
| 		return val, err | ||||
| 	} | ||||
|  | ||||
| 	val := reflect.New(t).Elem() | ||||
| 	err := vd.DecodeValue(dc, vr, val) | ||||
| 	return val, err | ||||
| } | ||||
|  | ||||
| // CodecZeroer is the interface implemented by Codecs that can also determine if | ||||
| // a value of the type that would be encoded is zero. | ||||
| type CodecZeroer interface { | ||||
| 	IsTypeZero(interface{}) bool | ||||
| } | ||||
							
								
								
									
										143
									
								
								mongo/bson/bsoncodec/bsoncodec_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										143
									
								
								mongo/bson/bsoncodec/bsoncodec_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,143 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsoncodec | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsonrw" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsontype" | ||||
| 	"go.mongodb.org/mongo-driver/bson/primitive" | ||||
| ) | ||||
|  | ||||
| func ExampleValueEncoder() { | ||||
| 	var _ ValueEncoderFunc = func(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { | ||||
| 		if val.Kind() != reflect.String { | ||||
| 			return ValueEncoderError{Name: "StringEncodeValue", Kinds: []reflect.Kind{reflect.String}, Received: val} | ||||
| 		} | ||||
|  | ||||
| 		return vw.WriteString(val.String()) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func ExampleValueDecoder() { | ||||
| 	var _ ValueDecoderFunc = func(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { | ||||
| 		if !val.CanSet() || val.Kind() != reflect.String { | ||||
| 			return ValueDecoderError{Name: "StringDecodeValue", Kinds: []reflect.Kind{reflect.String}, Received: val} | ||||
| 		} | ||||
|  | ||||
| 		if vr.Type() != bsontype.String { | ||||
| 			return fmt.Errorf("cannot decode %v into a string type", vr.Type()) | ||||
| 		} | ||||
|  | ||||
| 		str, err := vr.ReadString() | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		val.SetString(str) | ||||
| 		return nil | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func noerr(t *testing.T, err error) { | ||||
| 	if err != nil { | ||||
| 		t.Helper() | ||||
| 		t.Errorf("Unexpected error: (%T)%v", err, err) | ||||
| 		t.FailNow() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func compareTime(t1, t2 time.Time) bool { | ||||
| 	if t1.Location() != t2.Location() { | ||||
| 		return false | ||||
| 	} | ||||
| 	return t1.Equal(t2) | ||||
| } | ||||
|  | ||||
| func compareErrors(err1, err2 error) bool { | ||||
| 	if err1 == nil && err2 == nil { | ||||
| 		return true | ||||
| 	} | ||||
|  | ||||
| 	if err1 == nil || err2 == nil { | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	if err1.Error() != err2.Error() { | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	return true | ||||
| } | ||||
|  | ||||
| func compareDecimal128(d1, d2 primitive.Decimal128) bool { | ||||
| 	d1H, d1L := d1.GetBytes() | ||||
| 	d2H, d2L := d2.GetBytes() | ||||
|  | ||||
| 	if d1H != d2H { | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	if d1L != d2L { | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	return true | ||||
| } | ||||
|  | ||||
| type noPrivateFields struct { | ||||
| 	a string | ||||
| } | ||||
|  | ||||
| func compareNoPrivateFields(npf1, npf2 noPrivateFields) bool { | ||||
| 	return npf1.a != npf2.a // We don't want these to be equal | ||||
| } | ||||
|  | ||||
| type zeroTest struct { | ||||
| 	reportZero bool | ||||
| } | ||||
|  | ||||
| func (z zeroTest) IsZero() bool { return z.reportZero } | ||||
|  | ||||
| func compareZeroTest(_, _ zeroTest) bool { return true } | ||||
|  | ||||
| type nonZeroer struct { | ||||
| 	value bool | ||||
| } | ||||
|  | ||||
| type llCodec struct { | ||||
| 	t         *testing.T | ||||
| 	decodeval interface{} | ||||
| 	encodeval interface{} | ||||
| 	err       error | ||||
| } | ||||
|  | ||||
| func (llc *llCodec) EncodeValue(_ EncodeContext, _ bsonrw.ValueWriter, i interface{}) error { | ||||
| 	if llc.err != nil { | ||||
| 		return llc.err | ||||
| 	} | ||||
|  | ||||
| 	llc.encodeval = i | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (llc *llCodec) DecodeValue(_ DecodeContext, _ bsonrw.ValueReader, val reflect.Value) error { | ||||
| 	if llc.err != nil { | ||||
| 		return llc.err | ||||
| 	} | ||||
|  | ||||
| 	if !reflect.TypeOf(llc.decodeval).AssignableTo(val.Type()) { | ||||
| 		llc.t.Errorf("decodeval must be assignable to val provided to DecodeValue, but is not. decodeval %T; val %T", llc.decodeval, val) | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	val.Set(reflect.ValueOf(llc.decodeval)) | ||||
| 	return nil | ||||
| } | ||||
							
								
								
									
										111
									
								
								mongo/bson/bsoncodec/byte_slice_codec.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										111
									
								
								mongo/bson/bsoncodec/byte_slice_codec.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,111 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsoncodec | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
|  | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsonoptions" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsonrw" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsontype" | ||||
| ) | ||||
|  | ||||
| // ByteSliceCodec is the Codec used for []byte values. | ||||
| type ByteSliceCodec struct { | ||||
| 	EncodeNilAsEmpty bool | ||||
| } | ||||
|  | ||||
| var ( | ||||
| 	defaultByteSliceCodec = NewByteSliceCodec() | ||||
|  | ||||
| 	_ ValueCodec  = defaultByteSliceCodec | ||||
| 	_ typeDecoder = defaultByteSliceCodec | ||||
| ) | ||||
|  | ||||
| // NewByteSliceCodec returns a StringCodec with options opts. | ||||
| func NewByteSliceCodec(opts ...*bsonoptions.ByteSliceCodecOptions) *ByteSliceCodec { | ||||
| 	byteSliceOpt := bsonoptions.MergeByteSliceCodecOptions(opts...) | ||||
| 	codec := ByteSliceCodec{} | ||||
| 	if byteSliceOpt.EncodeNilAsEmpty != nil { | ||||
| 		codec.EncodeNilAsEmpty = *byteSliceOpt.EncodeNilAsEmpty | ||||
| 	} | ||||
| 	return &codec | ||||
| } | ||||
|  | ||||
| // EncodeValue is the ValueEncoder for []byte. | ||||
| func (bsc *ByteSliceCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { | ||||
| 	if !val.IsValid() || val.Type() != tByteSlice { | ||||
| 		return ValueEncoderError{Name: "ByteSliceEncodeValue", Types: []reflect.Type{tByteSlice}, Received: val} | ||||
| 	} | ||||
| 	if val.IsNil() && !bsc.EncodeNilAsEmpty { | ||||
| 		return vw.WriteNull() | ||||
| 	} | ||||
| 	return vw.WriteBinary(val.Interface().([]byte)) | ||||
| } | ||||
|  | ||||
| func (bsc *ByteSliceCodec) decodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { | ||||
| 	if t != tByteSlice { | ||||
| 		return emptyValue, ValueDecoderError{ | ||||
| 			Name:     "ByteSliceDecodeValue", | ||||
| 			Types:    []reflect.Type{tByteSlice}, | ||||
| 			Received: reflect.Zero(t), | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	var data []byte | ||||
| 	var err error | ||||
| 	switch vrType := vr.Type(); vrType { | ||||
| 	case bsontype.String: | ||||
| 		str, err := vr.ReadString() | ||||
| 		if err != nil { | ||||
| 			return emptyValue, err | ||||
| 		} | ||||
| 		data = []byte(str) | ||||
| 	case bsontype.Symbol: | ||||
| 		sym, err := vr.ReadSymbol() | ||||
| 		if err != nil { | ||||
| 			return emptyValue, err | ||||
| 		} | ||||
| 		data = []byte(sym) | ||||
| 	case bsontype.Binary: | ||||
| 		var subtype byte | ||||
| 		data, subtype, err = vr.ReadBinary() | ||||
| 		if err != nil { | ||||
| 			return emptyValue, err | ||||
| 		} | ||||
| 		if subtype != bsontype.BinaryGeneric && subtype != bsontype.BinaryBinaryOld { | ||||
| 			return emptyValue, decodeBinaryError{subtype: subtype, typeName: "[]byte"} | ||||
| 		} | ||||
| 	case bsontype.Null: | ||||
| 		err = vr.ReadNull() | ||||
| 	case bsontype.Undefined: | ||||
| 		err = vr.ReadUndefined() | ||||
| 	default: | ||||
| 		return emptyValue, fmt.Errorf("cannot decode %v into a []byte", vrType) | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		return emptyValue, err | ||||
| 	} | ||||
|  | ||||
| 	return reflect.ValueOf(data), nil | ||||
| } | ||||
|  | ||||
| // DecodeValue is the ValueDecoder for []byte. | ||||
| func (bsc *ByteSliceCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { | ||||
| 	if !val.CanSet() || val.Type() != tByteSlice { | ||||
| 		return ValueDecoderError{Name: "ByteSliceDecodeValue", Types: []reflect.Type{tByteSlice}, Received: val} | ||||
| 	} | ||||
|  | ||||
| 	elem, err := bsc.decodeType(dc, vr, tByteSlice) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	val.Set(elem) | ||||
| 	return nil | ||||
| } | ||||
							
								
								
									
										63
									
								
								mongo/bson/bsoncodec/cond_addr_codec.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										63
									
								
								mongo/bson/bsoncodec/cond_addr_codec.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,63 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsoncodec | ||||
|  | ||||
| import ( | ||||
| 	"reflect" | ||||
|  | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsonrw" | ||||
| ) | ||||
|  | ||||
| // condAddrEncoder is the encoder used when a pointer to the encoding value has an encoder. | ||||
| type condAddrEncoder struct { | ||||
| 	canAddrEnc ValueEncoder | ||||
| 	elseEnc    ValueEncoder | ||||
| } | ||||
|  | ||||
| var _ ValueEncoder = (*condAddrEncoder)(nil) | ||||
|  | ||||
| // newCondAddrEncoder returns an condAddrEncoder. | ||||
| func newCondAddrEncoder(canAddrEnc, elseEnc ValueEncoder) *condAddrEncoder { | ||||
| 	encoder := condAddrEncoder{canAddrEnc: canAddrEnc, elseEnc: elseEnc} | ||||
| 	return &encoder | ||||
| } | ||||
|  | ||||
| // EncodeValue is the ValueEncoderFunc for a value that may be addressable. | ||||
| func (cae *condAddrEncoder) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { | ||||
| 	if val.CanAddr() { | ||||
| 		return cae.canAddrEnc.EncodeValue(ec, vw, val) | ||||
| 	} | ||||
| 	if cae.elseEnc != nil { | ||||
| 		return cae.elseEnc.EncodeValue(ec, vw, val) | ||||
| 	} | ||||
| 	return ErrNoEncoder{Type: val.Type()} | ||||
| } | ||||
|  | ||||
| // condAddrDecoder is the decoder used when a pointer to the value has a decoder. | ||||
| type condAddrDecoder struct { | ||||
| 	canAddrDec ValueDecoder | ||||
| 	elseDec    ValueDecoder | ||||
| } | ||||
|  | ||||
| var _ ValueDecoder = (*condAddrDecoder)(nil) | ||||
|  | ||||
| // newCondAddrDecoder returns an CondAddrDecoder. | ||||
| func newCondAddrDecoder(canAddrDec, elseDec ValueDecoder) *condAddrDecoder { | ||||
| 	decoder := condAddrDecoder{canAddrDec: canAddrDec, elseDec: elseDec} | ||||
| 	return &decoder | ||||
| } | ||||
|  | ||||
| // DecodeValue is the ValueDecoderFunc for a value that may be addressable. | ||||
| func (cad *condAddrDecoder) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { | ||||
| 	if val.CanAddr() { | ||||
| 		return cad.canAddrDec.DecodeValue(dc, vr, val) | ||||
| 	} | ||||
| 	if cad.elseDec != nil { | ||||
| 		return cad.elseDec.DecodeValue(dc, vr, val) | ||||
| 	} | ||||
| 	return ErrNoDecoder{Type: val.Type()} | ||||
| } | ||||
							
								
								
									
										97
									
								
								mongo/bson/bsoncodec/cond_addr_codec_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										97
									
								
								mongo/bson/bsoncodec/cond_addr_codec_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,97 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsoncodec | ||||
|  | ||||
| import ( | ||||
| 	"reflect" | ||||
| 	"testing" | ||||
|  | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsonrw" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsonrw/bsonrwtest" | ||||
| 	"go.mongodb.org/mongo-driver/internal/testutil/assert" | ||||
| ) | ||||
|  | ||||
| func TestCondAddrCodec(t *testing.T) { | ||||
| 	var inner int | ||||
| 	canAddrVal := reflect.ValueOf(&inner) | ||||
| 	addressable := canAddrVal.Elem() | ||||
| 	unaddressable := reflect.ValueOf(inner) | ||||
| 	rw := &bsonrwtest.ValueReaderWriter{} | ||||
|  | ||||
| 	t.Run("addressEncode", func(t *testing.T) { | ||||
| 		invoked := 0 | ||||
| 		encode1 := ValueEncoderFunc(func(EncodeContext, bsonrw.ValueWriter, reflect.Value) error { | ||||
| 			invoked = 1 | ||||
| 			return nil | ||||
| 		}) | ||||
| 		encode2 := ValueEncoderFunc(func(EncodeContext, bsonrw.ValueWriter, reflect.Value) error { | ||||
| 			invoked = 2 | ||||
| 			return nil | ||||
| 		}) | ||||
| 		condEncoder := newCondAddrEncoder(encode1, encode2) | ||||
|  | ||||
| 		testCases := []struct { | ||||
| 			name    string | ||||
| 			val     reflect.Value | ||||
| 			invoked int | ||||
| 		}{ | ||||
| 			{"canAddr", addressable, 1}, | ||||
| 			{"else", unaddressable, 2}, | ||||
| 		} | ||||
| 		for _, tc := range testCases { | ||||
| 			t.Run(tc.name, func(t *testing.T) { | ||||
| 				err := condEncoder.EncodeValue(EncodeContext{}, rw, tc.val) | ||||
| 				assert.Nil(t, err, "CondAddrEncoder error: %v", err) | ||||
|  | ||||
| 				assert.Equal(t, invoked, tc.invoked, "Expected function %v to be called, called %v", tc.invoked, invoked) | ||||
| 			}) | ||||
| 		} | ||||
|  | ||||
| 		t.Run("error", func(t *testing.T) { | ||||
| 			errEncoder := newCondAddrEncoder(encode1, nil) | ||||
| 			err := errEncoder.EncodeValue(EncodeContext{}, rw, unaddressable) | ||||
| 			want := ErrNoEncoder{Type: unaddressable.Type()} | ||||
| 			assert.Equal(t, err, want, "expected error %v, got %v", want, err) | ||||
| 		}) | ||||
| 	}) | ||||
| 	t.Run("addressDecode", func(t *testing.T) { | ||||
| 		invoked := 0 | ||||
| 		decode1 := ValueDecoderFunc(func(DecodeContext, bsonrw.ValueReader, reflect.Value) error { | ||||
| 			invoked = 1 | ||||
| 			return nil | ||||
| 		}) | ||||
| 		decode2 := ValueDecoderFunc(func(DecodeContext, bsonrw.ValueReader, reflect.Value) error { | ||||
| 			invoked = 2 | ||||
| 			return nil | ||||
| 		}) | ||||
| 		condDecoder := newCondAddrDecoder(decode1, decode2) | ||||
|  | ||||
| 		testCases := []struct { | ||||
| 			name    string | ||||
| 			val     reflect.Value | ||||
| 			invoked int | ||||
| 		}{ | ||||
| 			{"canAddr", addressable, 1}, | ||||
| 			{"else", unaddressable, 2}, | ||||
| 		} | ||||
| 		for _, tc := range testCases { | ||||
| 			t.Run(tc.name, func(t *testing.T) { | ||||
| 				err := condDecoder.DecodeValue(DecodeContext{}, rw, tc.val) | ||||
| 				assert.Nil(t, err, "CondAddrDecoder error: %v", err) | ||||
|  | ||||
| 				assert.Equal(t, invoked, tc.invoked, "Expected function %v to be called, called %v", tc.invoked, invoked) | ||||
| 			}) | ||||
| 		} | ||||
|  | ||||
| 		t.Run("error", func(t *testing.T) { | ||||
| 			errDecoder := newCondAddrDecoder(decode1, nil) | ||||
| 			err := errDecoder.DecodeValue(DecodeContext{}, rw, unaddressable) | ||||
| 			want := ErrNoDecoder{Type: unaddressable.Type()} | ||||
| 			assert.Equal(t, err, want, "expected error %v, got %v", want, err) | ||||
| 		}) | ||||
| 	}) | ||||
| } | ||||
							
								
								
									
										1729
									
								
								mongo/bson/bsoncodec/default_value_decoders.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1729
									
								
								mongo/bson/bsoncodec/default_value_decoders.go
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										3794
									
								
								mongo/bson/bsoncodec/default_value_decoders_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3794
									
								
								mongo/bson/bsoncodec/default_value_decoders_test.go
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										766
									
								
								mongo/bson/bsoncodec/default_value_encoders.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										766
									
								
								mongo/bson/bsoncodec/default_value_encoders.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,766 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsoncodec | ||||
|  | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"math" | ||||
| 	"net/url" | ||||
| 	"reflect" | ||||
| 	"sync" | ||||
| 	"time" | ||||
|  | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsonrw" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsontype" | ||||
| 	"go.mongodb.org/mongo-driver/bson/primitive" | ||||
| 	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore" | ||||
| ) | ||||
|  | ||||
| var defaultValueEncoders DefaultValueEncoders | ||||
|  | ||||
| var bvwPool = bsonrw.NewBSONValueWriterPool() | ||||
|  | ||||
| var errInvalidValue = errors.New("cannot encode invalid element") | ||||
|  | ||||
| var sliceWriterPool = sync.Pool{ | ||||
| 	New: func() interface{} { | ||||
| 		sw := make(bsonrw.SliceWriter, 0) | ||||
| 		return &sw | ||||
| 	}, | ||||
| } | ||||
|  | ||||
| func encodeElement(ec EncodeContext, dw bsonrw.DocumentWriter, e primitive.E) error { | ||||
| 	vw, err := dw.WriteDocumentElement(e.Key) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	if e.Value == nil { | ||||
| 		return vw.WriteNull() | ||||
| 	} | ||||
| 	encoder, err := ec.LookupEncoder(reflect.TypeOf(e.Value)) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	err = encoder.EncodeValue(ec, vw, reflect.ValueOf(e.Value)) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // DefaultValueEncoders is a namespace type for the default ValueEncoders used | ||||
| // when creating a registry. | ||||
| type DefaultValueEncoders struct{} | ||||
|  | ||||
| // RegisterDefaultEncoders will register the encoder methods attached to DefaultValueEncoders with | ||||
| // the provided RegistryBuilder. | ||||
| func (dve DefaultValueEncoders) RegisterDefaultEncoders(rb *RegistryBuilder) { | ||||
| 	if rb == nil { | ||||
| 		panic(errors.New("argument to RegisterDefaultEncoders must not be nil")) | ||||
| 	} | ||||
| 	rb. | ||||
| 		RegisterTypeEncoder(tByteSlice, defaultByteSliceCodec). | ||||
| 		RegisterTypeEncoder(tTime, defaultTimeCodec). | ||||
| 		RegisterTypeEncoder(tEmpty, defaultEmptyInterfaceCodec). | ||||
| 		RegisterTypeEncoder(tCoreArray, defaultArrayCodec). | ||||
| 		RegisterTypeEncoder(tOID, ValueEncoderFunc(dve.ObjectIDEncodeValue)). | ||||
| 		RegisterTypeEncoder(tDecimal, ValueEncoderFunc(dve.Decimal128EncodeValue)). | ||||
| 		RegisterTypeEncoder(tJSONNumber, ValueEncoderFunc(dve.JSONNumberEncodeValue)). | ||||
| 		RegisterTypeEncoder(tURL, ValueEncoderFunc(dve.URLEncodeValue)). | ||||
| 		RegisterTypeEncoder(tJavaScript, ValueEncoderFunc(dve.JavaScriptEncodeValue)). | ||||
| 		RegisterTypeEncoder(tSymbol, ValueEncoderFunc(dve.SymbolEncodeValue)). | ||||
| 		RegisterTypeEncoder(tBinary, ValueEncoderFunc(dve.BinaryEncodeValue)). | ||||
| 		RegisterTypeEncoder(tUndefined, ValueEncoderFunc(dve.UndefinedEncodeValue)). | ||||
| 		RegisterTypeEncoder(tDateTime, ValueEncoderFunc(dve.DateTimeEncodeValue)). | ||||
| 		RegisterTypeEncoder(tNull, ValueEncoderFunc(dve.NullEncodeValue)). | ||||
| 		RegisterTypeEncoder(tRegex, ValueEncoderFunc(dve.RegexEncodeValue)). | ||||
| 		RegisterTypeEncoder(tDBPointer, ValueEncoderFunc(dve.DBPointerEncodeValue)). | ||||
| 		RegisterTypeEncoder(tTimestamp, ValueEncoderFunc(dve.TimestampEncodeValue)). | ||||
| 		RegisterTypeEncoder(tMinKey, ValueEncoderFunc(dve.MinKeyEncodeValue)). | ||||
| 		RegisterTypeEncoder(tMaxKey, ValueEncoderFunc(dve.MaxKeyEncodeValue)). | ||||
| 		RegisterTypeEncoder(tCoreDocument, ValueEncoderFunc(dve.CoreDocumentEncodeValue)). | ||||
| 		RegisterTypeEncoder(tCodeWithScope, ValueEncoderFunc(dve.CodeWithScopeEncodeValue)). | ||||
| 		RegisterDefaultEncoder(reflect.Bool, ValueEncoderFunc(dve.BooleanEncodeValue)). | ||||
| 		RegisterDefaultEncoder(reflect.Int, ValueEncoderFunc(dve.IntEncodeValue)). | ||||
| 		RegisterDefaultEncoder(reflect.Int8, ValueEncoderFunc(dve.IntEncodeValue)). | ||||
| 		RegisterDefaultEncoder(reflect.Int16, ValueEncoderFunc(dve.IntEncodeValue)). | ||||
| 		RegisterDefaultEncoder(reflect.Int32, ValueEncoderFunc(dve.IntEncodeValue)). | ||||
| 		RegisterDefaultEncoder(reflect.Int64, ValueEncoderFunc(dve.IntEncodeValue)). | ||||
| 		RegisterDefaultEncoder(reflect.Uint, defaultUIntCodec). | ||||
| 		RegisterDefaultEncoder(reflect.Uint8, defaultUIntCodec). | ||||
| 		RegisterDefaultEncoder(reflect.Uint16, defaultUIntCodec). | ||||
| 		RegisterDefaultEncoder(reflect.Uint32, defaultUIntCodec). | ||||
| 		RegisterDefaultEncoder(reflect.Uint64, defaultUIntCodec). | ||||
| 		RegisterDefaultEncoder(reflect.Float32, ValueEncoderFunc(dve.FloatEncodeValue)). | ||||
| 		RegisterDefaultEncoder(reflect.Float64, ValueEncoderFunc(dve.FloatEncodeValue)). | ||||
| 		RegisterDefaultEncoder(reflect.Array, ValueEncoderFunc(dve.ArrayEncodeValue)). | ||||
| 		RegisterDefaultEncoder(reflect.Map, defaultMapCodec). | ||||
| 		RegisterDefaultEncoder(reflect.Slice, defaultSliceCodec). | ||||
| 		RegisterDefaultEncoder(reflect.String, defaultStringCodec). | ||||
| 		RegisterDefaultEncoder(reflect.Struct, newDefaultStructCodec()). | ||||
| 		RegisterDefaultEncoder(reflect.Ptr, NewPointerCodec()). | ||||
| 		RegisterHookEncoder(tValueMarshaler, ValueEncoderFunc(dve.ValueMarshalerEncodeValue)). | ||||
| 		RegisterHookEncoder(tMarshaler, ValueEncoderFunc(dve.MarshalerEncodeValue)). | ||||
| 		RegisterHookEncoder(tProxy, ValueEncoderFunc(dve.ProxyEncodeValue)) | ||||
| } | ||||
|  | ||||
| // BooleanEncodeValue is the ValueEncoderFunc for bool types. | ||||
| func (dve DefaultValueEncoders) BooleanEncodeValue(ectx EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { | ||||
| 	if !val.IsValid() || val.Kind() != reflect.Bool { | ||||
| 		return ValueEncoderError{Name: "BooleanEncodeValue", Kinds: []reflect.Kind{reflect.Bool}, Received: val} | ||||
| 	} | ||||
| 	return vw.WriteBoolean(val.Bool()) | ||||
| } | ||||
|  | ||||
| func fitsIn32Bits(i int64) bool { | ||||
| 	return math.MinInt32 <= i && i <= math.MaxInt32 | ||||
| } | ||||
|  | ||||
| // IntEncodeValue is the ValueEncoderFunc for int types. | ||||
| func (dve DefaultValueEncoders) IntEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { | ||||
| 	switch val.Kind() { | ||||
| 	case reflect.Int8, reflect.Int16, reflect.Int32: | ||||
| 		return vw.WriteInt32(int32(val.Int())) | ||||
| 	case reflect.Int: | ||||
| 		i64 := val.Int() | ||||
| 		if fitsIn32Bits(i64) { | ||||
| 			return vw.WriteInt32(int32(i64)) | ||||
| 		} | ||||
| 		return vw.WriteInt64(i64) | ||||
| 	case reflect.Int64: | ||||
| 		i64 := val.Int() | ||||
| 		if ec.MinSize && fitsIn32Bits(i64) { | ||||
| 			return vw.WriteInt32(int32(i64)) | ||||
| 		} | ||||
| 		return vw.WriteInt64(i64) | ||||
| 	} | ||||
|  | ||||
| 	return ValueEncoderError{ | ||||
| 		Name:     "IntEncodeValue", | ||||
| 		Kinds:    []reflect.Kind{reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int}, | ||||
| 		Received: val, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // UintEncodeValue is the ValueEncoderFunc for uint types. | ||||
| // | ||||
| // Deprecated: UintEncodeValue is not registered by default. Use UintCodec.EncodeValue instead. | ||||
| func (dve DefaultValueEncoders) UintEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { | ||||
| 	switch val.Kind() { | ||||
| 	case reflect.Uint8, reflect.Uint16: | ||||
| 		return vw.WriteInt32(int32(val.Uint())) | ||||
| 	case reflect.Uint, reflect.Uint32, reflect.Uint64: | ||||
| 		u64 := val.Uint() | ||||
| 		if ec.MinSize && u64 <= math.MaxInt32 { | ||||
| 			return vw.WriteInt32(int32(u64)) | ||||
| 		} | ||||
| 		if u64 > math.MaxInt64 { | ||||
| 			return fmt.Errorf("%d overflows int64", u64) | ||||
| 		} | ||||
| 		return vw.WriteInt64(int64(u64)) | ||||
| 	} | ||||
|  | ||||
| 	return ValueEncoderError{ | ||||
| 		Name:     "UintEncodeValue", | ||||
| 		Kinds:    []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint}, | ||||
| 		Received: val, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // FloatEncodeValue is the ValueEncoderFunc for float types. | ||||
| func (dve DefaultValueEncoders) FloatEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { | ||||
| 	switch val.Kind() { | ||||
| 	case reflect.Float32, reflect.Float64: | ||||
| 		return vw.WriteDouble(val.Float()) | ||||
| 	} | ||||
|  | ||||
| 	return ValueEncoderError{Name: "FloatEncodeValue", Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, Received: val} | ||||
| } | ||||
|  | ||||
| // StringEncodeValue is the ValueEncoderFunc for string types. | ||||
| // | ||||
| // Deprecated: StringEncodeValue is not registered by default. Use StringCodec.EncodeValue instead. | ||||
| func (dve DefaultValueEncoders) StringEncodeValue(ectx EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { | ||||
| 	if val.Kind() != reflect.String { | ||||
| 		return ValueEncoderError{ | ||||
| 			Name:     "StringEncodeValue", | ||||
| 			Kinds:    []reflect.Kind{reflect.String}, | ||||
| 			Received: val, | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return vw.WriteString(val.String()) | ||||
| } | ||||
|  | ||||
| // ObjectIDEncodeValue is the ValueEncoderFunc for primitive.ObjectID. | ||||
| func (dve DefaultValueEncoders) ObjectIDEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { | ||||
| 	if !val.IsValid() || val.Type() != tOID { | ||||
| 		return ValueEncoderError{Name: "ObjectIDEncodeValue", Types: []reflect.Type{tOID}, Received: val} | ||||
| 	} | ||||
| 	return vw.WriteObjectID(val.Interface().(primitive.ObjectID)) | ||||
| } | ||||
|  | ||||
| // Decimal128EncodeValue is the ValueEncoderFunc for primitive.Decimal128. | ||||
| func (dve DefaultValueEncoders) Decimal128EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { | ||||
| 	if !val.IsValid() || val.Type() != tDecimal { | ||||
| 		return ValueEncoderError{Name: "Decimal128EncodeValue", Types: []reflect.Type{tDecimal}, Received: val} | ||||
| 	} | ||||
| 	return vw.WriteDecimal128(val.Interface().(primitive.Decimal128)) | ||||
| } | ||||
|  | ||||
| // JSONNumberEncodeValue is the ValueEncoderFunc for json.Number. | ||||
| func (dve DefaultValueEncoders) JSONNumberEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { | ||||
| 	if !val.IsValid() || val.Type() != tJSONNumber { | ||||
| 		return ValueEncoderError{Name: "JSONNumberEncodeValue", Types: []reflect.Type{tJSONNumber}, Received: val} | ||||
| 	} | ||||
| 	jsnum := val.Interface().(json.Number) | ||||
|  | ||||
| 	// Attempt int first, then float64 | ||||
| 	if i64, err := jsnum.Int64(); err == nil { | ||||
| 		return dve.IntEncodeValue(ec, vw, reflect.ValueOf(i64)) | ||||
| 	} | ||||
|  | ||||
| 	f64, err := jsnum.Float64() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return dve.FloatEncodeValue(ec, vw, reflect.ValueOf(f64)) | ||||
| } | ||||
|  | ||||
| // URLEncodeValue is the ValueEncoderFunc for url.URL. | ||||
| func (dve DefaultValueEncoders) URLEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { | ||||
| 	if !val.IsValid() || val.Type() != tURL { | ||||
| 		return ValueEncoderError{Name: "URLEncodeValue", Types: []reflect.Type{tURL}, Received: val} | ||||
| 	} | ||||
| 	u := val.Interface().(url.URL) | ||||
| 	return vw.WriteString(u.String()) | ||||
| } | ||||
|  | ||||
| // TimeEncodeValue is the ValueEncoderFunc for time.TIme. | ||||
| // | ||||
| // Deprecated: TimeEncodeValue is not registered by default. Use TimeCodec.EncodeValue instead. | ||||
| func (dve DefaultValueEncoders) TimeEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { | ||||
| 	if !val.IsValid() || val.Type() != tTime { | ||||
| 		return ValueEncoderError{Name: "TimeEncodeValue", Types: []reflect.Type{tTime}, Received: val} | ||||
| 	} | ||||
| 	tt := val.Interface().(time.Time) | ||||
| 	dt := primitive.NewDateTimeFromTime(tt) | ||||
| 	return vw.WriteDateTime(int64(dt)) | ||||
| } | ||||
|  | ||||
| // ByteSliceEncodeValue is the ValueEncoderFunc for []byte. | ||||
| // | ||||
| // Deprecated: ByteSliceEncodeValue is not registered by default. Use ByteSliceCodec.EncodeValue instead. | ||||
| func (dve DefaultValueEncoders) ByteSliceEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { | ||||
| 	if !val.IsValid() || val.Type() != tByteSlice { | ||||
| 		return ValueEncoderError{Name: "ByteSliceEncodeValue", Types: []reflect.Type{tByteSlice}, Received: val} | ||||
| 	} | ||||
| 	if val.IsNil() { | ||||
| 		return vw.WriteNull() | ||||
| 	} | ||||
| 	return vw.WriteBinary(val.Interface().([]byte)) | ||||
| } | ||||
|  | ||||
| // MapEncodeValue is the ValueEncoderFunc for map[string]* types. | ||||
| // | ||||
| // Deprecated: MapEncodeValue is not registered by default. Use MapCodec.EncodeValue instead. | ||||
| func (dve DefaultValueEncoders) MapEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { | ||||
| 	if !val.IsValid() || val.Kind() != reflect.Map || val.Type().Key().Kind() != reflect.String { | ||||
| 		return ValueEncoderError{Name: "MapEncodeValue", Kinds: []reflect.Kind{reflect.Map}, Received: val} | ||||
| 	} | ||||
|  | ||||
| 	if val.IsNil() { | ||||
| 		// If we have a nill map but we can't WriteNull, that means we're probably trying to encode | ||||
| 		// to a TopLevel document. We can't currently tell if this is what actually happened, but if | ||||
| 		// there's a deeper underlying problem, the error will also be returned from WriteDocument, | ||||
| 		// so just continue. The operations on a map reflection value are valid, so we can call | ||||
| 		// MapKeys within mapEncodeValue without a problem. | ||||
| 		err := vw.WriteNull() | ||||
| 		if err == nil { | ||||
| 			return nil | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	dw, err := vw.WriteDocument() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return dve.mapEncodeValue(ec, dw, val, nil) | ||||
| } | ||||
|  | ||||
| // mapEncodeValue handles encoding of the values of a map. The collisionFn returns | ||||
| // true if the provided key exists, this is mainly used for inline maps in the | ||||
| // struct codec. | ||||
| func (dve DefaultValueEncoders) mapEncodeValue(ec EncodeContext, dw bsonrw.DocumentWriter, val reflect.Value, collisionFn func(string) bool) error { | ||||
|  | ||||
| 	elemType := val.Type().Elem() | ||||
| 	encoder, err := ec.LookupEncoder(elemType) | ||||
| 	if err != nil && elemType.Kind() != reflect.Interface { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	keys := val.MapKeys() | ||||
| 	for _, key := range keys { | ||||
| 		if collisionFn != nil && collisionFn(key.String()) { | ||||
| 			return fmt.Errorf("Key %s of inlined map conflicts with a struct field name", key) | ||||
| 		} | ||||
|  | ||||
| 		currEncoder, currVal, lookupErr := dve.lookupElementEncoder(ec, encoder, val.MapIndex(key)) | ||||
| 		if lookupErr != nil && lookupErr != errInvalidValue { | ||||
| 			return lookupErr | ||||
| 		} | ||||
|  | ||||
| 		vw, err := dw.WriteDocumentElement(key.String()) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		if lookupErr == errInvalidValue { | ||||
| 			err = vw.WriteNull() | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		err = currEncoder.EncodeValue(ec, vw, currVal) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return dw.WriteDocumentEnd() | ||||
| } | ||||
|  | ||||
| // ArrayEncodeValue is the ValueEncoderFunc for array types. | ||||
| func (dve DefaultValueEncoders) ArrayEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { | ||||
| 	if !val.IsValid() || val.Kind() != reflect.Array { | ||||
| 		return ValueEncoderError{Name: "ArrayEncodeValue", Kinds: []reflect.Kind{reflect.Array}, Received: val} | ||||
| 	} | ||||
|  | ||||
| 	// If we have a []primitive.E we want to treat it as a document instead of as an array. | ||||
| 	if val.Type().Elem() == tE { | ||||
| 		dw, err := vw.WriteDocument() | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		for idx := 0; idx < val.Len(); idx++ { | ||||
| 			e := val.Index(idx).Interface().(primitive.E) | ||||
| 			err = encodeElement(ec, dw, e) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		return dw.WriteDocumentEnd() | ||||
| 	} | ||||
|  | ||||
| 	// If we have a []byte we want to treat it as a binary instead of as an array. | ||||
| 	if val.Type().Elem() == tByte { | ||||
| 		var byteSlice []byte | ||||
| 		for idx := 0; idx < val.Len(); idx++ { | ||||
| 			byteSlice = append(byteSlice, val.Index(idx).Interface().(byte)) | ||||
| 		} | ||||
| 		return vw.WriteBinary(byteSlice) | ||||
| 	} | ||||
|  | ||||
| 	aw, err := vw.WriteArray() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	elemType := val.Type().Elem() | ||||
| 	encoder, err := ec.LookupEncoder(elemType) | ||||
| 	if err != nil && elemType.Kind() != reflect.Interface { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	for idx := 0; idx < val.Len(); idx++ { | ||||
| 		currEncoder, currVal, lookupErr := dve.lookupElementEncoder(ec, encoder, val.Index(idx)) | ||||
| 		if lookupErr != nil && lookupErr != errInvalidValue { | ||||
| 			return lookupErr | ||||
| 		} | ||||
|  | ||||
| 		vw, err := aw.WriteArrayElement() | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		if lookupErr == errInvalidValue { | ||||
| 			err = vw.WriteNull() | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		err = currEncoder.EncodeValue(ec, vw, currVal) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
| 	return aw.WriteArrayEnd() | ||||
| } | ||||
|  | ||||
| // SliceEncodeValue is the ValueEncoderFunc for slice types. | ||||
| // | ||||
| // Deprecated: SliceEncodeValue is not registered by default. Use SliceCodec.EncodeValue instead. | ||||
| func (dve DefaultValueEncoders) SliceEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { | ||||
| 	if !val.IsValid() || val.Kind() != reflect.Slice { | ||||
| 		return ValueEncoderError{Name: "SliceEncodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: val} | ||||
| 	} | ||||
|  | ||||
| 	if val.IsNil() { | ||||
| 		return vw.WriteNull() | ||||
| 	} | ||||
|  | ||||
| 	// If we have a []primitive.E we want to treat it as a document instead of as an array. | ||||
| 	if val.Type().ConvertibleTo(tD) { | ||||
| 		d := val.Convert(tD).Interface().(primitive.D) | ||||
|  | ||||
| 		dw, err := vw.WriteDocument() | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		for _, e := range d { | ||||
| 			err = encodeElement(ec, dw, e) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		return dw.WriteDocumentEnd() | ||||
| 	} | ||||
|  | ||||
| 	aw, err := vw.WriteArray() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	elemType := val.Type().Elem() | ||||
| 	encoder, err := ec.LookupEncoder(elemType) | ||||
| 	if err != nil && elemType.Kind() != reflect.Interface { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	for idx := 0; idx < val.Len(); idx++ { | ||||
| 		currEncoder, currVal, lookupErr := dve.lookupElementEncoder(ec, encoder, val.Index(idx)) | ||||
| 		if lookupErr != nil && lookupErr != errInvalidValue { | ||||
| 			return lookupErr | ||||
| 		} | ||||
|  | ||||
| 		vw, err := aw.WriteArrayElement() | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		if lookupErr == errInvalidValue { | ||||
| 			err = vw.WriteNull() | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		err = currEncoder.EncodeValue(ec, vw, currVal) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
| 	return aw.WriteArrayEnd() | ||||
| } | ||||
|  | ||||
| func (dve DefaultValueEncoders) lookupElementEncoder(ec EncodeContext, origEncoder ValueEncoder, currVal reflect.Value) (ValueEncoder, reflect.Value, error) { | ||||
| 	if origEncoder != nil || (currVal.Kind() != reflect.Interface) { | ||||
| 		return origEncoder, currVal, nil | ||||
| 	} | ||||
| 	currVal = currVal.Elem() | ||||
| 	if !currVal.IsValid() { | ||||
| 		return nil, currVal, errInvalidValue | ||||
| 	} | ||||
| 	currEncoder, err := ec.LookupEncoder(currVal.Type()) | ||||
|  | ||||
| 	return currEncoder, currVal, err | ||||
| } | ||||
|  | ||||
| // EmptyInterfaceEncodeValue is the ValueEncoderFunc for interface{}. | ||||
| // | ||||
| // Deprecated: EmptyInterfaceEncodeValue is not registered by default. Use EmptyInterfaceCodec.EncodeValue instead. | ||||
| func (dve DefaultValueEncoders) EmptyInterfaceEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { | ||||
| 	if !val.IsValid() || val.Type() != tEmpty { | ||||
| 		return ValueEncoderError{Name: "EmptyInterfaceEncodeValue", Types: []reflect.Type{tEmpty}, Received: val} | ||||
| 	} | ||||
|  | ||||
| 	if val.IsNil() { | ||||
| 		return vw.WriteNull() | ||||
| 	} | ||||
| 	encoder, err := ec.LookupEncoder(val.Elem().Type()) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return encoder.EncodeValue(ec, vw, val.Elem()) | ||||
| } | ||||
|  | ||||
| // ValueMarshalerEncodeValue is the ValueEncoderFunc for ValueMarshaler implementations. | ||||
| func (dve DefaultValueEncoders) ValueMarshalerEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { | ||||
| 	// Either val or a pointer to val must implement ValueMarshaler | ||||
| 	switch { | ||||
| 	case !val.IsValid(): | ||||
| 		return ValueEncoderError{Name: "ValueMarshalerEncodeValue", Types: []reflect.Type{tValueMarshaler}, Received: val} | ||||
| 	case val.Type().Implements(tValueMarshaler): | ||||
| 		// If ValueMarshaler is implemented on a concrete type, make sure that val isn't a nil pointer | ||||
| 		if isImplementationNil(val, tValueMarshaler) { | ||||
| 			return vw.WriteNull() | ||||
| 		} | ||||
| 	case reflect.PtrTo(val.Type()).Implements(tValueMarshaler) && val.CanAddr(): | ||||
| 		val = val.Addr() | ||||
| 	default: | ||||
| 		return ValueEncoderError{Name: "ValueMarshalerEncodeValue", Types: []reflect.Type{tValueMarshaler}, Received: val} | ||||
| 	} | ||||
|  | ||||
| 	fn := val.Convert(tValueMarshaler).MethodByName("MarshalBSONValue") | ||||
| 	returns := fn.Call(nil) | ||||
| 	if !returns[2].IsNil() { | ||||
| 		return returns[2].Interface().(error) | ||||
| 	} | ||||
| 	t, data := returns[0].Interface().(bsontype.Type), returns[1].Interface().([]byte) | ||||
| 	return bsonrw.Copier{}.CopyValueFromBytes(vw, t, data) | ||||
| } | ||||
|  | ||||
| // MarshalerEncodeValue is the ValueEncoderFunc for Marshaler implementations. | ||||
| func (dve DefaultValueEncoders) MarshalerEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { | ||||
| 	// Either val or a pointer to val must implement Marshaler | ||||
| 	switch { | ||||
| 	case !val.IsValid(): | ||||
| 		return ValueEncoderError{Name: "MarshalerEncodeValue", Types: []reflect.Type{tMarshaler}, Received: val} | ||||
| 	case val.Type().Implements(tMarshaler): | ||||
| 		// If Marshaler is implemented on a concrete type, make sure that val isn't a nil pointer | ||||
| 		if isImplementationNil(val, tMarshaler) { | ||||
| 			return vw.WriteNull() | ||||
| 		} | ||||
| 	case reflect.PtrTo(val.Type()).Implements(tMarshaler) && val.CanAddr(): | ||||
| 		val = val.Addr() | ||||
| 	default: | ||||
| 		return ValueEncoderError{Name: "MarshalerEncodeValue", Types: []reflect.Type{tMarshaler}, Received: val} | ||||
| 	} | ||||
|  | ||||
| 	fn := val.Convert(tMarshaler).MethodByName("MarshalBSON") | ||||
| 	returns := fn.Call(nil) | ||||
| 	if !returns[1].IsNil() { | ||||
| 		return returns[1].Interface().(error) | ||||
| 	} | ||||
| 	data := returns[0].Interface().([]byte) | ||||
| 	return bsonrw.Copier{}.CopyValueFromBytes(vw, bsontype.EmbeddedDocument, data) | ||||
| } | ||||
|  | ||||
| // ProxyEncodeValue is the ValueEncoderFunc for Proxy implementations. | ||||
| func (dve DefaultValueEncoders) ProxyEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { | ||||
| 	// Either val or a pointer to val must implement Proxy | ||||
| 	switch { | ||||
| 	case !val.IsValid(): | ||||
| 		return ValueEncoderError{Name: "ProxyEncodeValue", Types: []reflect.Type{tProxy}, Received: val} | ||||
| 	case val.Type().Implements(tProxy): | ||||
| 		// If Proxy is implemented on a concrete type, make sure that val isn't a nil pointer | ||||
| 		if isImplementationNil(val, tProxy) { | ||||
| 			return vw.WriteNull() | ||||
| 		} | ||||
| 	case reflect.PtrTo(val.Type()).Implements(tProxy) && val.CanAddr(): | ||||
| 		val = val.Addr() | ||||
| 	default: | ||||
| 		return ValueEncoderError{Name: "ProxyEncodeValue", Types: []reflect.Type{tProxy}, Received: val} | ||||
| 	} | ||||
|  | ||||
| 	fn := val.Convert(tProxy).MethodByName("ProxyBSON") | ||||
| 	returns := fn.Call(nil) | ||||
| 	if !returns[1].IsNil() { | ||||
| 		return returns[1].Interface().(error) | ||||
| 	} | ||||
| 	data := returns[0] | ||||
| 	var encoder ValueEncoder | ||||
| 	var err error | ||||
| 	if data.Elem().IsValid() { | ||||
| 		encoder, err = ec.LookupEncoder(data.Elem().Type()) | ||||
| 	} else { | ||||
| 		encoder, err = ec.LookupEncoder(nil) | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	return encoder.EncodeValue(ec, vw, data.Elem()) | ||||
| } | ||||
|  | ||||
| // JavaScriptEncodeValue is the ValueEncoderFunc for the primitive.JavaScript type. | ||||
| func (DefaultValueEncoders) JavaScriptEncodeValue(ectx EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { | ||||
| 	if !val.IsValid() || val.Type() != tJavaScript { | ||||
| 		return ValueEncoderError{Name: "JavaScriptEncodeValue", Types: []reflect.Type{tJavaScript}, Received: val} | ||||
| 	} | ||||
|  | ||||
| 	return vw.WriteJavascript(val.String()) | ||||
| } | ||||
|  | ||||
| // SymbolEncodeValue is the ValueEncoderFunc for the primitive.Symbol type. | ||||
| func (DefaultValueEncoders) SymbolEncodeValue(ectx EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { | ||||
| 	if !val.IsValid() || val.Type() != tSymbol { | ||||
| 		return ValueEncoderError{Name: "SymbolEncodeValue", Types: []reflect.Type{tSymbol}, Received: val} | ||||
| 	} | ||||
|  | ||||
| 	return vw.WriteSymbol(val.String()) | ||||
| } | ||||
|  | ||||
| // BinaryEncodeValue is the ValueEncoderFunc for Binary. | ||||
| func (DefaultValueEncoders) BinaryEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { | ||||
| 	if !val.IsValid() || val.Type() != tBinary { | ||||
| 		return ValueEncoderError{Name: "BinaryEncodeValue", Types: []reflect.Type{tBinary}, Received: val} | ||||
| 	} | ||||
| 	b := val.Interface().(primitive.Binary) | ||||
|  | ||||
| 	return vw.WriteBinaryWithSubtype(b.Data, b.Subtype) | ||||
| } | ||||
|  | ||||
| // UndefinedEncodeValue is the ValueEncoderFunc for Undefined. | ||||
| func (DefaultValueEncoders) UndefinedEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { | ||||
| 	if !val.IsValid() || val.Type() != tUndefined { | ||||
| 		return ValueEncoderError{Name: "UndefinedEncodeValue", Types: []reflect.Type{tUndefined}, Received: val} | ||||
| 	} | ||||
|  | ||||
| 	return vw.WriteUndefined() | ||||
| } | ||||
|  | ||||
| // DateTimeEncodeValue is the ValueEncoderFunc for DateTime. | ||||
| func (DefaultValueEncoders) DateTimeEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { | ||||
| 	if !val.IsValid() || val.Type() != tDateTime { | ||||
| 		return ValueEncoderError{Name: "DateTimeEncodeValue", Types: []reflect.Type{tDateTime}, Received: val} | ||||
| 	} | ||||
|  | ||||
| 	return vw.WriteDateTime(val.Int()) | ||||
| } | ||||
|  | ||||
| // NullEncodeValue is the ValueEncoderFunc for Null. | ||||
| func (DefaultValueEncoders) NullEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { | ||||
| 	if !val.IsValid() || val.Type() != tNull { | ||||
| 		return ValueEncoderError{Name: "NullEncodeValue", Types: []reflect.Type{tNull}, Received: val} | ||||
| 	} | ||||
|  | ||||
| 	return vw.WriteNull() | ||||
| } | ||||
|  | ||||
| // RegexEncodeValue is the ValueEncoderFunc for Regex. | ||||
| func (DefaultValueEncoders) RegexEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { | ||||
| 	if !val.IsValid() || val.Type() != tRegex { | ||||
| 		return ValueEncoderError{Name: "RegexEncodeValue", Types: []reflect.Type{tRegex}, Received: val} | ||||
| 	} | ||||
|  | ||||
| 	regex := val.Interface().(primitive.Regex) | ||||
|  | ||||
| 	return vw.WriteRegex(regex.Pattern, regex.Options) | ||||
| } | ||||
|  | ||||
| // DBPointerEncodeValue is the ValueEncoderFunc for DBPointer. | ||||
| func (DefaultValueEncoders) DBPointerEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { | ||||
| 	if !val.IsValid() || val.Type() != tDBPointer { | ||||
| 		return ValueEncoderError{Name: "DBPointerEncodeValue", Types: []reflect.Type{tDBPointer}, Received: val} | ||||
| 	} | ||||
|  | ||||
| 	dbp := val.Interface().(primitive.DBPointer) | ||||
|  | ||||
| 	return vw.WriteDBPointer(dbp.DB, dbp.Pointer) | ||||
| } | ||||
|  | ||||
| // TimestampEncodeValue is the ValueEncoderFunc for Timestamp. | ||||
| func (DefaultValueEncoders) TimestampEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { | ||||
| 	if !val.IsValid() || val.Type() != tTimestamp { | ||||
| 		return ValueEncoderError{Name: "TimestampEncodeValue", Types: []reflect.Type{tTimestamp}, Received: val} | ||||
| 	} | ||||
|  | ||||
| 	ts := val.Interface().(primitive.Timestamp) | ||||
|  | ||||
| 	return vw.WriteTimestamp(ts.T, ts.I) | ||||
| } | ||||
|  | ||||
| // MinKeyEncodeValue is the ValueEncoderFunc for MinKey. | ||||
| func (DefaultValueEncoders) MinKeyEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { | ||||
| 	if !val.IsValid() || val.Type() != tMinKey { | ||||
| 		return ValueEncoderError{Name: "MinKeyEncodeValue", Types: []reflect.Type{tMinKey}, Received: val} | ||||
| 	} | ||||
|  | ||||
| 	return vw.WriteMinKey() | ||||
| } | ||||
|  | ||||
| // MaxKeyEncodeValue is the ValueEncoderFunc for MaxKey. | ||||
| func (DefaultValueEncoders) MaxKeyEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { | ||||
| 	if !val.IsValid() || val.Type() != tMaxKey { | ||||
| 		return ValueEncoderError{Name: "MaxKeyEncodeValue", Types: []reflect.Type{tMaxKey}, Received: val} | ||||
| 	} | ||||
|  | ||||
| 	return vw.WriteMaxKey() | ||||
| } | ||||
|  | ||||
| // CoreDocumentEncodeValue is the ValueEncoderFunc for bsoncore.Document. | ||||
| func (DefaultValueEncoders) CoreDocumentEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { | ||||
| 	if !val.IsValid() || val.Type() != tCoreDocument { | ||||
| 		return ValueEncoderError{Name: "CoreDocumentEncodeValue", Types: []reflect.Type{tCoreDocument}, Received: val} | ||||
| 	} | ||||
|  | ||||
| 	cdoc := val.Interface().(bsoncore.Document) | ||||
|  | ||||
| 	return bsonrw.Copier{}.CopyDocumentFromBytes(vw, cdoc) | ||||
| } | ||||
|  | ||||
| // CodeWithScopeEncodeValue is the ValueEncoderFunc for CodeWithScope. | ||||
| func (dve DefaultValueEncoders) CodeWithScopeEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { | ||||
| 	if !val.IsValid() || val.Type() != tCodeWithScope { | ||||
| 		return ValueEncoderError{Name: "CodeWithScopeEncodeValue", Types: []reflect.Type{tCodeWithScope}, Received: val} | ||||
| 	} | ||||
|  | ||||
| 	cws := val.Interface().(primitive.CodeWithScope) | ||||
|  | ||||
| 	dw, err := vw.WriteCodeWithScope(string(cws.Code)) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	sw := sliceWriterPool.Get().(*bsonrw.SliceWriter) | ||||
| 	defer sliceWriterPool.Put(sw) | ||||
| 	*sw = (*sw)[:0] | ||||
|  | ||||
| 	scopeVW := bvwPool.Get(sw) | ||||
| 	defer bvwPool.Put(scopeVW) | ||||
|  | ||||
| 	encoder, err := ec.LookupEncoder(reflect.TypeOf(cws.Scope)) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	err = encoder.EncodeValue(ec, scopeVW, reflect.ValueOf(cws.Scope)) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	err = bsonrw.Copier{}.CopyBytesToDocumentWriter(dw, *sw) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	return dw.WriteDocumentEnd() | ||||
| } | ||||
|  | ||||
| // isImplementationNil returns if val is a nil pointer and inter is implemented on a concrete type | ||||
| func isImplementationNil(val reflect.Value, inter reflect.Type) bool { | ||||
| 	vt := val.Type() | ||||
| 	for vt.Kind() == reflect.Ptr { | ||||
| 		vt = vt.Elem() | ||||
| 	} | ||||
| 	return vt.Implements(inter) && val.Kind() == reflect.Ptr && val.IsNil() | ||||
| } | ||||
							
								
								
									
										1909
									
								
								mongo/bson/bsoncodec/default_value_encoders_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1909
									
								
								mongo/bson/bsoncodec/default_value_encoders_test.go
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										90
									
								
								mongo/bson/bsoncodec/doc.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										90
									
								
								mongo/bson/bsoncodec/doc.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,90 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2022-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| // Package bsoncodec provides a system for encoding values to BSON representations and decoding | ||||
| // values from BSON representations. This package considers both binary BSON and ExtendedJSON as | ||||
| // BSON representations. The types in this package enable a flexible system for handling this | ||||
| // encoding and decoding. | ||||
| // | ||||
| // The codec system is composed of two parts: | ||||
| // | ||||
| // 1) ValueEncoders and ValueDecoders that handle encoding and decoding Go values to and from BSON | ||||
| // representations. | ||||
| // | ||||
| // 2) A Registry that holds these ValueEncoders and ValueDecoders and provides methods for | ||||
| // retrieving them. | ||||
| // | ||||
| // # ValueEncoders and ValueDecoders | ||||
| // | ||||
| // The ValueEncoder interface is implemented by types that can encode a provided Go type to BSON. | ||||
| // The value to encode is provided as a reflect.Value and a bsonrw.ValueWriter is used within the | ||||
| // EncodeValue method to actually create the BSON representation. For convenience, ValueEncoderFunc | ||||
| // is provided to allow use of a function with the correct signature as a ValueEncoder. An | ||||
| // EncodeContext instance is provided to allow implementations to lookup further ValueEncoders and | ||||
| // to provide configuration information. | ||||
| // | ||||
| // The ValueDecoder interface is the inverse of the ValueEncoder. Implementations should ensure that | ||||
| // the value they receive is settable. Similar to ValueEncoderFunc, ValueDecoderFunc is provided to | ||||
| // allow the use of a function with the correct signature as a ValueDecoder. A DecodeContext | ||||
| // instance is provided and serves similar functionality to the EncodeContext. | ||||
| // | ||||
| // # Registry and RegistryBuilder | ||||
| // | ||||
| // A Registry is an immutable store for ValueEncoders, ValueDecoders, and a type map. See the Registry type | ||||
| // documentation for examples of registering various custom encoders and decoders. A Registry can be constructed using a | ||||
| // RegistryBuilder, which handles three main types of codecs: | ||||
| // | ||||
| // 1. Type encoders/decoders - These can be registered using the RegisterTypeEncoder and RegisterTypeDecoder methods. | ||||
| // The registered codec will be invoked when encoding/decoding a value whose type matches the registered type exactly. | ||||
| // If the registered type is an interface, the codec will be invoked when encoding or decoding values whose type is the | ||||
| // interface, but not for values with concrete types that implement the interface. | ||||
| // | ||||
| // 2. Hook encoders/decoders - These can be registered using the RegisterHookEncoder and RegisterHookDecoder methods. | ||||
| // These methods only accept interface types and the registered codecs will be invoked when encoding or decoding values | ||||
| // whose types implement the interface. An example of a hook defined by the driver is bson.Marshaler. The driver will | ||||
| // call the MarshalBSON method for any value whose type implements bson.Marshaler, regardless of the value's concrete | ||||
| // type. | ||||
| // | ||||
| // 3. Type map entries - This can be used to associate a BSON type with a Go type. These type associations are used when | ||||
| // decoding into a bson.D/bson.M or a struct field of type interface{}. For example, by default, BSON int32 and int64 | ||||
| // values decode as Go int32 and int64 instances, respectively, when decoding into a bson.D. The following code would | ||||
| // change the behavior so these values decode as Go int instances instead: | ||||
| // | ||||
| //	intType := reflect.TypeOf(int(0)) | ||||
| //	registryBuilder.RegisterTypeMapEntry(bsontype.Int32, intType).RegisterTypeMapEntry(bsontype.Int64, intType) | ||||
| // | ||||
| // 4. Kind encoder/decoders - These can be registered using the RegisterDefaultEncoder and RegisterDefaultDecoder | ||||
| // methods. The registered codec will be invoked when encoding or decoding values whose reflect.Kind matches the | ||||
| // registered reflect.Kind as long as the value's type doesn't match a registered type or hook encoder/decoder first. | ||||
| // These methods should be used to change the behavior for all values for a specific kind. | ||||
| // | ||||
| // # Registry Lookup Procedure | ||||
| // | ||||
| // When looking up an encoder in a Registry, the precedence rules are as follows: | ||||
| // | ||||
| // 1. A type encoder registered for the exact type of the value. | ||||
| // | ||||
| // 2. A hook encoder registered for an interface that is implemented by the value or by a pointer to the value. If the | ||||
| // value matches multiple hooks (e.g. the type implements bsoncodec.Marshaler and bsoncodec.ValueMarshaler), the first | ||||
| // one registered will be selected. Note that registries constructed using bson.NewRegistryBuilder have driver-defined | ||||
| // hooks registered for the bsoncodec.Marshaler, bsoncodec.ValueMarshaler, and bsoncodec.Proxy interfaces, so those | ||||
| // will take precedence over any new hooks. | ||||
| // | ||||
| // 3. A kind encoder registered for the value's kind. | ||||
| // | ||||
| // If all of these lookups fail to find an encoder, an error of type ErrNoEncoder is returned. The same precedence | ||||
| // rules apply for decoders, with the exception that an error of type ErrNoDecoder will be returned if no decoder is | ||||
| // found. | ||||
| // | ||||
| // # DefaultValueEncoders and DefaultValueDecoders | ||||
| // | ||||
| // The DefaultValueEncoders and DefaultValueDecoders types provide a full set of ValueEncoders and | ||||
| // ValueDecoders for handling a wide range of Go types, including all of the types within the | ||||
| // primitive package. To make registering these codecs easier, a helper method on each type is | ||||
| // provided. For the DefaultValueEncoders type the method is called RegisterDefaultEncoders and for | ||||
| // the DefaultValueDecoders type the method is called RegisterDefaultDecoders, this method also | ||||
| // handles registering type map entries for each BSON type. | ||||
| package bsoncodec | ||||
							
								
								
									
										147
									
								
								mongo/bson/bsoncodec/empty_interface_codec.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										147
									
								
								mongo/bson/bsoncodec/empty_interface_codec.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,147 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsoncodec | ||||
|  | ||||
| import ( | ||||
| 	"reflect" | ||||
|  | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsonoptions" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsonrw" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsontype" | ||||
| 	"go.mongodb.org/mongo-driver/bson/primitive" | ||||
| ) | ||||
|  | ||||
| // EmptyInterfaceCodec is the Codec used for interface{} values. | ||||
| type EmptyInterfaceCodec struct { | ||||
| 	DecodeBinaryAsSlice bool | ||||
| } | ||||
|  | ||||
| var ( | ||||
| 	defaultEmptyInterfaceCodec = NewEmptyInterfaceCodec() | ||||
|  | ||||
| 	_ ValueCodec  = defaultEmptyInterfaceCodec | ||||
| 	_ typeDecoder = defaultEmptyInterfaceCodec | ||||
| ) | ||||
|  | ||||
| // NewEmptyInterfaceCodec returns a EmptyInterfaceCodec with options opts. | ||||
| func NewEmptyInterfaceCodec(opts ...*bsonoptions.EmptyInterfaceCodecOptions) *EmptyInterfaceCodec { | ||||
| 	interfaceOpt := bsonoptions.MergeEmptyInterfaceCodecOptions(opts...) | ||||
|  | ||||
| 	codec := EmptyInterfaceCodec{} | ||||
| 	if interfaceOpt.DecodeBinaryAsSlice != nil { | ||||
| 		codec.DecodeBinaryAsSlice = *interfaceOpt.DecodeBinaryAsSlice | ||||
| 	} | ||||
| 	return &codec | ||||
| } | ||||
|  | ||||
| // EncodeValue is the ValueEncoderFunc for interface{}. | ||||
| func (eic EmptyInterfaceCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { | ||||
| 	if !val.IsValid() || val.Type() != tEmpty { | ||||
| 		return ValueEncoderError{Name: "EmptyInterfaceEncodeValue", Types: []reflect.Type{tEmpty}, Received: val} | ||||
| 	} | ||||
|  | ||||
| 	if val.IsNil() { | ||||
| 		return vw.WriteNull() | ||||
| 	} | ||||
| 	encoder, err := ec.LookupEncoder(val.Elem().Type()) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return encoder.EncodeValue(ec, vw, val.Elem()) | ||||
| } | ||||
|  | ||||
| func (eic EmptyInterfaceCodec) getEmptyInterfaceDecodeType(dc DecodeContext, valueType bsontype.Type) (reflect.Type, error) { | ||||
| 	isDocument := valueType == bsontype.Type(0) || valueType == bsontype.EmbeddedDocument | ||||
| 	if isDocument { | ||||
| 		if dc.defaultDocumentType != nil { | ||||
| 			// If the bsontype is an embedded document and the DocumentType is set on the DecodeContext, then return | ||||
| 			// that type. | ||||
| 			return dc.defaultDocumentType, nil | ||||
| 		} | ||||
| 		if dc.Ancestor != nil { | ||||
| 			// Using ancestor information rather than looking up the type map entry forces consistent decoding. | ||||
| 			// If we're decoding into a bson.D, subdocuments should also be decoded as bson.D, even if a type map entry | ||||
| 			// has been registered. | ||||
| 			return dc.Ancestor, nil | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	rtype, err := dc.LookupTypeMapEntry(valueType) | ||||
| 	if err == nil { | ||||
| 		return rtype, nil | ||||
| 	} | ||||
|  | ||||
| 	if isDocument { | ||||
| 		// For documents, fallback to looking up a type map entry for bsontype.Type(0) or bsontype.EmbeddedDocument, | ||||
| 		// depending on the original valueType. | ||||
| 		var lookupType bsontype.Type | ||||
| 		switch valueType { | ||||
| 		case bsontype.Type(0): | ||||
| 			lookupType = bsontype.EmbeddedDocument | ||||
| 		case bsontype.EmbeddedDocument: | ||||
| 			lookupType = bsontype.Type(0) | ||||
| 		} | ||||
|  | ||||
| 		rtype, err = dc.LookupTypeMapEntry(lookupType) | ||||
| 		if err == nil { | ||||
| 			return rtype, nil | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return nil, err | ||||
| } | ||||
|  | ||||
| func (eic EmptyInterfaceCodec) decodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { | ||||
| 	if t != tEmpty { | ||||
| 		return emptyValue, ValueDecoderError{Name: "EmptyInterfaceDecodeValue", Types: []reflect.Type{tEmpty}, Received: reflect.Zero(t)} | ||||
| 	} | ||||
|  | ||||
| 	rtype, err := eic.getEmptyInterfaceDecodeType(dc, vr.Type()) | ||||
| 	if err != nil { | ||||
| 		switch vr.Type() { | ||||
| 		case bsontype.Null: | ||||
| 			return reflect.Zero(t), vr.ReadNull() | ||||
| 		default: | ||||
| 			return emptyValue, err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	decoder, err := dc.LookupDecoder(rtype) | ||||
| 	if err != nil { | ||||
| 		return emptyValue, err | ||||
| 	} | ||||
|  | ||||
| 	elem, err := decodeTypeOrValue(decoder, dc, vr, rtype) | ||||
| 	if err != nil { | ||||
| 		return emptyValue, err | ||||
| 	} | ||||
|  | ||||
| 	if eic.DecodeBinaryAsSlice && rtype == tBinary { | ||||
| 		binElem := elem.Interface().(primitive.Binary) | ||||
| 		if binElem.Subtype == bsontype.BinaryGeneric || binElem.Subtype == bsontype.BinaryBinaryOld { | ||||
| 			elem = reflect.ValueOf(binElem.Data) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return elem, nil | ||||
| } | ||||
|  | ||||
| // DecodeValue is the ValueDecoderFunc for interface{}. | ||||
| func (eic EmptyInterfaceCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { | ||||
| 	if !val.CanSet() || val.Type() != tEmpty { | ||||
| 		return ValueDecoderError{Name: "EmptyInterfaceDecodeValue", Types: []reflect.Type{tEmpty}, Received: val} | ||||
| 	} | ||||
|  | ||||
| 	elem, err := eic.decodeType(dc, vr, val.Type()) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	val.Set(elem) | ||||
| 	return nil | ||||
| } | ||||
							
								
								
									
										309
									
								
								mongo/bson/bsoncodec/map_codec.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										309
									
								
								mongo/bson/bsoncodec/map_codec.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,309 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsoncodec | ||||
|  | ||||
| import ( | ||||
| 	"encoding" | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"strconv" | ||||
|  | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsonoptions" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsonrw" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsontype" | ||||
| ) | ||||
|  | ||||
| var defaultMapCodec = NewMapCodec() | ||||
|  | ||||
| // MapCodec is the Codec used for map values. | ||||
| type MapCodec struct { | ||||
| 	DecodeZerosMap         bool | ||||
| 	EncodeNilAsEmpty       bool | ||||
| 	EncodeKeysWithStringer bool | ||||
| } | ||||
|  | ||||
| var _ ValueCodec = &MapCodec{} | ||||
|  | ||||
| // KeyMarshaler is the interface implemented by an object that can marshal itself into a string key. | ||||
| // This applies to types used as map keys and is similar to encoding.TextMarshaler. | ||||
| type KeyMarshaler interface { | ||||
| 	MarshalKey() (key string, err error) | ||||
| } | ||||
|  | ||||
| // KeyUnmarshaler is the interface implemented by an object that can unmarshal a string representation | ||||
| // of itself. This applies to types used as map keys and is similar to encoding.TextUnmarshaler. | ||||
| // | ||||
| // UnmarshalKey must be able to decode the form generated by MarshalKey. | ||||
| // UnmarshalKey must copy the text if it wishes to retain the text | ||||
| // after returning. | ||||
| type KeyUnmarshaler interface { | ||||
| 	UnmarshalKey(key string) error | ||||
| } | ||||
|  | ||||
| // NewMapCodec returns a MapCodec with options opts. | ||||
| func NewMapCodec(opts ...*bsonoptions.MapCodecOptions) *MapCodec { | ||||
| 	mapOpt := bsonoptions.MergeMapCodecOptions(opts...) | ||||
|  | ||||
| 	codec := MapCodec{} | ||||
| 	if mapOpt.DecodeZerosMap != nil { | ||||
| 		codec.DecodeZerosMap = *mapOpt.DecodeZerosMap | ||||
| 	} | ||||
| 	if mapOpt.EncodeNilAsEmpty != nil { | ||||
| 		codec.EncodeNilAsEmpty = *mapOpt.EncodeNilAsEmpty | ||||
| 	} | ||||
| 	if mapOpt.EncodeKeysWithStringer != nil { | ||||
| 		codec.EncodeKeysWithStringer = *mapOpt.EncodeKeysWithStringer | ||||
| 	} | ||||
| 	return &codec | ||||
| } | ||||
|  | ||||
| // EncodeValue is the ValueEncoder for map[*]* types. | ||||
| func (mc *MapCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { | ||||
| 	if !val.IsValid() || val.Kind() != reflect.Map { | ||||
| 		return ValueEncoderError{Name: "MapEncodeValue", Kinds: []reflect.Kind{reflect.Map}, Received: val} | ||||
| 	} | ||||
|  | ||||
| 	if val.IsNil() && !mc.EncodeNilAsEmpty { | ||||
| 		// If we have a nil map but we can't WriteNull, that means we're probably trying to encode | ||||
| 		// to a TopLevel document. We can't currently tell if this is what actually happened, but if | ||||
| 		// there's a deeper underlying problem, the error will also be returned from WriteDocument, | ||||
| 		// so just continue. The operations on a map reflection value are valid, so we can call | ||||
| 		// MapKeys within mapEncodeValue without a problem. | ||||
| 		err := vw.WriteNull() | ||||
| 		if err == nil { | ||||
| 			return nil | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	dw, err := vw.WriteDocument() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return mc.mapEncodeValue(ec, dw, val, nil) | ||||
| } | ||||
|  | ||||
| // mapEncodeValue handles encoding of the values of a map. The collisionFn returns | ||||
| // true if the provided key exists, this is mainly used for inline maps in the | ||||
| // struct codec. | ||||
| func (mc *MapCodec) mapEncodeValue(ec EncodeContext, dw bsonrw.DocumentWriter, val reflect.Value, collisionFn func(string) bool) error { | ||||
|  | ||||
| 	elemType := val.Type().Elem() | ||||
| 	encoder, err := ec.LookupEncoder(elemType) | ||||
| 	if err != nil && elemType.Kind() != reflect.Interface { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	keys := val.MapKeys() | ||||
| 	for _, key := range keys { | ||||
| 		keyStr, err := mc.encodeKey(key) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		if collisionFn != nil && collisionFn(keyStr) { | ||||
| 			return fmt.Errorf("Key %s of inlined map conflicts with a struct field name", key) | ||||
| 		} | ||||
|  | ||||
| 		currEncoder, currVal, lookupErr := defaultValueEncoders.lookupElementEncoder(ec, encoder, val.MapIndex(key)) | ||||
| 		if lookupErr != nil && lookupErr != errInvalidValue { | ||||
| 			return lookupErr | ||||
| 		} | ||||
|  | ||||
| 		vw, err := dw.WriteDocumentElement(keyStr) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		if lookupErr == errInvalidValue { | ||||
| 			err = vw.WriteNull() | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		err = currEncoder.EncodeValue(ec, vw, currVal) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return dw.WriteDocumentEnd() | ||||
| } | ||||
|  | ||||
| // DecodeValue is the ValueDecoder for map[string/decimal]* types. | ||||
| func (mc *MapCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { | ||||
| 	if val.Kind() != reflect.Map || (!val.CanSet() && val.IsNil()) { | ||||
| 		return ValueDecoderError{Name: "MapDecodeValue", Kinds: []reflect.Kind{reflect.Map}, Received: val} | ||||
| 	} | ||||
|  | ||||
| 	switch vrType := vr.Type(); vrType { | ||||
| 	case bsontype.Type(0), bsontype.EmbeddedDocument: | ||||
| 	case bsontype.Null: | ||||
| 		val.Set(reflect.Zero(val.Type())) | ||||
| 		return vr.ReadNull() | ||||
| 	case bsontype.Undefined: | ||||
| 		val.Set(reflect.Zero(val.Type())) | ||||
| 		return vr.ReadUndefined() | ||||
| 	default: | ||||
| 		return fmt.Errorf("cannot decode %v into a %s", vrType, val.Type()) | ||||
| 	} | ||||
|  | ||||
| 	dr, err := vr.ReadDocument() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	if val.IsNil() { | ||||
| 		val.Set(reflect.MakeMap(val.Type())) | ||||
| 	} | ||||
|  | ||||
| 	if val.Len() > 0 && mc.DecodeZerosMap { | ||||
| 		clearMap(val) | ||||
| 	} | ||||
|  | ||||
| 	eType := val.Type().Elem() | ||||
| 	decoder, err := dc.LookupDecoder(eType) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	eTypeDecoder, _ := decoder.(typeDecoder) | ||||
|  | ||||
| 	if eType == tEmpty { | ||||
| 		dc.Ancestor = val.Type() | ||||
| 	} | ||||
|  | ||||
| 	keyType := val.Type().Key() | ||||
|  | ||||
| 	for { | ||||
| 		key, vr, err := dr.ReadElement() | ||||
| 		if err == bsonrw.ErrEOD { | ||||
| 			break | ||||
| 		} | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		k, err := mc.decodeKey(key, keyType) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		elem, err := decodeTypeOrValueWithInfo(decoder, eTypeDecoder, dc, vr, eType, true) | ||||
| 		if err != nil { | ||||
| 			return newDecodeError(key, err) | ||||
| 		} | ||||
|  | ||||
| 		val.SetMapIndex(k, elem) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func clearMap(m reflect.Value) { | ||||
| 	var none reflect.Value | ||||
| 	for _, k := range m.MapKeys() { | ||||
| 		m.SetMapIndex(k, none) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (mc *MapCodec) encodeKey(val reflect.Value) (string, error) { | ||||
| 	if mc.EncodeKeysWithStringer { | ||||
| 		return fmt.Sprint(val), nil | ||||
| 	} | ||||
|  | ||||
| 	// keys of any string type are used directly | ||||
| 	if val.Kind() == reflect.String { | ||||
| 		return val.String(), nil | ||||
| 	} | ||||
| 	// KeyMarshalers are marshaled | ||||
| 	if km, ok := val.Interface().(KeyMarshaler); ok { | ||||
| 		if val.Kind() == reflect.Ptr && val.IsNil() { | ||||
| 			return "", nil | ||||
| 		} | ||||
| 		buf, err := km.MarshalKey() | ||||
| 		if err == nil { | ||||
| 			return buf, nil | ||||
| 		} | ||||
| 		return "", err | ||||
| 	} | ||||
| 	// keys implement encoding.TextMarshaler are marshaled. | ||||
| 	if km, ok := val.Interface().(encoding.TextMarshaler); ok { | ||||
| 		if val.Kind() == reflect.Ptr && val.IsNil() { | ||||
| 			return "", nil | ||||
| 		} | ||||
|  | ||||
| 		buf, err := km.MarshalText() | ||||
| 		if err != nil { | ||||
| 			return "", err | ||||
| 		} | ||||
|  | ||||
| 		return string(buf), nil | ||||
| 	} | ||||
|  | ||||
| 	switch val.Kind() { | ||||
| 	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | ||||
| 		return strconv.FormatInt(val.Int(), 10), nil | ||||
| 	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: | ||||
| 		return strconv.FormatUint(val.Uint(), 10), nil | ||||
| 	} | ||||
| 	return "", fmt.Errorf("unsupported key type: %v", val.Type()) | ||||
| } | ||||
|  | ||||
| var keyUnmarshalerType = reflect.TypeOf((*KeyUnmarshaler)(nil)).Elem() | ||||
| var textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() | ||||
|  | ||||
| func (mc *MapCodec) decodeKey(key string, keyType reflect.Type) (reflect.Value, error) { | ||||
| 	keyVal := reflect.ValueOf(key) | ||||
| 	var err error | ||||
| 	switch { | ||||
| 	// First, if EncodeKeysWithStringer is not enabled, try to decode withKeyUnmarshaler | ||||
| 	case !mc.EncodeKeysWithStringer && reflect.PtrTo(keyType).Implements(keyUnmarshalerType): | ||||
| 		keyVal = reflect.New(keyType) | ||||
| 		v := keyVal.Interface().(KeyUnmarshaler) | ||||
| 		err = v.UnmarshalKey(key) | ||||
| 		keyVal = keyVal.Elem() | ||||
| 	// Try to decode encoding.TextUnmarshalers. | ||||
| 	case reflect.PtrTo(keyType).Implements(textUnmarshalerType): | ||||
| 		keyVal = reflect.New(keyType) | ||||
| 		v := keyVal.Interface().(encoding.TextUnmarshaler) | ||||
| 		err = v.UnmarshalText([]byte(key)) | ||||
| 		keyVal = keyVal.Elem() | ||||
| 	// Otherwise, go to type specific behavior | ||||
| 	default: | ||||
| 		switch keyType.Kind() { | ||||
| 		case reflect.String: | ||||
| 			keyVal = reflect.ValueOf(key).Convert(keyType) | ||||
| 		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | ||||
| 			n, parseErr := strconv.ParseInt(key, 10, 64) | ||||
| 			if parseErr != nil || reflect.Zero(keyType).OverflowInt(n) { | ||||
| 				err = fmt.Errorf("failed to unmarshal number key %v", key) | ||||
| 			} | ||||
| 			keyVal = reflect.ValueOf(n).Convert(keyType) | ||||
| 		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: | ||||
| 			n, parseErr := strconv.ParseUint(key, 10, 64) | ||||
| 			if parseErr != nil || reflect.Zero(keyType).OverflowUint(n) { | ||||
| 				err = fmt.Errorf("failed to unmarshal number key %v", key) | ||||
| 				break | ||||
| 			} | ||||
| 			keyVal = reflect.ValueOf(n).Convert(keyType) | ||||
| 		case reflect.Float32, reflect.Float64: | ||||
| 			if mc.EncodeKeysWithStringer { | ||||
| 				parsed, err := strconv.ParseFloat(key, 64) | ||||
| 				if err != nil { | ||||
| 					return keyVal, fmt.Errorf("Map key is defined to be a decimal type (%v) but got error %v", keyType.Kind(), err) | ||||
| 				} | ||||
| 				keyVal = reflect.ValueOf(parsed) | ||||
| 				break | ||||
| 			} | ||||
| 			fallthrough | ||||
| 		default: | ||||
| 			return keyVal, fmt.Errorf("unsupported key type: %v", keyType) | ||||
| 		} | ||||
| 	} | ||||
| 	return keyVal, err | ||||
| } | ||||
							
								
								
									
										65
									
								
								mongo/bson/bsoncodec/mode.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										65
									
								
								mongo/bson/bsoncodec/mode.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,65 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsoncodec | ||||
|  | ||||
| import "fmt" | ||||
|  | ||||
| type mode int | ||||
|  | ||||
| const ( | ||||
| 	_ mode = iota | ||||
| 	mTopLevel | ||||
| 	mDocument | ||||
| 	mArray | ||||
| 	mValue | ||||
| 	mElement | ||||
| 	mCodeWithScope | ||||
| 	mSpacer | ||||
| ) | ||||
|  | ||||
| func (m mode) String() string { | ||||
| 	var str string | ||||
|  | ||||
| 	switch m { | ||||
| 	case mTopLevel: | ||||
| 		str = "TopLevel" | ||||
| 	case mDocument: | ||||
| 		str = "DocumentMode" | ||||
| 	case mArray: | ||||
| 		str = "ArrayMode" | ||||
| 	case mValue: | ||||
| 		str = "ValueMode" | ||||
| 	case mElement: | ||||
| 		str = "ElementMode" | ||||
| 	case mCodeWithScope: | ||||
| 		str = "CodeWithScopeMode" | ||||
| 	case mSpacer: | ||||
| 		str = "CodeWithScopeSpacerFrame" | ||||
| 	default: | ||||
| 		str = "UnknownMode" | ||||
| 	} | ||||
|  | ||||
| 	return str | ||||
| } | ||||
|  | ||||
| // TransitionError is an error returned when an invalid progressing a | ||||
| // ValueReader or ValueWriter state machine occurs. | ||||
| type TransitionError struct { | ||||
| 	parent      mode | ||||
| 	current     mode | ||||
| 	destination mode | ||||
| } | ||||
|  | ||||
| func (te TransitionError) Error() string { | ||||
| 	if te.destination == mode(0) { | ||||
| 		return fmt.Sprintf("invalid state transition: cannot read/write value while in %s", te.current) | ||||
| 	} | ||||
| 	if te.parent == mode(0) { | ||||
| 		return fmt.Sprintf("invalid state transition: %s -> %s", te.current, te.destination) | ||||
| 	} | ||||
| 	return fmt.Sprintf("invalid state transition: %s -> %s; parent %s", te.current, te.destination, te.parent) | ||||
| } | ||||
							
								
								
									
										109
									
								
								mongo/bson/bsoncodec/pointer_codec.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										109
									
								
								mongo/bson/bsoncodec/pointer_codec.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,109 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsoncodec | ||||
|  | ||||
| import ( | ||||
| 	"reflect" | ||||
| 	"sync" | ||||
|  | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsonrw" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsontype" | ||||
| ) | ||||
|  | ||||
| var _ ValueEncoder = &PointerCodec{} | ||||
| var _ ValueDecoder = &PointerCodec{} | ||||
|  | ||||
| // PointerCodec is the Codec used for pointers. | ||||
| type PointerCodec struct { | ||||
| 	ecache map[reflect.Type]ValueEncoder | ||||
| 	dcache map[reflect.Type]ValueDecoder | ||||
| 	l      sync.RWMutex | ||||
| } | ||||
|  | ||||
| // NewPointerCodec returns a PointerCodec that has been initialized. | ||||
| func NewPointerCodec() *PointerCodec { | ||||
| 	return &PointerCodec{ | ||||
| 		ecache: make(map[reflect.Type]ValueEncoder), | ||||
| 		dcache: make(map[reflect.Type]ValueDecoder), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // EncodeValue handles encoding a pointer by either encoding it to BSON Null if the pointer is nil | ||||
| // or looking up an encoder for the type of value the pointer points to. | ||||
| func (pc *PointerCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { | ||||
| 	if val.Kind() != reflect.Ptr { | ||||
| 		if !val.IsValid() { | ||||
| 			return vw.WriteNull() | ||||
| 		} | ||||
| 		return ValueEncoderError{Name: "PointerCodec.EncodeValue", Kinds: []reflect.Kind{reflect.Ptr}, Received: val} | ||||
| 	} | ||||
|  | ||||
| 	if val.IsNil() { | ||||
| 		return vw.WriteNull() | ||||
| 	} | ||||
|  | ||||
| 	pc.l.RLock() | ||||
| 	enc, ok := pc.ecache[val.Type()] | ||||
| 	pc.l.RUnlock() | ||||
| 	if ok { | ||||
| 		if enc == nil { | ||||
| 			return ErrNoEncoder{Type: val.Type()} | ||||
| 		} | ||||
| 		return enc.EncodeValue(ec, vw, val.Elem()) | ||||
| 	} | ||||
|  | ||||
| 	enc, err := ec.LookupEncoder(val.Type().Elem()) | ||||
| 	pc.l.Lock() | ||||
| 	pc.ecache[val.Type()] = enc | ||||
| 	pc.l.Unlock() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return enc.EncodeValue(ec, vw, val.Elem()) | ||||
| } | ||||
|  | ||||
| // DecodeValue handles decoding a pointer by looking up a decoder for the type it points to and | ||||
| // using that to decode. If the BSON value is Null, this method will set the pointer to nil. | ||||
| func (pc *PointerCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { | ||||
| 	if !val.CanSet() || val.Kind() != reflect.Ptr { | ||||
| 		return ValueDecoderError{Name: "PointerCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Ptr}, Received: val} | ||||
| 	} | ||||
|  | ||||
| 	if vr.Type() == bsontype.Null { | ||||
| 		val.Set(reflect.Zero(val.Type())) | ||||
| 		return vr.ReadNull() | ||||
| 	} | ||||
| 	if vr.Type() == bsontype.Undefined { | ||||
| 		val.Set(reflect.Zero(val.Type())) | ||||
| 		return vr.ReadUndefined() | ||||
| 	} | ||||
|  | ||||
| 	if val.IsNil() { | ||||
| 		val.Set(reflect.New(val.Type().Elem())) | ||||
| 	} | ||||
|  | ||||
| 	pc.l.RLock() | ||||
| 	dec, ok := pc.dcache[val.Type()] | ||||
| 	pc.l.RUnlock() | ||||
| 	if ok { | ||||
| 		if dec == nil { | ||||
| 			return ErrNoDecoder{Type: val.Type()} | ||||
| 		} | ||||
| 		return dec.DecodeValue(dc, vr, val.Elem()) | ||||
| 	} | ||||
|  | ||||
| 	dec, err := dc.LookupDecoder(val.Type().Elem()) | ||||
| 	pc.l.Lock() | ||||
| 	pc.dcache[val.Type()] = dec | ||||
| 	pc.l.Unlock() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return dec.DecodeValue(dc, vr, val.Elem()) | ||||
| } | ||||
							
								
								
									
										14
									
								
								mongo/bson/bsoncodec/proxy.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										14
									
								
								mongo/bson/bsoncodec/proxy.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,14 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsoncodec | ||||
|  | ||||
| // Proxy is an interface implemented by types that cannot themselves be directly encoded. Types | ||||
| // that implement this interface with have ProxyBSON called during the encoding process and that | ||||
| // value will be encoded in place for the implementer. | ||||
| type Proxy interface { | ||||
| 	ProxyBSON() (interface{}, error) | ||||
| } | ||||
							
								
								
									
										469
									
								
								mongo/bson/bsoncodec/registry.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										469
									
								
								mongo/bson/bsoncodec/registry.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,469 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsoncodec | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"sync" | ||||
|  | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsontype" | ||||
| ) | ||||
|  | ||||
| // ErrNilType is returned when nil is passed to either LookupEncoder or LookupDecoder. | ||||
| var ErrNilType = errors.New("cannot perform a decoder lookup on <nil>") | ||||
|  | ||||
| // ErrNotPointer is returned when a non-pointer type is provided to LookupDecoder. | ||||
| var ErrNotPointer = errors.New("non-pointer provided to LookupDecoder") | ||||
|  | ||||
| // ErrNoEncoder is returned when there wasn't an encoder available for a type. | ||||
| type ErrNoEncoder struct { | ||||
| 	Type reflect.Type | ||||
| } | ||||
|  | ||||
| func (ene ErrNoEncoder) Error() string { | ||||
| 	if ene.Type == nil { | ||||
| 		return "no encoder found for <nil>" | ||||
| 	} | ||||
| 	return "no encoder found for " + ene.Type.String() | ||||
| } | ||||
|  | ||||
| // ErrNoDecoder is returned when there wasn't a decoder available for a type. | ||||
| type ErrNoDecoder struct { | ||||
| 	Type reflect.Type | ||||
| } | ||||
|  | ||||
| func (end ErrNoDecoder) Error() string { | ||||
| 	return "no decoder found for " + end.Type.String() | ||||
| } | ||||
|  | ||||
| // ErrNoTypeMapEntry is returned when there wasn't a type available for the provided BSON type. | ||||
| type ErrNoTypeMapEntry struct { | ||||
| 	Type bsontype.Type | ||||
| } | ||||
|  | ||||
| func (entme ErrNoTypeMapEntry) Error() string { | ||||
| 	return "no type map entry found for " + entme.Type.String() | ||||
| } | ||||
|  | ||||
| // ErrNotInterface is returned when the provided type is not an interface. | ||||
| var ErrNotInterface = errors.New("The provided type is not an interface") | ||||
|  | ||||
| // A RegistryBuilder is used to build a Registry. This type is not goroutine | ||||
| // safe. | ||||
| type RegistryBuilder struct { | ||||
| 	typeEncoders      map[reflect.Type]ValueEncoder | ||||
| 	interfaceEncoders []interfaceValueEncoder | ||||
| 	kindEncoders      map[reflect.Kind]ValueEncoder | ||||
|  | ||||
| 	typeDecoders      map[reflect.Type]ValueDecoder | ||||
| 	interfaceDecoders []interfaceValueDecoder | ||||
| 	kindDecoders      map[reflect.Kind]ValueDecoder | ||||
|  | ||||
| 	typeMap map[bsontype.Type]reflect.Type | ||||
| } | ||||
|  | ||||
| // A Registry is used to store and retrieve codecs for types and interfaces. This type is the main | ||||
| // typed passed around and Encoders and Decoders are constructed from it. | ||||
| type Registry struct { | ||||
| 	typeEncoders map[reflect.Type]ValueEncoder | ||||
| 	typeDecoders map[reflect.Type]ValueDecoder | ||||
|  | ||||
| 	interfaceEncoders []interfaceValueEncoder | ||||
| 	interfaceDecoders []interfaceValueDecoder | ||||
|  | ||||
| 	kindEncoders map[reflect.Kind]ValueEncoder | ||||
| 	kindDecoders map[reflect.Kind]ValueDecoder | ||||
|  | ||||
| 	typeMap map[bsontype.Type]reflect.Type | ||||
|  | ||||
| 	mu sync.RWMutex | ||||
| } | ||||
|  | ||||
| // NewRegistryBuilder creates a new empty RegistryBuilder. | ||||
| func NewRegistryBuilder() *RegistryBuilder { | ||||
| 	return &RegistryBuilder{ | ||||
| 		typeEncoders: make(map[reflect.Type]ValueEncoder), | ||||
| 		typeDecoders: make(map[reflect.Type]ValueDecoder), | ||||
|  | ||||
| 		interfaceEncoders: make([]interfaceValueEncoder, 0), | ||||
| 		interfaceDecoders: make([]interfaceValueDecoder, 0), | ||||
|  | ||||
| 		kindEncoders: make(map[reflect.Kind]ValueEncoder), | ||||
| 		kindDecoders: make(map[reflect.Kind]ValueDecoder), | ||||
|  | ||||
| 		typeMap: make(map[bsontype.Type]reflect.Type), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func buildDefaultRegistry() *Registry { | ||||
| 	rb := NewRegistryBuilder() | ||||
| 	defaultValueEncoders.RegisterDefaultEncoders(rb) | ||||
| 	defaultValueDecoders.RegisterDefaultDecoders(rb) | ||||
| 	return rb.Build() | ||||
| } | ||||
|  | ||||
| // RegisterCodec will register the provided ValueCodec for the provided type. | ||||
| func (rb *RegistryBuilder) RegisterCodec(t reflect.Type, codec ValueCodec) *RegistryBuilder { | ||||
| 	rb.RegisterTypeEncoder(t, codec) | ||||
| 	rb.RegisterTypeDecoder(t, codec) | ||||
| 	return rb | ||||
| } | ||||
|  | ||||
| // RegisterTypeEncoder will register the provided ValueEncoder for the provided type. | ||||
| // | ||||
| // The type will be used directly, so an encoder can be registered for a type and a different encoder can be registered | ||||
| // for a pointer to that type. | ||||
| // | ||||
| // If the given type is an interface, the encoder will be called when marshalling a type that is that interface. It | ||||
| // will not be called when marshalling a non-interface type that implements the interface. | ||||
| func (rb *RegistryBuilder) RegisterTypeEncoder(t reflect.Type, enc ValueEncoder) *RegistryBuilder { | ||||
| 	rb.typeEncoders[t] = enc | ||||
| 	return rb | ||||
| } | ||||
|  | ||||
| // RegisterHookEncoder will register an encoder for the provided interface type t. This encoder will be called when | ||||
| // marshalling a type if the type implements t or a pointer to the type implements t. If the provided type is not | ||||
| // an interface (i.e. t.Kind() != reflect.Interface), this method will panic. | ||||
| func (rb *RegistryBuilder) RegisterHookEncoder(t reflect.Type, enc ValueEncoder) *RegistryBuilder { | ||||
| 	if t.Kind() != reflect.Interface { | ||||
| 		panicStr := fmt.Sprintf("RegisterHookEncoder expects a type with kind reflect.Interface, "+ | ||||
| 			"got type %s with kind %s", t, t.Kind()) | ||||
| 		panic(panicStr) | ||||
| 	} | ||||
|  | ||||
| 	for idx, encoder := range rb.interfaceEncoders { | ||||
| 		if encoder.i == t { | ||||
| 			rb.interfaceEncoders[idx].ve = enc | ||||
| 			return rb | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	rb.interfaceEncoders = append(rb.interfaceEncoders, interfaceValueEncoder{i: t, ve: enc}) | ||||
| 	return rb | ||||
| } | ||||
|  | ||||
| // RegisterTypeDecoder will register the provided ValueDecoder for the provided type. | ||||
| // | ||||
| // The type will be used directly, so a decoder can be registered for a type and a different decoder can be registered | ||||
| // for a pointer to that type. | ||||
| // | ||||
| // If the given type is an interface, the decoder will be called when unmarshalling into a type that is that interface. | ||||
| // It will not be called when unmarshalling into a non-interface type that implements the interface. | ||||
| func (rb *RegistryBuilder) RegisterTypeDecoder(t reflect.Type, dec ValueDecoder) *RegistryBuilder { | ||||
| 	rb.typeDecoders[t] = dec | ||||
| 	return rb | ||||
| } | ||||
|  | ||||
| // RegisterHookDecoder will register an decoder for the provided interface type t. This decoder will be called when | ||||
| // unmarshalling into a type if the type implements t or a pointer to the type implements t. If the provided type is not | ||||
| // an interface (i.e. t.Kind() != reflect.Interface), this method will panic. | ||||
| func (rb *RegistryBuilder) RegisterHookDecoder(t reflect.Type, dec ValueDecoder) *RegistryBuilder { | ||||
| 	if t.Kind() != reflect.Interface { | ||||
| 		panicStr := fmt.Sprintf("RegisterHookDecoder expects a type with kind reflect.Interface, "+ | ||||
| 			"got type %s with kind %s", t, t.Kind()) | ||||
| 		panic(panicStr) | ||||
| 	} | ||||
|  | ||||
| 	for idx, decoder := range rb.interfaceDecoders { | ||||
| 		if decoder.i == t { | ||||
| 			rb.interfaceDecoders[idx].vd = dec | ||||
| 			return rb | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	rb.interfaceDecoders = append(rb.interfaceDecoders, interfaceValueDecoder{i: t, vd: dec}) | ||||
| 	return rb | ||||
| } | ||||
|  | ||||
| // RegisterEncoder registers the provided type and encoder pair. | ||||
| // | ||||
| // Deprecated: Use RegisterTypeEncoder or RegisterHookEncoder instead. | ||||
| func (rb *RegistryBuilder) RegisterEncoder(t reflect.Type, enc ValueEncoder) *RegistryBuilder { | ||||
| 	if t == tEmpty { | ||||
| 		rb.typeEncoders[t] = enc | ||||
| 		return rb | ||||
| 	} | ||||
| 	switch t.Kind() { | ||||
| 	case reflect.Interface: | ||||
| 		for idx, ir := range rb.interfaceEncoders { | ||||
| 			if ir.i == t { | ||||
| 				rb.interfaceEncoders[idx].ve = enc | ||||
| 				return rb | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		rb.interfaceEncoders = append(rb.interfaceEncoders, interfaceValueEncoder{i: t, ve: enc}) | ||||
| 	default: | ||||
| 		rb.typeEncoders[t] = enc | ||||
| 	} | ||||
| 	return rb | ||||
| } | ||||
|  | ||||
| // RegisterDecoder registers the provided type and decoder pair. | ||||
| // | ||||
| // Deprecated: Use RegisterTypeDecoder or RegisterHookDecoder instead. | ||||
| func (rb *RegistryBuilder) RegisterDecoder(t reflect.Type, dec ValueDecoder) *RegistryBuilder { | ||||
| 	if t == nil { | ||||
| 		rb.typeDecoders[nil] = dec | ||||
| 		return rb | ||||
| 	} | ||||
| 	if t == tEmpty { | ||||
| 		rb.typeDecoders[t] = dec | ||||
| 		return rb | ||||
| 	} | ||||
| 	switch t.Kind() { | ||||
| 	case reflect.Interface: | ||||
| 		for idx, ir := range rb.interfaceDecoders { | ||||
| 			if ir.i == t { | ||||
| 				rb.interfaceDecoders[idx].vd = dec | ||||
| 				return rb | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		rb.interfaceDecoders = append(rb.interfaceDecoders, interfaceValueDecoder{i: t, vd: dec}) | ||||
| 	default: | ||||
| 		rb.typeDecoders[t] = dec | ||||
| 	} | ||||
| 	return rb | ||||
| } | ||||
|  | ||||
| // RegisterDefaultEncoder will registr the provided ValueEncoder to the provided | ||||
| // kind. | ||||
| func (rb *RegistryBuilder) RegisterDefaultEncoder(kind reflect.Kind, enc ValueEncoder) *RegistryBuilder { | ||||
| 	rb.kindEncoders[kind] = enc | ||||
| 	return rb | ||||
| } | ||||
|  | ||||
| // RegisterDefaultDecoder will register the provided ValueDecoder to the | ||||
| // provided kind. | ||||
| func (rb *RegistryBuilder) RegisterDefaultDecoder(kind reflect.Kind, dec ValueDecoder) *RegistryBuilder { | ||||
| 	rb.kindDecoders[kind] = dec | ||||
| 	return rb | ||||
| } | ||||
|  | ||||
| // RegisterTypeMapEntry will register the provided type to the BSON type. The primary usage for this | ||||
| // mapping is decoding situations where an empty interface is used and a default type needs to be | ||||
| // created and decoded into. | ||||
| // | ||||
| // By default, BSON documents will decode into interface{} values as bson.D. To change the default type for BSON | ||||
| // documents, a type map entry for bsontype.EmbeddedDocument should be registered. For example, to force BSON documents | ||||
| // to decode to bson.Raw, use the following code: | ||||
| // | ||||
| //	rb.RegisterTypeMapEntry(bsontype.EmbeddedDocument, reflect.TypeOf(bson.Raw{})) | ||||
| func (rb *RegistryBuilder) RegisterTypeMapEntry(bt bsontype.Type, rt reflect.Type) *RegistryBuilder { | ||||
| 	rb.typeMap[bt] = rt | ||||
| 	return rb | ||||
| } | ||||
|  | ||||
| // Build creates a Registry from the current state of this RegistryBuilder. | ||||
| func (rb *RegistryBuilder) Build() *Registry { | ||||
| 	registry := new(Registry) | ||||
|  | ||||
| 	registry.typeEncoders = make(map[reflect.Type]ValueEncoder) | ||||
| 	for t, enc := range rb.typeEncoders { | ||||
| 		registry.typeEncoders[t] = enc | ||||
| 	} | ||||
|  | ||||
| 	registry.typeDecoders = make(map[reflect.Type]ValueDecoder) | ||||
| 	for t, dec := range rb.typeDecoders { | ||||
| 		registry.typeDecoders[t] = dec | ||||
| 	} | ||||
|  | ||||
| 	registry.interfaceEncoders = make([]interfaceValueEncoder, len(rb.interfaceEncoders)) | ||||
| 	copy(registry.interfaceEncoders, rb.interfaceEncoders) | ||||
|  | ||||
| 	registry.interfaceDecoders = make([]interfaceValueDecoder, len(rb.interfaceDecoders)) | ||||
| 	copy(registry.interfaceDecoders, rb.interfaceDecoders) | ||||
|  | ||||
| 	registry.kindEncoders = make(map[reflect.Kind]ValueEncoder) | ||||
| 	for kind, enc := range rb.kindEncoders { | ||||
| 		registry.kindEncoders[kind] = enc | ||||
| 	} | ||||
|  | ||||
| 	registry.kindDecoders = make(map[reflect.Kind]ValueDecoder) | ||||
| 	for kind, dec := range rb.kindDecoders { | ||||
| 		registry.kindDecoders[kind] = dec | ||||
| 	} | ||||
|  | ||||
| 	registry.typeMap = make(map[bsontype.Type]reflect.Type) | ||||
| 	for bt, rt := range rb.typeMap { | ||||
| 		registry.typeMap[bt] = rt | ||||
| 	} | ||||
|  | ||||
| 	return registry | ||||
| } | ||||
|  | ||||
| // LookupEncoder inspects the registry for an encoder for the given type. The lookup precedence works as follows: | ||||
| // | ||||
| // 1. An encoder registered for the exact type. If the given type represents an interface, an encoder registered using | ||||
| // RegisterTypeEncoder for the interface will be selected. | ||||
| // | ||||
| // 2. An encoder registered using RegisterHookEncoder for an interface implemented by the type or by a pointer to the | ||||
| // type. | ||||
| // | ||||
| // 3. An encoder registered for the reflect.Kind of the value. | ||||
| // | ||||
| // If no encoder is found, an error of type ErrNoEncoder is returned. | ||||
| func (r *Registry) LookupEncoder(t reflect.Type) (ValueEncoder, error) { | ||||
| 	encodererr := ErrNoEncoder{Type: t} | ||||
| 	r.mu.RLock() | ||||
| 	enc, found := r.lookupTypeEncoder(t) | ||||
| 	r.mu.RUnlock() | ||||
| 	if found { | ||||
| 		if enc == nil { | ||||
| 			return nil, ErrNoEncoder{Type: t} | ||||
| 		} | ||||
| 		return enc, nil | ||||
| 	} | ||||
|  | ||||
| 	enc, found = r.lookupInterfaceEncoder(t, true) | ||||
| 	if found { | ||||
| 		r.mu.Lock() | ||||
| 		r.typeEncoders[t] = enc | ||||
| 		r.mu.Unlock() | ||||
| 		return enc, nil | ||||
| 	} | ||||
|  | ||||
| 	if t == nil { | ||||
| 		r.mu.Lock() | ||||
| 		r.typeEncoders[t] = nil | ||||
| 		r.mu.Unlock() | ||||
| 		return nil, encodererr | ||||
| 	} | ||||
|  | ||||
| 	enc, found = r.kindEncoders[t.Kind()] | ||||
| 	if !found { | ||||
| 		r.mu.Lock() | ||||
| 		r.typeEncoders[t] = nil | ||||
| 		r.mu.Unlock() | ||||
| 		return nil, encodererr | ||||
| 	} | ||||
|  | ||||
| 	r.mu.Lock() | ||||
| 	r.typeEncoders[t] = enc | ||||
| 	r.mu.Unlock() | ||||
| 	return enc, nil | ||||
| } | ||||
|  | ||||
| func (r *Registry) lookupTypeEncoder(t reflect.Type) (ValueEncoder, bool) { | ||||
| 	enc, found := r.typeEncoders[t] | ||||
| 	return enc, found | ||||
| } | ||||
|  | ||||
| func (r *Registry) lookupInterfaceEncoder(t reflect.Type, allowAddr bool) (ValueEncoder, bool) { | ||||
| 	if t == nil { | ||||
| 		return nil, false | ||||
| 	} | ||||
| 	for _, ienc := range r.interfaceEncoders { | ||||
| 		if t.Implements(ienc.i) { | ||||
| 			return ienc.ve, true | ||||
| 		} | ||||
| 		if allowAddr && t.Kind() != reflect.Ptr && reflect.PtrTo(t).Implements(ienc.i) { | ||||
| 			// if *t implements an interface, this will catch if t implements an interface further ahead | ||||
| 			// in interfaceEncoders | ||||
| 			defaultEnc, found := r.lookupInterfaceEncoder(t, false) | ||||
| 			if !found { | ||||
| 				defaultEnc = r.kindEncoders[t.Kind()] | ||||
| 			} | ||||
| 			return newCondAddrEncoder(ienc.ve, defaultEnc), true | ||||
| 		} | ||||
| 	} | ||||
| 	return nil, false | ||||
| } | ||||
|  | ||||
| // LookupDecoder inspects the registry for an decoder for the given type. The lookup precedence works as follows: | ||||
| // | ||||
| // 1. A decoder registered for the exact type. If the given type represents an interface, a decoder registered using | ||||
| // RegisterTypeDecoder for the interface will be selected. | ||||
| // | ||||
| // 2. A decoder registered using RegisterHookDecoder for an interface implemented by the type or by a pointer to the | ||||
| // type. | ||||
| // | ||||
| // 3. A decoder registered for the reflect.Kind of the value. | ||||
| // | ||||
| // If no decoder is found, an error of type ErrNoDecoder is returned. | ||||
| func (r *Registry) LookupDecoder(t reflect.Type) (ValueDecoder, error) { | ||||
| 	if t == nil { | ||||
| 		return nil, ErrNilType | ||||
| 	} | ||||
| 	decodererr := ErrNoDecoder{Type: t} | ||||
| 	r.mu.RLock() | ||||
| 	dec, found := r.lookupTypeDecoder(t) | ||||
| 	r.mu.RUnlock() | ||||
| 	if found { | ||||
| 		if dec == nil { | ||||
| 			return nil, ErrNoDecoder{Type: t} | ||||
| 		} | ||||
| 		return dec, nil | ||||
| 	} | ||||
|  | ||||
| 	dec, found = r.lookupInterfaceDecoder(t, true) | ||||
| 	if found { | ||||
| 		r.mu.Lock() | ||||
| 		r.typeDecoders[t] = dec | ||||
| 		r.mu.Unlock() | ||||
| 		return dec, nil | ||||
| 	} | ||||
|  | ||||
| 	dec, found = r.kindDecoders[t.Kind()] | ||||
| 	if !found { | ||||
| 		r.mu.Lock() | ||||
| 		r.typeDecoders[t] = nil | ||||
| 		r.mu.Unlock() | ||||
| 		return nil, decodererr | ||||
| 	} | ||||
|  | ||||
| 	r.mu.Lock() | ||||
| 	r.typeDecoders[t] = dec | ||||
| 	r.mu.Unlock() | ||||
| 	return dec, nil | ||||
| } | ||||
|  | ||||
| func (r *Registry) lookupTypeDecoder(t reflect.Type) (ValueDecoder, bool) { | ||||
| 	dec, found := r.typeDecoders[t] | ||||
| 	return dec, found | ||||
| } | ||||
|  | ||||
| func (r *Registry) lookupInterfaceDecoder(t reflect.Type, allowAddr bool) (ValueDecoder, bool) { | ||||
| 	for _, idec := range r.interfaceDecoders { | ||||
| 		if t.Implements(idec.i) { | ||||
| 			return idec.vd, true | ||||
| 		} | ||||
| 		if allowAddr && t.Kind() != reflect.Ptr && reflect.PtrTo(t).Implements(idec.i) { | ||||
| 			// if *t implements an interface, this will catch if t implements an interface further ahead | ||||
| 			// in interfaceDecoders | ||||
| 			defaultDec, found := r.lookupInterfaceDecoder(t, false) | ||||
| 			if !found { | ||||
| 				defaultDec = r.kindDecoders[t.Kind()] | ||||
| 			} | ||||
| 			return newCondAddrDecoder(idec.vd, defaultDec), true | ||||
| 		} | ||||
| 	} | ||||
| 	return nil, false | ||||
| } | ||||
|  | ||||
| // LookupTypeMapEntry inspects the registry's type map for a Go type for the corresponding BSON | ||||
| // type. If no type is found, ErrNoTypeMapEntry is returned. | ||||
| func (r *Registry) LookupTypeMapEntry(bt bsontype.Type) (reflect.Type, error) { | ||||
| 	t, ok := r.typeMap[bt] | ||||
| 	if !ok || t == nil { | ||||
| 		return nil, ErrNoTypeMapEntry{Type: bt} | ||||
| 	} | ||||
| 	return t, nil | ||||
| } | ||||
|  | ||||
| type interfaceValueEncoder struct { | ||||
| 	i  reflect.Type | ||||
| 	ve ValueEncoder | ||||
| } | ||||
|  | ||||
| type interfaceValueDecoder struct { | ||||
| 	i  reflect.Type | ||||
| 	vd ValueDecoder | ||||
| } | ||||
							
								
								
									
										125
									
								
								mongo/bson/bsoncodec/registry_examples_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										125
									
								
								mongo/bson/bsoncodec/registry_examples_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,125 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsoncodec_test | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"math" | ||||
| 	"reflect" | ||||
|  | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsoncodec" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsonrw" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsontype" | ||||
| ) | ||||
|  | ||||
| func ExampleRegistry_customEncoder() { | ||||
| 	// Write a custom encoder for an integer type that is multiplied by -1 when | ||||
| 	// encoding. | ||||
|  | ||||
| 	// To register the default encoders and decoders in addition to this custom | ||||
| 	// one, use bson.NewRegistryBuilder instead. | ||||
| 	rb := bsoncodec.NewRegistryBuilder() | ||||
| 	type negatedInt int | ||||
|  | ||||
| 	niType := reflect.TypeOf(negatedInt(0)) | ||||
| 	encoder := func( | ||||
| 		ec bsoncodec.EncodeContext, | ||||
| 		vw bsonrw.ValueWriter, | ||||
| 		val reflect.Value, | ||||
| 	) error { | ||||
| 		// All encoder implementations should check that val is valid and is of | ||||
| 		// the correct type before proceeding. | ||||
| 		if !val.IsValid() || val.Type() != niType { | ||||
| 			return bsoncodec.ValueEncoderError{ | ||||
| 				Name:     "negatedIntEncodeValue", | ||||
| 				Types:    []reflect.Type{niType}, | ||||
| 				Received: val, | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		// Negate val and encode as a BSON int32 if it can fit in 32 bits and a | ||||
| 		// BSON int64 otherwise. | ||||
| 		negatedVal := val.Int() * -1 | ||||
| 		if math.MinInt32 <= negatedVal && negatedVal <= math.MaxInt32 { | ||||
| 			return vw.WriteInt32(int32(negatedVal)) | ||||
| 		} | ||||
| 		return vw.WriteInt64(negatedVal) | ||||
| 	} | ||||
|  | ||||
| 	rb.RegisterTypeEncoder(niType, bsoncodec.ValueEncoderFunc(encoder)) | ||||
| } | ||||
|  | ||||
| func ExampleRegistry_customDecoder() { | ||||
| 	// Write a custom decoder for a boolean type that can be stored in the | ||||
| 	// database as a BSON boolean, int32, int64, double, or null. BSON int32, | ||||
| 	// int64, and double values are considered "true" in this decoder if they | ||||
| 	// are non-zero. BSON null values are always considered false. | ||||
|  | ||||
| 	// To register the default encoders and decoders in addition to this custom | ||||
| 	// one, use bson.NewRegistryBuilder instead. | ||||
| 	rb := bsoncodec.NewRegistryBuilder() | ||||
| 	type lenientBool bool | ||||
|  | ||||
| 	decoder := func( | ||||
| 		dc bsoncodec.DecodeContext, | ||||
| 		vr bsonrw.ValueReader, | ||||
| 		val reflect.Value, | ||||
| 	) error { | ||||
| 		// All decoder implementations should check that val is valid, settable, | ||||
| 		// and is of the correct kind before proceeding. | ||||
| 		if !val.IsValid() || !val.CanSet() || val.Kind() != reflect.Bool { | ||||
| 			return bsoncodec.ValueDecoderError{ | ||||
| 				Name:     "lenientBoolDecodeValue", | ||||
| 				Kinds:    []reflect.Kind{reflect.Bool}, | ||||
| 				Received: val, | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		var result bool | ||||
| 		switch vr.Type() { | ||||
| 		case bsontype.Boolean: | ||||
| 			b, err := vr.ReadBoolean() | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			result = b | ||||
| 		case bsontype.Int32: | ||||
| 			i32, err := vr.ReadInt32() | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			result = i32 != 0 | ||||
| 		case bsontype.Int64: | ||||
| 			i64, err := vr.ReadInt64() | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			result = i64 != 0 | ||||
| 		case bsontype.Double: | ||||
| 			f64, err := vr.ReadDouble() | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			result = f64 != 0 | ||||
| 		case bsontype.Null: | ||||
| 			if err := vr.ReadNull(); err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			result = false | ||||
| 		default: | ||||
| 			return fmt.Errorf( | ||||
| 				"received invalid BSON type to decode into lenientBool: %s", | ||||
| 				vr.Type()) | ||||
| 		} | ||||
|  | ||||
| 		val.SetBool(result) | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	lbType := reflect.TypeOf(lenientBool(true)) | ||||
| 	rb.RegisterTypeDecoder(lbType, bsoncodec.ValueDecoderFunc(decoder)) | ||||
| } | ||||
							
								
								
									
										452
									
								
								mongo/bson/bsoncodec/registry_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										452
									
								
								mongo/bson/bsoncodec/registry_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,452 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsoncodec | ||||
|  | ||||
| import ( | ||||
| 	"reflect" | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/google/go-cmp/cmp" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsonrw" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsontype" | ||||
| 	"go.mongodb.org/mongo-driver/internal/testutil/assert" | ||||
| ) | ||||
|  | ||||
| func TestRegistry(t *testing.T) { | ||||
| 	t.Run("Register", func(t *testing.T) { | ||||
| 		fc1, fc2, fc3, fc4 := new(fakeCodec), new(fakeCodec), new(fakeCodec), new(fakeCodec) | ||||
| 		t.Run("interface", func(t *testing.T) { | ||||
| 			var t1f *testInterface1 | ||||
| 			var t2f *testInterface2 | ||||
| 			var t4f *testInterface4 | ||||
| 			ips := []interfaceValueEncoder{ | ||||
| 				{i: reflect.TypeOf(t1f).Elem(), ve: fc1}, | ||||
| 				{i: reflect.TypeOf(t2f).Elem(), ve: fc2}, | ||||
| 				{i: reflect.TypeOf(t1f).Elem(), ve: fc3}, | ||||
| 				{i: reflect.TypeOf(t4f).Elem(), ve: fc4}, | ||||
| 			} | ||||
| 			want := []interfaceValueEncoder{ | ||||
| 				{i: reflect.TypeOf(t1f).Elem(), ve: fc3}, | ||||
| 				{i: reflect.TypeOf(t2f).Elem(), ve: fc2}, | ||||
| 				{i: reflect.TypeOf(t4f).Elem(), ve: fc4}, | ||||
| 			} | ||||
| 			rb := NewRegistryBuilder() | ||||
| 			for _, ip := range ips { | ||||
| 				rb.RegisterHookEncoder(ip.i, ip.ve) | ||||
| 			} | ||||
| 			got := rb.interfaceEncoders | ||||
| 			if !cmp.Equal(got, want, cmp.AllowUnexported(interfaceValueEncoder{}, fakeCodec{}), cmp.Comparer(typeComparer)) { | ||||
| 				t.Errorf("The registered interfaces are not correct. got %v; want %v", got, want) | ||||
| 			} | ||||
| 		}) | ||||
| 		t.Run("type", func(t *testing.T) { | ||||
| 			ft1, ft2, ft4 := fakeType1{}, fakeType2{}, fakeType4{} | ||||
| 			rb := NewRegistryBuilder(). | ||||
| 				RegisterTypeEncoder(reflect.TypeOf(ft1), fc1). | ||||
| 				RegisterTypeEncoder(reflect.TypeOf(ft2), fc2). | ||||
| 				RegisterTypeEncoder(reflect.TypeOf(ft1), fc3). | ||||
| 				RegisterTypeEncoder(reflect.TypeOf(ft4), fc4) | ||||
| 			want := []struct { | ||||
| 				t reflect.Type | ||||
| 				c ValueEncoder | ||||
| 			}{ | ||||
| 				{reflect.TypeOf(ft1), fc3}, | ||||
| 				{reflect.TypeOf(ft2), fc2}, | ||||
| 				{reflect.TypeOf(ft4), fc4}, | ||||
| 			} | ||||
| 			got := rb.typeEncoders | ||||
| 			for _, s := range want { | ||||
| 				wantT, wantC := s.t, s.c | ||||
| 				gotC, exists := got[wantT] | ||||
| 				if !exists { | ||||
| 					t.Errorf("Did not find type in the type registry: %v", wantT) | ||||
| 				} | ||||
| 				if !cmp.Equal(gotC, wantC, cmp.AllowUnexported(fakeCodec{})) { | ||||
| 					t.Errorf("Codecs did not match. got %#v; want %#v", gotC, wantC) | ||||
| 				} | ||||
| 			} | ||||
| 		}) | ||||
| 		t.Run("kind", func(t *testing.T) { | ||||
| 			k1, k2, k4 := reflect.Struct, reflect.Slice, reflect.Map | ||||
| 			rb := NewRegistryBuilder(). | ||||
| 				RegisterDefaultEncoder(k1, fc1). | ||||
| 				RegisterDefaultEncoder(k2, fc2). | ||||
| 				RegisterDefaultEncoder(k1, fc3). | ||||
| 				RegisterDefaultEncoder(k4, fc4) | ||||
| 			want := []struct { | ||||
| 				k reflect.Kind | ||||
| 				c ValueEncoder | ||||
| 			}{ | ||||
| 				{k1, fc3}, | ||||
| 				{k2, fc2}, | ||||
| 				{k4, fc4}, | ||||
| 			} | ||||
| 			got := rb.kindEncoders | ||||
| 			for _, s := range want { | ||||
| 				wantK, wantC := s.k, s.c | ||||
| 				gotC, exists := got[wantK] | ||||
| 				if !exists { | ||||
| 					t.Errorf("Did not find kind in the kind registry: %v", wantK) | ||||
| 				} | ||||
| 				if !cmp.Equal(gotC, wantC, cmp.AllowUnexported(fakeCodec{})) { | ||||
| 					t.Errorf("Codecs did not match. got %#v; want %#v", gotC, wantC) | ||||
| 				} | ||||
| 			} | ||||
| 		}) | ||||
| 		t.Run("RegisterDefault", func(t *testing.T) { | ||||
| 			t.Run("MapCodec", func(t *testing.T) { | ||||
| 				codec := fakeCodec{num: 1} | ||||
| 				codec2 := fakeCodec{num: 2} | ||||
| 				rb := NewRegistryBuilder() | ||||
| 				rb.RegisterDefaultEncoder(reflect.Map, codec) | ||||
| 				if rb.kindEncoders[reflect.Map] != codec { | ||||
| 					t.Errorf("Did not properly set the map codec. got %v; want %v", rb.kindEncoders[reflect.Map], codec) | ||||
| 				} | ||||
| 				rb.RegisterDefaultEncoder(reflect.Map, codec2) | ||||
| 				if rb.kindEncoders[reflect.Map] != codec2 { | ||||
| 					t.Errorf("Did not properly set the map codec. got %v; want %v", rb.kindEncoders[reflect.Map], codec2) | ||||
| 				} | ||||
| 			}) | ||||
| 			t.Run("StructCodec", func(t *testing.T) { | ||||
| 				codec := fakeCodec{num: 1} | ||||
| 				codec2 := fakeCodec{num: 2} | ||||
| 				rb := NewRegistryBuilder() | ||||
| 				rb.RegisterDefaultEncoder(reflect.Struct, codec) | ||||
| 				if rb.kindEncoders[reflect.Struct] != codec { | ||||
| 					t.Errorf("Did not properly set the struct codec. got %v; want %v", rb.kindEncoders[reflect.Struct], codec) | ||||
| 				} | ||||
| 				rb.RegisterDefaultEncoder(reflect.Struct, codec2) | ||||
| 				if rb.kindEncoders[reflect.Struct] != codec2 { | ||||
| 					t.Errorf("Did not properly set the struct codec. got %v; want %v", rb.kindEncoders[reflect.Struct], codec2) | ||||
| 				} | ||||
| 			}) | ||||
| 			t.Run("SliceCodec", func(t *testing.T) { | ||||
| 				codec := fakeCodec{num: 1} | ||||
| 				codec2 := fakeCodec{num: 2} | ||||
| 				rb := NewRegistryBuilder() | ||||
| 				rb.RegisterDefaultEncoder(reflect.Slice, codec) | ||||
| 				if rb.kindEncoders[reflect.Slice] != codec { | ||||
| 					t.Errorf("Did not properly set the slice codec. got %v; want %v", rb.kindEncoders[reflect.Slice], codec) | ||||
| 				} | ||||
| 				rb.RegisterDefaultEncoder(reflect.Slice, codec2) | ||||
| 				if rb.kindEncoders[reflect.Slice] != codec2 { | ||||
| 					t.Errorf("Did not properly set the slice codec. got %v; want %v", rb.kindEncoders[reflect.Slice], codec2) | ||||
| 				} | ||||
| 			}) | ||||
| 			t.Run("ArrayCodec", func(t *testing.T) { | ||||
| 				codec := fakeCodec{num: 1} | ||||
| 				codec2 := fakeCodec{num: 2} | ||||
| 				rb := NewRegistryBuilder() | ||||
| 				rb.RegisterDefaultEncoder(reflect.Array, codec) | ||||
| 				if rb.kindEncoders[reflect.Array] != codec { | ||||
| 					t.Errorf("Did not properly set the slice codec. got %v; want %v", rb.kindEncoders[reflect.Array], codec) | ||||
| 				} | ||||
| 				rb.RegisterDefaultEncoder(reflect.Array, codec2) | ||||
| 				if rb.kindEncoders[reflect.Array] != codec2 { | ||||
| 					t.Errorf("Did not properly set the slice codec. got %v; want %v", rb.kindEncoders[reflect.Array], codec2) | ||||
| 				} | ||||
| 			}) | ||||
| 		}) | ||||
| 		t.Run("Lookup", func(t *testing.T) { | ||||
| 			type Codec interface { | ||||
| 				ValueEncoder | ||||
| 				ValueDecoder | ||||
| 			} | ||||
|  | ||||
| 			var arrinstance [12]int | ||||
| 			arr := reflect.TypeOf(arrinstance) | ||||
| 			slc := reflect.TypeOf(make([]int, 12)) | ||||
| 			m := reflect.TypeOf(make(map[string]int)) | ||||
| 			strct := reflect.TypeOf(struct{ Foo string }{}) | ||||
| 			ft1 := reflect.PtrTo(reflect.TypeOf(fakeType1{})) | ||||
| 			ft2 := reflect.TypeOf(fakeType2{}) | ||||
| 			ft3 := reflect.TypeOf(fakeType5(func(string, string) string { return "fakeType5" })) | ||||
| 			ti1 := reflect.TypeOf((*testInterface1)(nil)).Elem() | ||||
| 			ti2 := reflect.TypeOf((*testInterface2)(nil)).Elem() | ||||
| 			ti1Impl := reflect.TypeOf(testInterface1Impl{}) | ||||
| 			ti2Impl := reflect.TypeOf(testInterface2Impl{}) | ||||
| 			ti3 := reflect.TypeOf((*testInterface3)(nil)).Elem() | ||||
| 			ti3Impl := reflect.TypeOf(testInterface3Impl{}) | ||||
| 			ti3ImplPtr := reflect.TypeOf((*testInterface3Impl)(nil)) | ||||
| 			fc1, fc2 := fakeCodec{num: 1}, fakeCodec{num: 2} | ||||
| 			fsc, fslcc, fmc := new(fakeStructCodec), new(fakeSliceCodec), new(fakeMapCodec) | ||||
| 			pc := NewPointerCodec() | ||||
|  | ||||
| 			reg := NewRegistryBuilder(). | ||||
| 				RegisterTypeEncoder(ft1, fc1). | ||||
| 				RegisterTypeEncoder(ft2, fc2). | ||||
| 				RegisterTypeEncoder(ti1, fc1). | ||||
| 				RegisterDefaultEncoder(reflect.Struct, fsc). | ||||
| 				RegisterDefaultEncoder(reflect.Slice, fslcc). | ||||
| 				RegisterDefaultEncoder(reflect.Array, fslcc). | ||||
| 				RegisterDefaultEncoder(reflect.Map, fmc). | ||||
| 				RegisterDefaultEncoder(reflect.Ptr, pc). | ||||
| 				RegisterTypeDecoder(ft1, fc1). | ||||
| 				RegisterTypeDecoder(ft2, fc2). | ||||
| 				RegisterTypeDecoder(ti1, fc1). // values whose exact type is testInterface1 will use fc1 encoder | ||||
| 				RegisterDefaultDecoder(reflect.Struct, fsc). | ||||
| 				RegisterDefaultDecoder(reflect.Slice, fslcc). | ||||
| 				RegisterDefaultDecoder(reflect.Array, fslcc). | ||||
| 				RegisterDefaultDecoder(reflect.Map, fmc). | ||||
| 				RegisterDefaultDecoder(reflect.Ptr, pc). | ||||
| 				RegisterHookEncoder(ti2, fc2). | ||||
| 				RegisterHookDecoder(ti2, fc2). | ||||
| 				RegisterHookEncoder(ti3, fc3). | ||||
| 				RegisterHookDecoder(ti3, fc3). | ||||
| 				Build() | ||||
|  | ||||
| 			testCases := []struct { | ||||
| 				name      string | ||||
| 				t         reflect.Type | ||||
| 				wantcodec Codec | ||||
| 				wanterr   error | ||||
| 				testcache bool | ||||
| 			}{ | ||||
| 				{ | ||||
| 					"type registry (pointer)", | ||||
| 					ft1, | ||||
| 					fc1, | ||||
| 					nil, | ||||
| 					false, | ||||
| 				}, | ||||
| 				{ | ||||
| 					"type registry (non-pointer)", | ||||
| 					ft2, | ||||
| 					fc2, | ||||
| 					nil, | ||||
| 					false, | ||||
| 				}, | ||||
| 				{ | ||||
| 					// lookup an interface type and expect that the registered encoder is returned | ||||
| 					"interface with type encoder", | ||||
| 					ti1, | ||||
| 					fc1, | ||||
| 					nil, | ||||
| 					true, | ||||
| 				}, | ||||
| 				{ | ||||
| 					// lookup a type that implements an interface and expect that the default struct codec is returned | ||||
| 					"interface implementation with type encoder", | ||||
| 					ti1Impl, | ||||
| 					fsc, | ||||
| 					nil, | ||||
| 					false, | ||||
| 				}, | ||||
| 				{ | ||||
| 					// lookup an interface type and expect that the registered hook is returned | ||||
| 					"interface with hook", | ||||
| 					ti2, | ||||
| 					fc2, | ||||
| 					nil, | ||||
| 					false, | ||||
| 				}, | ||||
| 				{ | ||||
| 					// lookup a type that implements an interface and expect that the registered hook is returned | ||||
| 					"interface implementation with hook", | ||||
| 					ti2Impl, | ||||
| 					fc2, | ||||
| 					nil, | ||||
| 					false, | ||||
| 				}, | ||||
| 				{ | ||||
| 					// lookup a pointer to a type where the pointer implements an interface and expect that the | ||||
| 					// registered hook is returned | ||||
| 					"interface pointer to implementation with hook (pointer)", | ||||
| 					ti3ImplPtr, | ||||
| 					fc3, | ||||
| 					nil, | ||||
| 					false, | ||||
| 				}, | ||||
| 				{ | ||||
| 					"default struct codec (pointer)", | ||||
| 					reflect.PtrTo(strct), | ||||
| 					pc, | ||||
| 					nil, | ||||
| 					false, | ||||
| 				}, | ||||
| 				{ | ||||
| 					"default struct codec (non-pointer)", | ||||
| 					strct, | ||||
| 					fsc, | ||||
| 					nil, | ||||
| 					false, | ||||
| 				}, | ||||
| 				{ | ||||
| 					"default array codec", | ||||
| 					arr, | ||||
| 					fslcc, | ||||
| 					nil, | ||||
| 					false, | ||||
| 				}, | ||||
| 				{ | ||||
| 					"default slice codec", | ||||
| 					slc, | ||||
| 					fslcc, | ||||
| 					nil, | ||||
| 					false, | ||||
| 				}, | ||||
| 				{ | ||||
| 					"default map", | ||||
| 					m, | ||||
| 					fmc, | ||||
| 					nil, | ||||
| 					false, | ||||
| 				}, | ||||
| 				{ | ||||
| 					"map non-string key", | ||||
| 					reflect.TypeOf(map[int]int{}), | ||||
| 					fmc, | ||||
| 					nil, | ||||
| 					false, | ||||
| 				}, | ||||
| 				{ | ||||
| 					"No Codec Registered", | ||||
| 					ft3, | ||||
| 					nil, | ||||
| 					ErrNoEncoder{Type: ft3}, | ||||
| 					false, | ||||
| 				}, | ||||
| 			} | ||||
|  | ||||
| 			allowunexported := cmp.AllowUnexported(fakeCodec{}, fakeStructCodec{}, fakeSliceCodec{}, fakeMapCodec{}) | ||||
| 			comparepc := func(pc1, pc2 *PointerCodec) bool { return true } | ||||
| 			for _, tc := range testCases { | ||||
| 				t.Run(tc.name, func(t *testing.T) { | ||||
| 					t.Run("Encoder", func(t *testing.T) { | ||||
| 						gotcodec, goterr := reg.LookupEncoder(tc.t) | ||||
| 						if !cmp.Equal(goterr, tc.wanterr, cmp.Comparer(compareErrors)) { | ||||
| 							t.Errorf("Errors did not match. got %v; want %v", goterr, tc.wanterr) | ||||
| 						} | ||||
| 						if !cmp.Equal(gotcodec, tc.wantcodec, allowunexported, cmp.Comparer(comparepc)) { | ||||
| 							t.Errorf("Codecs did not match. got %v; want %v", gotcodec, tc.wantcodec) | ||||
| 						} | ||||
| 					}) | ||||
| 					t.Run("Decoder", func(t *testing.T) { | ||||
| 						var wanterr error | ||||
| 						if ene, ok := tc.wanterr.(ErrNoEncoder); ok { | ||||
| 							wanterr = ErrNoDecoder(ene) | ||||
| 						} else { | ||||
| 							wanterr = tc.wanterr | ||||
| 						} | ||||
| 						gotcodec, goterr := reg.LookupDecoder(tc.t) | ||||
| 						if !cmp.Equal(goterr, wanterr, cmp.Comparer(compareErrors)) { | ||||
| 							t.Errorf("Errors did not match. got %v; want %v", goterr, wanterr) | ||||
| 						} | ||||
| 						if !cmp.Equal(gotcodec, tc.wantcodec, allowunexported, cmp.Comparer(comparepc)) { | ||||
| 							t.Errorf("Codecs did not match. got %v; want %v", gotcodec, tc.wantcodec) | ||||
| 							t.Errorf("Codecs did not match. got %T; want %T", gotcodec, tc.wantcodec) | ||||
| 						} | ||||
| 					}) | ||||
| 				}) | ||||
| 			} | ||||
| 			// lookup a type whose pointer implements an interface and expect that the registered hook is | ||||
| 			// returned | ||||
| 			t.Run("interface implementation with hook (pointer)", func(t *testing.T) { | ||||
| 				t.Run("Encoder", func(t *testing.T) { | ||||
| 					gotEnc, err := reg.LookupEncoder(ti3Impl) | ||||
| 					assert.Nil(t, err, "LookupEncoder error: %v", err) | ||||
|  | ||||
| 					cae, ok := gotEnc.(*condAddrEncoder) | ||||
| 					assert.True(t, ok, "Expected CondAddrEncoder, got %T", gotEnc) | ||||
| 					if !cmp.Equal(cae.canAddrEnc, fc3, allowunexported, cmp.Comparer(comparepc)) { | ||||
| 						t.Errorf("expected canAddrEnc %v, got %v", cae.canAddrEnc, fc3) | ||||
| 					} | ||||
| 					if !cmp.Equal(cae.elseEnc, fsc, allowunexported, cmp.Comparer(comparepc)) { | ||||
| 						t.Errorf("expected elseEnc %v, got %v", cae.elseEnc, fsc) | ||||
| 					} | ||||
| 				}) | ||||
| 				t.Run("Decoder", func(t *testing.T) { | ||||
| 					gotDec, err := reg.LookupDecoder(ti3Impl) | ||||
| 					assert.Nil(t, err, "LookupDecoder error: %v", err) | ||||
|  | ||||
| 					cad, ok := gotDec.(*condAddrDecoder) | ||||
| 					assert.True(t, ok, "Expected CondAddrDecoder, got %T", gotDec) | ||||
| 					if !cmp.Equal(cad.canAddrDec, fc3, allowunexported, cmp.Comparer(comparepc)) { | ||||
| 						t.Errorf("expected canAddrDec %v, got %v", cad.canAddrDec, fc3) | ||||
| 					} | ||||
| 					if !cmp.Equal(cad.elseDec, fsc, allowunexported, cmp.Comparer(comparepc)) { | ||||
| 						t.Errorf("expected elseDec %v, got %v", cad.elseDec, fsc) | ||||
| 					} | ||||
| 				}) | ||||
| 			}) | ||||
| 		}) | ||||
| 	}) | ||||
| 	t.Run("Type Map", func(t *testing.T) { | ||||
| 		reg := NewRegistryBuilder(). | ||||
| 			RegisterTypeMapEntry(bsontype.String, reflect.TypeOf("")). | ||||
| 			RegisterTypeMapEntry(bsontype.Int32, reflect.TypeOf(int(0))). | ||||
| 			Build() | ||||
|  | ||||
| 		var got, want reflect.Type | ||||
|  | ||||
| 		want = reflect.TypeOf("") | ||||
| 		got, err := reg.LookupTypeMapEntry(bsontype.String) | ||||
| 		noerr(t, err) | ||||
| 		if got != want { | ||||
| 			t.Errorf("Did not get expected type. got %v; want %v", got, want) | ||||
| 		} | ||||
|  | ||||
| 		want = reflect.TypeOf(int(0)) | ||||
| 		got, err = reg.LookupTypeMapEntry(bsontype.Int32) | ||||
| 		noerr(t, err) | ||||
| 		if got != want { | ||||
| 			t.Errorf("Did not get expected type. got %v; want %v", got, want) | ||||
| 		} | ||||
|  | ||||
| 		want = nil | ||||
| 		wanterr := ErrNoTypeMapEntry{Type: bsontype.ObjectID} | ||||
| 		got, err = reg.LookupTypeMapEntry(bsontype.ObjectID) | ||||
| 		if err != wanterr { | ||||
| 			t.Errorf("Did not get expected error. got %v; want %v", err, wanterr) | ||||
| 		} | ||||
| 		if got != want { | ||||
| 			t.Errorf("Did not get expected type. got %v; want %v", got, want) | ||||
| 		} | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| type fakeType1 struct{} | ||||
| type fakeType2 struct{} | ||||
| type fakeType4 struct{} | ||||
| type fakeType5 func(string, string) string | ||||
| type fakeStructCodec struct{ fakeCodec } | ||||
| type fakeSliceCodec struct{ fakeCodec } | ||||
| type fakeMapCodec struct{ fakeCodec } | ||||
|  | ||||
| type fakeCodec struct{ num int } | ||||
|  | ||||
| func (fc fakeCodec) EncodeValue(EncodeContext, bsonrw.ValueWriter, reflect.Value) error { | ||||
| 	return nil | ||||
| } | ||||
| func (fc fakeCodec) DecodeValue(DecodeContext, bsonrw.ValueReader, reflect.Value) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| type testInterface1 interface{ test1() } | ||||
| type testInterface2 interface{ test2() } | ||||
| type testInterface3 interface{ test3() } | ||||
| type testInterface4 interface{ test4() } | ||||
|  | ||||
| type testInterface1Impl struct{} | ||||
|  | ||||
| var _ testInterface1 = testInterface1Impl{} | ||||
|  | ||||
| func (testInterface1Impl) test1() {} | ||||
|  | ||||
| type testInterface2Impl struct{} | ||||
|  | ||||
| var _ testInterface2 = testInterface2Impl{} | ||||
|  | ||||
| func (testInterface2Impl) test2() {} | ||||
|  | ||||
| type testInterface3Impl struct{} | ||||
|  | ||||
| var _ testInterface3 = (*testInterface3Impl)(nil) | ||||
|  | ||||
| func (*testInterface3Impl) test3() {} | ||||
|  | ||||
| func typeComparer(i1, i2 reflect.Type) bool { return i1 == i2 } | ||||
							
								
								
									
										199
									
								
								mongo/bson/bsoncodec/slice_codec.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										199
									
								
								mongo/bson/bsoncodec/slice_codec.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,199 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsoncodec | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
|  | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsonoptions" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsonrw" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsontype" | ||||
| 	"go.mongodb.org/mongo-driver/bson/primitive" | ||||
| ) | ||||
|  | ||||
| var defaultSliceCodec = NewSliceCodec() | ||||
|  | ||||
| // SliceCodec is the Codec used for slice values. | ||||
| type SliceCodec struct { | ||||
| 	EncodeNilAsEmpty bool | ||||
| } | ||||
|  | ||||
| var _ ValueCodec = &MapCodec{} | ||||
|  | ||||
| // NewSliceCodec returns a MapCodec with options opts. | ||||
| func NewSliceCodec(opts ...*bsonoptions.SliceCodecOptions) *SliceCodec { | ||||
| 	sliceOpt := bsonoptions.MergeSliceCodecOptions(opts...) | ||||
|  | ||||
| 	codec := SliceCodec{} | ||||
| 	if sliceOpt.EncodeNilAsEmpty != nil { | ||||
| 		codec.EncodeNilAsEmpty = *sliceOpt.EncodeNilAsEmpty | ||||
| 	} | ||||
| 	return &codec | ||||
| } | ||||
|  | ||||
| // EncodeValue is the ValueEncoder for slice types. | ||||
| func (sc SliceCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { | ||||
| 	if !val.IsValid() || val.Kind() != reflect.Slice { | ||||
| 		return ValueEncoderError{Name: "SliceEncodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: val} | ||||
| 	} | ||||
|  | ||||
| 	if val.IsNil() && !sc.EncodeNilAsEmpty { | ||||
| 		return vw.WriteNull() | ||||
| 	} | ||||
|  | ||||
| 	// If we have a []byte we want to treat it as a binary instead of as an array. | ||||
| 	if val.Type().Elem() == tByte { | ||||
| 		var byteSlice []byte | ||||
| 		for idx := 0; idx < val.Len(); idx++ { | ||||
| 			byteSlice = append(byteSlice, val.Index(idx).Interface().(byte)) | ||||
| 		} | ||||
| 		return vw.WriteBinary(byteSlice) | ||||
| 	} | ||||
|  | ||||
| 	// If we have a []primitive.E we want to treat it as a document instead of as an array. | ||||
| 	if val.Type().ConvertibleTo(tD) { | ||||
| 		d := val.Convert(tD).Interface().(primitive.D) | ||||
|  | ||||
| 		dw, err := vw.WriteDocument() | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		for _, e := range d { | ||||
| 			err = encodeElement(ec, dw, e) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		return dw.WriteDocumentEnd() | ||||
| 	} | ||||
|  | ||||
| 	aw, err := vw.WriteArray() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	elemType := val.Type().Elem() | ||||
| 	encoder, err := ec.LookupEncoder(elemType) | ||||
| 	if err != nil && elemType.Kind() != reflect.Interface { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	for idx := 0; idx < val.Len(); idx++ { | ||||
| 		currEncoder, currVal, lookupErr := defaultValueEncoders.lookupElementEncoder(ec, encoder, val.Index(idx)) | ||||
| 		if lookupErr != nil && lookupErr != errInvalidValue { | ||||
| 			return lookupErr | ||||
| 		} | ||||
|  | ||||
| 		vw, err := aw.WriteArrayElement() | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		if lookupErr == errInvalidValue { | ||||
| 			err = vw.WriteNull() | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		err = currEncoder.EncodeValue(ec, vw, currVal) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
| 	return aw.WriteArrayEnd() | ||||
| } | ||||
|  | ||||
| // DecodeValue is the ValueDecoder for slice types. | ||||
| func (sc *SliceCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { | ||||
| 	if !val.CanSet() || val.Kind() != reflect.Slice { | ||||
| 		return ValueDecoderError{Name: "SliceDecodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: val} | ||||
| 	} | ||||
|  | ||||
| 	switch vrType := vr.Type(); vrType { | ||||
| 	case bsontype.Array: | ||||
| 	case bsontype.Null: | ||||
| 		val.Set(reflect.Zero(val.Type())) | ||||
| 		return vr.ReadNull() | ||||
| 	case bsontype.Undefined: | ||||
| 		val.Set(reflect.Zero(val.Type())) | ||||
| 		return vr.ReadUndefined() | ||||
| 	case bsontype.Type(0), bsontype.EmbeddedDocument: | ||||
| 		if val.Type().Elem() != tE { | ||||
| 			return fmt.Errorf("cannot decode document into %s", val.Type()) | ||||
| 		} | ||||
| 	case bsontype.Binary: | ||||
| 		if val.Type().Elem() != tByte { | ||||
| 			return fmt.Errorf("SliceDecodeValue can only decode a binary into a byte array, got %v", vrType) | ||||
| 		} | ||||
| 		data, subtype, err := vr.ReadBinary() | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		if subtype != bsontype.BinaryGeneric && subtype != bsontype.BinaryBinaryOld { | ||||
| 			return fmt.Errorf("SliceDecodeValue can only be used to decode subtype 0x00 or 0x02 for %s, got %v", bsontype.Binary, subtype) | ||||
| 		} | ||||
|  | ||||
| 		if val.IsNil() { | ||||
| 			val.Set(reflect.MakeSlice(val.Type(), 0, len(data))) | ||||
| 		} | ||||
|  | ||||
| 		val.SetLen(0) | ||||
| 		for _, elem := range data { | ||||
| 			val.Set(reflect.Append(val, reflect.ValueOf(elem))) | ||||
| 		} | ||||
| 		return nil | ||||
| 	case bsontype.String: | ||||
| 		if sliceType := val.Type().Elem(); sliceType != tByte { | ||||
| 			return fmt.Errorf("SliceDecodeValue can only decode a string into a byte array, got %v", sliceType) | ||||
| 		} | ||||
| 		str, err := vr.ReadString() | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		byteStr := []byte(str) | ||||
|  | ||||
| 		if val.IsNil() { | ||||
| 			val.Set(reflect.MakeSlice(val.Type(), 0, len(byteStr))) | ||||
| 		} | ||||
|  | ||||
| 		val.SetLen(0) | ||||
| 		for _, elem := range byteStr { | ||||
| 			val.Set(reflect.Append(val, reflect.ValueOf(elem))) | ||||
| 		} | ||||
| 		return nil | ||||
| 	default: | ||||
| 		return fmt.Errorf("cannot decode %v into a slice", vrType) | ||||
| 	} | ||||
|  | ||||
| 	var elemsFunc func(DecodeContext, bsonrw.ValueReader, reflect.Value) ([]reflect.Value, error) | ||||
| 	switch val.Type().Elem() { | ||||
| 	case tE: | ||||
| 		dc.Ancestor = val.Type() | ||||
| 		elemsFunc = defaultValueDecoders.decodeD | ||||
| 	default: | ||||
| 		elemsFunc = defaultValueDecoders.decodeDefault | ||||
| 	} | ||||
|  | ||||
| 	elems, err := elemsFunc(dc, vr, val) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	if val.IsNil() { | ||||
| 		val.Set(reflect.MakeSlice(val.Type(), 0, len(elems))) | ||||
| 	} | ||||
|  | ||||
| 	val.SetLen(0) | ||||
| 	val.Set(reflect.Append(val, elems...)) | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
							
								
								
									
										119
									
								
								mongo/bson/bsoncodec/string_codec.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										119
									
								
								mongo/bson/bsoncodec/string_codec.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,119 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsoncodec | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
|  | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsonoptions" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsonrw" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsontype" | ||||
| ) | ||||
|  | ||||
| // StringCodec is the Codec used for struct values. | ||||
| type StringCodec struct { | ||||
| 	DecodeObjectIDAsHex bool | ||||
| } | ||||
|  | ||||
| var ( | ||||
| 	defaultStringCodec = NewStringCodec() | ||||
|  | ||||
| 	_ ValueCodec  = defaultStringCodec | ||||
| 	_ typeDecoder = defaultStringCodec | ||||
| ) | ||||
|  | ||||
| // NewStringCodec returns a StringCodec with options opts. | ||||
| func NewStringCodec(opts ...*bsonoptions.StringCodecOptions) *StringCodec { | ||||
| 	stringOpt := bsonoptions.MergeStringCodecOptions(opts...) | ||||
| 	return &StringCodec{*stringOpt.DecodeObjectIDAsHex} | ||||
| } | ||||
|  | ||||
| // EncodeValue is the ValueEncoder for string types. | ||||
| func (sc *StringCodec) EncodeValue(ectx EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { | ||||
| 	if val.Kind() != reflect.String { | ||||
| 		return ValueEncoderError{ | ||||
| 			Name:     "StringEncodeValue", | ||||
| 			Kinds:    []reflect.Kind{reflect.String}, | ||||
| 			Received: val, | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return vw.WriteString(val.String()) | ||||
| } | ||||
|  | ||||
| func (sc *StringCodec) decodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { | ||||
| 	if t.Kind() != reflect.String { | ||||
| 		return emptyValue, ValueDecoderError{ | ||||
| 			Name:     "StringDecodeValue", | ||||
| 			Kinds:    []reflect.Kind{reflect.String}, | ||||
| 			Received: reflect.Zero(t), | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	var str string | ||||
| 	var err error | ||||
| 	switch vr.Type() { | ||||
| 	case bsontype.String: | ||||
| 		str, err = vr.ReadString() | ||||
| 		if err != nil { | ||||
| 			return emptyValue, err | ||||
| 		} | ||||
| 	case bsontype.ObjectID: | ||||
| 		oid, err := vr.ReadObjectID() | ||||
| 		if err != nil { | ||||
| 			return emptyValue, err | ||||
| 		} | ||||
| 		if sc.DecodeObjectIDAsHex { | ||||
| 			str = oid.Hex() | ||||
| 		} else { | ||||
| 			byteArray := [12]byte(oid) | ||||
| 			str = string(byteArray[:]) | ||||
| 		} | ||||
| 	case bsontype.Symbol: | ||||
| 		str, err = vr.ReadSymbol() | ||||
| 		if err != nil { | ||||
| 			return emptyValue, err | ||||
| 		} | ||||
| 	case bsontype.Binary: | ||||
| 		data, subtype, err := vr.ReadBinary() | ||||
| 		if err != nil { | ||||
| 			return emptyValue, err | ||||
| 		} | ||||
| 		if subtype != bsontype.BinaryGeneric && subtype != bsontype.BinaryBinaryOld { | ||||
| 			return emptyValue, decodeBinaryError{subtype: subtype, typeName: "string"} | ||||
| 		} | ||||
| 		str = string(data) | ||||
| 	case bsontype.Null: | ||||
| 		if err = vr.ReadNull(); err != nil { | ||||
| 			return emptyValue, err | ||||
| 		} | ||||
| 	case bsontype.Undefined: | ||||
| 		if err = vr.ReadUndefined(); err != nil { | ||||
| 			return emptyValue, err | ||||
| 		} | ||||
| 	default: | ||||
| 		return emptyValue, fmt.Errorf("cannot decode %v into a string type", vr.Type()) | ||||
| 	} | ||||
|  | ||||
| 	return reflect.ValueOf(str), nil | ||||
| } | ||||
|  | ||||
| // DecodeValue is the ValueDecoder for string types. | ||||
| func (sc *StringCodec) DecodeValue(dctx DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { | ||||
| 	if !val.CanSet() || val.Kind() != reflect.String { | ||||
| 		return ValueDecoderError{Name: "StringDecodeValue", Kinds: []reflect.Kind{reflect.String}, Received: val} | ||||
| 	} | ||||
|  | ||||
| 	elem, err := sc.decodeType(dctx, vr, val.Type()) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	val.SetString(elem.String()) | ||||
| 	return nil | ||||
| } | ||||
							
								
								
									
										48
									
								
								mongo/bson/bsoncodec/string_codec_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										48
									
								
								mongo/bson/bsoncodec/string_codec_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,48 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsoncodec | ||||
|  | ||||
| import ( | ||||
| 	"reflect" | ||||
| 	"testing" | ||||
|  | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsonoptions" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsonrw/bsonrwtest" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsontype" | ||||
| 	"go.mongodb.org/mongo-driver/bson/primitive" | ||||
| 	"go.mongodb.org/mongo-driver/internal/testutil/assert" | ||||
| ) | ||||
|  | ||||
| func TestStringCodec(t *testing.T) { | ||||
| 	t.Run("ObjectIDAsHex", func(t *testing.T) { | ||||
| 		oid := primitive.NewObjectID() | ||||
| 		byteArray := [12]byte(oid) | ||||
| 		reader := &bsonrwtest.ValueReaderWriter{BSONType: bsontype.ObjectID, Return: oid} | ||||
| 		testCases := []struct { | ||||
| 			name   string | ||||
| 			opts   *bsonoptions.StringCodecOptions | ||||
| 			hex    bool | ||||
| 			result string | ||||
| 		}{ | ||||
| 			{"default", bsonoptions.StringCodec(), true, oid.Hex()}, | ||||
| 			{"true", bsonoptions.StringCodec().SetDecodeObjectIDAsHex(true), true, oid.Hex()}, | ||||
| 			{"false", bsonoptions.StringCodec().SetDecodeObjectIDAsHex(false), false, string(byteArray[:])}, | ||||
| 		} | ||||
| 		for _, tc := range testCases { | ||||
| 			t.Run(tc.name, func(t *testing.T) { | ||||
| 				stringCodec := NewStringCodec(tc.opts) | ||||
|  | ||||
| 				actual := reflect.New(reflect.TypeOf("")).Elem() | ||||
| 				err := stringCodec.DecodeValue(DecodeContext{}, reader, actual) | ||||
| 				assert.Nil(t, err, "StringCodec.DecodeValue error: %v", err) | ||||
|  | ||||
| 				actualString := actual.Interface().(string) | ||||
| 				assert.Equal(t, tc.result, actualString, "Expected string %v, got %v", tc.result, actualString) | ||||
| 			}) | ||||
| 		} | ||||
| 	}) | ||||
| } | ||||
							
								
								
									
										675
									
								
								mongo/bson/bsoncodec/struct_codec.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										675
									
								
								mongo/bson/bsoncodec/struct_codec.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,675 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsoncodec | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"sort" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"time" | ||||
|  | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsonoptions" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsonrw" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsontype" | ||||
| ) | ||||
|  | ||||
| // DecodeError represents an error that occurs when unmarshalling BSON bytes into a native Go type. | ||||
| type DecodeError struct { | ||||
| 	keys    []string | ||||
| 	wrapped error | ||||
| } | ||||
|  | ||||
| // Unwrap returns the underlying error | ||||
| func (de *DecodeError) Unwrap() error { | ||||
| 	return de.wrapped | ||||
| } | ||||
|  | ||||
| // Error implements the error interface. | ||||
| func (de *DecodeError) Error() string { | ||||
| 	// The keys are stored in reverse order because the de.keys slice is builtup while propagating the error up the | ||||
| 	// stack of BSON keys, so we call de.Keys(), which reverses them. | ||||
| 	keyPath := strings.Join(de.Keys(), ".") | ||||
| 	return fmt.Sprintf("error decoding key %s: %v", keyPath, de.wrapped) | ||||
| } | ||||
|  | ||||
| // Keys returns the BSON key path that caused an error as a slice of strings. The keys in the slice are in top-down | ||||
| // order. For example, if the document being unmarshalled was {a: {b: {c: 1}}} and the value for c was supposed to be | ||||
| // a string, the keys slice will be ["a", "b", "c"]. | ||||
| func (de *DecodeError) Keys() []string { | ||||
| 	reversedKeys := make([]string, 0, len(de.keys)) | ||||
| 	for idx := len(de.keys) - 1; idx >= 0; idx-- { | ||||
| 		reversedKeys = append(reversedKeys, de.keys[idx]) | ||||
| 	} | ||||
|  | ||||
| 	return reversedKeys | ||||
| } | ||||
|  | ||||
| // Zeroer allows custom struct types to implement a report of zero | ||||
| // state. All struct types that don't implement Zeroer or where IsZero | ||||
| // returns false are considered to be not zero. | ||||
| type Zeroer interface { | ||||
| 	IsZero() bool | ||||
| } | ||||
|  | ||||
| // StructCodec is the Codec used for struct values. | ||||
| type StructCodec struct { | ||||
| 	cache                            map[reflect.Type]*structDescription | ||||
| 	l                                sync.RWMutex | ||||
| 	parser                           StructTagParser | ||||
| 	DecodeZeroStruct                 bool | ||||
| 	DecodeDeepZeroInline             bool | ||||
| 	EncodeOmitDefaultStruct          bool | ||||
| 	AllowUnexportedFields            bool | ||||
| 	OverwriteDuplicatedInlinedFields bool | ||||
| } | ||||
|  | ||||
| var _ ValueEncoder = &StructCodec{} | ||||
| var _ ValueDecoder = &StructCodec{} | ||||
|  | ||||
| // NewStructCodec returns a StructCodec that uses p for struct tag parsing. | ||||
| func NewStructCodec(p StructTagParser, opts ...*bsonoptions.StructCodecOptions) (*StructCodec, error) { | ||||
| 	if p == nil { | ||||
| 		return nil, errors.New("a StructTagParser must be provided to NewStructCodec") | ||||
| 	} | ||||
|  | ||||
| 	structOpt := bsonoptions.MergeStructCodecOptions(opts...) | ||||
|  | ||||
| 	codec := &StructCodec{ | ||||
| 		cache:  make(map[reflect.Type]*structDescription), | ||||
| 		parser: p, | ||||
| 	} | ||||
|  | ||||
| 	if structOpt.DecodeZeroStruct != nil { | ||||
| 		codec.DecodeZeroStruct = *structOpt.DecodeZeroStruct | ||||
| 	} | ||||
| 	if structOpt.DecodeDeepZeroInline != nil { | ||||
| 		codec.DecodeDeepZeroInline = *structOpt.DecodeDeepZeroInline | ||||
| 	} | ||||
| 	if structOpt.EncodeOmitDefaultStruct != nil { | ||||
| 		codec.EncodeOmitDefaultStruct = *structOpt.EncodeOmitDefaultStruct | ||||
| 	} | ||||
| 	if structOpt.OverwriteDuplicatedInlinedFields != nil { | ||||
| 		codec.OverwriteDuplicatedInlinedFields = *structOpt.OverwriteDuplicatedInlinedFields | ||||
| 	} | ||||
| 	if structOpt.AllowUnexportedFields != nil { | ||||
| 		codec.AllowUnexportedFields = *structOpt.AllowUnexportedFields | ||||
| 	} | ||||
|  | ||||
| 	return codec, nil | ||||
| } | ||||
|  | ||||
| // EncodeValue handles encoding generic struct types. | ||||
| func (sc *StructCodec) EncodeValue(r EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { | ||||
| 	if !val.IsValid() || val.Kind() != reflect.Struct { | ||||
| 		return ValueEncoderError{Name: "StructCodec.EncodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val} | ||||
| 	} | ||||
|  | ||||
| 	sd, err := sc.describeStruct(r.Registry, val.Type()) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	dw, err := vw.WriteDocument() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	var rv reflect.Value | ||||
| 	for _, desc := range sd.fl { | ||||
| 		if desc.omitAlways { | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		if desc.inline == nil { | ||||
| 			rv = val.Field(desc.idx) | ||||
| 		} else { | ||||
| 			rv, err = fieldByIndexErr(val, desc.inline) | ||||
| 			if err != nil { | ||||
| 				continue | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		desc.encoder, rv, err = defaultValueEncoders.lookupElementEncoder(r, desc.encoder, rv) | ||||
|  | ||||
| 		if err != nil && err != errInvalidValue { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		if err == errInvalidValue { | ||||
| 			if desc.omitEmpty { | ||||
| 				continue | ||||
| 			} | ||||
| 			vw2, err := dw.WriteDocumentElement(desc.name) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			err = vw2.WriteNull() | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		if desc.encoder == nil { | ||||
| 			return ErrNoEncoder{Type: rv.Type()} | ||||
| 		} | ||||
|  | ||||
| 		encoder := desc.encoder | ||||
|  | ||||
| 		var isZero bool | ||||
| 		rvInterface := rv.Interface() | ||||
| 		if cz, ok := encoder.(CodecZeroer); ok { | ||||
| 			isZero = cz.IsTypeZero(rvInterface) | ||||
| 		} else if rv.Kind() == reflect.Interface { | ||||
| 			// sc.isZero will not treat an interface rv as an interface, so we need to check for the zero interface separately. | ||||
| 			isZero = rv.IsNil() | ||||
| 		} else { | ||||
| 			isZero = sc.isZero(rvInterface) | ||||
| 		} | ||||
| 		if desc.omitEmpty && isZero { | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		vw2, err := dw.WriteDocumentElement(desc.name) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		ectx := EncodeContext{Registry: r.Registry, MinSize: desc.minSize} | ||||
| 		err = encoder.EncodeValue(ectx, vw2, rv) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if sd.inlineMap >= 0 { | ||||
| 		rv := val.Field(sd.inlineMap) | ||||
| 		collisionFn := func(key string) bool { | ||||
| 			_, exists := sd.fm[key] | ||||
| 			return exists | ||||
| 		} | ||||
|  | ||||
| 		return defaultMapCodec.mapEncodeValue(r, dw, rv, collisionFn) | ||||
| 	} | ||||
|  | ||||
| 	return dw.WriteDocumentEnd() | ||||
| } | ||||
|  | ||||
| func newDecodeError(key string, original error) error { | ||||
| 	de, ok := original.(*DecodeError) | ||||
| 	if !ok { | ||||
| 		return &DecodeError{ | ||||
| 			keys:    []string{key}, | ||||
| 			wrapped: original, | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	de.keys = append(de.keys, key) | ||||
| 	return de | ||||
| } | ||||
|  | ||||
| // DecodeValue implements the Codec interface. | ||||
| // By default, map types in val will not be cleared. If a map has existing key/value pairs, it will be extended with the new ones from vr. | ||||
| // For slices, the decoder will set the length of the slice to zero and append all elements. The underlying array will not be cleared. | ||||
| func (sc *StructCodec) DecodeValue(r DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { | ||||
| 	if !val.CanSet() || val.Kind() != reflect.Struct { | ||||
| 		return ValueDecoderError{Name: "StructCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val} | ||||
| 	} | ||||
|  | ||||
| 	switch vrType := vr.Type(); vrType { | ||||
| 	case bsontype.Type(0), bsontype.EmbeddedDocument: | ||||
| 	case bsontype.Null: | ||||
| 		if err := vr.ReadNull(); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		val.Set(reflect.Zero(val.Type())) | ||||
| 		return nil | ||||
| 	case bsontype.Undefined: | ||||
| 		if err := vr.ReadUndefined(); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		val.Set(reflect.Zero(val.Type())) | ||||
| 		return nil | ||||
| 	default: | ||||
| 		return fmt.Errorf("cannot decode %v into a %s", vrType, val.Type()) | ||||
| 	} | ||||
|  | ||||
| 	sd, err := sc.describeStruct(r.Registry, val.Type()) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	if sc.DecodeZeroStruct { | ||||
| 		val.Set(reflect.Zero(val.Type())) | ||||
| 	} | ||||
| 	if sc.DecodeDeepZeroInline && sd.inline { | ||||
| 		val.Set(deepZero(val.Type())) | ||||
| 	} | ||||
|  | ||||
| 	var decoder ValueDecoder | ||||
| 	var inlineMap reflect.Value | ||||
| 	if sd.inlineMap >= 0 { | ||||
| 		inlineMap = val.Field(sd.inlineMap) | ||||
| 		decoder, err = r.LookupDecoder(inlineMap.Type().Elem()) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	dr, err := vr.ReadDocument() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	for { | ||||
| 		name, vr, err := dr.ReadElement() | ||||
| 		if err == bsonrw.ErrEOD { | ||||
| 			break | ||||
| 		} | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		fd, exists := sd.fm[name] | ||||
| 		if !exists { | ||||
| 			// if the original name isn't found in the struct description, try again with the name in lowercase | ||||
| 			// this could match if a BSON tag isn't specified because by default, describeStruct lowercases all field | ||||
| 			// names | ||||
| 			fd, exists = sd.fm[strings.ToLower(name)] | ||||
| 		} | ||||
|  | ||||
| 		if !exists { | ||||
| 			if sd.inlineMap < 0 { | ||||
| 				// The encoding/json package requires a flag to return on error for non-existent fields. | ||||
| 				// This functionality seems appropriate for the struct codec. | ||||
| 				err = vr.Skip() | ||||
| 				if err != nil { | ||||
| 					return err | ||||
| 				} | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			if inlineMap.IsNil() { | ||||
| 				inlineMap.Set(reflect.MakeMap(inlineMap.Type())) | ||||
| 			} | ||||
|  | ||||
| 			elem := reflect.New(inlineMap.Type().Elem()).Elem() | ||||
| 			r.Ancestor = inlineMap.Type() | ||||
| 			err = decoder.DecodeValue(r, vr, elem) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			inlineMap.SetMapIndex(reflect.ValueOf(name), elem) | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		var field reflect.Value | ||||
| 		if fd.inline == nil { | ||||
| 			field = val.Field(fd.idx) | ||||
| 		} else { | ||||
| 			field, err = getInlineField(val, fd.inline) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		if !field.CanSet() { // Being settable is a super set of being addressable. | ||||
| 			innerErr := fmt.Errorf("field %v is not settable", field) | ||||
| 			return newDecodeError(fd.name, innerErr) | ||||
| 		} | ||||
| 		if field.Kind() == reflect.Ptr && field.IsNil() { | ||||
| 			field.Set(reflect.New(field.Type().Elem())) | ||||
| 		} | ||||
| 		field = field.Addr() | ||||
|  | ||||
| 		dctx := DecodeContext{ | ||||
| 			Registry:            r.Registry, | ||||
| 			Truncate:            fd.truncate || r.Truncate, | ||||
| 			defaultDocumentType: r.defaultDocumentType, | ||||
| 		} | ||||
|  | ||||
| 		if fd.decoder == nil { | ||||
| 			return newDecodeError(fd.name, ErrNoDecoder{Type: field.Elem().Type()}) | ||||
| 		} | ||||
|  | ||||
| 		err = fd.decoder.DecodeValue(dctx, vr, field.Elem()) | ||||
| 		if err != nil { | ||||
| 			return newDecodeError(fd.name, err) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (sc *StructCodec) isZero(i interface{}) bool { | ||||
| 	v := reflect.ValueOf(i) | ||||
|  | ||||
| 	// check the value validity | ||||
| 	if !v.IsValid() { | ||||
| 		return true | ||||
| 	} | ||||
|  | ||||
| 	if z, ok := v.Interface().(Zeroer); ok && (v.Kind() != reflect.Ptr || !v.IsNil()) { | ||||
| 		return z.IsZero() | ||||
| 	} | ||||
|  | ||||
| 	switch v.Kind() { | ||||
| 	case reflect.Array, reflect.Map, reflect.Slice, reflect.String: | ||||
| 		return v.Len() == 0 | ||||
| 	case reflect.Bool: | ||||
| 		return !v.Bool() | ||||
| 	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | ||||
| 		return v.Int() == 0 | ||||
| 	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: | ||||
| 		return v.Uint() == 0 | ||||
| 	case reflect.Float32, reflect.Float64: | ||||
| 		return v.Float() == 0 | ||||
| 	case reflect.Interface, reflect.Ptr: | ||||
| 		return v.IsNil() | ||||
| 	case reflect.Struct: | ||||
| 		if sc.EncodeOmitDefaultStruct { | ||||
| 			vt := v.Type() | ||||
| 			if vt == tTime { | ||||
| 				return v.Interface().(time.Time).IsZero() | ||||
| 			} | ||||
| 			for i := 0; i < v.NumField(); i++ { | ||||
| 				if vt.Field(i).PkgPath != "" && !vt.Field(i).Anonymous { | ||||
| 					continue // Private field | ||||
| 				} | ||||
| 				fld := v.Field(i) | ||||
| 				if !sc.isZero(fld.Interface()) { | ||||
| 					return false | ||||
| 				} | ||||
| 			} | ||||
| 			return true | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return false | ||||
| } | ||||
|  | ||||
| type structDescription struct { | ||||
| 	fm        map[string]fieldDescription | ||||
| 	fl        []fieldDescription | ||||
| 	inlineMap int | ||||
| 	inline    bool | ||||
| } | ||||
|  | ||||
| type fieldDescription struct { | ||||
| 	name       string // BSON key name | ||||
| 	fieldName  string // struct field name | ||||
| 	idx        int | ||||
| 	omitEmpty  bool | ||||
| 	omitAlways bool | ||||
| 	minSize    bool | ||||
| 	truncate   bool | ||||
| 	inline     []int | ||||
| 	encoder    ValueEncoder | ||||
| 	decoder    ValueDecoder | ||||
| } | ||||
|  | ||||
| type byIndex []fieldDescription | ||||
|  | ||||
| func (bi byIndex) Len() int { return len(bi) } | ||||
|  | ||||
| func (bi byIndex) Swap(i, j int) { bi[i], bi[j] = bi[j], bi[i] } | ||||
|  | ||||
| func (bi byIndex) Less(i, j int) bool { | ||||
| 	// If a field is inlined, its index in the top level struct is stored at inline[0] | ||||
| 	iIdx, jIdx := bi[i].idx, bi[j].idx | ||||
| 	if len(bi[i].inline) > 0 { | ||||
| 		iIdx = bi[i].inline[0] | ||||
| 	} | ||||
| 	if len(bi[j].inline) > 0 { | ||||
| 		jIdx = bi[j].inline[0] | ||||
| 	} | ||||
| 	if iIdx != jIdx { | ||||
| 		return iIdx < jIdx | ||||
| 	} | ||||
| 	for k, biik := range bi[i].inline { | ||||
| 		if k >= len(bi[j].inline) { | ||||
| 			return false | ||||
| 		} | ||||
| 		if biik != bi[j].inline[k] { | ||||
| 			return biik < bi[j].inline[k] | ||||
| 		} | ||||
| 	} | ||||
| 	return len(bi[i].inline) < len(bi[j].inline) | ||||
| } | ||||
|  | ||||
| func (sc *StructCodec) describeStruct(r *Registry, t reflect.Type) (*structDescription, error) { | ||||
| 	// We need to analyze the struct, including getting the tags, collecting | ||||
| 	// information about inlining, and create a map of the field name to the field. | ||||
| 	sc.l.RLock() | ||||
| 	ds, exists := sc.cache[t] | ||||
| 	sc.l.RUnlock() | ||||
| 	if exists { | ||||
| 		return ds, nil | ||||
| 	} | ||||
|  | ||||
| 	numFields := t.NumField() | ||||
| 	sd := &structDescription{ | ||||
| 		fm:        make(map[string]fieldDescription, numFields), | ||||
| 		fl:        make([]fieldDescription, 0, numFields), | ||||
| 		inlineMap: -1, | ||||
| 	} | ||||
|  | ||||
| 	var fields []fieldDescription | ||||
| 	for i := 0; i < numFields; i++ { | ||||
| 		sf := t.Field(i) | ||||
| 		if sf.PkgPath != "" && (!sc.AllowUnexportedFields || !sf.Anonymous) { | ||||
| 			// field is private or unexported fields aren't allowed, ignore | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		sfType := sf.Type | ||||
| 		encoder, err := r.LookupEncoder(sfType) | ||||
| 		if err != nil { | ||||
| 			encoder = nil | ||||
| 		} | ||||
| 		decoder, err := r.LookupDecoder(sfType) | ||||
| 		if err != nil { | ||||
| 			decoder = nil | ||||
| 		} | ||||
|  | ||||
| 		description := fieldDescription{ | ||||
| 			fieldName: sf.Name, | ||||
| 			idx:       i, | ||||
| 			encoder:   encoder, | ||||
| 			decoder:   decoder, | ||||
| 		} | ||||
|  | ||||
| 		stags, err := sc.parser.ParseStructTags(sf) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		if stags.Skip { | ||||
| 			continue | ||||
| 		} | ||||
| 		description.name = stags.Name | ||||
| 		description.omitEmpty = stags.OmitEmpty | ||||
| 		description.omitAlways = stags.OmitAlways | ||||
| 		description.minSize = stags.MinSize | ||||
| 		description.truncate = stags.Truncate | ||||
|  | ||||
| 		if stags.Inline { | ||||
| 			sd.inline = true | ||||
| 			switch sfType.Kind() { | ||||
| 			case reflect.Map: | ||||
| 				if sd.inlineMap >= 0 { | ||||
| 					return nil, errors.New("(struct " + t.String() + ") multiple inline maps") | ||||
| 				} | ||||
| 				if sfType.Key() != tString { | ||||
| 					return nil, errors.New("(struct " + t.String() + ") inline map must have a string keys") | ||||
| 				} | ||||
| 				sd.inlineMap = description.idx | ||||
| 			case reflect.Ptr: | ||||
| 				sfType = sfType.Elem() | ||||
| 				if sfType.Kind() != reflect.Struct { | ||||
| 					return nil, fmt.Errorf("(struct %s) inline fields must be a struct, a struct pointer, or a map", t.String()) | ||||
| 				} | ||||
| 				fallthrough | ||||
| 			case reflect.Struct: | ||||
| 				inlinesf, err := sc.describeStruct(r, sfType) | ||||
| 				if err != nil { | ||||
| 					return nil, err | ||||
| 				} | ||||
| 				for _, fd := range inlinesf.fl { | ||||
| 					if fd.inline == nil { | ||||
| 						fd.inline = []int{i, fd.idx} | ||||
| 					} else { | ||||
| 						fd.inline = append([]int{i}, fd.inline...) | ||||
| 					} | ||||
| 					fields = append(fields, fd) | ||||
|  | ||||
| 				} | ||||
| 			default: | ||||
| 				return nil, fmt.Errorf("(struct %s) inline fields must be a struct, a struct pointer, or a map", t.String()) | ||||
| 			} | ||||
| 			continue | ||||
| 		} | ||||
| 		fields = append(fields, description) | ||||
| 	} | ||||
|  | ||||
| 	// Sort fieldDescriptions by name and use dominance rules to determine which should be added for each name | ||||
| 	sort.Slice(fields, func(i, j int) bool { | ||||
| 		x := fields | ||||
| 		// sort field by name, breaking ties with depth, then | ||||
| 		// breaking ties with index sequence. | ||||
| 		if x[i].name != x[j].name { | ||||
| 			return x[i].name < x[j].name | ||||
| 		} | ||||
| 		if len(x[i].inline) != len(x[j].inline) { | ||||
| 			return len(x[i].inline) < len(x[j].inline) | ||||
| 		} | ||||
| 		return byIndex(x).Less(i, j) | ||||
| 	}) | ||||
|  | ||||
| 	for advance, i := 0, 0; i < len(fields); i += advance { | ||||
| 		// One iteration per name. | ||||
| 		// Find the sequence of fields with the name of this first field. | ||||
| 		fi := fields[i] | ||||
| 		name := fi.name | ||||
| 		for advance = 1; i+advance < len(fields); advance++ { | ||||
| 			fj := fields[i+advance] | ||||
| 			if fj.name != name { | ||||
| 				break | ||||
| 			} | ||||
| 		} | ||||
| 		if advance == 1 { // Only one field with this name | ||||
| 			sd.fl = append(sd.fl, fi) | ||||
| 			sd.fm[name] = fi | ||||
| 			continue | ||||
| 		} | ||||
| 		dominant, ok := dominantField(fields[i : i+advance]) | ||||
| 		if !ok || !sc.OverwriteDuplicatedInlinedFields { | ||||
| 			return nil, fmt.Errorf("struct %s has duplicated key %s", t.String(), name) | ||||
| 		} | ||||
| 		sd.fl = append(sd.fl, dominant) | ||||
| 		sd.fm[name] = dominant | ||||
| 	} | ||||
|  | ||||
| 	sort.Sort(byIndex(sd.fl)) | ||||
|  | ||||
| 	sc.l.Lock() | ||||
| 	sc.cache[t] = sd | ||||
| 	sc.l.Unlock() | ||||
|  | ||||
| 	return sd, nil | ||||
| } | ||||
|  | ||||
| // dominantField looks through the fields, all of which are known to | ||||
| // have the same name, to find the single field that dominates the | ||||
| // others using Go's inlining rules. If there are multiple top-level | ||||
| // fields, the boolean will be false: This condition is an error in Go | ||||
| // and we skip all the fields. | ||||
| func dominantField(fields []fieldDescription) (fieldDescription, bool) { | ||||
| 	// The fields are sorted in increasing index-length order, then by presence of tag. | ||||
| 	// That means that the first field is the dominant one. We need only check | ||||
| 	// for error cases: two fields at top level. | ||||
| 	if len(fields) > 1 && | ||||
| 		len(fields[0].inline) == len(fields[1].inline) { | ||||
| 		return fieldDescription{}, false | ||||
| 	} | ||||
| 	return fields[0], true | ||||
| } | ||||
|  | ||||
| func fieldByIndexErr(v reflect.Value, index []int) (result reflect.Value, err error) { | ||||
| 	defer func() { | ||||
| 		if recovered := recover(); recovered != nil { | ||||
| 			switch r := recovered.(type) { | ||||
| 			case string: | ||||
| 				err = fmt.Errorf("%s", r) | ||||
| 			case error: | ||||
| 				err = r | ||||
| 			} | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	result = v.FieldByIndex(index) | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func getInlineField(val reflect.Value, index []int) (reflect.Value, error) { | ||||
| 	field, err := fieldByIndexErr(val, index) | ||||
| 	if err == nil { | ||||
| 		return field, nil | ||||
| 	} | ||||
|  | ||||
| 	// if parent of this element doesn't exist, fix its parent | ||||
| 	inlineParent := index[:len(index)-1] | ||||
| 	var fParent reflect.Value | ||||
| 	if fParent, err = fieldByIndexErr(val, inlineParent); err != nil { | ||||
| 		fParent, err = getInlineField(val, inlineParent) | ||||
| 		if err != nil { | ||||
| 			return fParent, err | ||||
| 		} | ||||
| 	} | ||||
| 	fParent.Set(reflect.New(fParent.Type().Elem())) | ||||
|  | ||||
| 	return fieldByIndexErr(val, index) | ||||
| } | ||||
|  | ||||
| // DeepZero returns recursive zero object | ||||
| func deepZero(st reflect.Type) (result reflect.Value) { | ||||
| 	result = reflect.Indirect(reflect.New(st)) | ||||
|  | ||||
| 	if result.Kind() == reflect.Struct { | ||||
| 		for i := 0; i < result.NumField(); i++ { | ||||
| 			if f := result.Field(i); f.Kind() == reflect.Ptr { | ||||
| 				if f.CanInterface() { | ||||
| 					if ft := reflect.TypeOf(f.Interface()); ft.Elem().Kind() == reflect.Struct { | ||||
| 						result.Field(i).Set(recursivePointerTo(deepZero(ft.Elem()))) | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return | ||||
| } | ||||
|  | ||||
| // recursivePointerTo calls reflect.New(v.Type) but recursively for its fields inside | ||||
| func recursivePointerTo(v reflect.Value) reflect.Value { | ||||
| 	v = reflect.Indirect(v) | ||||
| 	result := reflect.New(v.Type()) | ||||
| 	if v.Kind() == reflect.Struct { | ||||
| 		for i := 0; i < v.NumField(); i++ { | ||||
| 			if f := v.Field(i); f.Kind() == reflect.Ptr { | ||||
| 				if f.Elem().Kind() == reflect.Struct { | ||||
| 					result.Elem().Field(i).Set(recursivePointerTo(f)) | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return result | ||||
| } | ||||
							
								
								
									
										47
									
								
								mongo/bson/bsoncodec/struct_codec_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										47
									
								
								mongo/bson/bsoncodec/struct_codec_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,47 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsoncodec | ||||
|  | ||||
| import ( | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| ) | ||||
|  | ||||
| func TestZeoerInterfaceUsedByDecoder(t *testing.T) { | ||||
| 	enc := &StructCodec{} | ||||
|  | ||||
| 	// cases that are zero, because they are known types or pointers | ||||
| 	var st *nonZeroer | ||||
| 	assert.True(t, enc.isZero(st)) | ||||
| 	assert.True(t, enc.isZero(0)) | ||||
| 	assert.True(t, enc.isZero(false)) | ||||
|  | ||||
| 	// cases that shouldn't be zero | ||||
| 	st = &nonZeroer{value: false} | ||||
| 	assert.False(t, enc.isZero(struct{ val bool }{val: true})) | ||||
| 	assert.False(t, enc.isZero(struct{ val bool }{val: false})) | ||||
| 	assert.False(t, enc.isZero(st)) | ||||
| 	st.value = true | ||||
| 	assert.False(t, enc.isZero(st)) | ||||
|  | ||||
| 	// a test to see if the interface impacts the outcome | ||||
| 	z := zeroTest{} | ||||
| 	assert.False(t, enc.isZero(z)) | ||||
|  | ||||
| 	z.reportZero = true | ||||
| 	assert.True(t, enc.isZero(z)) | ||||
|  | ||||
| 	// *time.Time with nil should be zero | ||||
| 	var tp *time.Time | ||||
| 	assert.True(t, enc.isZero(tp)) | ||||
|  | ||||
| 	// actually all zeroer if nil should also be zero | ||||
| 	var zp *zeroTest | ||||
| 	assert.True(t, enc.isZero(zp)) | ||||
| } | ||||
							
								
								
									
										142
									
								
								mongo/bson/bsoncodec/struct_tag_parser.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										142
									
								
								mongo/bson/bsoncodec/struct_tag_parser.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,142 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsoncodec | ||||
|  | ||||
| import ( | ||||
| 	"reflect" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| // StructTagParser returns the struct tags for a given struct field. | ||||
| type StructTagParser interface { | ||||
| 	ParseStructTags(reflect.StructField) (StructTags, error) | ||||
| } | ||||
|  | ||||
| // StructTagParserFunc is an adapter that allows a generic function to be used | ||||
| // as a StructTagParser. | ||||
| type StructTagParserFunc func(reflect.StructField) (StructTags, error) | ||||
|  | ||||
| // ParseStructTags implements the StructTagParser interface. | ||||
| func (stpf StructTagParserFunc) ParseStructTags(sf reflect.StructField) (StructTags, error) { | ||||
| 	return stpf(sf) | ||||
| } | ||||
|  | ||||
| // StructTags represents the struct tag fields that the StructCodec uses during | ||||
| // the encoding and decoding process. | ||||
| // | ||||
| // In the case of a struct, the lowercased field name is used as the key for each exported | ||||
| // field but this behavior may be changed using a struct tag. The tag may also contain flags to | ||||
| // adjust the marshalling behavior for the field. | ||||
| // | ||||
| // The properties are defined below: | ||||
| // | ||||
| //	OmitEmpty  Only include the field if it's not set to the zero value for the type or to | ||||
| //	           empty slices or maps. | ||||
| // | ||||
| //	MinSize    Marshal an integer of a type larger than 32 bits value as an int32, if that's | ||||
| //	           feasible while preserving the numeric value. | ||||
| // | ||||
| //	Truncate   When unmarshaling a BSON double, it is permitted to lose precision to fit within | ||||
| //	           a float32. | ||||
| // | ||||
| //	Inline     Inline the field, which must be a struct or a map, causing all of its fields | ||||
| //	           or keys to be processed as if they were part of the outer struct. For maps, | ||||
| //	           keys must not conflict with the bson keys of other struct fields. | ||||
| // | ||||
| //	Skip       This struct field should be skipped. This is usually denoted by parsing a "-" | ||||
| //	           for the name. | ||||
| // | ||||
| // TODO(skriptble): Add tags for undefined as nil and for null as nil. | ||||
| type StructTags struct { | ||||
| 	Name       string | ||||
| 	OmitEmpty  bool | ||||
| 	OmitAlways bool | ||||
| 	MinSize    bool | ||||
| 	Truncate   bool | ||||
| 	Inline     bool | ||||
| 	Skip       bool | ||||
| } | ||||
|  | ||||
| // DefaultStructTagParser is the StructTagParser used by the StructCodec by default. | ||||
| // It will handle the bson struct tag. See the documentation for StructTags to see | ||||
| // what each of the returned fields means. | ||||
| // | ||||
| // If there is no name in the struct tag fields, the struct field name is lowercased. | ||||
| // The tag formats accepted are: | ||||
| // | ||||
| //	"[<key>][,<flag1>[,<flag2>]]" | ||||
| // | ||||
| //	`(...) bson:"[<key>][,<flag1>[,<flag2>]]" (...)` | ||||
| // | ||||
| // An example: | ||||
| // | ||||
| //	type T struct { | ||||
| //	    A bool | ||||
| //	    B int    "myb" | ||||
| //	    C string "myc,omitempty" | ||||
| //	    D string `bson:",omitempty" json:"jsonkey"` | ||||
| //	    E int64  ",minsize" | ||||
| //	    F int64  "myf,omitempty,minsize" | ||||
| //	} | ||||
| // | ||||
| // A struct tag either consisting entirely of '-' or with a bson key with a | ||||
| // value consisting entirely of '-' will return a StructTags with Skip true and | ||||
| // the remaining fields will be their default values. | ||||
| var DefaultStructTagParser StructTagParserFunc = func(sf reflect.StructField) (StructTags, error) { | ||||
| 	key := strings.ToLower(sf.Name) | ||||
| 	tag, ok := sf.Tag.Lookup("bson") | ||||
| 	if !ok && !strings.Contains(string(sf.Tag), ":") && len(sf.Tag) > 0 { | ||||
| 		tag = string(sf.Tag) | ||||
| 	} | ||||
| 	return parseTags(key, tag) | ||||
| } | ||||
|  | ||||
| func parseTags(key string, tag string) (StructTags, error) { | ||||
| 	var st StructTags | ||||
| 	if tag == "-" { | ||||
| 		st.Skip = true | ||||
| 		return st, nil | ||||
| 	} | ||||
|  | ||||
| 	for idx, str := range strings.Split(tag, ",") { | ||||
| 		if idx == 0 && str != "" { | ||||
| 			key = str | ||||
| 		} | ||||
| 		switch str { | ||||
| 		case "omitempty": | ||||
| 			st.OmitEmpty = true | ||||
| 		case "omitalways": | ||||
| 			st.OmitAlways = true | ||||
| 		case "minsize": | ||||
| 			st.MinSize = true | ||||
| 		case "truncate": | ||||
| 			st.Truncate = true | ||||
| 		case "inline": | ||||
| 			st.Inline = true | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	st.Name = key | ||||
|  | ||||
| 	return st, nil | ||||
| } | ||||
|  | ||||
| // JSONFallbackStructTagParser has the same behavior as DefaultStructTagParser | ||||
| // but will also fallback to parsing the json tag instead on a field where the | ||||
| // bson tag isn't available. | ||||
| var JSONFallbackStructTagParser StructTagParserFunc = func(sf reflect.StructField) (StructTags, error) { | ||||
| 	key := strings.ToLower(sf.Name) | ||||
| 	tag, ok := sf.Tag.Lookup("bson") | ||||
| 	if !ok { | ||||
| 		tag, ok = sf.Tag.Lookup("json") | ||||
| 	} | ||||
| 	if !ok && !strings.Contains(string(sf.Tag), ":") && len(sf.Tag) > 0 { | ||||
| 		tag = string(sf.Tag) | ||||
| 	} | ||||
|  | ||||
| 	return parseTags(key, tag) | ||||
| } | ||||
							
								
								
									
										160
									
								
								mongo/bson/bsoncodec/struct_tag_parser_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										160
									
								
								mongo/bson/bsoncodec/struct_tag_parser_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,160 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsoncodec | ||||
|  | ||||
| import ( | ||||
| 	"reflect" | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/google/go-cmp/cmp" | ||||
| ) | ||||
|  | ||||
| func TestStructTagParsers(t *testing.T) { | ||||
| 	testCases := []struct { | ||||
| 		name   string | ||||
| 		sf     reflect.StructField | ||||
| 		want   StructTags | ||||
| 		parser StructTagParserFunc | ||||
| 	}{ | ||||
| 		{ | ||||
| 			"default no bson tag", | ||||
| 			reflect.StructField{Name: "foo", Tag: reflect.StructTag("bar")}, | ||||
| 			StructTags{Name: "bar"}, | ||||
| 			DefaultStructTagParser, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"default empty", | ||||
| 			reflect.StructField{Name: "foo", Tag: reflect.StructTag("")}, | ||||
| 			StructTags{Name: "foo"}, | ||||
| 			DefaultStructTagParser, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"default tag only dash", | ||||
| 			reflect.StructField{Name: "foo", Tag: reflect.StructTag("-")}, | ||||
| 			StructTags{Skip: true}, | ||||
| 			DefaultStructTagParser, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"default bson tag only dash", | ||||
| 			reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bson:"-"`)}, | ||||
| 			StructTags{Skip: true}, | ||||
| 			DefaultStructTagParser, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"default all options", | ||||
| 			reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bar,omitempty,minsize,truncate,inline`)}, | ||||
| 			StructTags{Name: "bar", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, | ||||
| 			DefaultStructTagParser, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"default all options default name", | ||||
| 			reflect.StructField{Name: "foo", Tag: reflect.StructTag(`,omitempty,minsize,truncate,inline`)}, | ||||
| 			StructTags{Name: "foo", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, | ||||
| 			DefaultStructTagParser, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"default bson tag all options", | ||||
| 			reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bson:"bar,omitempty,minsize,truncate,inline"`)}, | ||||
| 			StructTags{Name: "bar", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, | ||||
| 			DefaultStructTagParser, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"default bson tag all options default name", | ||||
| 			reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bson:",omitempty,minsize,truncate,inline"`)}, | ||||
| 			StructTags{Name: "foo", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, | ||||
| 			DefaultStructTagParser, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"default ignore xml", | ||||
| 			reflect.StructField{Name: "foo", Tag: reflect.StructTag(`xml:"bar"`)}, | ||||
| 			StructTags{Name: "foo"}, | ||||
| 			DefaultStructTagParser, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"JSONFallback no bson tag", | ||||
| 			reflect.StructField{Name: "foo", Tag: reflect.StructTag("bar")}, | ||||
| 			StructTags{Name: "bar"}, | ||||
| 			JSONFallbackStructTagParser, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"JSONFallback empty", | ||||
| 			reflect.StructField{Name: "foo", Tag: reflect.StructTag("")}, | ||||
| 			StructTags{Name: "foo"}, | ||||
| 			JSONFallbackStructTagParser, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"JSONFallback tag only dash", | ||||
| 			reflect.StructField{Name: "foo", Tag: reflect.StructTag("-")}, | ||||
| 			StructTags{Skip: true}, | ||||
| 			JSONFallbackStructTagParser, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"JSONFallback bson tag only dash", | ||||
| 			reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bson:"-"`)}, | ||||
| 			StructTags{Skip: true}, | ||||
| 			JSONFallbackStructTagParser, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"JSONFallback all options", | ||||
| 			reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bar,omitempty,minsize,truncate,inline`)}, | ||||
| 			StructTags{Name: "bar", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, | ||||
| 			JSONFallbackStructTagParser, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"JSONFallback all options default name", | ||||
| 			reflect.StructField{Name: "foo", Tag: reflect.StructTag(`,omitempty,minsize,truncate,inline`)}, | ||||
| 			StructTags{Name: "foo", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, | ||||
| 			JSONFallbackStructTagParser, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"JSONFallback bson tag all options", | ||||
| 			reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bson:"bar,omitempty,minsize,truncate,inline"`)}, | ||||
| 			StructTags{Name: "bar", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, | ||||
| 			JSONFallbackStructTagParser, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"JSONFallback bson tag all options default name", | ||||
| 			reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bson:",omitempty,minsize,truncate,inline"`)}, | ||||
| 			StructTags{Name: "foo", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, | ||||
| 			JSONFallbackStructTagParser, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"JSONFallback json tag all options", | ||||
| 			reflect.StructField{Name: "foo", Tag: reflect.StructTag(`json:"bar,omitempty,minsize,truncate,inline"`)}, | ||||
| 			StructTags{Name: "bar", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, | ||||
| 			JSONFallbackStructTagParser, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"JSONFallback json tag all options default name", | ||||
| 			reflect.StructField{Name: "foo", Tag: reflect.StructTag(`json:",omitempty,minsize,truncate,inline"`)}, | ||||
| 			StructTags{Name: "foo", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true}, | ||||
| 			JSONFallbackStructTagParser, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"JSONFallback bson tag overrides other tags", | ||||
| 			reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bson:"bar" json:"qux,truncate"`)}, | ||||
| 			StructTags{Name: "bar"}, | ||||
| 			JSONFallbackStructTagParser, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"JSONFallback ignore xml", | ||||
| 			reflect.StructField{Name: "foo", Tag: reflect.StructTag(`xml:"bar"`)}, | ||||
| 			StructTags{Name: "foo"}, | ||||
| 			JSONFallbackStructTagParser, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for _, tc := range testCases { | ||||
| 		t.Run(tc.name, func(t *testing.T) { | ||||
| 			got, err := tc.parser(tc.sf) | ||||
| 			noerr(t, err) | ||||
| 			if !cmp.Equal(got, tc.want) { | ||||
| 				t.Errorf("Returned struct tags do not match. got %#v; want %#v", got, tc.want) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										127
									
								
								mongo/bson/bsoncodec/time_codec.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										127
									
								
								mongo/bson/bsoncodec/time_codec.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,127 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsoncodec | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"time" | ||||
|  | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsonoptions" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsonrw" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsontype" | ||||
| 	"go.mongodb.org/mongo-driver/bson/primitive" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	timeFormatString = "2006-01-02T15:04:05.999Z07:00" | ||||
| ) | ||||
|  | ||||
| // TimeCodec is the Codec used for time.Time values. | ||||
| type TimeCodec struct { | ||||
| 	UseLocalTimeZone bool | ||||
| } | ||||
|  | ||||
| var ( | ||||
| 	defaultTimeCodec = NewTimeCodec() | ||||
|  | ||||
| 	_ ValueCodec  = defaultTimeCodec | ||||
| 	_ typeDecoder = defaultTimeCodec | ||||
| ) | ||||
|  | ||||
| // NewTimeCodec returns a TimeCodec with options opts. | ||||
| func NewTimeCodec(opts ...*bsonoptions.TimeCodecOptions) *TimeCodec { | ||||
| 	timeOpt := bsonoptions.MergeTimeCodecOptions(opts...) | ||||
|  | ||||
| 	codec := TimeCodec{} | ||||
| 	if timeOpt.UseLocalTimeZone != nil { | ||||
| 		codec.UseLocalTimeZone = *timeOpt.UseLocalTimeZone | ||||
| 	} | ||||
| 	return &codec | ||||
| } | ||||
|  | ||||
| func (tc *TimeCodec) decodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { | ||||
| 	if t != tTime { | ||||
| 		return emptyValue, ValueDecoderError{ | ||||
| 			Name:     "TimeDecodeValue", | ||||
| 			Types:    []reflect.Type{tTime}, | ||||
| 			Received: reflect.Zero(t), | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	var timeVal time.Time | ||||
| 	switch vrType := vr.Type(); vrType { | ||||
| 	case bsontype.DateTime: | ||||
| 		dt, err := vr.ReadDateTime() | ||||
| 		if err != nil { | ||||
| 			return emptyValue, err | ||||
| 		} | ||||
| 		timeVal = time.Unix(dt/1000, dt%1000*1000000) | ||||
| 	case bsontype.String: | ||||
| 		// assume strings are in the isoTimeFormat | ||||
| 		timeStr, err := vr.ReadString() | ||||
| 		if err != nil { | ||||
| 			return emptyValue, err | ||||
| 		} | ||||
| 		timeVal, err = time.Parse(timeFormatString, timeStr) | ||||
| 		if err != nil { | ||||
| 			return emptyValue, err | ||||
| 		} | ||||
| 	case bsontype.Int64: | ||||
| 		i64, err := vr.ReadInt64() | ||||
| 		if err != nil { | ||||
| 			return emptyValue, err | ||||
| 		} | ||||
| 		timeVal = time.Unix(i64/1000, i64%1000*1000000) | ||||
| 	case bsontype.Timestamp: | ||||
| 		t, _, err := vr.ReadTimestamp() | ||||
| 		if err != nil { | ||||
| 			return emptyValue, err | ||||
| 		} | ||||
| 		timeVal = time.Unix(int64(t), 0) | ||||
| 	case bsontype.Null: | ||||
| 		if err := vr.ReadNull(); err != nil { | ||||
| 			return emptyValue, err | ||||
| 		} | ||||
| 	case bsontype.Undefined: | ||||
| 		if err := vr.ReadUndefined(); err != nil { | ||||
| 			return emptyValue, err | ||||
| 		} | ||||
| 	default: | ||||
| 		return emptyValue, fmt.Errorf("cannot decode %v into a time.Time", vrType) | ||||
| 	} | ||||
|  | ||||
| 	if !tc.UseLocalTimeZone { | ||||
| 		timeVal = timeVal.UTC() | ||||
| 	} | ||||
| 	return reflect.ValueOf(timeVal), nil | ||||
| } | ||||
|  | ||||
| // DecodeValue is the ValueDecoderFunc for time.Time. | ||||
| func (tc *TimeCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { | ||||
| 	if !val.CanSet() || val.Type() != tTime { | ||||
| 		return ValueDecoderError{Name: "TimeDecodeValue", Types: []reflect.Type{tTime}, Received: val} | ||||
| 	} | ||||
|  | ||||
| 	elem, err := tc.decodeType(dc, vr, tTime) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	val.Set(elem) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // EncodeValue is the ValueEncoderFunc for time.TIme. | ||||
| func (tc *TimeCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { | ||||
| 	if !val.IsValid() || val.Type() != tTime { | ||||
| 		return ValueEncoderError{Name: "TimeEncodeValue", Types: []reflect.Type{tTime}, Received: val} | ||||
| 	} | ||||
| 	tt := val.Interface().(time.Time) | ||||
| 	dt := primitive.NewDateTimeFromTime(tt) | ||||
| 	return vw.WriteDateTime(int64(dt)) | ||||
| } | ||||
							
								
								
									
										79
									
								
								mongo/bson/bsoncodec/time_codec_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										79
									
								
								mongo/bson/bsoncodec/time_codec_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,79 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsoncodec | ||||
|  | ||||
| import ( | ||||
| 	"reflect" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsonoptions" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsonrw/bsonrwtest" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsontype" | ||||
| 	"go.mongodb.org/mongo-driver/internal/testutil/assert" | ||||
| 	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore" | ||||
| ) | ||||
|  | ||||
| func TestTimeCodec(t *testing.T) { | ||||
| 	now := time.Now().Truncate(time.Millisecond) | ||||
|  | ||||
| 	t.Run("UseLocalTimeZone", func(t *testing.T) { | ||||
| 		reader := &bsonrwtest.ValueReaderWriter{BSONType: bsontype.DateTime, Return: now.UnixNano() / int64(time.Millisecond)} | ||||
| 		testCases := []struct { | ||||
| 			name string | ||||
| 			opts *bsonoptions.TimeCodecOptions | ||||
| 			utc  bool | ||||
| 		}{ | ||||
| 			{"default", bsonoptions.TimeCodec(), true}, | ||||
| 			{"false", bsonoptions.TimeCodec().SetUseLocalTimeZone(false), true}, | ||||
| 			{"true", bsonoptions.TimeCodec().SetUseLocalTimeZone(true), false}, | ||||
| 		} | ||||
| 		for _, tc := range testCases { | ||||
| 			t.Run(tc.name, func(t *testing.T) { | ||||
| 				timeCodec := NewTimeCodec(tc.opts) | ||||
|  | ||||
| 				actual := reflect.New(reflect.TypeOf(now)).Elem() | ||||
| 				err := timeCodec.DecodeValue(DecodeContext{}, reader, actual) | ||||
| 				assert.Nil(t, err, "TimeCodec.DecodeValue error: %v", err) | ||||
|  | ||||
| 				actualTime := actual.Interface().(time.Time) | ||||
| 				assert.Equal(t, actualTime.Location().String() == "UTC", tc.utc, | ||||
| 					"Expected UTC: %v, got %v", tc.utc, actualTime.Location()) | ||||
| 				assert.Equal(t, now, actualTime, "expected time %v, got %v", now, actualTime) | ||||
| 			}) | ||||
| 		} | ||||
| 	}) | ||||
|  | ||||
| 	t.Run("DecodeFromBsontype", func(t *testing.T) { | ||||
| 		testCases := []struct { | ||||
| 			name   string | ||||
| 			reader *bsonrwtest.ValueReaderWriter | ||||
| 		}{ | ||||
| 			{"string", &bsonrwtest.ValueReaderWriter{BSONType: bsontype.String, Return: now.Format(timeFormatString)}}, | ||||
| 			{"int64", &bsonrwtest.ValueReaderWriter{BSONType: bsontype.Int64, Return: now.Unix()*1000 + int64(now.Nanosecond()/1e6)}}, | ||||
| 			{"timestamp", &bsonrwtest.ValueReaderWriter{BSONType: bsontype.Timestamp, | ||||
| 				Return: bsoncore.Value{ | ||||
| 					Type: bsontype.Timestamp, | ||||
| 					Data: bsoncore.AppendTimestamp(nil, uint32(now.Unix()), 0), | ||||
| 				}}, | ||||
| 			}, | ||||
| 		} | ||||
| 		for _, tc := range testCases { | ||||
| 			t.Run(tc.name, func(t *testing.T) { | ||||
| 				actual := reflect.New(reflect.TypeOf(now)).Elem() | ||||
| 				err := defaultTimeCodec.DecodeValue(DecodeContext{}, tc.reader, actual) | ||||
| 				assert.Nil(t, err, "DecodeValue error: %v", err) | ||||
|  | ||||
| 				actualTime := actual.Interface().(time.Time) | ||||
| 				if tc.name == "timestamp" { | ||||
| 					now = time.Unix(now.Unix(), 0) | ||||
| 				} | ||||
| 				assert.Equal(t, now, actualTime, "expected time %v, got %v", now, actualTime) | ||||
| 			}) | ||||
| 		} | ||||
| 	}) | ||||
| } | ||||
							
								
								
									
										57
									
								
								mongo/bson/bsoncodec/types.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										57
									
								
								mongo/bson/bsoncodec/types.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,57 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsoncodec | ||||
|  | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"net/url" | ||||
| 	"reflect" | ||||
| 	"time" | ||||
|  | ||||
| 	"go.mongodb.org/mongo-driver/bson/primitive" | ||||
| 	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore" | ||||
| ) | ||||
|  | ||||
| var tBool = reflect.TypeOf(false) | ||||
| var tFloat64 = reflect.TypeOf(float64(0)) | ||||
| var tInt32 = reflect.TypeOf(int32(0)) | ||||
| var tInt64 = reflect.TypeOf(int64(0)) | ||||
| var tString = reflect.TypeOf("") | ||||
| var tTime = reflect.TypeOf(time.Time{}) | ||||
|  | ||||
| var tEmpty = reflect.TypeOf((*interface{})(nil)).Elem() | ||||
| var tByteSlice = reflect.TypeOf([]byte(nil)) | ||||
| var tByte = reflect.TypeOf(byte(0x00)) | ||||
| var tURL = reflect.TypeOf(url.URL{}) | ||||
| var tJSONNumber = reflect.TypeOf(json.Number("")) | ||||
|  | ||||
| var tValueMarshaler = reflect.TypeOf((*ValueMarshaler)(nil)).Elem() | ||||
| var tValueUnmarshaler = reflect.TypeOf((*ValueUnmarshaler)(nil)).Elem() | ||||
| var tMarshaler = reflect.TypeOf((*Marshaler)(nil)).Elem() | ||||
| var tUnmarshaler = reflect.TypeOf((*Unmarshaler)(nil)).Elem() | ||||
| var tProxy = reflect.TypeOf((*Proxy)(nil)).Elem() | ||||
|  | ||||
| var tBinary = reflect.TypeOf(primitive.Binary{}) | ||||
| var tUndefined = reflect.TypeOf(primitive.Undefined{}) | ||||
| var tOID = reflect.TypeOf(primitive.ObjectID{}) | ||||
| var tDateTime = reflect.TypeOf(primitive.DateTime(0)) | ||||
| var tNull = reflect.TypeOf(primitive.Null{}) | ||||
| var tRegex = reflect.TypeOf(primitive.Regex{}) | ||||
| var tCodeWithScope = reflect.TypeOf(primitive.CodeWithScope{}) | ||||
| var tDBPointer = reflect.TypeOf(primitive.DBPointer{}) | ||||
| var tJavaScript = reflect.TypeOf(primitive.JavaScript("")) | ||||
| var tSymbol = reflect.TypeOf(primitive.Symbol("")) | ||||
| var tTimestamp = reflect.TypeOf(primitive.Timestamp{}) | ||||
| var tDecimal = reflect.TypeOf(primitive.Decimal128{}) | ||||
| var tMinKey = reflect.TypeOf(primitive.MinKey{}) | ||||
| var tMaxKey = reflect.TypeOf(primitive.MaxKey{}) | ||||
| var tD = reflect.TypeOf(primitive.D{}) | ||||
| var tA = reflect.TypeOf(primitive.A{}) | ||||
| var tE = reflect.TypeOf(primitive.E{}) | ||||
|  | ||||
| var tCoreDocument = reflect.TypeOf(bsoncore.Document{}) | ||||
| var tCoreArray = reflect.TypeOf(bsoncore.Array{}) | ||||
							
								
								
									
										173
									
								
								mongo/bson/bsoncodec/uint_codec.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										173
									
								
								mongo/bson/bsoncodec/uint_codec.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,173 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsoncodec | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"math" | ||||
| 	"reflect" | ||||
|  | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsonoptions" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsonrw" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsontype" | ||||
| ) | ||||
|  | ||||
| // UIntCodec is the Codec used for uint values. | ||||
| type UIntCodec struct { | ||||
| 	EncodeToMinSize bool | ||||
| } | ||||
|  | ||||
| var ( | ||||
| 	defaultUIntCodec = NewUIntCodec() | ||||
|  | ||||
| 	_ ValueCodec  = defaultUIntCodec | ||||
| 	_ typeDecoder = defaultUIntCodec | ||||
| ) | ||||
|  | ||||
| // NewUIntCodec returns a UIntCodec with options opts. | ||||
| func NewUIntCodec(opts ...*bsonoptions.UIntCodecOptions) *UIntCodec { | ||||
| 	uintOpt := bsonoptions.MergeUIntCodecOptions(opts...) | ||||
|  | ||||
| 	codec := UIntCodec{} | ||||
| 	if uintOpt.EncodeToMinSize != nil { | ||||
| 		codec.EncodeToMinSize = *uintOpt.EncodeToMinSize | ||||
| 	} | ||||
| 	return &codec | ||||
| } | ||||
|  | ||||
| // EncodeValue is the ValueEncoder for uint types. | ||||
| func (uic *UIntCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { | ||||
| 	switch val.Kind() { | ||||
| 	case reflect.Uint8, reflect.Uint16: | ||||
| 		return vw.WriteInt32(int32(val.Uint())) | ||||
| 	case reflect.Uint, reflect.Uint32, reflect.Uint64: | ||||
| 		u64 := val.Uint() | ||||
|  | ||||
| 		// If ec.MinSize or if encodeToMinSize is true for a non-uint64 value we should write val as an int32 | ||||
| 		useMinSize := ec.MinSize || (uic.EncodeToMinSize && val.Kind() != reflect.Uint64) | ||||
|  | ||||
| 		if u64 <= math.MaxInt32 && useMinSize { | ||||
| 			return vw.WriteInt32(int32(u64)) | ||||
| 		} | ||||
| 		if u64 > math.MaxInt64 { | ||||
| 			return fmt.Errorf("%d overflows int64", u64) | ||||
| 		} | ||||
| 		return vw.WriteInt64(int64(u64)) | ||||
| 	} | ||||
|  | ||||
| 	return ValueEncoderError{ | ||||
| 		Name:     "UintEncodeValue", | ||||
| 		Kinds:    []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint}, | ||||
| 		Received: val, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (uic *UIntCodec) decodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) { | ||||
| 	var i64 int64 | ||||
| 	var err error | ||||
| 	switch vrType := vr.Type(); vrType { | ||||
| 	case bsontype.Int32: | ||||
| 		i32, err := vr.ReadInt32() | ||||
| 		if err != nil { | ||||
| 			return emptyValue, err | ||||
| 		} | ||||
| 		i64 = int64(i32) | ||||
| 	case bsontype.Int64: | ||||
| 		i64, err = vr.ReadInt64() | ||||
| 		if err != nil { | ||||
| 			return emptyValue, err | ||||
| 		} | ||||
| 	case bsontype.Double: | ||||
| 		f64, err := vr.ReadDouble() | ||||
| 		if err != nil { | ||||
| 			return emptyValue, err | ||||
| 		} | ||||
| 		if !dc.Truncate && math.Floor(f64) != f64 { | ||||
| 			return emptyValue, errCannotTruncate | ||||
| 		} | ||||
| 		if f64 > float64(math.MaxInt64) { | ||||
| 			return emptyValue, fmt.Errorf("%g overflows int64", f64) | ||||
| 		} | ||||
| 		i64 = int64(f64) | ||||
| 	case bsontype.Boolean: | ||||
| 		b, err := vr.ReadBoolean() | ||||
| 		if err != nil { | ||||
| 			return emptyValue, err | ||||
| 		} | ||||
| 		if b { | ||||
| 			i64 = 1 | ||||
| 		} | ||||
| 	case bsontype.Null: | ||||
| 		if err = vr.ReadNull(); err != nil { | ||||
| 			return emptyValue, err | ||||
| 		} | ||||
| 	case bsontype.Undefined: | ||||
| 		if err = vr.ReadUndefined(); err != nil { | ||||
| 			return emptyValue, err | ||||
| 		} | ||||
| 	default: | ||||
| 		return emptyValue, fmt.Errorf("cannot decode %v into an integer type", vrType) | ||||
| 	} | ||||
|  | ||||
| 	switch t.Kind() { | ||||
| 	case reflect.Uint8: | ||||
| 		if i64 < 0 || i64 > math.MaxUint8 { | ||||
| 			return emptyValue, fmt.Errorf("%d overflows uint8", i64) | ||||
| 		} | ||||
|  | ||||
| 		return reflect.ValueOf(uint8(i64)), nil | ||||
| 	case reflect.Uint16: | ||||
| 		if i64 < 0 || i64 > math.MaxUint16 { | ||||
| 			return emptyValue, fmt.Errorf("%d overflows uint16", i64) | ||||
| 		} | ||||
|  | ||||
| 		return reflect.ValueOf(uint16(i64)), nil | ||||
| 	case reflect.Uint32: | ||||
| 		if i64 < 0 || i64 > math.MaxUint32 { | ||||
| 			return emptyValue, fmt.Errorf("%d overflows uint32", i64) | ||||
| 		} | ||||
|  | ||||
| 		return reflect.ValueOf(uint32(i64)), nil | ||||
| 	case reflect.Uint64: | ||||
| 		if i64 < 0 { | ||||
| 			return emptyValue, fmt.Errorf("%d overflows uint64", i64) | ||||
| 		} | ||||
|  | ||||
| 		return reflect.ValueOf(uint64(i64)), nil | ||||
| 	case reflect.Uint: | ||||
| 		if i64 < 0 || int64(uint(i64)) != i64 { // Can we fit this inside of an uint | ||||
| 			return emptyValue, fmt.Errorf("%d overflows uint", i64) | ||||
| 		} | ||||
|  | ||||
| 		return reflect.ValueOf(uint(i64)), nil | ||||
| 	default: | ||||
| 		return emptyValue, ValueDecoderError{ | ||||
| 			Name:     "UintDecodeValue", | ||||
| 			Kinds:    []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint}, | ||||
| 			Received: reflect.Zero(t), | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // DecodeValue is the ValueDecoder for uint types. | ||||
| func (uic *UIntCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error { | ||||
| 	if !val.CanSet() { | ||||
| 		return ValueDecoderError{ | ||||
| 			Name:     "UintDecodeValue", | ||||
| 			Kinds:    []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint}, | ||||
| 			Received: val, | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	elem, err := uic.decodeType(dc, vr, val.Type()) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	val.SetUint(elem.Uint()) | ||||
| 	return nil | ||||
| } | ||||
							
								
								
									
										38
									
								
								mongo/bson/bsonoptions/byte_slice_codec_options.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										38
									
								
								mongo/bson/bsonoptions/byte_slice_codec_options.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,38 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsonoptions | ||||
|  | ||||
| // ByteSliceCodecOptions represents all possible options for byte slice encoding and decoding. | ||||
| type ByteSliceCodecOptions struct { | ||||
| 	EncodeNilAsEmpty *bool // Specifies if a nil byte slice should encode as an empty binary instead of null. Defaults to false. | ||||
| } | ||||
|  | ||||
| // ByteSliceCodec creates a new *ByteSliceCodecOptions | ||||
| func ByteSliceCodec() *ByteSliceCodecOptions { | ||||
| 	return &ByteSliceCodecOptions{} | ||||
| } | ||||
|  | ||||
| // SetEncodeNilAsEmpty specifies  if a nil byte slice should encode as an empty binary instead of null. Defaults to false. | ||||
| func (bs *ByteSliceCodecOptions) SetEncodeNilAsEmpty(b bool) *ByteSliceCodecOptions { | ||||
| 	bs.EncodeNilAsEmpty = &b | ||||
| 	return bs | ||||
| } | ||||
|  | ||||
| // MergeByteSliceCodecOptions combines the given *ByteSliceCodecOptions into a single *ByteSliceCodecOptions in a last one wins fashion. | ||||
| func MergeByteSliceCodecOptions(opts ...*ByteSliceCodecOptions) *ByteSliceCodecOptions { | ||||
| 	bs := ByteSliceCodec() | ||||
| 	for _, opt := range opts { | ||||
| 		if opt == nil { | ||||
| 			continue | ||||
| 		} | ||||
| 		if opt.EncodeNilAsEmpty != nil { | ||||
| 			bs.EncodeNilAsEmpty = opt.EncodeNilAsEmpty | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return bs | ||||
| } | ||||
							
								
								
									
										8
									
								
								mongo/bson/bsonoptions/doc.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										8
									
								
								mongo/bson/bsonoptions/doc.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,8 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2022-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| // Package bsonoptions defines the optional configurations for the BSON codecs. | ||||
| package bsonoptions | ||||
							
								
								
									
										38
									
								
								mongo/bson/bsonoptions/empty_interface_codec_options.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										38
									
								
								mongo/bson/bsonoptions/empty_interface_codec_options.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,38 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsonoptions | ||||
|  | ||||
| // EmptyInterfaceCodecOptions represents all possible options for interface{} encoding and decoding. | ||||
| type EmptyInterfaceCodecOptions struct { | ||||
| 	DecodeBinaryAsSlice *bool // Specifies if Old and Generic type binarys should default to []slice instead of primitive.Binary. Defaults to false. | ||||
| } | ||||
|  | ||||
| // EmptyInterfaceCodec creates a new *EmptyInterfaceCodecOptions | ||||
| func EmptyInterfaceCodec() *EmptyInterfaceCodecOptions { | ||||
| 	return &EmptyInterfaceCodecOptions{} | ||||
| } | ||||
|  | ||||
| // SetDecodeBinaryAsSlice specifies if Old and Generic type binarys should default to []slice instead of primitive.Binary. Defaults to false. | ||||
| func (e *EmptyInterfaceCodecOptions) SetDecodeBinaryAsSlice(b bool) *EmptyInterfaceCodecOptions { | ||||
| 	e.DecodeBinaryAsSlice = &b | ||||
| 	return e | ||||
| } | ||||
|  | ||||
| // MergeEmptyInterfaceCodecOptions combines the given *EmptyInterfaceCodecOptions into a single *EmptyInterfaceCodecOptions in a last one wins fashion. | ||||
| func MergeEmptyInterfaceCodecOptions(opts ...*EmptyInterfaceCodecOptions) *EmptyInterfaceCodecOptions { | ||||
| 	e := EmptyInterfaceCodec() | ||||
| 	for _, opt := range opts { | ||||
| 		if opt == nil { | ||||
| 			continue | ||||
| 		} | ||||
| 		if opt.DecodeBinaryAsSlice != nil { | ||||
| 			e.DecodeBinaryAsSlice = opt.DecodeBinaryAsSlice | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return e | ||||
| } | ||||
							
								
								
									
										67
									
								
								mongo/bson/bsonoptions/map_codec_options.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										67
									
								
								mongo/bson/bsonoptions/map_codec_options.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,67 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsonoptions | ||||
|  | ||||
| // MapCodecOptions represents all possible options for map encoding and decoding. | ||||
| type MapCodecOptions struct { | ||||
| 	DecodeZerosMap   *bool // Specifies if the map should be zeroed before decoding into it. Defaults to false. | ||||
| 	EncodeNilAsEmpty *bool // Specifies if a nil map should encode as an empty document instead of null. Defaults to false. | ||||
| 	// Specifies how keys should be handled. If false, the behavior matches encoding/json, where the encoding key type must | ||||
| 	// either be a string, an integer type, or implement bsoncodec.KeyMarshaler and the decoding key type must either be a | ||||
| 	// string, an integer type, or implement bsoncodec.KeyUnmarshaler. If true, keys are encoded with fmt.Sprint() and the | ||||
| 	// encoding key type must be a string, an integer type, or a float. If true, the use of Stringer will override | ||||
| 	// TextMarshaler/TextUnmarshaler. Defaults to false. | ||||
| 	EncodeKeysWithStringer *bool | ||||
| } | ||||
|  | ||||
| // MapCodec creates a new *MapCodecOptions | ||||
| func MapCodec() *MapCodecOptions { | ||||
| 	return &MapCodecOptions{} | ||||
| } | ||||
|  | ||||
| // SetDecodeZerosMap specifies if the map should be zeroed before decoding into it. Defaults to false. | ||||
| func (t *MapCodecOptions) SetDecodeZerosMap(b bool) *MapCodecOptions { | ||||
| 	t.DecodeZerosMap = &b | ||||
| 	return t | ||||
| } | ||||
|  | ||||
| // SetEncodeNilAsEmpty specifies if a nil map should encode as an empty document instead of null. Defaults to false. | ||||
| func (t *MapCodecOptions) SetEncodeNilAsEmpty(b bool) *MapCodecOptions { | ||||
| 	t.EncodeNilAsEmpty = &b | ||||
| 	return t | ||||
| } | ||||
|  | ||||
| // SetEncodeKeysWithStringer specifies how keys should be handled. If false, the behavior matches encoding/json, where the | ||||
| // encoding key type must either be a string, an integer type, or implement bsoncodec.KeyMarshaler and the decoding key | ||||
| // type must either be a string, an integer type, or implement bsoncodec.KeyUnmarshaler. If true, keys are encoded with | ||||
| // fmt.Sprint() and the encoding key type must be a string, an integer type, or a float. If true, the use of Stringer | ||||
| // will override TextMarshaler/TextUnmarshaler. Defaults to false. | ||||
| func (t *MapCodecOptions) SetEncodeKeysWithStringer(b bool) *MapCodecOptions { | ||||
| 	t.EncodeKeysWithStringer = &b | ||||
| 	return t | ||||
| } | ||||
|  | ||||
| // MergeMapCodecOptions combines the given *MapCodecOptions into a single *MapCodecOptions in a last one wins fashion. | ||||
| func MergeMapCodecOptions(opts ...*MapCodecOptions) *MapCodecOptions { | ||||
| 	s := MapCodec() | ||||
| 	for _, opt := range opts { | ||||
| 		if opt == nil { | ||||
| 			continue | ||||
| 		} | ||||
| 		if opt.DecodeZerosMap != nil { | ||||
| 			s.DecodeZerosMap = opt.DecodeZerosMap | ||||
| 		} | ||||
| 		if opt.EncodeNilAsEmpty != nil { | ||||
| 			s.EncodeNilAsEmpty = opt.EncodeNilAsEmpty | ||||
| 		} | ||||
| 		if opt.EncodeKeysWithStringer != nil { | ||||
| 			s.EncodeKeysWithStringer = opt.EncodeKeysWithStringer | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return s | ||||
| } | ||||
							
								
								
									
										38
									
								
								mongo/bson/bsonoptions/slice_codec_options.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										38
									
								
								mongo/bson/bsonoptions/slice_codec_options.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,38 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsonoptions | ||||
|  | ||||
| // SliceCodecOptions represents all possible options for slice encoding and decoding. | ||||
| type SliceCodecOptions struct { | ||||
| 	EncodeNilAsEmpty *bool // Specifies if a nil slice should encode as an empty array instead of null. Defaults to false. | ||||
| } | ||||
|  | ||||
| // SliceCodec creates a new *SliceCodecOptions | ||||
| func SliceCodec() *SliceCodecOptions { | ||||
| 	return &SliceCodecOptions{} | ||||
| } | ||||
|  | ||||
| // SetEncodeNilAsEmpty specifies  if a nil slice should encode as an empty array instead of null. Defaults to false. | ||||
| func (s *SliceCodecOptions) SetEncodeNilAsEmpty(b bool) *SliceCodecOptions { | ||||
| 	s.EncodeNilAsEmpty = &b | ||||
| 	return s | ||||
| } | ||||
|  | ||||
| // MergeSliceCodecOptions combines the given *SliceCodecOptions into a single *SliceCodecOptions in a last one wins fashion. | ||||
| func MergeSliceCodecOptions(opts ...*SliceCodecOptions) *SliceCodecOptions { | ||||
| 	s := SliceCodec() | ||||
| 	for _, opt := range opts { | ||||
| 		if opt == nil { | ||||
| 			continue | ||||
| 		} | ||||
| 		if opt.EncodeNilAsEmpty != nil { | ||||
| 			s.EncodeNilAsEmpty = opt.EncodeNilAsEmpty | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return s | ||||
| } | ||||
							
								
								
									
										41
									
								
								mongo/bson/bsonoptions/string_codec_options.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										41
									
								
								mongo/bson/bsonoptions/string_codec_options.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,41 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsonoptions | ||||
|  | ||||
| var defaultDecodeOIDAsHex = true | ||||
|  | ||||
| // StringCodecOptions represents all possible options for string encoding and decoding. | ||||
| type StringCodecOptions struct { | ||||
| 	DecodeObjectIDAsHex *bool // Specifies if we should decode ObjectID as the hex value. Defaults to true. | ||||
| } | ||||
|  | ||||
| // StringCodec creates a new *StringCodecOptions | ||||
| func StringCodec() *StringCodecOptions { | ||||
| 	return &StringCodecOptions{} | ||||
| } | ||||
|  | ||||
| // SetDecodeObjectIDAsHex specifies if object IDs should be decoded as their hex representation. If false, a string made | ||||
| // from the raw object ID bytes will be used. Defaults to true. | ||||
| func (t *StringCodecOptions) SetDecodeObjectIDAsHex(b bool) *StringCodecOptions { | ||||
| 	t.DecodeObjectIDAsHex = &b | ||||
| 	return t | ||||
| } | ||||
|  | ||||
| // MergeStringCodecOptions combines the given *StringCodecOptions into a single *StringCodecOptions in a last one wins fashion. | ||||
| func MergeStringCodecOptions(opts ...*StringCodecOptions) *StringCodecOptions { | ||||
| 	s := &StringCodecOptions{&defaultDecodeOIDAsHex} | ||||
| 	for _, opt := range opts { | ||||
| 		if opt == nil { | ||||
| 			continue | ||||
| 		} | ||||
| 		if opt.DecodeObjectIDAsHex != nil { | ||||
| 			s.DecodeObjectIDAsHex = opt.DecodeObjectIDAsHex | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return s | ||||
| } | ||||
							
								
								
									
										87
									
								
								mongo/bson/bsonoptions/struct_codec_options.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										87
									
								
								mongo/bson/bsonoptions/struct_codec_options.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,87 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsonoptions | ||||
|  | ||||
| var defaultOverwriteDuplicatedInlinedFields = true | ||||
|  | ||||
| // StructCodecOptions represents all possible options for struct encoding and decoding. | ||||
| type StructCodecOptions struct { | ||||
| 	DecodeZeroStruct                 *bool // Specifies if structs should be zeroed before decoding into them. Defaults to false. | ||||
| 	DecodeDeepZeroInline             *bool // Specifies if structs should be recursively zeroed when a inline value is decoded. Defaults to false. | ||||
| 	EncodeOmitDefaultStruct          *bool // Specifies if default structs should be considered empty by omitempty. Defaults to false. | ||||
| 	AllowUnexportedFields            *bool // Specifies if unexported fields should be marshaled/unmarshaled. Defaults to false. | ||||
| 	OverwriteDuplicatedInlinedFields *bool // Specifies if fields in inlined structs can be overwritten by higher level struct fields with the same key. Defaults to true. | ||||
| } | ||||
|  | ||||
| // StructCodec creates a new *StructCodecOptions | ||||
| func StructCodec() *StructCodecOptions { | ||||
| 	return &StructCodecOptions{} | ||||
| } | ||||
|  | ||||
| // SetDecodeZeroStruct specifies if structs should be zeroed before decoding into them. Defaults to false. | ||||
| func (t *StructCodecOptions) SetDecodeZeroStruct(b bool) *StructCodecOptions { | ||||
| 	t.DecodeZeroStruct = &b | ||||
| 	return t | ||||
| } | ||||
|  | ||||
| // SetDecodeDeepZeroInline specifies if structs should be zeroed before decoding into them. Defaults to false. | ||||
| func (t *StructCodecOptions) SetDecodeDeepZeroInline(b bool) *StructCodecOptions { | ||||
| 	t.DecodeDeepZeroInline = &b | ||||
| 	return t | ||||
| } | ||||
|  | ||||
| // SetEncodeOmitDefaultStruct specifies if default structs should be considered empty by omitempty. A default struct has all | ||||
| // its values set to their default value. Defaults to false. | ||||
| func (t *StructCodecOptions) SetEncodeOmitDefaultStruct(b bool) *StructCodecOptions { | ||||
| 	t.EncodeOmitDefaultStruct = &b | ||||
| 	return t | ||||
| } | ||||
|  | ||||
| // SetOverwriteDuplicatedInlinedFields specifies if inlined struct fields can be overwritten by higher level struct fields with the | ||||
| // same bson key. When true and decoding, values will be written to the outermost struct with a matching key, and when | ||||
| // encoding, keys will have the value of the top-most matching field. When false, decoding and encoding will error if | ||||
| // there are duplicate keys after the struct is inlined. Defaults to true. | ||||
| func (t *StructCodecOptions) SetOverwriteDuplicatedInlinedFields(b bool) *StructCodecOptions { | ||||
| 	t.OverwriteDuplicatedInlinedFields = &b | ||||
| 	return t | ||||
| } | ||||
|  | ||||
| // SetAllowUnexportedFields specifies if unexported fields should be marshaled/unmarshaled. Defaults to false. | ||||
| func (t *StructCodecOptions) SetAllowUnexportedFields(b bool) *StructCodecOptions { | ||||
| 	t.AllowUnexportedFields = &b | ||||
| 	return t | ||||
| } | ||||
|  | ||||
| // MergeStructCodecOptions combines the given *StructCodecOptions into a single *StructCodecOptions in a last one wins fashion. | ||||
| func MergeStructCodecOptions(opts ...*StructCodecOptions) *StructCodecOptions { | ||||
| 	s := &StructCodecOptions{ | ||||
| 		OverwriteDuplicatedInlinedFields: &defaultOverwriteDuplicatedInlinedFields, | ||||
| 	} | ||||
| 	for _, opt := range opts { | ||||
| 		if opt == nil { | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		if opt.DecodeZeroStruct != nil { | ||||
| 			s.DecodeZeroStruct = opt.DecodeZeroStruct | ||||
| 		} | ||||
| 		if opt.DecodeDeepZeroInline != nil { | ||||
| 			s.DecodeDeepZeroInline = opt.DecodeDeepZeroInline | ||||
| 		} | ||||
| 		if opt.EncodeOmitDefaultStruct != nil { | ||||
| 			s.EncodeOmitDefaultStruct = opt.EncodeOmitDefaultStruct | ||||
| 		} | ||||
| 		if opt.OverwriteDuplicatedInlinedFields != nil { | ||||
| 			s.OverwriteDuplicatedInlinedFields = opt.OverwriteDuplicatedInlinedFields | ||||
| 		} | ||||
| 		if opt.AllowUnexportedFields != nil { | ||||
| 			s.AllowUnexportedFields = opt.AllowUnexportedFields | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return s | ||||
| } | ||||
							
								
								
									
										38
									
								
								mongo/bson/bsonoptions/time_codec_options.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										38
									
								
								mongo/bson/bsonoptions/time_codec_options.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,38 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsonoptions | ||||
|  | ||||
| // TimeCodecOptions represents all possible options for time.Time encoding and decoding. | ||||
| type TimeCodecOptions struct { | ||||
| 	UseLocalTimeZone *bool // Specifies if we should decode into the local time zone. Defaults to false. | ||||
| } | ||||
|  | ||||
| // TimeCodec creates a new *TimeCodecOptions | ||||
| func TimeCodec() *TimeCodecOptions { | ||||
| 	return &TimeCodecOptions{} | ||||
| } | ||||
|  | ||||
| // SetUseLocalTimeZone specifies if we should decode into the local time zone. Defaults to false. | ||||
| func (t *TimeCodecOptions) SetUseLocalTimeZone(b bool) *TimeCodecOptions { | ||||
| 	t.UseLocalTimeZone = &b | ||||
| 	return t | ||||
| } | ||||
|  | ||||
| // MergeTimeCodecOptions combines the given *TimeCodecOptions into a single *TimeCodecOptions in a last one wins fashion. | ||||
| func MergeTimeCodecOptions(opts ...*TimeCodecOptions) *TimeCodecOptions { | ||||
| 	t := TimeCodec() | ||||
| 	for _, opt := range opts { | ||||
| 		if opt == nil { | ||||
| 			continue | ||||
| 		} | ||||
| 		if opt.UseLocalTimeZone != nil { | ||||
| 			t.UseLocalTimeZone = opt.UseLocalTimeZone | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return t | ||||
| } | ||||
							
								
								
									
										38
									
								
								mongo/bson/bsonoptions/uint_codec_options.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										38
									
								
								mongo/bson/bsonoptions/uint_codec_options.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,38 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsonoptions | ||||
|  | ||||
| // UIntCodecOptions represents all possible options for uint encoding and decoding. | ||||
| type UIntCodecOptions struct { | ||||
| 	EncodeToMinSize *bool // Specifies if all uints except uint64 should be decoded to minimum size bsontype. Defaults to false. | ||||
| } | ||||
|  | ||||
| // UIntCodec creates a new *UIntCodecOptions | ||||
| func UIntCodec() *UIntCodecOptions { | ||||
| 	return &UIntCodecOptions{} | ||||
| } | ||||
|  | ||||
| // SetEncodeToMinSize specifies if all uints except uint64 should be decoded to minimum size bsontype. Defaults to false. | ||||
| func (u *UIntCodecOptions) SetEncodeToMinSize(b bool) *UIntCodecOptions { | ||||
| 	u.EncodeToMinSize = &b | ||||
| 	return u | ||||
| } | ||||
|  | ||||
| // MergeUIntCodecOptions combines the given *UIntCodecOptions into a single *UIntCodecOptions in a last one wins fashion. | ||||
| func MergeUIntCodecOptions(opts ...*UIntCodecOptions) *UIntCodecOptions { | ||||
| 	u := UIntCodec() | ||||
| 	for _, opt := range opts { | ||||
| 		if opt == nil { | ||||
| 			continue | ||||
| 		} | ||||
| 		if opt.EncodeToMinSize != nil { | ||||
| 			u.EncodeToMinSize = opt.EncodeToMinSize | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return u | ||||
| } | ||||
							
								
								
									
										33
									
								
								mongo/bson/bsonrw/bsonrw_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										33
									
								
								mongo/bson/bsonrw/bsonrw_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,33 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsonrw | ||||
|  | ||||
| import "testing" | ||||
|  | ||||
| func compareErrors(err1, err2 error) bool { | ||||
| 	if err1 == nil && err2 == nil { | ||||
| 		return true | ||||
| 	} | ||||
|  | ||||
| 	if err1 == nil || err2 == nil { | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	if err1.Error() != err2.Error() { | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	return true | ||||
| } | ||||
|  | ||||
| func noerr(t *testing.T, err error) { | ||||
| 	if err != nil { | ||||
| 		t.Helper() | ||||
| 		t.Errorf("Unexpected error: (%T)%v", err, err) | ||||
| 		t.FailNow() | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										848
									
								
								mongo/bson/bsonrw/bsonrwtest/bsonrwtest.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										848
									
								
								mongo/bson/bsonrw/bsonrwtest/bsonrwtest.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,848 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| // Package bsonrwtest provides utilities for testing the "bson/bsonrw" package. | ||||
| package bsonrwtest // import "go.mongodb.org/mongo-driver/bson/bsonrw/bsonrwtest" | ||||
|  | ||||
| import ( | ||||
| 	"testing" | ||||
|  | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsonrw" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsontype" | ||||
| 	"go.mongodb.org/mongo-driver/bson/primitive" | ||||
| 	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore" | ||||
| ) | ||||
|  | ||||
| var _ bsonrw.ValueReader = (*ValueReaderWriter)(nil) | ||||
| var _ bsonrw.ValueWriter = (*ValueReaderWriter)(nil) | ||||
|  | ||||
| // Invoked is a type used to indicate what method was called last. | ||||
| type Invoked byte | ||||
|  | ||||
| // These are the different methods that can be invoked. | ||||
| const ( | ||||
| 	Nothing Invoked = iota | ||||
| 	ReadArray | ||||
| 	ReadBinary | ||||
| 	ReadBoolean | ||||
| 	ReadDocument | ||||
| 	ReadCodeWithScope | ||||
| 	ReadDBPointer | ||||
| 	ReadDateTime | ||||
| 	ReadDecimal128 | ||||
| 	ReadDouble | ||||
| 	ReadInt32 | ||||
| 	ReadInt64 | ||||
| 	ReadJavascript | ||||
| 	ReadMaxKey | ||||
| 	ReadMinKey | ||||
| 	ReadNull | ||||
| 	ReadObjectID | ||||
| 	ReadRegex | ||||
| 	ReadString | ||||
| 	ReadSymbol | ||||
| 	ReadTimestamp | ||||
| 	ReadUndefined | ||||
| 	ReadElement | ||||
| 	ReadValue | ||||
| 	WriteArray | ||||
| 	WriteBinary | ||||
| 	WriteBinaryWithSubtype | ||||
| 	WriteBoolean | ||||
| 	WriteCodeWithScope | ||||
| 	WriteDBPointer | ||||
| 	WriteDateTime | ||||
| 	WriteDecimal128 | ||||
| 	WriteDouble | ||||
| 	WriteInt32 | ||||
| 	WriteInt64 | ||||
| 	WriteJavascript | ||||
| 	WriteMaxKey | ||||
| 	WriteMinKey | ||||
| 	WriteNull | ||||
| 	WriteObjectID | ||||
| 	WriteRegex | ||||
| 	WriteString | ||||
| 	WriteDocument | ||||
| 	WriteSymbol | ||||
| 	WriteTimestamp | ||||
| 	WriteUndefined | ||||
| 	WriteDocumentElement | ||||
| 	WriteDocumentEnd | ||||
| 	WriteArrayElement | ||||
| 	WriteArrayEnd | ||||
| 	Skip | ||||
| ) | ||||
|  | ||||
| func (i Invoked) String() string { | ||||
| 	switch i { | ||||
| 	case Nothing: | ||||
| 		return "Nothing" | ||||
| 	case ReadArray: | ||||
| 		return "ReadArray" | ||||
| 	case ReadBinary: | ||||
| 		return "ReadBinary" | ||||
| 	case ReadBoolean: | ||||
| 		return "ReadBoolean" | ||||
| 	case ReadDocument: | ||||
| 		return "ReadDocument" | ||||
| 	case ReadCodeWithScope: | ||||
| 		return "ReadCodeWithScope" | ||||
| 	case ReadDBPointer: | ||||
| 		return "ReadDBPointer" | ||||
| 	case ReadDateTime: | ||||
| 		return "ReadDateTime" | ||||
| 	case ReadDecimal128: | ||||
| 		return "ReadDecimal128" | ||||
| 	case ReadDouble: | ||||
| 		return "ReadDouble" | ||||
| 	case ReadInt32: | ||||
| 		return "ReadInt32" | ||||
| 	case ReadInt64: | ||||
| 		return "ReadInt64" | ||||
| 	case ReadJavascript: | ||||
| 		return "ReadJavascript" | ||||
| 	case ReadMaxKey: | ||||
| 		return "ReadMaxKey" | ||||
| 	case ReadMinKey: | ||||
| 		return "ReadMinKey" | ||||
| 	case ReadNull: | ||||
| 		return "ReadNull" | ||||
| 	case ReadObjectID: | ||||
| 		return "ReadObjectID" | ||||
| 	case ReadRegex: | ||||
| 		return "ReadRegex" | ||||
| 	case ReadString: | ||||
| 		return "ReadString" | ||||
| 	case ReadSymbol: | ||||
| 		return "ReadSymbol" | ||||
| 	case ReadTimestamp: | ||||
| 		return "ReadTimestamp" | ||||
| 	case ReadUndefined: | ||||
| 		return "ReadUndefined" | ||||
| 	case ReadElement: | ||||
| 		return "ReadElement" | ||||
| 	case ReadValue: | ||||
| 		return "ReadValue" | ||||
| 	case WriteArray: | ||||
| 		return "WriteArray" | ||||
| 	case WriteBinary: | ||||
| 		return "WriteBinary" | ||||
| 	case WriteBinaryWithSubtype: | ||||
| 		return "WriteBinaryWithSubtype" | ||||
| 	case WriteBoolean: | ||||
| 		return "WriteBoolean" | ||||
| 	case WriteCodeWithScope: | ||||
| 		return "WriteCodeWithScope" | ||||
| 	case WriteDBPointer: | ||||
| 		return "WriteDBPointer" | ||||
| 	case WriteDateTime: | ||||
| 		return "WriteDateTime" | ||||
| 	case WriteDecimal128: | ||||
| 		return "WriteDecimal128" | ||||
| 	case WriteDouble: | ||||
| 		return "WriteDouble" | ||||
| 	case WriteInt32: | ||||
| 		return "WriteInt32" | ||||
| 	case WriteInt64: | ||||
| 		return "WriteInt64" | ||||
| 	case WriteJavascript: | ||||
| 		return "WriteJavascript" | ||||
| 	case WriteMaxKey: | ||||
| 		return "WriteMaxKey" | ||||
| 	case WriteMinKey: | ||||
| 		return "WriteMinKey" | ||||
| 	case WriteNull: | ||||
| 		return "WriteNull" | ||||
| 	case WriteObjectID: | ||||
| 		return "WriteObjectID" | ||||
| 	case WriteRegex: | ||||
| 		return "WriteRegex" | ||||
| 	case WriteString: | ||||
| 		return "WriteString" | ||||
| 	case WriteDocument: | ||||
| 		return "WriteDocument" | ||||
| 	case WriteSymbol: | ||||
| 		return "WriteSymbol" | ||||
| 	case WriteTimestamp: | ||||
| 		return "WriteTimestamp" | ||||
| 	case WriteUndefined: | ||||
| 		return "WriteUndefined" | ||||
| 	case WriteDocumentElement: | ||||
| 		return "WriteDocumentElement" | ||||
| 	case WriteDocumentEnd: | ||||
| 		return "WriteDocumentEnd" | ||||
| 	case WriteArrayElement: | ||||
| 		return "WriteArrayElement" | ||||
| 	case WriteArrayEnd: | ||||
| 		return "WriteArrayEnd" | ||||
| 	default: | ||||
| 		return "<unknown>" | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ValueReaderWriter is a test implementation of a bsonrw.ValueReader and bsonrw.ValueWriter | ||||
| type ValueReaderWriter struct { | ||||
| 	T        *testing.T | ||||
| 	Invoked  Invoked | ||||
| 	Return   interface{} // Can be a primitive or a bsoncore.Value | ||||
| 	BSONType bsontype.Type | ||||
| 	Err      error | ||||
| 	ErrAfter Invoked // error after this method is called | ||||
| 	depth    uint64 | ||||
| } | ||||
|  | ||||
| // prevent infinite recursion. | ||||
| func (llvrw *ValueReaderWriter) checkdepth() { | ||||
| 	llvrw.depth++ | ||||
| 	if llvrw.depth > 1000 { | ||||
| 		panic("max depth exceeded") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Type implements the bsonrw.ValueReader interface. | ||||
| func (llvrw *ValueReaderWriter) Type() bsontype.Type { | ||||
| 	llvrw.checkdepth() | ||||
| 	return llvrw.BSONType | ||||
| } | ||||
|  | ||||
| // Skip implements the bsonrw.ValueReader interface. | ||||
| func (llvrw *ValueReaderWriter) Skip() error { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = Skip | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return llvrw.Err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // ReadArray implements the bsonrw.ValueReader interface. | ||||
| func (llvrw *ValueReaderWriter) ReadArray() (bsonrw.ArrayReader, error) { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = ReadArray | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return nil, llvrw.Err | ||||
| 	} | ||||
|  | ||||
| 	return llvrw, nil | ||||
| } | ||||
|  | ||||
| // ReadBinary implements the bsonrw.ValueReader interface. | ||||
| func (llvrw *ValueReaderWriter) ReadBinary() (b []byte, btype byte, err error) { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = ReadBinary | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return nil, 0x00, llvrw.Err | ||||
| 	} | ||||
|  | ||||
| 	switch tt := llvrw.Return.(type) { | ||||
| 	case bsoncore.Value: | ||||
| 		subtype, data, _, ok := bsoncore.ReadBinary(tt.Data) | ||||
| 		if !ok { | ||||
| 			llvrw.T.Error("Invalid Value provided for return value of ReadBinary.") | ||||
| 			return nil, 0x00, nil | ||||
| 		} | ||||
| 		return data, subtype, nil | ||||
| 	default: | ||||
| 		llvrw.T.Errorf("Incorrect type provided for return value of ReadBinary: %T", llvrw.Return) | ||||
| 		return nil, 0x00, nil | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ReadBoolean implements the bsonrw.ValueReader interface. | ||||
| func (llvrw *ValueReaderWriter) ReadBoolean() (bool, error) { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = ReadBoolean | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return false, llvrw.Err | ||||
| 	} | ||||
|  | ||||
| 	switch tt := llvrw.Return.(type) { | ||||
| 	case bool: | ||||
| 		return tt, nil | ||||
| 	case bsoncore.Value: | ||||
| 		b, _, ok := bsoncore.ReadBoolean(tt.Data) | ||||
| 		if !ok { | ||||
| 			llvrw.T.Error("Invalid Value provided for return value of ReadBoolean.") | ||||
| 			return false, nil | ||||
| 		} | ||||
| 		return b, nil | ||||
| 	default: | ||||
| 		llvrw.T.Errorf("Incorrect type provided for return value of ReadBoolean: %T", llvrw.Return) | ||||
| 		return false, nil | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ReadDocument implements the bsonrw.ValueReader interface. | ||||
| func (llvrw *ValueReaderWriter) ReadDocument() (bsonrw.DocumentReader, error) { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = ReadDocument | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return nil, llvrw.Err | ||||
| 	} | ||||
|  | ||||
| 	return llvrw, nil | ||||
| } | ||||
|  | ||||
| // ReadCodeWithScope implements the bsonrw.ValueReader interface. | ||||
| func (llvrw *ValueReaderWriter) ReadCodeWithScope() (code string, dr bsonrw.DocumentReader, err error) { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = ReadCodeWithScope | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return "", nil, llvrw.Err | ||||
| 	} | ||||
|  | ||||
| 	return "", llvrw, nil | ||||
| } | ||||
|  | ||||
| // ReadDBPointer implements the bsonrw.ValueReader interface. | ||||
| func (llvrw *ValueReaderWriter) ReadDBPointer() (ns string, oid primitive.ObjectID, err error) { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = ReadDBPointer | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return "", primitive.ObjectID{}, llvrw.Err | ||||
| 	} | ||||
|  | ||||
| 	switch tt := llvrw.Return.(type) { | ||||
| 	case bsoncore.Value: | ||||
| 		ns, oid, _, ok := bsoncore.ReadDBPointer(tt.Data) | ||||
| 		if !ok { | ||||
| 			llvrw.T.Error("Invalid Value instance provided for return value of ReadDBPointer") | ||||
| 			return "", primitive.ObjectID{}, nil | ||||
| 		} | ||||
| 		return ns, oid, nil | ||||
| 	default: | ||||
| 		llvrw.T.Errorf("Incorrect type provided for return value of ReadDBPointer: %T", llvrw.Return) | ||||
| 		return "", primitive.ObjectID{}, nil | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ReadDateTime implements the bsonrw.ValueReader interface. | ||||
| func (llvrw *ValueReaderWriter) ReadDateTime() (int64, error) { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = ReadDateTime | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return 0, llvrw.Err | ||||
| 	} | ||||
|  | ||||
| 	dt, ok := llvrw.Return.(int64) | ||||
| 	if !ok { | ||||
| 		llvrw.T.Errorf("Incorrect type provided for return value of ReadDateTime: %T", llvrw.Return) | ||||
| 		return 0, nil | ||||
| 	} | ||||
|  | ||||
| 	return dt, nil | ||||
| } | ||||
|  | ||||
| // ReadDecimal128 implements the bsonrw.ValueReader interface. | ||||
| func (llvrw *ValueReaderWriter) ReadDecimal128() (primitive.Decimal128, error) { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = ReadDecimal128 | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return primitive.Decimal128{}, llvrw.Err | ||||
| 	} | ||||
|  | ||||
| 	d128, ok := llvrw.Return.(primitive.Decimal128) | ||||
| 	if !ok { | ||||
| 		llvrw.T.Errorf("Incorrect type provided for return value of ReadDecimal128: %T", llvrw.Return) | ||||
| 		return primitive.Decimal128{}, nil | ||||
| 	} | ||||
|  | ||||
| 	return d128, nil | ||||
| } | ||||
|  | ||||
| // ReadDouble implements the bsonrw.ValueReader interface. | ||||
| func (llvrw *ValueReaderWriter) ReadDouble() (float64, error) { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = ReadDouble | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return 0, llvrw.Err | ||||
| 	} | ||||
|  | ||||
| 	f64, ok := llvrw.Return.(float64) | ||||
| 	if !ok { | ||||
| 		llvrw.T.Errorf("Incorrect type provided for return value of ReadDouble: %T", llvrw.Return) | ||||
| 		return 0, nil | ||||
| 	} | ||||
|  | ||||
| 	return f64, nil | ||||
| } | ||||
|  | ||||
| // ReadInt32 implements the bsonrw.ValueReader interface. | ||||
| func (llvrw *ValueReaderWriter) ReadInt32() (int32, error) { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = ReadInt32 | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return 0, llvrw.Err | ||||
| 	} | ||||
|  | ||||
| 	i32, ok := llvrw.Return.(int32) | ||||
| 	if !ok { | ||||
| 		llvrw.T.Errorf("Incorrect type provided for return value of ReadInt32: %T", llvrw.Return) | ||||
| 		return 0, nil | ||||
| 	} | ||||
|  | ||||
| 	return i32, nil | ||||
| } | ||||
|  | ||||
| // ReadInt64 implements the bsonrw.ValueReader interface. | ||||
| func (llvrw *ValueReaderWriter) ReadInt64() (int64, error) { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = ReadInt64 | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return 0, llvrw.Err | ||||
| 	} | ||||
| 	i64, ok := llvrw.Return.(int64) | ||||
| 	if !ok { | ||||
| 		llvrw.T.Errorf("Incorrect type provided for return value of ReadInt64: %T", llvrw.Return) | ||||
| 		return 0, nil | ||||
| 	} | ||||
|  | ||||
| 	return i64, nil | ||||
| } | ||||
|  | ||||
| // ReadJavascript implements the bsonrw.ValueReader interface. | ||||
| func (llvrw *ValueReaderWriter) ReadJavascript() (code string, err error) { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = ReadJavascript | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return "", llvrw.Err | ||||
| 	} | ||||
| 	js, ok := llvrw.Return.(string) | ||||
| 	if !ok { | ||||
| 		llvrw.T.Errorf("Incorrect type provided for return value of ReadJavascript: %T", llvrw.Return) | ||||
| 		return "", nil | ||||
| 	} | ||||
|  | ||||
| 	return js, nil | ||||
| } | ||||
|  | ||||
| // ReadMaxKey implements the bsonrw.ValueReader interface. | ||||
| func (llvrw *ValueReaderWriter) ReadMaxKey() error { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = ReadMaxKey | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return llvrw.Err | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // ReadMinKey implements the bsonrw.ValueReader interface. | ||||
| func (llvrw *ValueReaderWriter) ReadMinKey() error { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = ReadMinKey | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return llvrw.Err | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // ReadNull implements the bsonrw.ValueReader interface. | ||||
| func (llvrw *ValueReaderWriter) ReadNull() error { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = ReadNull | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return llvrw.Err | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // ReadObjectID implements the bsonrw.ValueReader interface. | ||||
| func (llvrw *ValueReaderWriter) ReadObjectID() (primitive.ObjectID, error) { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = ReadObjectID | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return primitive.ObjectID{}, llvrw.Err | ||||
| 	} | ||||
| 	oid, ok := llvrw.Return.(primitive.ObjectID) | ||||
| 	if !ok { | ||||
| 		llvrw.T.Errorf("Incorrect type provided for return value of ReadObjectID: %T", llvrw.Return) | ||||
| 		return primitive.ObjectID{}, nil | ||||
| 	} | ||||
|  | ||||
| 	return oid, nil | ||||
| } | ||||
|  | ||||
| // ReadRegex implements the bsonrw.ValueReader interface. | ||||
| func (llvrw *ValueReaderWriter) ReadRegex() (pattern string, options string, err error) { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = ReadRegex | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return "", "", llvrw.Err | ||||
| 	} | ||||
| 	switch tt := llvrw.Return.(type) { | ||||
| 	case bsoncore.Value: | ||||
| 		pattern, options, _, ok := bsoncore.ReadRegex(tt.Data) | ||||
| 		if !ok { | ||||
| 			llvrw.T.Error("Invalid Value instance provided for ReadRegex") | ||||
| 			return "", "", nil | ||||
| 		} | ||||
| 		return pattern, options, nil | ||||
| 	default: | ||||
| 		llvrw.T.Errorf("Incorrect type provided for return value of ReadRegex: %T", llvrw.Return) | ||||
| 		return "", "", nil | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ReadString implements the bsonrw.ValueReader interface. | ||||
| func (llvrw *ValueReaderWriter) ReadString() (string, error) { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = ReadString | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return "", llvrw.Err | ||||
| 	} | ||||
| 	str, ok := llvrw.Return.(string) | ||||
| 	if !ok { | ||||
| 		llvrw.T.Errorf("Incorrect type provided for return value of ReadString: %T", llvrw.Return) | ||||
| 		return "", nil | ||||
| 	} | ||||
|  | ||||
| 	return str, nil | ||||
| } | ||||
|  | ||||
| // ReadSymbol implements the bsonrw.ValueReader interface. | ||||
| func (llvrw *ValueReaderWriter) ReadSymbol() (symbol string, err error) { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = ReadSymbol | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return "", llvrw.Err | ||||
| 	} | ||||
| 	switch tt := llvrw.Return.(type) { | ||||
| 	case string: | ||||
| 		return tt, nil | ||||
| 	case bsoncore.Value: | ||||
| 		symbol, _, ok := bsoncore.ReadSymbol(tt.Data) | ||||
| 		if !ok { | ||||
| 			llvrw.T.Error("Invalid Value instance provided for ReadSymbol") | ||||
| 			return "", nil | ||||
| 		} | ||||
| 		return symbol, nil | ||||
| 	default: | ||||
| 		llvrw.T.Errorf("Incorrect type provided for return value of ReadSymbol: %T", llvrw.Return) | ||||
| 		return "", nil | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ReadTimestamp implements the bsonrw.ValueReader interface. | ||||
| func (llvrw *ValueReaderWriter) ReadTimestamp() (t uint32, i uint32, err error) { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = ReadTimestamp | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return 0, 0, llvrw.Err | ||||
| 	} | ||||
| 	switch tt := llvrw.Return.(type) { | ||||
| 	case bsoncore.Value: | ||||
| 		t, i, _, ok := bsoncore.ReadTimestamp(tt.Data) | ||||
| 		if !ok { | ||||
| 			llvrw.T.Errorf("Invalid Value instance provided for return value of ReadTimestamp") | ||||
| 			return 0, 0, nil | ||||
| 		} | ||||
| 		return t, i, nil | ||||
| 	default: | ||||
| 		llvrw.T.Errorf("Incorrect type provided for return value of ReadTimestamp: %T", llvrw.Return) | ||||
| 		return 0, 0, nil | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ReadUndefined implements the bsonrw.ValueReader interface. | ||||
| func (llvrw *ValueReaderWriter) ReadUndefined() error { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = ReadUndefined | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return llvrw.Err | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // WriteArray implements the bsonrw.ValueWriter interface. | ||||
| func (llvrw *ValueReaderWriter) WriteArray() (bsonrw.ArrayWriter, error) { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = WriteArray | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return nil, llvrw.Err | ||||
| 	} | ||||
| 	return llvrw, nil | ||||
| } | ||||
|  | ||||
| // WriteBinary implements the bsonrw.ValueWriter interface. | ||||
| func (llvrw *ValueReaderWriter) WriteBinary(b []byte) error { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = WriteBinary | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return llvrw.Err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // WriteBinaryWithSubtype implements the bsonrw.ValueWriter interface. | ||||
| func (llvrw *ValueReaderWriter) WriteBinaryWithSubtype(b []byte, btype byte) error { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = WriteBinaryWithSubtype | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return llvrw.Err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // WriteBoolean implements the bsonrw.ValueWriter interface. | ||||
| func (llvrw *ValueReaderWriter) WriteBoolean(bool) error { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = WriteBoolean | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return llvrw.Err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // WriteCodeWithScope implements the bsonrw.ValueWriter interface. | ||||
| func (llvrw *ValueReaderWriter) WriteCodeWithScope(code string) (bsonrw.DocumentWriter, error) { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = WriteCodeWithScope | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return nil, llvrw.Err | ||||
| 	} | ||||
| 	return llvrw, nil | ||||
| } | ||||
|  | ||||
| // WriteDBPointer implements the bsonrw.ValueWriter interface. | ||||
| func (llvrw *ValueReaderWriter) WriteDBPointer(ns string, oid primitive.ObjectID) error { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = WriteDBPointer | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return llvrw.Err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // WriteDateTime implements the bsonrw.ValueWriter interface. | ||||
| func (llvrw *ValueReaderWriter) WriteDateTime(dt int64) error { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = WriteDateTime | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return llvrw.Err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // WriteDecimal128 implements the bsonrw.ValueWriter interface. | ||||
| func (llvrw *ValueReaderWriter) WriteDecimal128(primitive.Decimal128) error { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = WriteDecimal128 | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return llvrw.Err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // WriteDouble implements the bsonrw.ValueWriter interface. | ||||
| func (llvrw *ValueReaderWriter) WriteDouble(float64) error { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = WriteDouble | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return llvrw.Err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // WriteInt32 implements the bsonrw.ValueWriter interface. | ||||
| func (llvrw *ValueReaderWriter) WriteInt32(int32) error { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = WriteInt32 | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return llvrw.Err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // WriteInt64 implements the bsonrw.ValueWriter interface. | ||||
| func (llvrw *ValueReaderWriter) WriteInt64(int64) error { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = WriteInt64 | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return llvrw.Err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // WriteJavascript implements the bsonrw.ValueWriter interface. | ||||
| func (llvrw *ValueReaderWriter) WriteJavascript(code string) error { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = WriteJavascript | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return llvrw.Err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // WriteMaxKey implements the bsonrw.ValueWriter interface. | ||||
| func (llvrw *ValueReaderWriter) WriteMaxKey() error { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = WriteMaxKey | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return llvrw.Err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // WriteMinKey implements the bsonrw.ValueWriter interface. | ||||
| func (llvrw *ValueReaderWriter) WriteMinKey() error { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = WriteMinKey | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return llvrw.Err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // WriteNull implements the bsonrw.ValueWriter interface. | ||||
| func (llvrw *ValueReaderWriter) WriteNull() error { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = WriteNull | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return llvrw.Err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // WriteObjectID implements the bsonrw.ValueWriter interface. | ||||
| func (llvrw *ValueReaderWriter) WriteObjectID(primitive.ObjectID) error { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = WriteObjectID | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return llvrw.Err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // WriteRegex implements the bsonrw.ValueWriter interface. | ||||
| func (llvrw *ValueReaderWriter) WriteRegex(pattern string, options string) error { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = WriteRegex | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return llvrw.Err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // WriteString implements the bsonrw.ValueWriter interface. | ||||
| func (llvrw *ValueReaderWriter) WriteString(string) error { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = WriteString | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return llvrw.Err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // WriteDocument implements the bsonrw.ValueWriter interface. | ||||
| func (llvrw *ValueReaderWriter) WriteDocument() (bsonrw.DocumentWriter, error) { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = WriteDocument | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return nil, llvrw.Err | ||||
| 	} | ||||
| 	return llvrw, nil | ||||
| } | ||||
|  | ||||
| // WriteSymbol implements the bsonrw.ValueWriter interface. | ||||
| func (llvrw *ValueReaderWriter) WriteSymbol(symbol string) error { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = WriteSymbol | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return llvrw.Err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // WriteTimestamp implements the bsonrw.ValueWriter interface. | ||||
| func (llvrw *ValueReaderWriter) WriteTimestamp(t uint32, i uint32) error { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = WriteTimestamp | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return llvrw.Err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // WriteUndefined implements the bsonrw.ValueWriter interface. | ||||
| func (llvrw *ValueReaderWriter) WriteUndefined() error { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = WriteUndefined | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return llvrw.Err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // ReadElement implements the bsonrw.DocumentReader interface. | ||||
| func (llvrw *ValueReaderWriter) ReadElement() (string, bsonrw.ValueReader, error) { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = ReadElement | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return "", nil, llvrw.Err | ||||
| 	} | ||||
|  | ||||
| 	return "", llvrw, nil | ||||
| } | ||||
|  | ||||
| // WriteDocumentElement implements the bsonrw.DocumentWriter interface. | ||||
| func (llvrw *ValueReaderWriter) WriteDocumentElement(string) (bsonrw.ValueWriter, error) { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = WriteDocumentElement | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return nil, llvrw.Err | ||||
| 	} | ||||
|  | ||||
| 	return llvrw, nil | ||||
| } | ||||
|  | ||||
| // WriteDocumentEnd implements the bsonrw.DocumentWriter interface. | ||||
| func (llvrw *ValueReaderWriter) WriteDocumentEnd() error { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = WriteDocumentEnd | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return llvrw.Err | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // ReadValue implements the bsonrw.ArrayReader interface. | ||||
| func (llvrw *ValueReaderWriter) ReadValue() (bsonrw.ValueReader, error) { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = ReadValue | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return nil, llvrw.Err | ||||
| 	} | ||||
|  | ||||
| 	return llvrw, nil | ||||
| } | ||||
|  | ||||
| // WriteArrayElement implements the bsonrw.ArrayWriter interface. | ||||
| func (llvrw *ValueReaderWriter) WriteArrayElement() (bsonrw.ValueWriter, error) { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = WriteArrayElement | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return nil, llvrw.Err | ||||
| 	} | ||||
|  | ||||
| 	return llvrw, nil | ||||
| } | ||||
|  | ||||
| // WriteArrayEnd implements the bsonrw.ArrayWriter interface. | ||||
| func (llvrw *ValueReaderWriter) WriteArrayEnd() error { | ||||
| 	llvrw.checkdepth() | ||||
| 	llvrw.Invoked = WriteArrayEnd | ||||
| 	if llvrw.ErrAfter == llvrw.Invoked { | ||||
| 		return llvrw.Err | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
							
								
								
									
										445
									
								
								mongo/bson/bsonrw/copier.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										445
									
								
								mongo/bson/bsonrw/copier.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,445 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsonrw | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"io" | ||||
|  | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsontype" | ||||
| 	"go.mongodb.org/mongo-driver/bson/primitive" | ||||
| 	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore" | ||||
| ) | ||||
|  | ||||
| // Copier is a type that allows copying between ValueReaders, ValueWriters, and | ||||
| // []byte values. | ||||
| type Copier struct{} | ||||
|  | ||||
| // NewCopier creates a new copier with the given registry. If a nil registry is provided | ||||
| // a default registry is used. | ||||
| func NewCopier() Copier { | ||||
| 	return Copier{} | ||||
| } | ||||
|  | ||||
| // CopyDocument handles copying a document from src to dst. | ||||
| func CopyDocument(dst ValueWriter, src ValueReader) error { | ||||
| 	return Copier{}.CopyDocument(dst, src) | ||||
| } | ||||
|  | ||||
| // CopyDocument handles copying one document from the src to the dst. | ||||
| func (c Copier) CopyDocument(dst ValueWriter, src ValueReader) error { | ||||
| 	dr, err := src.ReadDocument() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	dw, err := dst.WriteDocument() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return c.copyDocumentCore(dw, dr) | ||||
| } | ||||
|  | ||||
| // CopyArrayFromBytes copies the values from a BSON array represented as a | ||||
| // []byte to a ValueWriter. | ||||
| func (c Copier) CopyArrayFromBytes(dst ValueWriter, src []byte) error { | ||||
| 	aw, err := dst.WriteArray() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	err = c.CopyBytesToArrayWriter(aw, src) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return aw.WriteArrayEnd() | ||||
| } | ||||
|  | ||||
| // CopyDocumentFromBytes copies the values from a BSON document represented as a | ||||
| // []byte to a ValueWriter. | ||||
| func (c Copier) CopyDocumentFromBytes(dst ValueWriter, src []byte) error { | ||||
| 	dw, err := dst.WriteDocument() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	err = c.CopyBytesToDocumentWriter(dw, src) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return dw.WriteDocumentEnd() | ||||
| } | ||||
|  | ||||
| type writeElementFn func(key string) (ValueWriter, error) | ||||
|  | ||||
| // CopyBytesToArrayWriter copies the values from a BSON Array represented as a []byte to an | ||||
| // ArrayWriter. | ||||
| func (c Copier) CopyBytesToArrayWriter(dst ArrayWriter, src []byte) error { | ||||
| 	wef := func(_ string) (ValueWriter, error) { | ||||
| 		return dst.WriteArrayElement() | ||||
| 	} | ||||
|  | ||||
| 	return c.copyBytesToValueWriter(src, wef) | ||||
| } | ||||
|  | ||||
| // CopyBytesToDocumentWriter copies the values from a BSON document represented as a []byte to a | ||||
| // DocumentWriter. | ||||
| func (c Copier) CopyBytesToDocumentWriter(dst DocumentWriter, src []byte) error { | ||||
| 	wef := func(key string) (ValueWriter, error) { | ||||
| 		return dst.WriteDocumentElement(key) | ||||
| 	} | ||||
|  | ||||
| 	return c.copyBytesToValueWriter(src, wef) | ||||
| } | ||||
|  | ||||
| func (c Copier) copyBytesToValueWriter(src []byte, wef writeElementFn) error { | ||||
| 	// TODO(skriptble): Create errors types here. Anything thats a tag should be a property. | ||||
| 	length, rem, ok := bsoncore.ReadLength(src) | ||||
| 	if !ok { | ||||
| 		return fmt.Errorf("couldn't read length from src, not enough bytes. length=%d", len(src)) | ||||
| 	} | ||||
| 	if len(src) < int(length) { | ||||
| 		return fmt.Errorf("length read exceeds number of bytes available. length=%d bytes=%d", len(src), length) | ||||
| 	} | ||||
| 	rem = rem[:length-4] | ||||
|  | ||||
| 	var t bsontype.Type | ||||
| 	var key string | ||||
| 	var val bsoncore.Value | ||||
| 	for { | ||||
| 		t, rem, ok = bsoncore.ReadType(rem) | ||||
| 		if !ok { | ||||
| 			return io.EOF | ||||
| 		} | ||||
| 		if t == bsontype.Type(0) { | ||||
| 			if len(rem) != 0 { | ||||
| 				return fmt.Errorf("document end byte found before end of document. remaining bytes=%v", rem) | ||||
| 			} | ||||
| 			break | ||||
| 		} | ||||
|  | ||||
| 		key, rem, ok = bsoncore.ReadKey(rem) | ||||
| 		if !ok { | ||||
| 			return fmt.Errorf("invalid key found. remaining bytes=%v", rem) | ||||
| 		} | ||||
|  | ||||
| 		// write as either array element or document element using writeElementFn | ||||
| 		vw, err := wef(key) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		val, rem, ok = bsoncore.ReadValue(rem, t) | ||||
| 		if !ok { | ||||
| 			return fmt.Errorf("not enough bytes available to read type. bytes=%d type=%s", len(rem), t) | ||||
| 		} | ||||
| 		err = c.CopyValueFromBytes(vw, t, val.Data) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // CopyDocumentToBytes copies an entire document from the ValueReader and | ||||
| // returns it as bytes. | ||||
| func (c Copier) CopyDocumentToBytes(src ValueReader) ([]byte, error) { | ||||
| 	return c.AppendDocumentBytes(nil, src) | ||||
| } | ||||
|  | ||||
| // AppendDocumentBytes functions the same as CopyDocumentToBytes, but will | ||||
| // append the result to dst. | ||||
| func (c Copier) AppendDocumentBytes(dst []byte, src ValueReader) ([]byte, error) { | ||||
| 	if br, ok := src.(BytesReader); ok { | ||||
| 		_, dst, err := br.ReadValueBytes(dst) | ||||
| 		return dst, err | ||||
| 	} | ||||
|  | ||||
| 	vw := vwPool.Get().(*valueWriter) | ||||
| 	defer vwPool.Put(vw) | ||||
|  | ||||
| 	vw.reset(dst) | ||||
|  | ||||
| 	err := c.CopyDocument(vw, src) | ||||
| 	dst = vw.buf | ||||
| 	return dst, err | ||||
| } | ||||
|  | ||||
| // AppendArrayBytes copies an array from the ValueReader to dst. | ||||
| func (c Copier) AppendArrayBytes(dst []byte, src ValueReader) ([]byte, error) { | ||||
| 	if br, ok := src.(BytesReader); ok { | ||||
| 		_, dst, err := br.ReadValueBytes(dst) | ||||
| 		return dst, err | ||||
| 	} | ||||
|  | ||||
| 	vw := vwPool.Get().(*valueWriter) | ||||
| 	defer vwPool.Put(vw) | ||||
|  | ||||
| 	vw.reset(dst) | ||||
|  | ||||
| 	err := c.copyArray(vw, src) | ||||
| 	dst = vw.buf | ||||
| 	return dst, err | ||||
| } | ||||
|  | ||||
| // CopyValueFromBytes will write the value represtend by t and src to dst. | ||||
| func (c Copier) CopyValueFromBytes(dst ValueWriter, t bsontype.Type, src []byte) error { | ||||
| 	if wvb, ok := dst.(BytesWriter); ok { | ||||
| 		return wvb.WriteValueBytes(t, src) | ||||
| 	} | ||||
|  | ||||
| 	vr := vrPool.Get().(*valueReader) | ||||
| 	defer vrPool.Put(vr) | ||||
|  | ||||
| 	vr.reset(src) | ||||
| 	vr.pushElement(t) | ||||
|  | ||||
| 	return c.CopyValue(dst, vr) | ||||
| } | ||||
|  | ||||
| // CopyValueToBytes copies a value from src and returns it as a bsontype.Type and a | ||||
| // []byte. | ||||
| func (c Copier) CopyValueToBytes(src ValueReader) (bsontype.Type, []byte, error) { | ||||
| 	return c.AppendValueBytes(nil, src) | ||||
| } | ||||
|  | ||||
| // AppendValueBytes functions the same as CopyValueToBytes, but will append the | ||||
| // result to dst. | ||||
| func (c Copier) AppendValueBytes(dst []byte, src ValueReader) (bsontype.Type, []byte, error) { | ||||
| 	if br, ok := src.(BytesReader); ok { | ||||
| 		return br.ReadValueBytes(dst) | ||||
| 	} | ||||
|  | ||||
| 	vw := vwPool.Get().(*valueWriter) | ||||
| 	defer vwPool.Put(vw) | ||||
|  | ||||
| 	start := len(dst) | ||||
|  | ||||
| 	vw.reset(dst) | ||||
| 	vw.push(mElement) | ||||
|  | ||||
| 	err := c.CopyValue(vw, src) | ||||
| 	if err != nil { | ||||
| 		return 0, dst, err | ||||
| 	} | ||||
|  | ||||
| 	return bsontype.Type(vw.buf[start]), vw.buf[start+2:], nil | ||||
| } | ||||
|  | ||||
| // CopyValue will copy a single value from src to dst. | ||||
| func (c Copier) CopyValue(dst ValueWriter, src ValueReader) error { | ||||
| 	var err error | ||||
| 	switch src.Type() { | ||||
| 	case bsontype.Double: | ||||
| 		var f64 float64 | ||||
| 		f64, err = src.ReadDouble() | ||||
| 		if err != nil { | ||||
| 			break | ||||
| 		} | ||||
| 		err = dst.WriteDouble(f64) | ||||
| 	case bsontype.String: | ||||
| 		var str string | ||||
| 		str, err = src.ReadString() | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		err = dst.WriteString(str) | ||||
| 	case bsontype.EmbeddedDocument: | ||||
| 		err = c.CopyDocument(dst, src) | ||||
| 	case bsontype.Array: | ||||
| 		err = c.copyArray(dst, src) | ||||
| 	case bsontype.Binary: | ||||
| 		var data []byte | ||||
| 		var subtype byte | ||||
| 		data, subtype, err = src.ReadBinary() | ||||
| 		if err != nil { | ||||
| 			break | ||||
| 		} | ||||
| 		err = dst.WriteBinaryWithSubtype(data, subtype) | ||||
| 	case bsontype.Undefined: | ||||
| 		err = src.ReadUndefined() | ||||
| 		if err != nil { | ||||
| 			break | ||||
| 		} | ||||
| 		err = dst.WriteUndefined() | ||||
| 	case bsontype.ObjectID: | ||||
| 		var oid primitive.ObjectID | ||||
| 		oid, err = src.ReadObjectID() | ||||
| 		if err != nil { | ||||
| 			break | ||||
| 		} | ||||
| 		err = dst.WriteObjectID(oid) | ||||
| 	case bsontype.Boolean: | ||||
| 		var b bool | ||||
| 		b, err = src.ReadBoolean() | ||||
| 		if err != nil { | ||||
| 			break | ||||
| 		} | ||||
| 		err = dst.WriteBoolean(b) | ||||
| 	case bsontype.DateTime: | ||||
| 		var dt int64 | ||||
| 		dt, err = src.ReadDateTime() | ||||
| 		if err != nil { | ||||
| 			break | ||||
| 		} | ||||
| 		err = dst.WriteDateTime(dt) | ||||
| 	case bsontype.Null: | ||||
| 		err = src.ReadNull() | ||||
| 		if err != nil { | ||||
| 			break | ||||
| 		} | ||||
| 		err = dst.WriteNull() | ||||
| 	case bsontype.Regex: | ||||
| 		var pattern, options string | ||||
| 		pattern, options, err = src.ReadRegex() | ||||
| 		if err != nil { | ||||
| 			break | ||||
| 		} | ||||
| 		err = dst.WriteRegex(pattern, options) | ||||
| 	case bsontype.DBPointer: | ||||
| 		var ns string | ||||
| 		var pointer primitive.ObjectID | ||||
| 		ns, pointer, err = src.ReadDBPointer() | ||||
| 		if err != nil { | ||||
| 			break | ||||
| 		} | ||||
| 		err = dst.WriteDBPointer(ns, pointer) | ||||
| 	case bsontype.JavaScript: | ||||
| 		var js string | ||||
| 		js, err = src.ReadJavascript() | ||||
| 		if err != nil { | ||||
| 			break | ||||
| 		} | ||||
| 		err = dst.WriteJavascript(js) | ||||
| 	case bsontype.Symbol: | ||||
| 		var symbol string | ||||
| 		symbol, err = src.ReadSymbol() | ||||
| 		if err != nil { | ||||
| 			break | ||||
| 		} | ||||
| 		err = dst.WriteSymbol(symbol) | ||||
| 	case bsontype.CodeWithScope: | ||||
| 		var code string | ||||
| 		var srcScope DocumentReader | ||||
| 		code, srcScope, err = src.ReadCodeWithScope() | ||||
| 		if err != nil { | ||||
| 			break | ||||
| 		} | ||||
|  | ||||
| 		var dstScope DocumentWriter | ||||
| 		dstScope, err = dst.WriteCodeWithScope(code) | ||||
| 		if err != nil { | ||||
| 			break | ||||
| 		} | ||||
| 		err = c.copyDocumentCore(dstScope, srcScope) | ||||
| 	case bsontype.Int32: | ||||
| 		var i32 int32 | ||||
| 		i32, err = src.ReadInt32() | ||||
| 		if err != nil { | ||||
| 			break | ||||
| 		} | ||||
| 		err = dst.WriteInt32(i32) | ||||
| 	case bsontype.Timestamp: | ||||
| 		var t, i uint32 | ||||
| 		t, i, err = src.ReadTimestamp() | ||||
| 		if err != nil { | ||||
| 			break | ||||
| 		} | ||||
| 		err = dst.WriteTimestamp(t, i) | ||||
| 	case bsontype.Int64: | ||||
| 		var i64 int64 | ||||
| 		i64, err = src.ReadInt64() | ||||
| 		if err != nil { | ||||
| 			break | ||||
| 		} | ||||
| 		err = dst.WriteInt64(i64) | ||||
| 	case bsontype.Decimal128: | ||||
| 		var d128 primitive.Decimal128 | ||||
| 		d128, err = src.ReadDecimal128() | ||||
| 		if err != nil { | ||||
| 			break | ||||
| 		} | ||||
| 		err = dst.WriteDecimal128(d128) | ||||
| 	case bsontype.MinKey: | ||||
| 		err = src.ReadMinKey() | ||||
| 		if err != nil { | ||||
| 			break | ||||
| 		} | ||||
| 		err = dst.WriteMinKey() | ||||
| 	case bsontype.MaxKey: | ||||
| 		err = src.ReadMaxKey() | ||||
| 		if err != nil { | ||||
| 			break | ||||
| 		} | ||||
| 		err = dst.WriteMaxKey() | ||||
| 	default: | ||||
| 		err = fmt.Errorf("Cannot copy unknown BSON type %s", src.Type()) | ||||
| 	} | ||||
|  | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func (c Copier) copyArray(dst ValueWriter, src ValueReader) error { | ||||
| 	ar, err := src.ReadArray() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	aw, err := dst.WriteArray() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	for { | ||||
| 		vr, err := ar.ReadValue() | ||||
| 		if err == ErrEOA { | ||||
| 			break | ||||
| 		} | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		vw, err := aw.WriteArrayElement() | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		err = c.CopyValue(vw, vr) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return aw.WriteArrayEnd() | ||||
| } | ||||
|  | ||||
| func (c Copier) copyDocumentCore(dw DocumentWriter, dr DocumentReader) error { | ||||
| 	for { | ||||
| 		key, vr, err := dr.ReadElement() | ||||
| 		if err == ErrEOD { | ||||
| 			break | ||||
| 		} | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		vw, err := dw.WriteDocumentElement(key) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		err = c.CopyValue(vw, vr) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return dw.WriteDocumentEnd() | ||||
| } | ||||
							
								
								
									
										529
									
								
								mongo/bson/bsonrw/copier_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										529
									
								
								mongo/bson/bsonrw/copier_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,529 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsonrw | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"testing" | ||||
|  | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsontype" | ||||
| 	"go.mongodb.org/mongo-driver/bson/primitive" | ||||
| 	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore" | ||||
| ) | ||||
|  | ||||
| func TestCopier(t *testing.T) { | ||||
| 	t.Run("CopyDocument", func(t *testing.T) { | ||||
| 		t.Run("ReadDocument Error", func(t *testing.T) { | ||||
| 			want := errors.New("ReadDocumentError") | ||||
| 			src := &TestValueReaderWriter{t: t, err: want, errAfter: llvrwReadDocument} | ||||
| 			got := Copier{}.CopyDocument(nil, src) | ||||
| 			if !compareErrors(got, want) { | ||||
| 				t.Errorf("Did not receive correct error. got %v; want %v", got, want) | ||||
| 			} | ||||
| 		}) | ||||
| 		t.Run("WriteDocument Error", func(t *testing.T) { | ||||
| 			want := errors.New("WriteDocumentError") | ||||
| 			src := &TestValueReaderWriter{} | ||||
| 			dst := &TestValueReaderWriter{t: t, err: want, errAfter: llvrwWriteDocument} | ||||
| 			got := Copier{}.CopyDocument(dst, src) | ||||
| 			if !compareErrors(got, want) { | ||||
| 				t.Errorf("Did not receive correct error. got %v; want %v", got, want) | ||||
| 			} | ||||
| 		}) | ||||
| 		t.Run("success", func(t *testing.T) { | ||||
| 			idx, doc := bsoncore.AppendDocumentStart(nil) | ||||
| 			doc = bsoncore.AppendStringElement(doc, "Hello", "world") | ||||
| 			doc, err := bsoncore.AppendDocumentEnd(doc, idx) | ||||
| 			noerr(t, err) | ||||
| 			src := newValueReader(doc) | ||||
| 			dst := newValueWriterFromSlice(make([]byte, 0)) | ||||
| 			want := doc | ||||
| 			err = Copier{}.CopyDocument(dst, src) | ||||
| 			noerr(t, err) | ||||
| 			got := dst.buf | ||||
| 			if !bytes.Equal(got, want) { | ||||
| 				t.Errorf("Bytes are not equal. got %v; want %v", got, want) | ||||
| 			} | ||||
| 		}) | ||||
| 	}) | ||||
| 	t.Run("copyArray", func(t *testing.T) { | ||||
| 		t.Run("ReadArray Error", func(t *testing.T) { | ||||
| 			want := errors.New("ReadArrayError") | ||||
| 			src := &TestValueReaderWriter{t: t, err: want, errAfter: llvrwReadArray} | ||||
| 			got := Copier{}.copyArray(nil, src) | ||||
| 			if !compareErrors(got, want) { | ||||
| 				t.Errorf("Did not receive correct error. got %v; want %v", got, want) | ||||
| 			} | ||||
| 		}) | ||||
| 		t.Run("WriteArray Error", func(t *testing.T) { | ||||
| 			want := errors.New("WriteArrayError") | ||||
| 			src := &TestValueReaderWriter{} | ||||
| 			dst := &TestValueReaderWriter{t: t, err: want, errAfter: llvrwWriteArray} | ||||
| 			got := Copier{}.copyArray(dst, src) | ||||
| 			if !compareErrors(got, want) { | ||||
| 				t.Errorf("Did not receive correct error. got %v; want %v", got, want) | ||||
| 			} | ||||
| 		}) | ||||
| 		t.Run("success", func(t *testing.T) { | ||||
| 			idx, doc := bsoncore.AppendDocumentStart(nil) | ||||
| 			aidx, doc := bsoncore.AppendArrayElementStart(doc, "foo") | ||||
| 			doc = bsoncore.AppendStringElement(doc, "0", "Hello, world!") | ||||
| 			doc, err := bsoncore.AppendArrayEnd(doc, aidx) | ||||
| 			noerr(t, err) | ||||
| 			doc, err = bsoncore.AppendDocumentEnd(doc, idx) | ||||
| 			noerr(t, err) | ||||
| 			src := newValueReader(doc) | ||||
|  | ||||
| 			_, err = src.ReadDocument() | ||||
| 			noerr(t, err) | ||||
| 			_, _, err = src.ReadElement() | ||||
| 			noerr(t, err) | ||||
|  | ||||
| 			dst := newValueWriterFromSlice(make([]byte, 0)) | ||||
| 			_, err = dst.WriteDocument() | ||||
| 			noerr(t, err) | ||||
| 			_, err = dst.WriteDocumentElement("foo") | ||||
| 			noerr(t, err) | ||||
| 			want := doc | ||||
|  | ||||
| 			err = Copier{}.copyArray(dst, src) | ||||
| 			noerr(t, err) | ||||
|  | ||||
| 			err = dst.WriteDocumentEnd() | ||||
| 			noerr(t, err) | ||||
|  | ||||
| 			got := dst.buf | ||||
| 			if !bytes.Equal(got, want) { | ||||
| 				t.Errorf("Bytes are not equal. got %v; want %v", got, want) | ||||
| 			} | ||||
| 		}) | ||||
| 	}) | ||||
| 	t.Run("CopyValue", func(t *testing.T) { | ||||
| 		testCases := []struct { | ||||
| 			name string | ||||
| 			dst  *TestValueReaderWriter | ||||
| 			src  *TestValueReaderWriter | ||||
| 			err  error | ||||
| 		}{ | ||||
| 			{ | ||||
| 				"Double/src/error", | ||||
| 				&TestValueReaderWriter{}, | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.Double, err: errors.New("1"), errAfter: llvrwReadDouble}, | ||||
| 				errors.New("1"), | ||||
| 			}, | ||||
| 			{ | ||||
| 				"Double/dst/error", | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.Double, err: errors.New("2"), errAfter: llvrwWriteDouble}, | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.Double, readval: float64(3.14159)}, | ||||
| 				errors.New("2"), | ||||
| 			}, | ||||
| 			{ | ||||
| 				"String/src/error", | ||||
| 				&TestValueReaderWriter{}, | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.String, err: errors.New("1"), errAfter: llvrwReadString}, | ||||
| 				errors.New("1"), | ||||
| 			}, | ||||
| 			{ | ||||
| 				"String/dst/error", | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.String, err: errors.New("2"), errAfter: llvrwWriteString}, | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.String, readval: "hello, world"}, | ||||
| 				errors.New("2"), | ||||
| 			}, | ||||
| 			{ | ||||
| 				"Document/src/error", | ||||
| 				&TestValueReaderWriter{}, | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.EmbeddedDocument, err: errors.New("1"), errAfter: llvrwReadDocument}, | ||||
| 				errors.New("1"), | ||||
| 			}, | ||||
| 			{ | ||||
| 				"Array/dst/error", | ||||
| 				&TestValueReaderWriter{}, | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.Array, err: errors.New("2"), errAfter: llvrwReadArray}, | ||||
| 				errors.New("2"), | ||||
| 			}, | ||||
| 			{ | ||||
| 				"Binary/src/error", | ||||
| 				&TestValueReaderWriter{}, | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.Binary, err: errors.New("1"), errAfter: llvrwReadBinary}, | ||||
| 				errors.New("1"), | ||||
| 			}, | ||||
| 			{ | ||||
| 				"Binary/dst/error", | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.Binary, err: errors.New("2"), errAfter: llvrwWriteBinaryWithSubtype}, | ||||
| 				&TestValueReaderWriter{ | ||||
| 					bsontype: bsontype.Binary, | ||||
| 					readval: bsoncore.Value{ | ||||
| 						Type: bsontype.Binary, | ||||
| 						Data: []byte{0x03, 0x00, 0x00, 0x00, 0xFF, 0x01, 0x02, 0x03}, | ||||
| 					}, | ||||
| 				}, | ||||
| 				errors.New("2"), | ||||
| 			}, | ||||
| 			{ | ||||
| 				"Undefined/src/error", | ||||
| 				&TestValueReaderWriter{}, | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.Undefined, err: errors.New("1"), errAfter: llvrwReadUndefined}, | ||||
| 				errors.New("1"), | ||||
| 			}, | ||||
| 			{ | ||||
| 				"Undefined/dst/error", | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.Undefined, err: errors.New("2"), errAfter: llvrwWriteUndefined}, | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.Undefined}, | ||||
| 				errors.New("2"), | ||||
| 			}, | ||||
| 			{ | ||||
| 				"ObjectID/src/error", | ||||
| 				&TestValueReaderWriter{}, | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.ObjectID, err: errors.New("1"), errAfter: llvrwReadObjectID}, | ||||
| 				errors.New("1"), | ||||
| 			}, | ||||
| 			{ | ||||
| 				"ObjectID/dst/error", | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.ObjectID, err: errors.New("2"), errAfter: llvrwWriteObjectID}, | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.ObjectID, readval: primitive.ObjectID{0x01, 0x02, 0x03}}, | ||||
| 				errors.New("2"), | ||||
| 			}, | ||||
| 			{ | ||||
| 				"Boolean/src/error", | ||||
| 				&TestValueReaderWriter{}, | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.Boolean, err: errors.New("1"), errAfter: llvrwReadBoolean}, | ||||
| 				errors.New("1"), | ||||
| 			}, | ||||
| 			{ | ||||
| 				"Boolean/dst/error", | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.Boolean, err: errors.New("2"), errAfter: llvrwWriteBoolean}, | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.Boolean, readval: bool(true)}, | ||||
| 				errors.New("2"), | ||||
| 			}, | ||||
| 			{ | ||||
| 				"DateTime/src/error", | ||||
| 				&TestValueReaderWriter{}, | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.DateTime, err: errors.New("1"), errAfter: llvrwReadDateTime}, | ||||
| 				errors.New("1"), | ||||
| 			}, | ||||
| 			{ | ||||
| 				"DateTime/dst/error", | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.DateTime, err: errors.New("2"), errAfter: llvrwWriteDateTime}, | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.DateTime, readval: int64(1234567890)}, | ||||
| 				errors.New("2"), | ||||
| 			}, | ||||
| 			{ | ||||
| 				"Null/src/error", | ||||
| 				&TestValueReaderWriter{}, | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.Null, err: errors.New("1"), errAfter: llvrwReadNull}, | ||||
| 				errors.New("1"), | ||||
| 			}, | ||||
| 			{ | ||||
| 				"Null/dst/error", | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.Null, err: errors.New("2"), errAfter: llvrwWriteNull}, | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.Null}, | ||||
| 				errors.New("2"), | ||||
| 			}, | ||||
| 			{ | ||||
| 				"Regex/src/error", | ||||
| 				&TestValueReaderWriter{}, | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.Regex, err: errors.New("1"), errAfter: llvrwReadRegex}, | ||||
| 				errors.New("1"), | ||||
| 			}, | ||||
| 			{ | ||||
| 				"Regex/dst/error", | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.Regex, err: errors.New("2"), errAfter: llvrwWriteRegex}, | ||||
| 				&TestValueReaderWriter{ | ||||
| 					bsontype: bsontype.Regex, | ||||
| 					readval: bsoncore.Value{ | ||||
| 						Type: bsontype.Regex, | ||||
| 						Data: bsoncore.AppendRegex(nil, "hello", "world"), | ||||
| 					}, | ||||
| 				}, | ||||
| 				errors.New("2"), | ||||
| 			}, | ||||
| 			{ | ||||
| 				"DBPointer/src/error", | ||||
| 				&TestValueReaderWriter{}, | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.DBPointer, err: errors.New("1"), errAfter: llvrwReadDBPointer}, | ||||
| 				errors.New("1"), | ||||
| 			}, | ||||
| 			{ | ||||
| 				"DBPointer/dst/error", | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.DBPointer, err: errors.New("2"), errAfter: llvrwWriteDBPointer}, | ||||
| 				&TestValueReaderWriter{ | ||||
| 					bsontype: bsontype.DBPointer, | ||||
| 					readval: bsoncore.Value{ | ||||
| 						Type: bsontype.DBPointer, | ||||
| 						Data: bsoncore.AppendDBPointer(nil, "foo", primitive.ObjectID{0x01, 0x02, 0x03}), | ||||
| 					}, | ||||
| 				}, | ||||
| 				errors.New("2"), | ||||
| 			}, | ||||
| 			{ | ||||
| 				"Javascript/src/error", | ||||
| 				&TestValueReaderWriter{}, | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.JavaScript, err: errors.New("1"), errAfter: llvrwReadJavascript}, | ||||
| 				errors.New("1"), | ||||
| 			}, | ||||
| 			{ | ||||
| 				"Javascript/dst/error", | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.JavaScript, err: errors.New("2"), errAfter: llvrwWriteJavascript}, | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.JavaScript, readval: "hello, world"}, | ||||
| 				errors.New("2"), | ||||
| 			}, | ||||
| 			{ | ||||
| 				"Symbol/src/error", | ||||
| 				&TestValueReaderWriter{}, | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.Symbol, err: errors.New("1"), errAfter: llvrwReadSymbol}, | ||||
| 				errors.New("1"), | ||||
| 			}, | ||||
| 			{ | ||||
| 				"Symbol/dst/error", | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.Symbol, err: errors.New("2"), errAfter: llvrwWriteSymbol}, | ||||
| 				&TestValueReaderWriter{ | ||||
| 					bsontype: bsontype.Symbol, | ||||
| 					readval: bsoncore.Value{ | ||||
| 						Type: bsontype.Symbol, | ||||
| 						Data: bsoncore.AppendSymbol(nil, "hello, world"), | ||||
| 					}, | ||||
| 				}, | ||||
| 				errors.New("2"), | ||||
| 			}, | ||||
| 			{ | ||||
| 				"CodeWithScope/src/error", | ||||
| 				&TestValueReaderWriter{}, | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.CodeWithScope, err: errors.New("1"), errAfter: llvrwReadCodeWithScope}, | ||||
| 				errors.New("1"), | ||||
| 			}, | ||||
| 			{ | ||||
| 				"CodeWithScope/dst/error", | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.CodeWithScope, err: errors.New("2"), errAfter: llvrwWriteCodeWithScope}, | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.CodeWithScope}, | ||||
| 				errors.New("2"), | ||||
| 			}, | ||||
| 			{ | ||||
| 				"CodeWithScope/dst/copyDocumentCore error", | ||||
| 				&TestValueReaderWriter{err: errors.New("3"), errAfter: llvrwWriteDocumentElement}, | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.CodeWithScope}, | ||||
| 				errors.New("3"), | ||||
| 			}, | ||||
| 			{ | ||||
| 				"Int32/src/error", | ||||
| 				&TestValueReaderWriter{}, | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.Int32, err: errors.New("1"), errAfter: llvrwReadInt32}, | ||||
| 				errors.New("1"), | ||||
| 			}, | ||||
| 			{ | ||||
| 				"Int32/dst/error", | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.Int32, err: errors.New("2"), errAfter: llvrwWriteInt32}, | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.Int32, readval: int32(12345)}, | ||||
| 				errors.New("2"), | ||||
| 			}, | ||||
| 			{ | ||||
| 				"Timestamp/src/error", | ||||
| 				&TestValueReaderWriter{}, | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.Timestamp, err: errors.New("1"), errAfter: llvrwReadTimestamp}, | ||||
| 				errors.New("1"), | ||||
| 			}, | ||||
| 			{ | ||||
| 				"Timestamp/dst/error", | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.Timestamp, err: errors.New("2"), errAfter: llvrwWriteTimestamp}, | ||||
| 				&TestValueReaderWriter{ | ||||
| 					bsontype: bsontype.Timestamp, | ||||
| 					readval: bsoncore.Value{ | ||||
| 						Type: bsontype.Timestamp, | ||||
| 						Data: bsoncore.AppendTimestamp(nil, 12345, 67890), | ||||
| 					}, | ||||
| 				}, | ||||
| 				errors.New("2"), | ||||
| 			}, | ||||
| 			{ | ||||
| 				"Int64/src/error", | ||||
| 				&TestValueReaderWriter{}, | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.Int64, err: errors.New("1"), errAfter: llvrwReadInt64}, | ||||
| 				errors.New("1"), | ||||
| 			}, | ||||
| 			{ | ||||
| 				"Int64/dst/error", | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.Int64, err: errors.New("2"), errAfter: llvrwWriteInt64}, | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.Int64, readval: int64(1234567890)}, | ||||
| 				errors.New("2"), | ||||
| 			}, | ||||
| 			{ | ||||
| 				"Decimal128/src/error", | ||||
| 				&TestValueReaderWriter{}, | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.Decimal128, err: errors.New("1"), errAfter: llvrwReadDecimal128}, | ||||
| 				errors.New("1"), | ||||
| 			}, | ||||
| 			{ | ||||
| 				"Decimal128/dst/error", | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.Decimal128, err: errors.New("2"), errAfter: llvrwWriteDecimal128}, | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.Decimal128, readval: primitive.NewDecimal128(12345, 67890)}, | ||||
| 				errors.New("2"), | ||||
| 			}, | ||||
| 			{ | ||||
| 				"MinKey/src/error", | ||||
| 				&TestValueReaderWriter{}, | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.MinKey, err: errors.New("1"), errAfter: llvrwReadMinKey}, | ||||
| 				errors.New("1"), | ||||
| 			}, | ||||
| 			{ | ||||
| 				"MinKey/dst/error", | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.MinKey, err: errors.New("2"), errAfter: llvrwWriteMinKey}, | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.MinKey}, | ||||
| 				errors.New("2"), | ||||
| 			}, | ||||
| 			{ | ||||
| 				"MaxKey/src/error", | ||||
| 				&TestValueReaderWriter{}, | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.MaxKey, err: errors.New("1"), errAfter: llvrwReadMaxKey}, | ||||
| 				errors.New("1"), | ||||
| 			}, | ||||
| 			{ | ||||
| 				"MaxKey/dst/error", | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.MaxKey, err: errors.New("2"), errAfter: llvrwWriteMaxKey}, | ||||
| 				&TestValueReaderWriter{bsontype: bsontype.MaxKey}, | ||||
| 				errors.New("2"), | ||||
| 			}, | ||||
| 			{ | ||||
| 				"Unknown BSON type error", | ||||
| 				&TestValueReaderWriter{}, | ||||
| 				&TestValueReaderWriter{}, | ||||
| 				fmt.Errorf("Cannot copy unknown BSON type %s", bsontype.Type(0)), | ||||
| 			}, | ||||
| 		} | ||||
|  | ||||
| 		for _, tc := range testCases { | ||||
| 			t.Run(tc.name, func(t *testing.T) { | ||||
| 				tc.dst.t, tc.src.t = t, t | ||||
| 				err := Copier{}.CopyValue(tc.dst, tc.src) | ||||
| 				if !compareErrors(err, tc.err) { | ||||
| 					t.Errorf("Did not receive expected error. got %v; want %v", err, tc.err) | ||||
| 				} | ||||
| 			}) | ||||
| 		} | ||||
| 	}) | ||||
| 	t.Run("CopyValueFromBytes", func(t *testing.T) { | ||||
| 		t.Run("BytesWriter", func(t *testing.T) { | ||||
| 			vw := newValueWriterFromSlice(make([]byte, 0)) | ||||
| 			_, err := vw.WriteDocument() | ||||
| 			noerr(t, err) | ||||
| 			_, err = vw.WriteDocumentElement("foo") | ||||
| 			noerr(t, err) | ||||
| 			err = Copier{}.CopyValueFromBytes(vw, bsontype.String, bsoncore.AppendString(nil, "bar")) | ||||
| 			noerr(t, err) | ||||
| 			err = vw.WriteDocumentEnd() | ||||
| 			noerr(t, err) | ||||
| 			var idx int32 | ||||
| 			want, err := bsoncore.AppendDocumentEnd( | ||||
| 				bsoncore.AppendStringElement( | ||||
| 					bsoncore.AppendDocumentStartInline(nil, &idx), | ||||
| 					"foo", "bar", | ||||
| 				), | ||||
| 				idx, | ||||
| 			) | ||||
| 			noerr(t, err) | ||||
| 			got := vw.buf | ||||
| 			if !bytes.Equal(got, want) { | ||||
| 				t.Errorf("Bytes are not equal. got %v; want %v", got, want) | ||||
| 			} | ||||
| 		}) | ||||
| 		t.Run("Non BytesWriter", func(t *testing.T) { | ||||
| 			llvrw := &TestValueReaderWriter{t: t} | ||||
| 			err := Copier{}.CopyValueFromBytes(llvrw, bsontype.String, bsoncore.AppendString(nil, "bar")) | ||||
| 			noerr(t, err) | ||||
| 			got, want := llvrw.invoked, llvrwWriteString | ||||
| 			if got != want { | ||||
| 				t.Errorf("Incorrect method invoked on llvrw. got %v; want %v", got, want) | ||||
| 			} | ||||
| 		}) | ||||
| 	}) | ||||
| 	t.Run("CopyValueToBytes", func(t *testing.T) { | ||||
| 		t.Run("BytesReader", func(t *testing.T) { | ||||
| 			var idx int32 | ||||
| 			b, err := bsoncore.AppendDocumentEnd( | ||||
| 				bsoncore.AppendStringElement( | ||||
| 					bsoncore.AppendDocumentStartInline(nil, &idx), | ||||
| 					"hello", "world", | ||||
| 				), | ||||
| 				idx, | ||||
| 			) | ||||
| 			noerr(t, err) | ||||
| 			vr := newValueReader(b) | ||||
| 			_, err = vr.ReadDocument() | ||||
| 			noerr(t, err) | ||||
| 			_, _, err = vr.ReadElement() | ||||
| 			noerr(t, err) | ||||
| 			btype, got, err := Copier{}.CopyValueToBytes(vr) | ||||
| 			noerr(t, err) | ||||
| 			want := bsoncore.AppendString(nil, "world") | ||||
| 			if btype != bsontype.String { | ||||
| 				t.Errorf("Incorrect type returned. got %v; want %v", btype, bsontype.String) | ||||
| 			} | ||||
| 			if !bytes.Equal(got, want) { | ||||
| 				t.Errorf("Bytes do not match. got %v; want %v", got, want) | ||||
| 			} | ||||
| 		}) | ||||
| 		t.Run("Non BytesReader", func(t *testing.T) { | ||||
| 			llvrw := &TestValueReaderWriter{t: t, bsontype: bsontype.String, readval: "Hello, world!"} | ||||
| 			btype, got, err := Copier{}.CopyValueToBytes(llvrw) | ||||
| 			noerr(t, err) | ||||
| 			want := bsoncore.AppendString(nil, "Hello, world!") | ||||
| 			if btype != bsontype.String { | ||||
| 				t.Errorf("Incorrect type returned. got %v; want %v", btype, bsontype.String) | ||||
| 			} | ||||
| 			if !bytes.Equal(got, want) { | ||||
| 				t.Errorf("Bytes do not match. got %v; want %v", got, want) | ||||
| 			} | ||||
| 		}) | ||||
| 	}) | ||||
| 	t.Run("AppendValueBytes", func(t *testing.T) { | ||||
| 		t.Run("BytesReader", func(t *testing.T) { | ||||
| 			var idx int32 | ||||
| 			b, err := bsoncore.AppendDocumentEnd( | ||||
| 				bsoncore.AppendStringElement( | ||||
| 					bsoncore.AppendDocumentStartInline(nil, &idx), | ||||
| 					"hello", "world", | ||||
| 				), | ||||
| 				idx, | ||||
| 			) | ||||
| 			noerr(t, err) | ||||
| 			vr := newValueReader(b) | ||||
| 			_, err = vr.ReadDocument() | ||||
| 			noerr(t, err) | ||||
| 			_, _, err = vr.ReadElement() | ||||
| 			noerr(t, err) | ||||
| 			btype, got, err := Copier{}.AppendValueBytes(nil, vr) | ||||
| 			noerr(t, err) | ||||
| 			want := bsoncore.AppendString(nil, "world") | ||||
| 			if btype != bsontype.String { | ||||
| 				t.Errorf("Incorrect type returned. got %v; want %v", btype, bsontype.String) | ||||
| 			} | ||||
| 			if !bytes.Equal(got, want) { | ||||
| 				t.Errorf("Bytes do not match. got %v; want %v", got, want) | ||||
| 			} | ||||
| 		}) | ||||
| 		t.Run("Non BytesReader", func(t *testing.T) { | ||||
| 			llvrw := &TestValueReaderWriter{t: t, bsontype: bsontype.String, readval: "Hello, world!"} | ||||
| 			btype, got, err := Copier{}.AppendValueBytes(nil, llvrw) | ||||
| 			noerr(t, err) | ||||
| 			want := bsoncore.AppendString(nil, "Hello, world!") | ||||
| 			if btype != bsontype.String { | ||||
| 				t.Errorf("Incorrect type returned. got %v; want %v", btype, bsontype.String) | ||||
| 			} | ||||
| 			if !bytes.Equal(got, want) { | ||||
| 				t.Errorf("Bytes do not match. got %v; want %v", got, want) | ||||
| 			} | ||||
| 		}) | ||||
| 		t.Run("CopyValue error", func(t *testing.T) { | ||||
| 			want := errors.New("CopyValue error") | ||||
| 			llvrw := &TestValueReaderWriter{t: t, bsontype: bsontype.String, err: want, errAfter: llvrwReadString} | ||||
| 			_, _, got := Copier{}.AppendValueBytes(make([]byte, 0), llvrw) | ||||
| 			if !compareErrors(got, want) { | ||||
| 				t.Errorf("Errors do not match. got %v; want %v", got, want) | ||||
| 			} | ||||
| 		}) | ||||
| 	}) | ||||
| } | ||||
							
								
								
									
										9
									
								
								mongo/bson/bsonrw/doc.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										9
									
								
								mongo/bson/bsonrw/doc.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,9 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| // Package bsonrw contains abstractions for reading and writing | ||||
| // BSON and BSON like types from sources. | ||||
| package bsonrw // import "go.mongodb.org/mongo-driver/bson/bsonrw" | ||||
							
								
								
									
										806
									
								
								mongo/bson/bsonrw/extjson_parser.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										806
									
								
								mongo/bson/bsonrw/extjson_parser.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,806 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsonrw | ||||
|  | ||||
| import ( | ||||
| 	"encoding/base64" | ||||
| 	"encoding/hex" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"strings" | ||||
|  | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsontype" | ||||
| ) | ||||
|  | ||||
| const maxNestingDepth = 200 | ||||
|  | ||||
| // ErrInvalidJSON indicates the JSON input is invalid | ||||
| var ErrInvalidJSON = errors.New("invalid JSON input") | ||||
|  | ||||
| type jsonParseState byte | ||||
|  | ||||
| const ( | ||||
| 	jpsStartState jsonParseState = iota | ||||
| 	jpsSawBeginObject | ||||
| 	jpsSawEndObject | ||||
| 	jpsSawBeginArray | ||||
| 	jpsSawEndArray | ||||
| 	jpsSawColon | ||||
| 	jpsSawComma | ||||
| 	jpsSawKey | ||||
| 	jpsSawValue | ||||
| 	jpsDoneState | ||||
| 	jpsInvalidState | ||||
| ) | ||||
|  | ||||
| type jsonParseMode byte | ||||
|  | ||||
| const ( | ||||
| 	jpmInvalidMode jsonParseMode = iota | ||||
| 	jpmObjectMode | ||||
| 	jpmArrayMode | ||||
| ) | ||||
|  | ||||
| type extJSONValue struct { | ||||
| 	t bsontype.Type | ||||
| 	v interface{} | ||||
| } | ||||
|  | ||||
| type extJSONObject struct { | ||||
| 	keys   []string | ||||
| 	values []*extJSONValue | ||||
| } | ||||
|  | ||||
| type extJSONParser struct { | ||||
| 	js *jsonScanner | ||||
| 	s  jsonParseState | ||||
| 	m  []jsonParseMode | ||||
| 	k  string | ||||
| 	v  *extJSONValue | ||||
|  | ||||
| 	err       error | ||||
| 	canonical bool | ||||
| 	depth     int | ||||
| 	maxDepth  int | ||||
|  | ||||
| 	emptyObject bool | ||||
| 	relaxedUUID bool | ||||
| } | ||||
|  | ||||
| // newExtJSONParser returns a new extended JSON parser, ready to to begin | ||||
| // parsing from the first character of the argued json input. It will not | ||||
| // perform any read-ahead and will therefore not report any errors about | ||||
| // malformed JSON at this point. | ||||
| func newExtJSONParser(r io.Reader, canonical bool) *extJSONParser { | ||||
| 	return &extJSONParser{ | ||||
| 		js:        &jsonScanner{r: r}, | ||||
| 		s:         jpsStartState, | ||||
| 		m:         []jsonParseMode{}, | ||||
| 		canonical: canonical, | ||||
| 		maxDepth:  maxNestingDepth, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // peekType examines the next value and returns its BSON Type | ||||
| func (ejp *extJSONParser) peekType() (bsontype.Type, error) { | ||||
| 	var t bsontype.Type | ||||
| 	var err error | ||||
| 	initialState := ejp.s | ||||
|  | ||||
| 	ejp.advanceState() | ||||
| 	switch ejp.s { | ||||
| 	case jpsSawValue: | ||||
| 		t = ejp.v.t | ||||
| 	case jpsSawBeginArray: | ||||
| 		t = bsontype.Array | ||||
| 	case jpsInvalidState: | ||||
| 		err = ejp.err | ||||
| 	case jpsSawComma: | ||||
| 		// in array mode, seeing a comma means we need to progress again to actually observe a type | ||||
| 		if ejp.peekMode() == jpmArrayMode { | ||||
| 			return ejp.peekType() | ||||
| 		} | ||||
| 	case jpsSawEndArray: | ||||
| 		// this would only be a valid state if we were in array mode, so return end-of-array error | ||||
| 		err = ErrEOA | ||||
| 	case jpsSawBeginObject: | ||||
| 		// peek key to determine type | ||||
| 		ejp.advanceState() | ||||
| 		switch ejp.s { | ||||
| 		case jpsSawEndObject: // empty embedded document | ||||
| 			t = bsontype.EmbeddedDocument | ||||
| 			ejp.emptyObject = true | ||||
| 		case jpsInvalidState: | ||||
| 			err = ejp.err | ||||
| 		case jpsSawKey: | ||||
| 			if initialState == jpsStartState { | ||||
| 				return bsontype.EmbeddedDocument, nil | ||||
| 			} | ||||
| 			t = wrapperKeyBSONType(ejp.k) | ||||
|  | ||||
| 			// if $uuid is encountered, parse as binary subtype 4 | ||||
| 			if ejp.k == "$uuid" { | ||||
| 				ejp.relaxedUUID = true | ||||
| 				t = bsontype.Binary | ||||
| 			} | ||||
|  | ||||
| 			switch t { | ||||
| 			case bsontype.JavaScript: | ||||
| 				// just saw $code, need to check for $scope at same level | ||||
| 				_, err = ejp.readValue(bsontype.JavaScript) | ||||
| 				if err != nil { | ||||
| 					break | ||||
| 				} | ||||
|  | ||||
| 				switch ejp.s { | ||||
| 				case jpsSawEndObject: // type is TypeJavaScript | ||||
| 				case jpsSawComma: | ||||
| 					ejp.advanceState() | ||||
|  | ||||
| 					if ejp.s == jpsSawKey && ejp.k == "$scope" { | ||||
| 						t = bsontype.CodeWithScope | ||||
| 					} else { | ||||
| 						err = fmt.Errorf("invalid extended JSON: unexpected key %s in CodeWithScope object", ejp.k) | ||||
| 					} | ||||
| 				case jpsInvalidState: | ||||
| 					err = ejp.err | ||||
| 				default: | ||||
| 					err = ErrInvalidJSON | ||||
| 				} | ||||
| 			case bsontype.CodeWithScope: | ||||
| 				err = errors.New("invalid extended JSON: code with $scope must contain $code before $scope") | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return t, err | ||||
| } | ||||
|  | ||||
| // readKey parses the next key and its type and returns them | ||||
| func (ejp *extJSONParser) readKey() (string, bsontype.Type, error) { | ||||
| 	if ejp.emptyObject { | ||||
| 		ejp.emptyObject = false | ||||
| 		return "", 0, ErrEOD | ||||
| 	} | ||||
|  | ||||
| 	// advance to key (or return with error) | ||||
| 	switch ejp.s { | ||||
| 	case jpsStartState: | ||||
| 		ejp.advanceState() | ||||
| 		if ejp.s == jpsSawBeginObject { | ||||
| 			ejp.advanceState() | ||||
| 		} | ||||
| 	case jpsSawBeginObject: | ||||
| 		ejp.advanceState() | ||||
| 	case jpsSawValue, jpsSawEndObject, jpsSawEndArray: | ||||
| 		ejp.advanceState() | ||||
| 		switch ejp.s { | ||||
| 		case jpsSawBeginObject, jpsSawComma: | ||||
| 			ejp.advanceState() | ||||
| 		case jpsSawEndObject: | ||||
| 			return "", 0, ErrEOD | ||||
| 		case jpsDoneState: | ||||
| 			return "", 0, io.EOF | ||||
| 		case jpsInvalidState: | ||||
| 			return "", 0, ejp.err | ||||
| 		default: | ||||
| 			return "", 0, ErrInvalidJSON | ||||
| 		} | ||||
| 	case jpsSawKey: // do nothing (key was peeked before) | ||||
| 	default: | ||||
| 		return "", 0, invalidRequestError("key") | ||||
| 	} | ||||
|  | ||||
| 	// read key | ||||
| 	var key string | ||||
|  | ||||
| 	switch ejp.s { | ||||
| 	case jpsSawKey: | ||||
| 		key = ejp.k | ||||
| 	case jpsSawEndObject: | ||||
| 		return "", 0, ErrEOD | ||||
| 	case jpsInvalidState: | ||||
| 		return "", 0, ejp.err | ||||
| 	default: | ||||
| 		return "", 0, invalidRequestError("key") | ||||
| 	} | ||||
|  | ||||
| 	// check for colon | ||||
| 	ejp.advanceState() | ||||
| 	if err := ensureColon(ejp.s, key); err != nil { | ||||
| 		return "", 0, err | ||||
| 	} | ||||
|  | ||||
| 	// peek at the value to determine type | ||||
| 	t, err := ejp.peekType() | ||||
| 	if err != nil { | ||||
| 		return "", 0, err | ||||
| 	} | ||||
|  | ||||
| 	return key, t, nil | ||||
| } | ||||
|  | ||||
| // readValue returns the value corresponding to the Type returned by peekType | ||||
| func (ejp *extJSONParser) readValue(t bsontype.Type) (*extJSONValue, error) { | ||||
| 	if ejp.s == jpsInvalidState { | ||||
| 		return nil, ejp.err | ||||
| 	} | ||||
|  | ||||
| 	var v *extJSONValue | ||||
|  | ||||
| 	switch t { | ||||
| 	case bsontype.Null, bsontype.Boolean, bsontype.String: | ||||
| 		if ejp.s != jpsSawValue { | ||||
| 			return nil, invalidRequestError(t.String()) | ||||
| 		} | ||||
| 		v = ejp.v | ||||
| 	case bsontype.Int32, bsontype.Int64, bsontype.Double: | ||||
| 		// relaxed version allows these to be literal number values | ||||
| 		if ejp.s == jpsSawValue { | ||||
| 			v = ejp.v | ||||
| 			break | ||||
| 		} | ||||
| 		fallthrough | ||||
| 	case bsontype.Decimal128, bsontype.Symbol, bsontype.ObjectID, bsontype.MinKey, bsontype.MaxKey, bsontype.Undefined: | ||||
| 		switch ejp.s { | ||||
| 		case jpsSawKey: | ||||
| 			// read colon | ||||
| 			ejp.advanceState() | ||||
| 			if err := ensureColon(ejp.s, ejp.k); err != nil { | ||||
| 				return nil, err | ||||
| 			} | ||||
|  | ||||
| 			// read value | ||||
| 			ejp.advanceState() | ||||
| 			if ejp.s != jpsSawValue || !ejp.ensureExtValueType(t) { | ||||
| 				return nil, invalidJSONErrorForType("value", t) | ||||
| 			} | ||||
|  | ||||
| 			v = ejp.v | ||||
|  | ||||
| 			// read end object | ||||
| 			ejp.advanceState() | ||||
| 			if ejp.s != jpsSawEndObject { | ||||
| 				return nil, invalidJSONErrorForType("} after value", t) | ||||
| 			} | ||||
| 		default: | ||||
| 			return nil, invalidRequestError(t.String()) | ||||
| 		} | ||||
| 	case bsontype.Binary, bsontype.Regex, bsontype.Timestamp, bsontype.DBPointer: | ||||
| 		if ejp.s != jpsSawKey { | ||||
| 			return nil, invalidRequestError(t.String()) | ||||
| 		} | ||||
| 		// read colon | ||||
| 		ejp.advanceState() | ||||
| 		if err := ensureColon(ejp.s, ejp.k); err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
|  | ||||
| 		ejp.advanceState() | ||||
| 		if t == bsontype.Binary && ejp.s == jpsSawValue { | ||||
| 			// convert relaxed $uuid format | ||||
| 			if ejp.relaxedUUID { | ||||
| 				defer func() { ejp.relaxedUUID = false }() | ||||
| 				uuid, err := ejp.v.parseSymbol() | ||||
| 				if err != nil { | ||||
| 					return nil, err | ||||
| 				} | ||||
|  | ||||
| 				// RFC 4122 defines the length of a UUID as 36 and the hyphens in a UUID as appearing | ||||
| 				// in the 8th, 13th, 18th, and 23rd characters. | ||||
| 				// | ||||
| 				// See https://tools.ietf.org/html/rfc4122#section-3 | ||||
| 				valid := len(uuid) == 36 && | ||||
| 					string(uuid[8]) == "-" && | ||||
| 					string(uuid[13]) == "-" && | ||||
| 					string(uuid[18]) == "-" && | ||||
| 					string(uuid[23]) == "-" | ||||
| 				if !valid { | ||||
| 					return nil, fmt.Errorf("$uuid value does not follow RFC 4122 format regarding length and hyphens") | ||||
| 				} | ||||
|  | ||||
| 				// remove hyphens | ||||
| 				uuidNoHyphens := strings.Replace(uuid, "-", "", -1) | ||||
| 				if len(uuidNoHyphens) != 32 { | ||||
| 					return nil, fmt.Errorf("$uuid value does not follow RFC 4122 format regarding length and hyphens") | ||||
| 				} | ||||
|  | ||||
| 				// convert hex to bytes | ||||
| 				bytes, err := hex.DecodeString(uuidNoHyphens) | ||||
| 				if err != nil { | ||||
| 					return nil, fmt.Errorf("$uuid value does not follow RFC 4122 format regarding hex bytes: %v", err) | ||||
| 				} | ||||
|  | ||||
| 				ejp.advanceState() | ||||
| 				if ejp.s != jpsSawEndObject { | ||||
| 					return nil, invalidJSONErrorForType("$uuid and value and then }", bsontype.Binary) | ||||
| 				} | ||||
|  | ||||
| 				base64 := &extJSONValue{ | ||||
| 					t: bsontype.String, | ||||
| 					v: base64.StdEncoding.EncodeToString(bytes), | ||||
| 				} | ||||
| 				subType := &extJSONValue{ | ||||
| 					t: bsontype.String, | ||||
| 					v: "04", | ||||
| 				} | ||||
|  | ||||
| 				v = &extJSONValue{ | ||||
| 					t: bsontype.EmbeddedDocument, | ||||
| 					v: &extJSONObject{ | ||||
| 						keys:   []string{"base64", "subType"}, | ||||
| 						values: []*extJSONValue{base64, subType}, | ||||
| 					}, | ||||
| 				} | ||||
|  | ||||
| 				break | ||||
| 			} | ||||
|  | ||||
| 			// convert legacy $binary format | ||||
| 			base64 := ejp.v | ||||
|  | ||||
| 			ejp.advanceState() | ||||
| 			if ejp.s != jpsSawComma { | ||||
| 				return nil, invalidJSONErrorForType(",", bsontype.Binary) | ||||
| 			} | ||||
|  | ||||
| 			ejp.advanceState() | ||||
| 			key, t, err := ejp.readKey() | ||||
| 			if err != nil { | ||||
| 				return nil, err | ||||
| 			} | ||||
| 			if key != "$type" { | ||||
| 				return nil, invalidJSONErrorForType("$type", bsontype.Binary) | ||||
| 			} | ||||
|  | ||||
| 			subType, err := ejp.readValue(t) | ||||
| 			if err != nil { | ||||
| 				return nil, err | ||||
| 			} | ||||
|  | ||||
| 			ejp.advanceState() | ||||
| 			if ejp.s != jpsSawEndObject { | ||||
| 				return nil, invalidJSONErrorForType("2 key-value pairs and then }", bsontype.Binary) | ||||
| 			} | ||||
|  | ||||
| 			v = &extJSONValue{ | ||||
| 				t: bsontype.EmbeddedDocument, | ||||
| 				v: &extJSONObject{ | ||||
| 					keys:   []string{"base64", "subType"}, | ||||
| 					values: []*extJSONValue{base64, subType}, | ||||
| 				}, | ||||
| 			} | ||||
| 			break | ||||
| 		} | ||||
|  | ||||
| 		// read KV pairs | ||||
| 		if ejp.s != jpsSawBeginObject { | ||||
| 			return nil, invalidJSONErrorForType("{", t) | ||||
| 		} | ||||
|  | ||||
| 		keys, vals, err := ejp.readObject(2, true) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
|  | ||||
| 		ejp.advanceState() | ||||
| 		if ejp.s != jpsSawEndObject { | ||||
| 			return nil, invalidJSONErrorForType("2 key-value pairs and then }", t) | ||||
| 		} | ||||
|  | ||||
| 		v = &extJSONValue{t: bsontype.EmbeddedDocument, v: &extJSONObject{keys: keys, values: vals}} | ||||
|  | ||||
| 	case bsontype.DateTime: | ||||
| 		switch ejp.s { | ||||
| 		case jpsSawValue: | ||||
| 			v = ejp.v | ||||
| 		case jpsSawKey: | ||||
| 			// read colon | ||||
| 			ejp.advanceState() | ||||
| 			if err := ensureColon(ejp.s, ejp.k); err != nil { | ||||
| 				return nil, err | ||||
| 			} | ||||
|  | ||||
| 			ejp.advanceState() | ||||
| 			switch ejp.s { | ||||
| 			case jpsSawBeginObject: | ||||
| 				keys, vals, err := ejp.readObject(1, true) | ||||
| 				if err != nil { | ||||
| 					return nil, err | ||||
| 				} | ||||
| 				v = &extJSONValue{t: bsontype.EmbeddedDocument, v: &extJSONObject{keys: keys, values: vals}} | ||||
| 			case jpsSawValue: | ||||
| 				if ejp.canonical { | ||||
| 					return nil, invalidJSONError("{") | ||||
| 				} | ||||
| 				v = ejp.v | ||||
| 			default: | ||||
| 				if ejp.canonical { | ||||
| 					return nil, invalidJSONErrorForType("object", t) | ||||
| 				} | ||||
| 				return nil, invalidJSONErrorForType("ISO-8601 Internet Date/Time Format as described in RFC-3339", t) | ||||
| 			} | ||||
|  | ||||
| 			ejp.advanceState() | ||||
| 			if ejp.s != jpsSawEndObject { | ||||
| 				return nil, invalidJSONErrorForType("value and then }", t) | ||||
| 			} | ||||
| 		default: | ||||
| 			return nil, invalidRequestError(t.String()) | ||||
| 		} | ||||
| 	case bsontype.JavaScript: | ||||
| 		switch ejp.s { | ||||
| 		case jpsSawKey: | ||||
| 			// read colon | ||||
| 			ejp.advanceState() | ||||
| 			if err := ensureColon(ejp.s, ejp.k); err != nil { | ||||
| 				return nil, err | ||||
| 			} | ||||
|  | ||||
| 			// read value | ||||
| 			ejp.advanceState() | ||||
| 			if ejp.s != jpsSawValue { | ||||
| 				return nil, invalidJSONErrorForType("value", t) | ||||
| 			} | ||||
| 			v = ejp.v | ||||
|  | ||||
| 			// read end object or comma and just return | ||||
| 			ejp.advanceState() | ||||
| 		case jpsSawEndObject: | ||||
| 			v = ejp.v | ||||
| 		default: | ||||
| 			return nil, invalidRequestError(t.String()) | ||||
| 		} | ||||
| 	case bsontype.CodeWithScope: | ||||
| 		if ejp.s == jpsSawKey && ejp.k == "$scope" { | ||||
| 			v = ejp.v // this is the $code string from earlier | ||||
|  | ||||
| 			// read colon | ||||
| 			ejp.advanceState() | ||||
| 			if err := ensureColon(ejp.s, ejp.k); err != nil { | ||||
| 				return nil, err | ||||
| 			} | ||||
|  | ||||
| 			// read { | ||||
| 			ejp.advanceState() | ||||
| 			if ejp.s != jpsSawBeginObject { | ||||
| 				return nil, invalidJSONError("$scope to be embedded document") | ||||
| 			} | ||||
| 		} else { | ||||
| 			return nil, invalidRequestError(t.String()) | ||||
| 		} | ||||
| 	case bsontype.EmbeddedDocument, bsontype.Array: | ||||
| 		return nil, invalidRequestError(t.String()) | ||||
| 	} | ||||
|  | ||||
| 	return v, nil | ||||
| } | ||||
|  | ||||
| // readObject is a utility method for reading full objects of known (or expected) size | ||||
| // it is useful for extended JSON types such as binary, datetime, regex, and timestamp | ||||
| func (ejp *extJSONParser) readObject(numKeys int, started bool) ([]string, []*extJSONValue, error) { | ||||
| 	keys := make([]string, numKeys) | ||||
| 	vals := make([]*extJSONValue, numKeys) | ||||
|  | ||||
| 	if !started { | ||||
| 		ejp.advanceState() | ||||
| 		if ejp.s != jpsSawBeginObject { | ||||
| 			return nil, nil, invalidJSONError("{") | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	for i := 0; i < numKeys; i++ { | ||||
| 		key, t, err := ejp.readKey() | ||||
| 		if err != nil { | ||||
| 			return nil, nil, err | ||||
| 		} | ||||
|  | ||||
| 		switch ejp.s { | ||||
| 		case jpsSawKey: | ||||
| 			v, err := ejp.readValue(t) | ||||
| 			if err != nil { | ||||
| 				return nil, nil, err | ||||
| 			} | ||||
|  | ||||
| 			keys[i] = key | ||||
| 			vals[i] = v | ||||
| 		case jpsSawValue: | ||||
| 			keys[i] = key | ||||
| 			vals[i] = ejp.v | ||||
| 		default: | ||||
| 			return nil, nil, invalidJSONError("value") | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	ejp.advanceState() | ||||
| 	if ejp.s != jpsSawEndObject { | ||||
| 		return nil, nil, invalidJSONError("}") | ||||
| 	} | ||||
|  | ||||
| 	return keys, vals, nil | ||||
| } | ||||
|  | ||||
| // advanceState reads the next JSON token from the scanner and transitions | ||||
| // from the current state based on that token's type | ||||
| func (ejp *extJSONParser) advanceState() { | ||||
| 	if ejp.s == jpsDoneState || ejp.s == jpsInvalidState { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	jt, err := ejp.js.nextToken() | ||||
|  | ||||
| 	if err != nil { | ||||
| 		ejp.err = err | ||||
| 		ejp.s = jpsInvalidState | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	valid := ejp.validateToken(jt.t) | ||||
| 	if !valid { | ||||
| 		ejp.err = unexpectedTokenError(jt) | ||||
| 		ejp.s = jpsInvalidState | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	switch jt.t { | ||||
| 	case jttBeginObject: | ||||
| 		ejp.s = jpsSawBeginObject | ||||
| 		ejp.pushMode(jpmObjectMode) | ||||
| 		ejp.depth++ | ||||
|  | ||||
| 		if ejp.depth > ejp.maxDepth { | ||||
| 			ejp.err = nestingDepthError(jt.p, ejp.depth) | ||||
| 			ejp.s = jpsInvalidState | ||||
| 		} | ||||
| 	case jttEndObject: | ||||
| 		ejp.s = jpsSawEndObject | ||||
| 		ejp.depth-- | ||||
|  | ||||
| 		if ejp.popMode() != jpmObjectMode { | ||||
| 			ejp.err = unexpectedTokenError(jt) | ||||
| 			ejp.s = jpsInvalidState | ||||
| 		} | ||||
| 	case jttBeginArray: | ||||
| 		ejp.s = jpsSawBeginArray | ||||
| 		ejp.pushMode(jpmArrayMode) | ||||
| 	case jttEndArray: | ||||
| 		ejp.s = jpsSawEndArray | ||||
|  | ||||
| 		if ejp.popMode() != jpmArrayMode { | ||||
| 			ejp.err = unexpectedTokenError(jt) | ||||
| 			ejp.s = jpsInvalidState | ||||
| 		} | ||||
| 	case jttColon: | ||||
| 		ejp.s = jpsSawColon | ||||
| 	case jttComma: | ||||
| 		ejp.s = jpsSawComma | ||||
| 	case jttEOF: | ||||
| 		ejp.s = jpsDoneState | ||||
| 		if len(ejp.m) != 0 { | ||||
| 			ejp.err = unexpectedTokenError(jt) | ||||
| 			ejp.s = jpsInvalidState | ||||
| 		} | ||||
| 	case jttString: | ||||
| 		switch ejp.s { | ||||
| 		case jpsSawComma: | ||||
| 			if ejp.peekMode() == jpmArrayMode { | ||||
| 				ejp.s = jpsSawValue | ||||
| 				ejp.v = extendJSONToken(jt) | ||||
| 				return | ||||
| 			} | ||||
| 			fallthrough | ||||
| 		case jpsSawBeginObject: | ||||
| 			ejp.s = jpsSawKey | ||||
| 			ejp.k = jt.v.(string) | ||||
| 			return | ||||
| 		} | ||||
| 		fallthrough | ||||
| 	default: | ||||
| 		ejp.s = jpsSawValue | ||||
| 		ejp.v = extendJSONToken(jt) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| var jpsValidTransitionTokens = map[jsonParseState]map[jsonTokenType]bool{ | ||||
| 	jpsStartState: { | ||||
| 		jttBeginObject: true, | ||||
| 		jttBeginArray:  true, | ||||
| 		jttInt32:       true, | ||||
| 		jttInt64:       true, | ||||
| 		jttDouble:      true, | ||||
| 		jttString:      true, | ||||
| 		jttBool:        true, | ||||
| 		jttNull:        true, | ||||
| 		jttEOF:         true, | ||||
| 	}, | ||||
| 	jpsSawBeginObject: { | ||||
| 		jttEndObject: true, | ||||
| 		jttString:    true, | ||||
| 	}, | ||||
| 	jpsSawEndObject: { | ||||
| 		jttEndObject: true, | ||||
| 		jttEndArray:  true, | ||||
| 		jttComma:     true, | ||||
| 		jttEOF:       true, | ||||
| 	}, | ||||
| 	jpsSawBeginArray: { | ||||
| 		jttBeginObject: true, | ||||
| 		jttBeginArray:  true, | ||||
| 		jttEndArray:    true, | ||||
| 		jttInt32:       true, | ||||
| 		jttInt64:       true, | ||||
| 		jttDouble:      true, | ||||
| 		jttString:      true, | ||||
| 		jttBool:        true, | ||||
| 		jttNull:        true, | ||||
| 	}, | ||||
| 	jpsSawEndArray: { | ||||
| 		jttEndObject: true, | ||||
| 		jttEndArray:  true, | ||||
| 		jttComma:     true, | ||||
| 		jttEOF:       true, | ||||
| 	}, | ||||
| 	jpsSawColon: { | ||||
| 		jttBeginObject: true, | ||||
| 		jttBeginArray:  true, | ||||
| 		jttInt32:       true, | ||||
| 		jttInt64:       true, | ||||
| 		jttDouble:      true, | ||||
| 		jttString:      true, | ||||
| 		jttBool:        true, | ||||
| 		jttNull:        true, | ||||
| 	}, | ||||
| 	jpsSawComma: { | ||||
| 		jttBeginObject: true, | ||||
| 		jttBeginArray:  true, | ||||
| 		jttInt32:       true, | ||||
| 		jttInt64:       true, | ||||
| 		jttDouble:      true, | ||||
| 		jttString:      true, | ||||
| 		jttBool:        true, | ||||
| 		jttNull:        true, | ||||
| 	}, | ||||
| 	jpsSawKey: { | ||||
| 		jttColon: true, | ||||
| 	}, | ||||
| 	jpsSawValue: { | ||||
| 		jttEndObject: true, | ||||
| 		jttEndArray:  true, | ||||
| 		jttComma:     true, | ||||
| 		jttEOF:       true, | ||||
| 	}, | ||||
| 	jpsDoneState:    {}, | ||||
| 	jpsInvalidState: {}, | ||||
| } | ||||
|  | ||||
| func (ejp *extJSONParser) validateToken(jtt jsonTokenType) bool { | ||||
| 	switch ejp.s { | ||||
| 	case jpsSawEndObject: | ||||
| 		// if we are at depth zero and the next token is a '{', | ||||
| 		// we can consider it valid only if we are not in array mode. | ||||
| 		if jtt == jttBeginObject && ejp.depth == 0 { | ||||
| 			return ejp.peekMode() != jpmArrayMode | ||||
| 		} | ||||
| 	case jpsSawComma: | ||||
| 		switch ejp.peekMode() { | ||||
| 		// the only valid next token after a comma inside a document is a string (a key) | ||||
| 		case jpmObjectMode: | ||||
| 			return jtt == jttString | ||||
| 		case jpmInvalidMode: | ||||
| 			return false | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	_, ok := jpsValidTransitionTokens[ejp.s][jtt] | ||||
| 	return ok | ||||
| } | ||||
|  | ||||
| // ensureExtValueType returns true if the current value has the expected | ||||
| // value type for single-key extended JSON types. For example, | ||||
| // {"$numberInt": v} v must be TypeString | ||||
| func (ejp *extJSONParser) ensureExtValueType(t bsontype.Type) bool { | ||||
| 	switch t { | ||||
| 	case bsontype.MinKey, bsontype.MaxKey: | ||||
| 		return ejp.v.t == bsontype.Int32 | ||||
| 	case bsontype.Undefined: | ||||
| 		return ejp.v.t == bsontype.Boolean | ||||
| 	case bsontype.Int32, bsontype.Int64, bsontype.Double, bsontype.Decimal128, bsontype.Symbol, bsontype.ObjectID: | ||||
| 		return ejp.v.t == bsontype.String | ||||
| 	default: | ||||
| 		return false | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (ejp *extJSONParser) pushMode(m jsonParseMode) { | ||||
| 	ejp.m = append(ejp.m, m) | ||||
| } | ||||
|  | ||||
| func (ejp *extJSONParser) popMode() jsonParseMode { | ||||
| 	l := len(ejp.m) | ||||
| 	if l == 0 { | ||||
| 		return jpmInvalidMode | ||||
| 	} | ||||
|  | ||||
| 	m := ejp.m[l-1] | ||||
| 	ejp.m = ejp.m[:l-1] | ||||
|  | ||||
| 	return m | ||||
| } | ||||
|  | ||||
| func (ejp *extJSONParser) peekMode() jsonParseMode { | ||||
| 	l := len(ejp.m) | ||||
| 	if l == 0 { | ||||
| 		return jpmInvalidMode | ||||
| 	} | ||||
|  | ||||
| 	return ejp.m[l-1] | ||||
| } | ||||
|  | ||||
| func extendJSONToken(jt *jsonToken) *extJSONValue { | ||||
| 	var t bsontype.Type | ||||
|  | ||||
| 	switch jt.t { | ||||
| 	case jttInt32: | ||||
| 		t = bsontype.Int32 | ||||
| 	case jttInt64: | ||||
| 		t = bsontype.Int64 | ||||
| 	case jttDouble: | ||||
| 		t = bsontype.Double | ||||
| 	case jttString: | ||||
| 		t = bsontype.String | ||||
| 	case jttBool: | ||||
| 		t = bsontype.Boolean | ||||
| 	case jttNull: | ||||
| 		t = bsontype.Null | ||||
| 	default: | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	return &extJSONValue{t: t, v: jt.v} | ||||
| } | ||||
|  | ||||
| func ensureColon(s jsonParseState, key string) error { | ||||
| 	if s != jpsSawColon { | ||||
| 		return fmt.Errorf("invalid JSON input: missing colon after key \"%s\"", key) | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func invalidRequestError(s string) error { | ||||
| 	return fmt.Errorf("invalid request to read %s", s) | ||||
| } | ||||
|  | ||||
| func invalidJSONError(expected string) error { | ||||
| 	return fmt.Errorf("invalid JSON input; expected %s", expected) | ||||
| } | ||||
|  | ||||
| func invalidJSONErrorForType(expected string, t bsontype.Type) error { | ||||
| 	return fmt.Errorf("invalid JSON input; expected %s for %s", expected, t) | ||||
| } | ||||
|  | ||||
| func unexpectedTokenError(jt *jsonToken) error { | ||||
| 	switch jt.t { | ||||
| 	case jttInt32, jttInt64, jttDouble: | ||||
| 		return fmt.Errorf("invalid JSON input; unexpected number (%v) at position %d", jt.v, jt.p) | ||||
| 	case jttString: | ||||
| 		return fmt.Errorf("invalid JSON input; unexpected string (\"%v\") at position %d", jt.v, jt.p) | ||||
| 	case jttBool: | ||||
| 		return fmt.Errorf("invalid JSON input; unexpected boolean literal (%v) at position %d", jt.v, jt.p) | ||||
| 	case jttNull: | ||||
| 		return fmt.Errorf("invalid JSON input; unexpected null literal at position %d", jt.p) | ||||
| 	case jttEOF: | ||||
| 		return fmt.Errorf("invalid JSON input; unexpected end of input at position %d", jt.p) | ||||
| 	default: | ||||
| 		return fmt.Errorf("invalid JSON input; unexpected %c at position %d", jt.v.(byte), jt.p) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func nestingDepthError(p, depth int) error { | ||||
| 	return fmt.Errorf("invalid JSON input; nesting too deep (%d levels) at position %d", depth, p) | ||||
| } | ||||
							
								
								
									
										788
									
								
								mongo/bson/bsonrw/extjson_parser_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										788
									
								
								mongo/bson/bsonrw/extjson_parser_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,788 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsonrw | ||||
|  | ||||
| import ( | ||||
| 	"io" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/google/go-cmp/cmp" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsontype" | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	keyDiff = specificDiff("key") | ||||
| 	typDiff = specificDiff("type") | ||||
| 	valDiff = specificDiff("value") | ||||
|  | ||||
| 	expectErrEOF = expectSpecificError(io.EOF) | ||||
| 	expectErrEOD = expectSpecificError(ErrEOD) | ||||
| 	expectErrEOA = expectSpecificError(ErrEOA) | ||||
| ) | ||||
|  | ||||
| type expectedErrorFunc func(t *testing.T, err error, desc string) | ||||
|  | ||||
| type peekTypeTestCase struct { | ||||
| 	desc  string | ||||
| 	input string | ||||
| 	typs  []bsontype.Type | ||||
| 	errFs []expectedErrorFunc | ||||
| } | ||||
|  | ||||
| type readKeyValueTestCase struct { | ||||
| 	desc  string | ||||
| 	input string | ||||
| 	keys  []string | ||||
| 	typs  []bsontype.Type | ||||
| 	vals  []*extJSONValue | ||||
|  | ||||
| 	keyEFs []expectedErrorFunc | ||||
| 	valEFs []expectedErrorFunc | ||||
| } | ||||
|  | ||||
| func expectSpecificError(expected error) expectedErrorFunc { | ||||
| 	return func(t *testing.T, err error, desc string) { | ||||
| 		if err != expected { | ||||
| 			t.Helper() | ||||
| 			t.Errorf("%s: Expected %v but got: %v", desc, expected, err) | ||||
| 			t.FailNow() | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func specificDiff(name string) func(t *testing.T, expected, actual interface{}, desc string) { | ||||
| 	return func(t *testing.T, expected, actual interface{}, desc string) { | ||||
| 		if diff := cmp.Diff(expected, actual); diff != "" { | ||||
| 			t.Helper() | ||||
| 			t.Errorf("%s: Incorrect JSON %s (-want, +got): %s\n", desc, name, diff) | ||||
| 			t.FailNow() | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func expectErrorNOOP(_ *testing.T, _ error, _ string) { | ||||
| } | ||||
|  | ||||
| func readKeyDiff(t *testing.T, eKey, aKey string, eTyp, aTyp bsontype.Type, err error, errF expectedErrorFunc, desc string) { | ||||
| 	keyDiff(t, eKey, aKey, desc) | ||||
| 	typDiff(t, eTyp, aTyp, desc) | ||||
| 	errF(t, err, desc) | ||||
| } | ||||
|  | ||||
| func readValueDiff(t *testing.T, eVal, aVal *extJSONValue, err error, errF expectedErrorFunc, desc string) { | ||||
| 	if aVal != nil { | ||||
| 		typDiff(t, eVal.t, aVal.t, desc) | ||||
| 		valDiff(t, eVal.v, aVal.v, desc) | ||||
| 	} else { | ||||
| 		valDiff(t, eVal, aVal, desc) | ||||
| 	} | ||||
|  | ||||
| 	errF(t, err, desc) | ||||
| } | ||||
|  | ||||
| func TestExtJSONParserPeekType(t *testing.T) { | ||||
| 	makeValidPeekTypeTestCase := func(input string, typ bsontype.Type, desc string) peekTypeTestCase { | ||||
| 		return peekTypeTestCase{ | ||||
| 			desc: desc, input: input, | ||||
| 			typs:  []bsontype.Type{typ}, | ||||
| 			errFs: []expectedErrorFunc{expectNoError}, | ||||
| 		} | ||||
| 	} | ||||
| 	makeInvalidTestCase := func(desc, input string, lastEF expectedErrorFunc) peekTypeTestCase { | ||||
| 		return peekTypeTestCase{ | ||||
| 			desc: desc, input: input, | ||||
| 			typs:  []bsontype.Type{bsontype.Type(0)}, | ||||
| 			errFs: []expectedErrorFunc{lastEF}, | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	makeInvalidPeekTypeTestCase := func(desc, input string, lastEF expectedErrorFunc) peekTypeTestCase { | ||||
| 		return peekTypeTestCase{ | ||||
| 			desc: desc, input: input, | ||||
| 			typs:  []bsontype.Type{bsontype.Array, bsontype.String, bsontype.Type(0)}, | ||||
| 			errFs: []expectedErrorFunc{expectNoError, expectNoError, lastEF}, | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	cases := []peekTypeTestCase{ | ||||
| 		makeValidPeekTypeTestCase(`null`, bsontype.Null, "Null"), | ||||
| 		makeValidPeekTypeTestCase(`"string"`, bsontype.String, "String"), | ||||
| 		makeValidPeekTypeTestCase(`true`, bsontype.Boolean, "Boolean--true"), | ||||
| 		makeValidPeekTypeTestCase(`false`, bsontype.Boolean, "Boolean--false"), | ||||
| 		makeValidPeekTypeTestCase(`{"$minKey": 1}`, bsontype.MinKey, "MinKey"), | ||||
| 		makeValidPeekTypeTestCase(`{"$maxKey": 1}`, bsontype.MaxKey, "MaxKey"), | ||||
| 		makeValidPeekTypeTestCase(`{"$numberInt": "42"}`, bsontype.Int32, "Int32"), | ||||
| 		makeValidPeekTypeTestCase(`{"$numberLong": "42"}`, bsontype.Int64, "Int64"), | ||||
| 		makeValidPeekTypeTestCase(`{"$symbol": "symbol"}`, bsontype.Symbol, "Symbol"), | ||||
| 		makeValidPeekTypeTestCase(`{"$numberDouble": "42.42"}`, bsontype.Double, "Double"), | ||||
| 		makeValidPeekTypeTestCase(`{"$undefined": true}`, bsontype.Undefined, "Undefined"), | ||||
| 		makeValidPeekTypeTestCase(`{"$numberDouble": "NaN"}`, bsontype.Double, "Double--NaN"), | ||||
| 		makeValidPeekTypeTestCase(`{"$numberDecimal": "1234"}`, bsontype.Decimal128, "Decimal"), | ||||
| 		makeValidPeekTypeTestCase(`{"foo": "bar"}`, bsontype.EmbeddedDocument, "Toplevel document"), | ||||
| 		makeValidPeekTypeTestCase(`{"$date": {"$numberLong": "0"}}`, bsontype.DateTime, "Datetime"), | ||||
| 		makeValidPeekTypeTestCase(`{"$code": "function() {}"}`, bsontype.JavaScript, "Code no scope"), | ||||
| 		makeValidPeekTypeTestCase(`[{"$numberInt": "1"},{"$numberInt": "2"}]`, bsontype.Array, "Array"), | ||||
| 		makeValidPeekTypeTestCase(`{"$timestamp": {"t": 42, "i": 1}}`, bsontype.Timestamp, "Timestamp"), | ||||
| 		makeValidPeekTypeTestCase(`{"$oid": "57e193d7a9cc81b4027498b5"}`, bsontype.ObjectID, "Object ID"), | ||||
| 		makeValidPeekTypeTestCase(`{"$binary": {"base64": "AQIDBAU=", "subType": "80"}}`, bsontype.Binary, "Binary"), | ||||
| 		makeValidPeekTypeTestCase(`{"$code": "function() {}", "$scope": {}}`, bsontype.CodeWithScope, "Code With Scope"), | ||||
| 		makeValidPeekTypeTestCase(`{"$binary": {"base64": "o0w498Or7cijeBSpkquNtg==", "subType": "03"}}`, bsontype.Binary, "Binary"), | ||||
| 		makeValidPeekTypeTestCase(`{"$binary": "o0w498Or7cijeBSpkquNtg==", "$type": "03"}`, bsontype.Binary, "Binary"), | ||||
| 		makeValidPeekTypeTestCase(`{"$regularExpression": {"pattern": "foo*", "options": "ix"}}`, bsontype.Regex, "Regular expression"), | ||||
| 		makeValidPeekTypeTestCase(`{"$dbPointer": {"$ref": "db.collection", "$id": {"$oid": "57e193d7a9cc81b4027498b1"}}}`, bsontype.DBPointer, "DBPointer"), | ||||
| 		makeValidPeekTypeTestCase(`{"$ref": "collection", "$id": {"$oid": "57fd71e96e32ab4225b723fb"}, "$db": "database"}`, bsontype.EmbeddedDocument, "DBRef"), | ||||
| 		makeInvalidPeekTypeTestCase("invalid array--missing ]", `["a"`, expectError), | ||||
| 		makeInvalidPeekTypeTestCase("invalid array--colon in array", `["a":`, expectError), | ||||
| 		makeInvalidPeekTypeTestCase("invalid array--extra comma", `["a",,`, expectError), | ||||
| 		makeInvalidPeekTypeTestCase("invalid array--trailing comma", `["a",]`, expectError), | ||||
| 		makeInvalidPeekTypeTestCase("peekType after end of array", `["a"]`, expectErrEOA), | ||||
| 		{ | ||||
| 			desc:  "invalid array--leading comma", | ||||
| 			input: `[,`, | ||||
| 			typs:  []bsontype.Type{bsontype.Array, bsontype.Type(0)}, | ||||
| 			errFs: []expectedErrorFunc{expectNoError, expectError}, | ||||
| 		}, | ||||
| 		makeInvalidTestCase("lone $scope", `{"$scope": {}}`, expectError), | ||||
| 		makeInvalidTestCase("empty code with unknown extra key", `{"$code":"", "0":""}`, expectError), | ||||
| 		makeInvalidTestCase("non-empty code with unknown extra key", `{"$code":"foobar", "0":""}`, expectError), | ||||
| 	} | ||||
|  | ||||
| 	for _, tc := range cases { | ||||
| 		t.Run(tc.desc, func(t *testing.T) { | ||||
| 			ejp := newExtJSONParser(strings.NewReader(tc.input), true) | ||||
| 			// Manually set the parser's starting state to jpsSawColon so peekType will read ahead to find the extjson | ||||
| 			// type of the value. If not set, the parser will be in jpsStartState and advance to jpsSawKey, which will | ||||
| 			// cause it to return without peeking the extjson type. | ||||
| 			ejp.s = jpsSawColon | ||||
|  | ||||
| 			for i, eTyp := range tc.typs { | ||||
| 				errF := tc.errFs[i] | ||||
|  | ||||
| 				typ, err := ejp.peekType() | ||||
| 				errF(t, err, tc.desc) | ||||
| 				if err != nil { | ||||
| 					// Don't inspect the type if there was an error | ||||
| 					return | ||||
| 				} | ||||
|  | ||||
| 				typDiff(t, eTyp, typ, tc.desc) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestExtJSONParserReadKeyReadValue(t *testing.T) { | ||||
| 	// several test cases will use the same keys, types, and values, and only differ on input structure | ||||
|  | ||||
| 	keys := []string{"_id", "Symbol", "String", "Int32", "Int64", "Int", "MinKey"} | ||||
| 	types := []bsontype.Type{bsontype.ObjectID, bsontype.Symbol, bsontype.String, bsontype.Int32, bsontype.Int64, bsontype.Int32, bsontype.MinKey} | ||||
| 	values := []*extJSONValue{ | ||||
| 		{t: bsontype.String, v: "57e193d7a9cc81b4027498b5"}, | ||||
| 		{t: bsontype.String, v: "symbol"}, | ||||
| 		{t: bsontype.String, v: "string"}, | ||||
| 		{t: bsontype.String, v: "42"}, | ||||
| 		{t: bsontype.String, v: "42"}, | ||||
| 		{t: bsontype.Int32, v: int32(42)}, | ||||
| 		{t: bsontype.Int32, v: int32(1)}, | ||||
| 	} | ||||
|  | ||||
| 	errFuncs := make([]expectedErrorFunc, 7) | ||||
| 	for i := 0; i < 7; i++ { | ||||
| 		errFuncs[i] = expectNoError | ||||
| 	} | ||||
|  | ||||
| 	firstKeyError := func(desc, input string) readKeyValueTestCase { | ||||
| 		return readKeyValueTestCase{ | ||||
| 			desc:   desc, | ||||
| 			input:  input, | ||||
| 			keys:   []string{""}, | ||||
| 			typs:   []bsontype.Type{bsontype.Type(0)}, | ||||
| 			vals:   []*extJSONValue{nil}, | ||||
| 			keyEFs: []expectedErrorFunc{expectError}, | ||||
| 			valEFs: []expectedErrorFunc{expectErrorNOOP}, | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	secondKeyError := func(desc, input, firstKey string, firstType bsontype.Type, firstValue *extJSONValue) readKeyValueTestCase { | ||||
| 		return readKeyValueTestCase{ | ||||
| 			desc:   desc, | ||||
| 			input:  input, | ||||
| 			keys:   []string{firstKey, ""}, | ||||
| 			typs:   []bsontype.Type{firstType, bsontype.Type(0)}, | ||||
| 			vals:   []*extJSONValue{firstValue, nil}, | ||||
| 			keyEFs: []expectedErrorFunc{expectNoError, expectError}, | ||||
| 			valEFs: []expectedErrorFunc{expectNoError, expectErrorNOOP}, | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	cases := []readKeyValueTestCase{ | ||||
| 		{ | ||||
| 			desc: "normal spacing", | ||||
| 			input: `{ | ||||
| 					"_id": { "$oid": "57e193d7a9cc81b4027498b5" }, | ||||
| 					"Symbol": { "$symbol": "symbol" }, | ||||
| 					"String": "string", | ||||
| 					"Int32": { "$numberInt": "42" }, | ||||
| 					"Int64": { "$numberLong": "42" }, | ||||
| 					"Int": 42, | ||||
| 					"MinKey": { "$minKey": 1 } | ||||
| 				}`, | ||||
| 			keys: keys, typs: types, vals: values, | ||||
| 			keyEFs: errFuncs, valEFs: errFuncs, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "new line before comma", | ||||
| 			input: `{ "_id": { "$oid": "57e193d7a9cc81b4027498b5" } | ||||
| 				 , "Symbol": { "$symbol": "symbol" } | ||||
| 				 , "String": "string" | ||||
| 				 , "Int32": { "$numberInt": "42" } | ||||
| 				 , "Int64": { "$numberLong": "42" } | ||||
| 				 , "Int": 42 | ||||
| 				 , "MinKey": { "$minKey": 1 } | ||||
| 				 }`, | ||||
| 			keys: keys, typs: types, vals: values, | ||||
| 			keyEFs: errFuncs, valEFs: errFuncs, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "tabs around colons", | ||||
| 			input: `{ | ||||
| 					"_id":    { "$oid"       : "57e193d7a9cc81b4027498b5" }, | ||||
| 					"Symbol": { "$symbol"    : "symbol" }, | ||||
| 					"String": "string", | ||||
| 					"Int32":  { "$numberInt" : "42" }, | ||||
| 					"Int64":  { "$numberLong": "42" }, | ||||
| 					"Int":    42, | ||||
| 					"MinKey": { "$minKey": 1 } | ||||
| 				}`, | ||||
| 			keys: keys, typs: types, vals: values, | ||||
| 			keyEFs: errFuncs, valEFs: errFuncs, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc:  "no whitespace", | ||||
| 			input: `{"_id":{"$oid":"57e193d7a9cc81b4027498b5"},"Symbol":{"$symbol":"symbol"},"String":"string","Int32":{"$numberInt":"42"},"Int64":{"$numberLong":"42"},"Int":42,"MinKey":{"$minKey":1}}`, | ||||
| 			keys:  keys, typs: types, vals: values, | ||||
| 			keyEFs: errFuncs, valEFs: errFuncs, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "mixed whitespace", | ||||
| 			input: `	{ | ||||
| 					"_id"		: { "$oid": "57e193d7a9cc81b4027498b5" }, | ||||
| 			        "Symbol"	: { "$symbol": "symbol" }	, | ||||
| 				    "String"	: "string", | ||||
| 					"Int32"		: { "$numberInt": "42" }    , | ||||
| 					"Int64"		: {"$numberLong" : "42"}, | ||||
| 					"Int"		: 42, | ||||
| 			      	"MinKey"	: { "$minKey": 1 } 	}	`, | ||||
| 			keys: keys, typs: types, vals: values, | ||||
| 			keyEFs: errFuncs, valEFs: errFuncs, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc:  "nested object", | ||||
| 			input: `{"k1": 1, "k2": { "k3": { "k4": 4 } }, "k5": 5}`, | ||||
| 			keys:  []string{"k1", "k2", "k3", "k4", "", "", "k5", ""}, | ||||
| 			typs:  []bsontype.Type{bsontype.Int32, bsontype.EmbeddedDocument, bsontype.EmbeddedDocument, bsontype.Int32, bsontype.Type(0), bsontype.Type(0), bsontype.Int32, bsontype.Type(0)}, | ||||
| 			vals: []*extJSONValue{ | ||||
| 				{t: bsontype.Int32, v: int32(1)}, nil, nil, {t: bsontype.Int32, v: int32(4)}, nil, nil, {t: bsontype.Int32, v: int32(5)}, nil, | ||||
| 			}, | ||||
| 			keyEFs: []expectedErrorFunc{ | ||||
| 				expectNoError, expectNoError, expectNoError, expectNoError, expectErrEOD, | ||||
| 				expectErrEOD, expectNoError, expectErrEOD, | ||||
| 			}, | ||||
| 			valEFs: []expectedErrorFunc{ | ||||
| 				expectNoError, expectError, expectError, expectNoError, expectErrorNOOP, | ||||
| 				expectErrorNOOP, expectNoError, expectErrorNOOP, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc:   "invalid input: invalid values for extended type", | ||||
| 			input:  `{"a": {"$numberInt": "1", "x"`, | ||||
| 			keys:   []string{"a"}, | ||||
| 			typs:   []bsontype.Type{bsontype.Int32}, | ||||
| 			vals:   []*extJSONValue{nil}, | ||||
| 			keyEFs: []expectedErrorFunc{expectNoError}, | ||||
| 			valEFs: []expectedErrorFunc{expectError}, | ||||
| 		}, | ||||
| 		firstKeyError("invalid input: missing key--EOF", "{"), | ||||
| 		firstKeyError("invalid input: missing key--colon first", "{:"), | ||||
| 		firstKeyError("invalid input: missing value", `{"a":`), | ||||
| 		firstKeyError("invalid input: missing colon", `{"a" 1`), | ||||
| 		firstKeyError("invalid input: extra colon", `{"a"::`), | ||||
| 		secondKeyError("invalid input: missing }", `{"a": 1`, "a", bsontype.Int32, &extJSONValue{t: bsontype.Int32, v: int32(1)}), | ||||
| 		secondKeyError("invalid input: missing comma", `{"a": 1 "b"`, "a", bsontype.Int32, &extJSONValue{t: bsontype.Int32, v: int32(1)}), | ||||
| 		secondKeyError("invalid input: extra comma", `{"a": 1,, "b"`, "a", bsontype.Int32, &extJSONValue{t: bsontype.Int32, v: int32(1)}), | ||||
| 		secondKeyError("invalid input: trailing comma in object", `{"a": 1,}`, "a", bsontype.Int32, &extJSONValue{t: bsontype.Int32, v: int32(1)}), | ||||
| 		{ | ||||
| 			desc:   "invalid input: lone scope after a complete value", | ||||
| 			input:  `{"a": "", "b": {"$scope: ""}}`, | ||||
| 			keys:   []string{"a"}, | ||||
| 			typs:   []bsontype.Type{bsontype.String}, | ||||
| 			vals:   []*extJSONValue{{bsontype.String, ""}}, | ||||
| 			keyEFs: []expectedErrorFunc{expectNoError, expectNoError}, | ||||
| 			valEFs: []expectedErrorFunc{expectNoError, expectError}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc:   "invalid input: lone scope nested", | ||||
| 			input:  `{"a":{"b":{"$scope":{`, | ||||
| 			keys:   []string{}, | ||||
| 			typs:   []bsontype.Type{}, | ||||
| 			vals:   []*extJSONValue{nil}, | ||||
| 			keyEFs: []expectedErrorFunc{expectNoError}, | ||||
| 			valEFs: []expectedErrorFunc{expectError}, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for _, tc := range cases { | ||||
| 		t.Run(tc.desc, func(t *testing.T) { | ||||
| 			ejp := newExtJSONParser(strings.NewReader(tc.input), true) | ||||
|  | ||||
| 			for i, eKey := range tc.keys { | ||||
| 				eTyp := tc.typs[i] | ||||
| 				eVal := tc.vals[i] | ||||
|  | ||||
| 				keyErrF := tc.keyEFs[i] | ||||
| 				valErrF := tc.valEFs[i] | ||||
|  | ||||
| 				k, typ, err := ejp.readKey() | ||||
| 				readKeyDiff(t, eKey, k, eTyp, typ, err, keyErrF, tc.desc) | ||||
|  | ||||
| 				v, err := ejp.readValue(typ) | ||||
| 				readValueDiff(t, eVal, v, err, valErrF, tc.desc) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type ejpExpectationTest func(t *testing.T, p *extJSONParser, expectedKey string, expectedType bsontype.Type, expectedValue interface{}) | ||||
|  | ||||
| type ejpTestCase struct { | ||||
| 	f ejpExpectationTest | ||||
| 	p *extJSONParser | ||||
| 	k string | ||||
| 	t bsontype.Type | ||||
| 	v interface{} | ||||
| } | ||||
|  | ||||
| // expectSingleValue is used for simple JSON types (strings, numbers, literals) and for extended JSON types that | ||||
| // have single key-value pairs (i.e. { "$minKey": 1 }, { "$numberLong": "42.42" }) | ||||
| func expectSingleValue(t *testing.T, p *extJSONParser, expectedKey string, expectedType bsontype.Type, expectedValue interface{}) { | ||||
| 	eVal := expectedValue.(*extJSONValue) | ||||
|  | ||||
| 	k, typ, err := p.readKey() | ||||
| 	readKeyDiff(t, expectedKey, k, expectedType, typ, err, expectNoError, expectedKey) | ||||
|  | ||||
| 	v, err := p.readValue(typ) | ||||
| 	readValueDiff(t, eVal, v, err, expectNoError, expectedKey) | ||||
| } | ||||
|  | ||||
| // expectMultipleValues is used for values that are subdocuments of known size and with known keys (such as extended | ||||
| // JSON types { "$timestamp": {"t": 1, "i": 1} } and { "$regularExpression": {"pattern": "", options: ""} }) | ||||
| func expectMultipleValues(t *testing.T, p *extJSONParser, expectedKey string, expectedType bsontype.Type, expectedValue interface{}) { | ||||
| 	k, typ, err := p.readKey() | ||||
| 	readKeyDiff(t, expectedKey, k, expectedType, typ, err, expectNoError, expectedKey) | ||||
|  | ||||
| 	v, err := p.readValue(typ) | ||||
| 	expectNoError(t, err, "") | ||||
| 	typDiff(t, bsontype.EmbeddedDocument, v.t, expectedKey) | ||||
|  | ||||
| 	actObj := v.v.(*extJSONObject) | ||||
| 	expObj := expectedValue.(*extJSONObject) | ||||
|  | ||||
| 	for i, actKey := range actObj.keys { | ||||
| 		expKey := expObj.keys[i] | ||||
| 		actVal := actObj.values[i] | ||||
| 		expVal := expObj.values[i] | ||||
|  | ||||
| 		keyDiff(t, expKey, actKey, expectedKey) | ||||
| 		typDiff(t, expVal.t, actVal.t, expectedKey) | ||||
| 		valDiff(t, expVal.v, actVal.v, expectedKey) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type ejpKeyTypValTriple struct { | ||||
| 	key string | ||||
| 	typ bsontype.Type | ||||
| 	val *extJSONValue | ||||
| } | ||||
|  | ||||
| type ejpSubDocumentTestValue struct { | ||||
| 	code string               // code is only used for TypeCodeWithScope (and is ignored for TypeEmbeddedDocument | ||||
| 	ktvs []ejpKeyTypValTriple // list of (key, type, value) triples; this is "scope" for TypeCodeWithScope | ||||
| } | ||||
|  | ||||
| // expectSubDocument is used for embedded documents and code with scope types; it reads all the keys and values | ||||
| // in the embedded document (or scope for codeWithScope) and compares them to the expectedValue's list of (key, type, | ||||
| // value) triples | ||||
| func expectSubDocument(t *testing.T, p *extJSONParser, expectedKey string, expectedType bsontype.Type, expectedValue interface{}) { | ||||
| 	subdoc := expectedValue.(ejpSubDocumentTestValue) | ||||
|  | ||||
| 	k, typ, err := p.readKey() | ||||
| 	readKeyDiff(t, expectedKey, k, expectedType, typ, err, expectNoError, expectedKey) | ||||
|  | ||||
| 	if expectedType == bsontype.CodeWithScope { | ||||
| 		v, err := p.readValue(typ) | ||||
| 		readValueDiff(t, &extJSONValue{t: bsontype.String, v: subdoc.code}, v, err, expectNoError, expectedKey) | ||||
| 	} | ||||
|  | ||||
| 	for _, ktv := range subdoc.ktvs { | ||||
| 		eKey := ktv.key | ||||
| 		eTyp := ktv.typ | ||||
| 		eVal := ktv.val | ||||
|  | ||||
| 		k, typ, err = p.readKey() | ||||
| 		readKeyDiff(t, eKey, k, eTyp, typ, err, expectNoError, expectedKey) | ||||
|  | ||||
| 		v, err := p.readValue(typ) | ||||
| 		readValueDiff(t, eVal, v, err, expectNoError, expectedKey) | ||||
| 	} | ||||
|  | ||||
| 	if expectedType == bsontype.CodeWithScope { | ||||
| 		// expect scope doc to close | ||||
| 		k, typ, err = p.readKey() | ||||
| 		readKeyDiff(t, "", k, bsontype.Type(0), typ, err, expectErrEOD, expectedKey) | ||||
| 	} | ||||
|  | ||||
| 	// expect subdoc to close | ||||
| 	k, typ, err = p.readKey() | ||||
| 	readKeyDiff(t, "", k, bsontype.Type(0), typ, err, expectErrEOD, expectedKey) | ||||
| } | ||||
|  | ||||
| // expectArray takes the expectedKey, ignores the expectedType, and uses the expectedValue | ||||
| // as a slice of (type Type, value *extJSONValue) pairs | ||||
| func expectArray(t *testing.T, p *extJSONParser, expectedKey string, _ bsontype.Type, expectedValue interface{}) { | ||||
| 	ktvs := expectedValue.([]ejpKeyTypValTriple) | ||||
|  | ||||
| 	k, typ, err := p.readKey() | ||||
| 	readKeyDiff(t, expectedKey, k, bsontype.Array, typ, err, expectNoError, expectedKey) | ||||
|  | ||||
| 	for _, ktv := range ktvs { | ||||
| 		eTyp := ktv.typ | ||||
| 		eVal := ktv.val | ||||
|  | ||||
| 		typ, err = p.peekType() | ||||
| 		typDiff(t, eTyp, typ, expectedKey) | ||||
| 		expectNoError(t, err, expectedKey) | ||||
|  | ||||
| 		v, err := p.readValue(typ) | ||||
| 		readValueDiff(t, eVal, v, err, expectNoError, expectedKey) | ||||
| 	} | ||||
|  | ||||
| 	// expect array to end | ||||
| 	typ, err = p.peekType() | ||||
| 	typDiff(t, bsontype.Type(0), typ, expectedKey) | ||||
| 	expectErrEOA(t, err, expectedKey) | ||||
| } | ||||
|  | ||||
| func TestExtJSONParserAllTypes(t *testing.T) { | ||||
| 	in := ` { "_id"					: { "$oid": "57e193d7a9cc81b4027498b5"} | ||||
| 			, "Symbol"				: { "$symbol": "symbol"} | ||||
| 			, "String"				: "string" | ||||
| 			, "Int32"				: { "$numberInt": "42"} | ||||
| 			, "Int64"				: { "$numberLong": "42"} | ||||
| 			, "Double"				: { "$numberDouble": "42.42"} | ||||
| 			, "SpecialFloat"		: { "$numberDouble": "NaN" } | ||||
| 			, "Decimal"				: { "$numberDecimal": "1234" } | ||||
| 			, "Binary"			 	: { "$binary": { "base64": "o0w498Or7cijeBSpkquNtg==", "subType": "03" } } | ||||
| 			, "BinaryLegacy"  : { "$binary": "o0w498Or7cijeBSpkquNtg==", "$type": "03" } | ||||
| 			, "BinaryUserDefined"	: { "$binary": { "base64": "AQIDBAU=", "subType": "80" } } | ||||
| 			, "Code"				: { "$code": "function() {}" } | ||||
| 			, "CodeWithEmptyScope"	: { "$code": "function() {}", "$scope": {} } | ||||
| 			, "CodeWithScope"		: { "$code": "function() {}", "$scope": { "x": 1 } } | ||||
| 			, "EmptySubdocument"    : {} | ||||
| 			, "Subdocument"			: { "foo": "bar", "baz": { "$numberInt": "42" } } | ||||
| 			, "Array"				: [{"$numberInt": "1"}, {"$numberLong": "2"}, {"$numberDouble": "3"}, 4, "string", 5.0] | ||||
| 			, "Timestamp"			: { "$timestamp": { "t": 42, "i": 1 } } | ||||
| 			, "RegularExpression"	: { "$regularExpression": { "pattern": "foo*", "options": "ix" } } | ||||
| 			, "DatetimeEpoch"		: { "$date": { "$numberLong": "0" } } | ||||
| 			, "DatetimePositive"	: { "$date": { "$numberLong": "9223372036854775807" } } | ||||
| 			, "DatetimeNegative"	: { "$date": { "$numberLong": "-9223372036854775808" } } | ||||
| 			, "True"				: true | ||||
| 			, "False"				: false | ||||
| 			, "DBPointer"			: { "$dbPointer": { "$ref": "db.collection", "$id": { "$oid": "57e193d7a9cc81b4027498b1" } } } | ||||
| 			, "DBRef"				: { "$ref": "collection", "$id": { "$oid": "57fd71e96e32ab4225b723fb" }, "$db": "database" } | ||||
| 			, "DBRefNoDB"			: { "$ref": "collection", "$id": { "$oid": "57fd71e96e32ab4225b723fb" } } | ||||
| 			, "MinKey"				: { "$minKey": 1 } | ||||
| 			, "MaxKey"				: { "$maxKey": 1 } | ||||
| 			, "Null"				: null | ||||
| 			, "Undefined"			: { "$undefined": true } | ||||
| 			}` | ||||
|  | ||||
| 	ejp := newExtJSONParser(strings.NewReader(in), true) | ||||
|  | ||||
| 	cases := []ejpTestCase{ | ||||
| 		{ | ||||
| 			f: expectSingleValue, p: ejp, | ||||
| 			k: "_id", t: bsontype.ObjectID, v: &extJSONValue{t: bsontype.String, v: "57e193d7a9cc81b4027498b5"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			f: expectSingleValue, p: ejp, | ||||
| 			k: "Symbol", t: bsontype.Symbol, v: &extJSONValue{t: bsontype.String, v: "symbol"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			f: expectSingleValue, p: ejp, | ||||
| 			k: "String", t: bsontype.String, v: &extJSONValue{t: bsontype.String, v: "string"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			f: expectSingleValue, p: ejp, | ||||
| 			k: "Int32", t: bsontype.Int32, v: &extJSONValue{t: bsontype.String, v: "42"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			f: expectSingleValue, p: ejp, | ||||
| 			k: "Int64", t: bsontype.Int64, v: &extJSONValue{t: bsontype.String, v: "42"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			f: expectSingleValue, p: ejp, | ||||
| 			k: "Double", t: bsontype.Double, v: &extJSONValue{t: bsontype.String, v: "42.42"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			f: expectSingleValue, p: ejp, | ||||
| 			k: "SpecialFloat", t: bsontype.Double, v: &extJSONValue{t: bsontype.String, v: "NaN"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			f: expectSingleValue, p: ejp, | ||||
| 			k: "Decimal", t: bsontype.Decimal128, v: &extJSONValue{t: bsontype.String, v: "1234"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			f: expectMultipleValues, p: ejp, | ||||
| 			k: "Binary", t: bsontype.Binary, | ||||
| 			v: &extJSONObject{ | ||||
| 				keys: []string{"base64", "subType"}, | ||||
| 				values: []*extJSONValue{ | ||||
| 					{t: bsontype.String, v: "o0w498Or7cijeBSpkquNtg=="}, | ||||
| 					{t: bsontype.String, v: "03"}, | ||||
| 				}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			f: expectMultipleValues, p: ejp, | ||||
| 			k: "BinaryLegacy", t: bsontype.Binary, | ||||
| 			v: &extJSONObject{ | ||||
| 				keys: []string{"base64", "subType"}, | ||||
| 				values: []*extJSONValue{ | ||||
| 					{t: bsontype.String, v: "o0w498Or7cijeBSpkquNtg=="}, | ||||
| 					{t: bsontype.String, v: "03"}, | ||||
| 				}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			f: expectMultipleValues, p: ejp, | ||||
| 			k: "BinaryUserDefined", t: bsontype.Binary, | ||||
| 			v: &extJSONObject{ | ||||
| 				keys: []string{"base64", "subType"}, | ||||
| 				values: []*extJSONValue{ | ||||
| 					{t: bsontype.String, v: "AQIDBAU="}, | ||||
| 					{t: bsontype.String, v: "80"}, | ||||
| 				}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			f: expectSingleValue, p: ejp, | ||||
| 			k: "Code", t: bsontype.JavaScript, v: &extJSONValue{t: bsontype.String, v: "function() {}"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			f: expectSubDocument, p: ejp, | ||||
| 			k: "CodeWithEmptyScope", t: bsontype.CodeWithScope, | ||||
| 			v: ejpSubDocumentTestValue{ | ||||
| 				code: "function() {}", | ||||
| 				ktvs: []ejpKeyTypValTriple{}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			f: expectSubDocument, p: ejp, | ||||
| 			k: "CodeWithScope", t: bsontype.CodeWithScope, | ||||
| 			v: ejpSubDocumentTestValue{ | ||||
| 				code: "function() {}", | ||||
| 				ktvs: []ejpKeyTypValTriple{ | ||||
| 					{"x", bsontype.Int32, &extJSONValue{t: bsontype.Int32, v: int32(1)}}, | ||||
| 				}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			f: expectSubDocument, p: ejp, | ||||
| 			k: "EmptySubdocument", t: bsontype.EmbeddedDocument, | ||||
| 			v: ejpSubDocumentTestValue{ | ||||
| 				ktvs: []ejpKeyTypValTriple{}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			f: expectSubDocument, p: ejp, | ||||
| 			k: "Subdocument", t: bsontype.EmbeddedDocument, | ||||
| 			v: ejpSubDocumentTestValue{ | ||||
| 				ktvs: []ejpKeyTypValTriple{ | ||||
| 					{"foo", bsontype.String, &extJSONValue{t: bsontype.String, v: "bar"}}, | ||||
| 					{"baz", bsontype.Int32, &extJSONValue{t: bsontype.String, v: "42"}}, | ||||
| 				}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			f: expectArray, p: ejp, | ||||
| 			k: "Array", t: bsontype.Array, | ||||
| 			v: []ejpKeyTypValTriple{ | ||||
| 				{typ: bsontype.Int32, val: &extJSONValue{t: bsontype.String, v: "1"}}, | ||||
| 				{typ: bsontype.Int64, val: &extJSONValue{t: bsontype.String, v: "2"}}, | ||||
| 				{typ: bsontype.Double, val: &extJSONValue{t: bsontype.String, v: "3"}}, | ||||
| 				{typ: bsontype.Int32, val: &extJSONValue{t: bsontype.Int32, v: int32(4)}}, | ||||
| 				{typ: bsontype.String, val: &extJSONValue{t: bsontype.String, v: "string"}}, | ||||
| 				{typ: bsontype.Double, val: &extJSONValue{t: bsontype.Double, v: 5.0}}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			f: expectMultipleValues, p: ejp, | ||||
| 			k: "Timestamp", t: bsontype.Timestamp, | ||||
| 			v: &extJSONObject{ | ||||
| 				keys: []string{"t", "i"}, | ||||
| 				values: []*extJSONValue{ | ||||
| 					{t: bsontype.Int32, v: int32(42)}, | ||||
| 					{t: bsontype.Int32, v: int32(1)}, | ||||
| 				}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			f: expectMultipleValues, p: ejp, | ||||
| 			k: "RegularExpression", t: bsontype.Regex, | ||||
| 			v: &extJSONObject{ | ||||
| 				keys: []string{"pattern", "options"}, | ||||
| 				values: []*extJSONValue{ | ||||
| 					{t: bsontype.String, v: "foo*"}, | ||||
| 					{t: bsontype.String, v: "ix"}, | ||||
| 				}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			f: expectMultipleValues, p: ejp, | ||||
| 			k: "DatetimeEpoch", t: bsontype.DateTime, | ||||
| 			v: &extJSONObject{ | ||||
| 				keys: []string{"$numberLong"}, | ||||
| 				values: []*extJSONValue{ | ||||
| 					{t: bsontype.String, v: "0"}, | ||||
| 				}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			f: expectMultipleValues, p: ejp, | ||||
| 			k: "DatetimePositive", t: bsontype.DateTime, | ||||
| 			v: &extJSONObject{ | ||||
| 				keys: []string{"$numberLong"}, | ||||
| 				values: []*extJSONValue{ | ||||
| 					{t: bsontype.String, v: "9223372036854775807"}, | ||||
| 				}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			f: expectMultipleValues, p: ejp, | ||||
| 			k: "DatetimeNegative", t: bsontype.DateTime, | ||||
| 			v: &extJSONObject{ | ||||
| 				keys: []string{"$numberLong"}, | ||||
| 				values: []*extJSONValue{ | ||||
| 					{t: bsontype.String, v: "-9223372036854775808"}, | ||||
| 				}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			f: expectSingleValue, p: ejp, | ||||
| 			k: "True", t: bsontype.Boolean, v: &extJSONValue{t: bsontype.Boolean, v: true}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			f: expectSingleValue, p: ejp, | ||||
| 			k: "False", t: bsontype.Boolean, v: &extJSONValue{t: bsontype.Boolean, v: false}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			f: expectMultipleValues, p: ejp, | ||||
| 			k: "DBPointer", t: bsontype.DBPointer, | ||||
| 			v: &extJSONObject{ | ||||
| 				keys: []string{"$ref", "$id"}, | ||||
| 				values: []*extJSONValue{ | ||||
| 					{t: bsontype.String, v: "db.collection"}, | ||||
| 					{t: bsontype.String, v: "57e193d7a9cc81b4027498b1"}, | ||||
| 				}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			f: expectSubDocument, p: ejp, | ||||
| 			k: "DBRef", t: bsontype.EmbeddedDocument, | ||||
| 			v: ejpSubDocumentTestValue{ | ||||
| 				ktvs: []ejpKeyTypValTriple{ | ||||
| 					{"$ref", bsontype.String, &extJSONValue{t: bsontype.String, v: "collection"}}, | ||||
| 					{"$id", bsontype.ObjectID, &extJSONValue{t: bsontype.String, v: "57fd71e96e32ab4225b723fb"}}, | ||||
| 					{"$db", bsontype.String, &extJSONValue{t: bsontype.String, v: "database"}}, | ||||
| 				}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			f: expectSubDocument, p: ejp, | ||||
| 			k: "DBRefNoDB", t: bsontype.EmbeddedDocument, | ||||
| 			v: ejpSubDocumentTestValue{ | ||||
| 				ktvs: []ejpKeyTypValTriple{ | ||||
| 					{"$ref", bsontype.String, &extJSONValue{t: bsontype.String, v: "collection"}}, | ||||
| 					{"$id", bsontype.ObjectID, &extJSONValue{t: bsontype.String, v: "57fd71e96e32ab4225b723fb"}}, | ||||
| 				}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			f: expectSingleValue, p: ejp, | ||||
| 			k: "MinKey", t: bsontype.MinKey, v: &extJSONValue{t: bsontype.Int32, v: int32(1)}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			f: expectSingleValue, p: ejp, | ||||
| 			k: "MaxKey", t: bsontype.MaxKey, v: &extJSONValue{t: bsontype.Int32, v: int32(1)}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			f: expectSingleValue, p: ejp, | ||||
| 			k: "Null", t: bsontype.Null, v: &extJSONValue{t: bsontype.Null, v: nil}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			f: expectSingleValue, p: ejp, | ||||
| 			k: "Undefined", t: bsontype.Undefined, v: &extJSONValue{t: bsontype.Boolean, v: true}, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	// run the test cases | ||||
| 	for _, tc := range cases { | ||||
| 		tc.f(t, tc.p, tc.k, tc.t, tc.v) | ||||
| 	} | ||||
|  | ||||
| 	// expect end of whole document: read final } | ||||
| 	k, typ, err := ejp.readKey() | ||||
| 	readKeyDiff(t, "", k, bsontype.Type(0), typ, err, expectErrEOD, "") | ||||
|  | ||||
| 	// expect end of whole document: read EOF | ||||
| 	k, typ, err = ejp.readKey() | ||||
| 	readKeyDiff(t, "", k, bsontype.Type(0), typ, err, expectErrEOF, "") | ||||
| 	if diff := cmp.Diff(jpsDoneState, ejp.s); diff != "" { | ||||
| 		t.Errorf("expected parser to be in done state but instead is in %v\n", ejp.s) | ||||
| 		t.FailNow() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestExtJSONValue(t *testing.T) { | ||||
| 	t.Run("Large Date", func(t *testing.T) { | ||||
| 		val := &extJSONValue{ | ||||
| 			t: bsontype.String, | ||||
| 			v: "3001-01-01T00:00:00Z", | ||||
| 		} | ||||
|  | ||||
| 		intVal, err := val.parseDateTime() | ||||
| 		if err != nil { | ||||
| 			t.Fatalf("error parsing date time: %v", err) | ||||
| 		} | ||||
|  | ||||
| 		if intVal <= 0 { | ||||
| 			t.Fatalf("expected value above 0, got %v", intVal) | ||||
| 		} | ||||
| 	}) | ||||
| 	t.Run("fallback time format", func(t *testing.T) { | ||||
| 		val := &extJSONValue{ | ||||
| 			t: bsontype.String, | ||||
| 			v: "2019-06-04T14:54:31.416+0000", | ||||
| 		} | ||||
|  | ||||
| 		_, err := val.parseDateTime() | ||||
| 		if err != nil { | ||||
| 			t.Fatalf("error parsing date time: %v", err) | ||||
| 		} | ||||
| 	}) | ||||
| } | ||||
							
								
								
									
										644
									
								
								mongo/bson/bsonrw/extjson_reader.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										644
									
								
								mongo/bson/bsonrw/extjson_reader.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,644 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsonrw | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"sync" | ||||
|  | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsontype" | ||||
| 	"go.mongodb.org/mongo-driver/bson/primitive" | ||||
| ) | ||||
|  | ||||
| // ExtJSONValueReaderPool is a pool for ValueReaders that read ExtJSON. | ||||
| type ExtJSONValueReaderPool struct { | ||||
| 	pool sync.Pool | ||||
| } | ||||
|  | ||||
| // NewExtJSONValueReaderPool instantiates a new ExtJSONValueReaderPool. | ||||
| func NewExtJSONValueReaderPool() *ExtJSONValueReaderPool { | ||||
| 	return &ExtJSONValueReaderPool{ | ||||
| 		pool: sync.Pool{ | ||||
| 			New: func() interface{} { | ||||
| 				return new(extJSONValueReader) | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Get retrieves a ValueReader from the pool and uses src as the underlying ExtJSON. | ||||
| func (bvrp *ExtJSONValueReaderPool) Get(r io.Reader, canonical bool) (ValueReader, error) { | ||||
| 	vr := bvrp.pool.Get().(*extJSONValueReader) | ||||
| 	return vr.reset(r, canonical) | ||||
| } | ||||
|  | ||||
| // Put inserts a ValueReader into the pool. If the ValueReader is not a ExtJSON ValueReader nothing | ||||
| // is inserted into the pool and ok will be false. | ||||
| func (bvrp *ExtJSONValueReaderPool) Put(vr ValueReader) (ok bool) { | ||||
| 	bvr, ok := vr.(*extJSONValueReader) | ||||
| 	if !ok { | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	bvr, _ = bvr.reset(nil, false) | ||||
| 	bvrp.pool.Put(bvr) | ||||
| 	return true | ||||
| } | ||||
|  | ||||
| type ejvrState struct { | ||||
| 	mode  mode | ||||
| 	vType bsontype.Type | ||||
| 	depth int | ||||
| } | ||||
|  | ||||
| // extJSONValueReader is for reading extended JSON. | ||||
| type extJSONValueReader struct { | ||||
| 	p *extJSONParser | ||||
|  | ||||
| 	stack []ejvrState | ||||
| 	frame int | ||||
| } | ||||
|  | ||||
| // NewExtJSONValueReader creates a new ValueReader from a given io.Reader | ||||
| // It will interpret the JSON of r as canonical or relaxed according to the | ||||
| // given canonical flag | ||||
| func NewExtJSONValueReader(r io.Reader, canonical bool) (ValueReader, error) { | ||||
| 	return newExtJSONValueReader(r, canonical) | ||||
| } | ||||
|  | ||||
| func newExtJSONValueReader(r io.Reader, canonical bool) (*extJSONValueReader, error) { | ||||
| 	ejvr := new(extJSONValueReader) | ||||
| 	return ejvr.reset(r, canonical) | ||||
| } | ||||
|  | ||||
| func (ejvr *extJSONValueReader) reset(r io.Reader, canonical bool) (*extJSONValueReader, error) { | ||||
| 	p := newExtJSONParser(r, canonical) | ||||
| 	typ, err := p.peekType() | ||||
|  | ||||
| 	if err != nil { | ||||
| 		return nil, ErrInvalidJSON | ||||
| 	} | ||||
|  | ||||
| 	var m mode | ||||
| 	switch typ { | ||||
| 	case bsontype.EmbeddedDocument: | ||||
| 		m = mTopLevel | ||||
| 	case bsontype.Array: | ||||
| 		m = mArray | ||||
| 	default: | ||||
| 		m = mValue | ||||
| 	} | ||||
|  | ||||
| 	stack := make([]ejvrState, 1, 5) | ||||
| 	stack[0] = ejvrState{ | ||||
| 		mode:  m, | ||||
| 		vType: typ, | ||||
| 	} | ||||
| 	return &extJSONValueReader{ | ||||
| 		p:     p, | ||||
| 		stack: stack, | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| func (ejvr *extJSONValueReader) advanceFrame() { | ||||
| 	if ejvr.frame+1 >= len(ejvr.stack) { // We need to grow the stack | ||||
| 		length := len(ejvr.stack) | ||||
| 		if length+1 >= cap(ejvr.stack) { | ||||
| 			// double it | ||||
| 			buf := make([]ejvrState, 2*cap(ejvr.stack)+1) | ||||
| 			copy(buf, ejvr.stack) | ||||
| 			ejvr.stack = buf | ||||
| 		} | ||||
| 		ejvr.stack = ejvr.stack[:length+1] | ||||
| 	} | ||||
| 	ejvr.frame++ | ||||
|  | ||||
| 	// Clean the stack | ||||
| 	ejvr.stack[ejvr.frame].mode = 0 | ||||
| 	ejvr.stack[ejvr.frame].vType = 0 | ||||
| 	ejvr.stack[ejvr.frame].depth = 0 | ||||
| } | ||||
|  | ||||
| func (ejvr *extJSONValueReader) pushDocument() { | ||||
| 	ejvr.advanceFrame() | ||||
|  | ||||
| 	ejvr.stack[ejvr.frame].mode = mDocument | ||||
| 	ejvr.stack[ejvr.frame].depth = ejvr.p.depth | ||||
| } | ||||
|  | ||||
| func (ejvr *extJSONValueReader) pushCodeWithScope() { | ||||
| 	ejvr.advanceFrame() | ||||
|  | ||||
| 	ejvr.stack[ejvr.frame].mode = mCodeWithScope | ||||
| } | ||||
|  | ||||
| func (ejvr *extJSONValueReader) pushArray() { | ||||
| 	ejvr.advanceFrame() | ||||
|  | ||||
| 	ejvr.stack[ejvr.frame].mode = mArray | ||||
| } | ||||
|  | ||||
| func (ejvr *extJSONValueReader) push(m mode, t bsontype.Type) { | ||||
| 	ejvr.advanceFrame() | ||||
|  | ||||
| 	ejvr.stack[ejvr.frame].mode = m | ||||
| 	ejvr.stack[ejvr.frame].vType = t | ||||
| } | ||||
|  | ||||
| func (ejvr *extJSONValueReader) pop() { | ||||
| 	switch ejvr.stack[ejvr.frame].mode { | ||||
| 	case mElement, mValue: | ||||
| 		ejvr.frame-- | ||||
| 	case mDocument, mArray, mCodeWithScope: | ||||
| 		ejvr.frame -= 2 // we pop twice to jump over the vrElement: vrDocument -> vrElement -> vrDocument/TopLevel/etc... | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (ejvr *extJSONValueReader) skipObject() { | ||||
| 	// read entire object until depth returns to 0 (last ending } or ] seen) | ||||
| 	depth := 1 | ||||
| 	for depth > 0 { | ||||
| 		ejvr.p.advanceState() | ||||
|  | ||||
| 		// If object is empty, raise depth and continue. When emptyObject is true, the | ||||
| 		// parser has already read both the opening and closing brackets of an empty | ||||
| 		// object ("{}"), so the next valid token will be part of the parent document, | ||||
| 		// not part of the nested document. | ||||
| 		// | ||||
| 		// If there is a comma, there are remaining fields, emptyObject must be set back | ||||
| 		// to false, and comma must be skipped with advanceState(). | ||||
| 		if ejvr.p.emptyObject { | ||||
| 			if ejvr.p.s == jpsSawComma { | ||||
| 				ejvr.p.emptyObject = false | ||||
| 				ejvr.p.advanceState() | ||||
| 			} | ||||
| 			depth-- | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		switch ejvr.p.s { | ||||
| 		case jpsSawBeginObject, jpsSawBeginArray: | ||||
| 			depth++ | ||||
| 		case jpsSawEndObject, jpsSawEndArray: | ||||
| 			depth-- | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (ejvr *extJSONValueReader) invalidTransitionErr(destination mode, name string, modes []mode) error { | ||||
| 	te := TransitionError{ | ||||
| 		name:        name, | ||||
| 		current:     ejvr.stack[ejvr.frame].mode, | ||||
| 		destination: destination, | ||||
| 		modes:       modes, | ||||
| 		action:      "read", | ||||
| 	} | ||||
| 	if ejvr.frame != 0 { | ||||
| 		te.parent = ejvr.stack[ejvr.frame-1].mode | ||||
| 	} | ||||
| 	return te | ||||
| } | ||||
|  | ||||
| func (ejvr *extJSONValueReader) typeError(t bsontype.Type) error { | ||||
| 	return fmt.Errorf("positioned on %s, but attempted to read %s", ejvr.stack[ejvr.frame].vType, t) | ||||
| } | ||||
|  | ||||
| func (ejvr *extJSONValueReader) ensureElementValue(t bsontype.Type, destination mode, callerName string, addModes ...mode) error { | ||||
| 	switch ejvr.stack[ejvr.frame].mode { | ||||
| 	case mElement, mValue: | ||||
| 		if ejvr.stack[ejvr.frame].vType != t { | ||||
| 			return ejvr.typeError(t) | ||||
| 		} | ||||
| 	default: | ||||
| 		modes := []mode{mElement, mValue} | ||||
| 		if addModes != nil { | ||||
| 			modes = append(modes, addModes...) | ||||
| 		} | ||||
| 		return ejvr.invalidTransitionErr(destination, callerName, modes) | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (ejvr *extJSONValueReader) Type() bsontype.Type { | ||||
| 	return ejvr.stack[ejvr.frame].vType | ||||
| } | ||||
|  | ||||
| func (ejvr *extJSONValueReader) Skip() error { | ||||
| 	switch ejvr.stack[ejvr.frame].mode { | ||||
| 	case mElement, mValue: | ||||
| 	default: | ||||
| 		return ejvr.invalidTransitionErr(0, "Skip", []mode{mElement, mValue}) | ||||
| 	} | ||||
|  | ||||
| 	defer ejvr.pop() | ||||
|  | ||||
| 	t := ejvr.stack[ejvr.frame].vType | ||||
| 	switch t { | ||||
| 	case bsontype.Array, bsontype.EmbeddedDocument, bsontype.CodeWithScope: | ||||
| 		// read entire array, doc or CodeWithScope | ||||
| 		ejvr.skipObject() | ||||
| 	default: | ||||
| 		_, err := ejvr.p.readValue(t) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (ejvr *extJSONValueReader) ReadArray() (ArrayReader, error) { | ||||
| 	switch ejvr.stack[ejvr.frame].mode { | ||||
| 	case mTopLevel: // allow reading array from top level | ||||
| 	case mArray: | ||||
| 		return ejvr, nil | ||||
| 	default: | ||||
| 		if err := ejvr.ensureElementValue(bsontype.Array, mArray, "ReadArray", mTopLevel, mArray); err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	ejvr.pushArray() | ||||
|  | ||||
| 	return ejvr, nil | ||||
| } | ||||
|  | ||||
| func (ejvr *extJSONValueReader) ReadBinary() (b []byte, btype byte, err error) { | ||||
| 	if err := ejvr.ensureElementValue(bsontype.Binary, 0, "ReadBinary"); err != nil { | ||||
| 		return nil, 0, err | ||||
| 	} | ||||
|  | ||||
| 	v, err := ejvr.p.readValue(bsontype.Binary) | ||||
| 	if err != nil { | ||||
| 		return nil, 0, err | ||||
| 	} | ||||
|  | ||||
| 	b, btype, err = v.parseBinary() | ||||
|  | ||||
| 	ejvr.pop() | ||||
| 	return b, btype, err | ||||
| } | ||||
|  | ||||
| func (ejvr *extJSONValueReader) ReadBoolean() (bool, error) { | ||||
| 	if err := ejvr.ensureElementValue(bsontype.Boolean, 0, "ReadBoolean"); err != nil { | ||||
| 		return false, err | ||||
| 	} | ||||
|  | ||||
| 	v, err := ejvr.p.readValue(bsontype.Boolean) | ||||
| 	if err != nil { | ||||
| 		return false, err | ||||
| 	} | ||||
|  | ||||
| 	if v.t != bsontype.Boolean { | ||||
| 		return false, fmt.Errorf("expected type bool, but got type %s", v.t) | ||||
| 	} | ||||
|  | ||||
| 	ejvr.pop() | ||||
| 	return v.v.(bool), nil | ||||
| } | ||||
|  | ||||
| func (ejvr *extJSONValueReader) ReadDocument() (DocumentReader, error) { | ||||
| 	switch ejvr.stack[ejvr.frame].mode { | ||||
| 	case mTopLevel: | ||||
| 		return ejvr, nil | ||||
| 	case mElement, mValue: | ||||
| 		if ejvr.stack[ejvr.frame].vType != bsontype.EmbeddedDocument { | ||||
| 			return nil, ejvr.typeError(bsontype.EmbeddedDocument) | ||||
| 		} | ||||
|  | ||||
| 		ejvr.pushDocument() | ||||
| 		return ejvr, nil | ||||
| 	default: | ||||
| 		return nil, ejvr.invalidTransitionErr(mDocument, "ReadDocument", []mode{mTopLevel, mElement, mValue}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (ejvr *extJSONValueReader) ReadCodeWithScope() (code string, dr DocumentReader, err error) { | ||||
| 	if err = ejvr.ensureElementValue(bsontype.CodeWithScope, 0, "ReadCodeWithScope"); err != nil { | ||||
| 		return "", nil, err | ||||
| 	} | ||||
|  | ||||
| 	v, err := ejvr.p.readValue(bsontype.CodeWithScope) | ||||
| 	if err != nil { | ||||
| 		return "", nil, err | ||||
| 	} | ||||
|  | ||||
| 	code, err = v.parseJavascript() | ||||
|  | ||||
| 	ejvr.pushCodeWithScope() | ||||
| 	return code, ejvr, err | ||||
| } | ||||
|  | ||||
| func (ejvr *extJSONValueReader) ReadDBPointer() (ns string, oid primitive.ObjectID, err error) { | ||||
| 	if err = ejvr.ensureElementValue(bsontype.DBPointer, 0, "ReadDBPointer"); err != nil { | ||||
| 		return "", primitive.NilObjectID, err | ||||
| 	} | ||||
|  | ||||
| 	v, err := ejvr.p.readValue(bsontype.DBPointer) | ||||
| 	if err != nil { | ||||
| 		return "", primitive.NilObjectID, err | ||||
| 	} | ||||
|  | ||||
| 	ns, oid, err = v.parseDBPointer() | ||||
|  | ||||
| 	ejvr.pop() | ||||
| 	return ns, oid, err | ||||
| } | ||||
|  | ||||
| func (ejvr *extJSONValueReader) ReadDateTime() (int64, error) { | ||||
| 	if err := ejvr.ensureElementValue(bsontype.DateTime, 0, "ReadDateTime"); err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
|  | ||||
| 	v, err := ejvr.p.readValue(bsontype.DateTime) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
|  | ||||
| 	d, err := v.parseDateTime() | ||||
|  | ||||
| 	ejvr.pop() | ||||
| 	return d, err | ||||
| } | ||||
|  | ||||
| func (ejvr *extJSONValueReader) ReadDecimal128() (primitive.Decimal128, error) { | ||||
| 	if err := ejvr.ensureElementValue(bsontype.Decimal128, 0, "ReadDecimal128"); err != nil { | ||||
| 		return primitive.Decimal128{}, err | ||||
| 	} | ||||
|  | ||||
| 	v, err := ejvr.p.readValue(bsontype.Decimal128) | ||||
| 	if err != nil { | ||||
| 		return primitive.Decimal128{}, err | ||||
| 	} | ||||
|  | ||||
| 	d, err := v.parseDecimal128() | ||||
|  | ||||
| 	ejvr.pop() | ||||
| 	return d, err | ||||
| } | ||||
|  | ||||
| func (ejvr *extJSONValueReader) ReadDouble() (float64, error) { | ||||
| 	if err := ejvr.ensureElementValue(bsontype.Double, 0, "ReadDouble"); err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
|  | ||||
| 	v, err := ejvr.p.readValue(bsontype.Double) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
|  | ||||
| 	d, err := v.parseDouble() | ||||
|  | ||||
| 	ejvr.pop() | ||||
| 	return d, err | ||||
| } | ||||
|  | ||||
| func (ejvr *extJSONValueReader) ReadInt32() (int32, error) { | ||||
| 	if err := ejvr.ensureElementValue(bsontype.Int32, 0, "ReadInt32"); err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
|  | ||||
| 	v, err := ejvr.p.readValue(bsontype.Int32) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
|  | ||||
| 	i, err := v.parseInt32() | ||||
|  | ||||
| 	ejvr.pop() | ||||
| 	return i, err | ||||
| } | ||||
|  | ||||
| func (ejvr *extJSONValueReader) ReadInt64() (int64, error) { | ||||
| 	if err := ejvr.ensureElementValue(bsontype.Int64, 0, "ReadInt64"); err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
|  | ||||
| 	v, err := ejvr.p.readValue(bsontype.Int64) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
|  | ||||
| 	i, err := v.parseInt64() | ||||
|  | ||||
| 	ejvr.pop() | ||||
| 	return i, err | ||||
| } | ||||
|  | ||||
| func (ejvr *extJSONValueReader) ReadJavascript() (code string, err error) { | ||||
| 	if err = ejvr.ensureElementValue(bsontype.JavaScript, 0, "ReadJavascript"); err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
|  | ||||
| 	v, err := ejvr.p.readValue(bsontype.JavaScript) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
|  | ||||
| 	code, err = v.parseJavascript() | ||||
|  | ||||
| 	ejvr.pop() | ||||
| 	return code, err | ||||
| } | ||||
|  | ||||
| func (ejvr *extJSONValueReader) ReadMaxKey() error { | ||||
| 	if err := ejvr.ensureElementValue(bsontype.MaxKey, 0, "ReadMaxKey"); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	v, err := ejvr.p.readValue(bsontype.MaxKey) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	err = v.parseMinMaxKey("max") | ||||
|  | ||||
| 	ejvr.pop() | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func (ejvr *extJSONValueReader) ReadMinKey() error { | ||||
| 	if err := ejvr.ensureElementValue(bsontype.MinKey, 0, "ReadMinKey"); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	v, err := ejvr.p.readValue(bsontype.MinKey) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	err = v.parseMinMaxKey("min") | ||||
|  | ||||
| 	ejvr.pop() | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func (ejvr *extJSONValueReader) ReadNull() error { | ||||
| 	if err := ejvr.ensureElementValue(bsontype.Null, 0, "ReadNull"); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	v, err := ejvr.p.readValue(bsontype.Null) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	if v.t != bsontype.Null { | ||||
| 		return fmt.Errorf("expected type null but got type %s", v.t) | ||||
| 	} | ||||
|  | ||||
| 	ejvr.pop() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (ejvr *extJSONValueReader) ReadObjectID() (primitive.ObjectID, error) { | ||||
| 	if err := ejvr.ensureElementValue(bsontype.ObjectID, 0, "ReadObjectID"); err != nil { | ||||
| 		return primitive.ObjectID{}, err | ||||
| 	} | ||||
|  | ||||
| 	v, err := ejvr.p.readValue(bsontype.ObjectID) | ||||
| 	if err != nil { | ||||
| 		return primitive.ObjectID{}, err | ||||
| 	} | ||||
|  | ||||
| 	oid, err := v.parseObjectID() | ||||
|  | ||||
| 	ejvr.pop() | ||||
| 	return oid, err | ||||
| } | ||||
|  | ||||
| func (ejvr *extJSONValueReader) ReadRegex() (pattern string, options string, err error) { | ||||
| 	if err = ejvr.ensureElementValue(bsontype.Regex, 0, "ReadRegex"); err != nil { | ||||
| 		return "", "", err | ||||
| 	} | ||||
|  | ||||
| 	v, err := ejvr.p.readValue(bsontype.Regex) | ||||
| 	if err != nil { | ||||
| 		return "", "", err | ||||
| 	} | ||||
|  | ||||
| 	pattern, options, err = v.parseRegex() | ||||
|  | ||||
| 	ejvr.pop() | ||||
| 	return pattern, options, err | ||||
| } | ||||
|  | ||||
| func (ejvr *extJSONValueReader) ReadString() (string, error) { | ||||
| 	if err := ejvr.ensureElementValue(bsontype.String, 0, "ReadString"); err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
|  | ||||
| 	v, err := ejvr.p.readValue(bsontype.String) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
|  | ||||
| 	if v.t != bsontype.String { | ||||
| 		return "", fmt.Errorf("expected type string but got type %s", v.t) | ||||
| 	} | ||||
|  | ||||
| 	ejvr.pop() | ||||
| 	return v.v.(string), nil | ||||
| } | ||||
|  | ||||
| func (ejvr *extJSONValueReader) ReadSymbol() (symbol string, err error) { | ||||
| 	if err = ejvr.ensureElementValue(bsontype.Symbol, 0, "ReadSymbol"); err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
|  | ||||
| 	v, err := ejvr.p.readValue(bsontype.Symbol) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
|  | ||||
| 	symbol, err = v.parseSymbol() | ||||
|  | ||||
| 	ejvr.pop() | ||||
| 	return symbol, err | ||||
| } | ||||
|  | ||||
| func (ejvr *extJSONValueReader) ReadTimestamp() (t uint32, i uint32, err error) { | ||||
| 	if err = ejvr.ensureElementValue(bsontype.Timestamp, 0, "ReadTimestamp"); err != nil { | ||||
| 		return 0, 0, err | ||||
| 	} | ||||
|  | ||||
| 	v, err := ejvr.p.readValue(bsontype.Timestamp) | ||||
| 	if err != nil { | ||||
| 		return 0, 0, err | ||||
| 	} | ||||
|  | ||||
| 	t, i, err = v.parseTimestamp() | ||||
|  | ||||
| 	ejvr.pop() | ||||
| 	return t, i, err | ||||
| } | ||||
|  | ||||
| func (ejvr *extJSONValueReader) ReadUndefined() error { | ||||
| 	if err := ejvr.ensureElementValue(bsontype.Undefined, 0, "ReadUndefined"); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	v, err := ejvr.p.readValue(bsontype.Undefined) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	err = v.parseUndefined() | ||||
|  | ||||
| 	ejvr.pop() | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func (ejvr *extJSONValueReader) ReadElement() (string, ValueReader, error) { | ||||
| 	switch ejvr.stack[ejvr.frame].mode { | ||||
| 	case mTopLevel, mDocument, mCodeWithScope: | ||||
| 	default: | ||||
| 		return "", nil, ejvr.invalidTransitionErr(mElement, "ReadElement", []mode{mTopLevel, mDocument, mCodeWithScope}) | ||||
| 	} | ||||
|  | ||||
| 	name, t, err := ejvr.p.readKey() | ||||
|  | ||||
| 	if err != nil { | ||||
| 		if err == ErrEOD { | ||||
| 			if ejvr.stack[ejvr.frame].mode == mCodeWithScope { | ||||
| 				_, err := ejvr.p.peekType() | ||||
| 				if err != nil { | ||||
| 					return "", nil, err | ||||
| 				} | ||||
| 			} | ||||
|  | ||||
| 			ejvr.pop() | ||||
| 		} | ||||
|  | ||||
| 		return "", nil, err | ||||
| 	} | ||||
|  | ||||
| 	ejvr.push(mElement, t) | ||||
| 	return name, ejvr, nil | ||||
| } | ||||
|  | ||||
| func (ejvr *extJSONValueReader) ReadValue() (ValueReader, error) { | ||||
| 	switch ejvr.stack[ejvr.frame].mode { | ||||
| 	case mArray: | ||||
| 	default: | ||||
| 		return nil, ejvr.invalidTransitionErr(mValue, "ReadValue", []mode{mArray}) | ||||
| 	} | ||||
|  | ||||
| 	t, err := ejvr.p.peekType() | ||||
| 	if err != nil { | ||||
| 		if err == ErrEOA { | ||||
| 			ejvr.pop() | ||||
| 		} | ||||
|  | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	ejvr.push(mValue, t) | ||||
| 	return ejvr, nil | ||||
| } | ||||
							
								
								
									
										168
									
								
								mongo/bson/bsonrw/extjson_reader_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										168
									
								
								mongo/bson/bsonrw/extjson_reader_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,168 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsonrw | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/google/go-cmp/cmp" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsontype" | ||||
| ) | ||||
|  | ||||
| func TestExtJSONReader(t *testing.T) { | ||||
| 	t.Run("ReadDocument", func(t *testing.T) { | ||||
| 		t.Run("EmbeddedDocument", func(t *testing.T) { | ||||
| 			ejvr := &extJSONValueReader{ | ||||
| 				stack: []ejvrState{ | ||||
| 					{mode: mTopLevel}, | ||||
| 					{mode: mElement, vType: bsontype.Boolean}, | ||||
| 				}, | ||||
| 				frame: 1, | ||||
| 			} | ||||
|  | ||||
| 			ejvr.stack[1].mode = mArray | ||||
| 			wanterr := ejvr.invalidTransitionErr(mDocument, "ReadDocument", []mode{mTopLevel, mElement, mValue}) | ||||
| 			_, err := ejvr.ReadDocument() | ||||
| 			if err == nil || err.Error() != wanterr.Error() { | ||||
| 				t.Errorf("Incorrect returned error. got %v; want %v", err, wanterr) | ||||
| 			} | ||||
|  | ||||
| 		}) | ||||
| 	}) | ||||
|  | ||||
| 	t.Run("invalid transition", func(t *testing.T) { | ||||
| 		t.Run("Skip", func(t *testing.T) { | ||||
| 			ejvr := &extJSONValueReader{stack: []ejvrState{{mode: mTopLevel}}} | ||||
| 			wanterr := (&extJSONValueReader{stack: []ejvrState{{mode: mTopLevel}}}).invalidTransitionErr(0, "Skip", []mode{mElement, mValue}) | ||||
| 			goterr := ejvr.Skip() | ||||
| 			if !cmp.Equal(goterr, wanterr, cmp.Comparer(compareErrors)) { | ||||
| 				t.Errorf("Expected correct invalid transition error. got %v; want %v", goterr, wanterr) | ||||
| 			} | ||||
| 		}) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func TestReadMultipleTopLevelDocuments(t *testing.T) { | ||||
| 	testCases := []struct { | ||||
| 		name     string | ||||
| 		input    string | ||||
| 		expected [][]byte | ||||
| 	}{ | ||||
| 		{ | ||||
| 			"single top-level document", | ||||
| 			"{\"foo\":1}", | ||||
| 			[][]byte{ | ||||
| 				{0x0E, 0x00, 0x00, 0x00, 0x10, 'f', 'o', 'o', 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"single top-level document with leading and trailing whitespace", | ||||
| 			"\n\n   {\"foo\":1}   \n", | ||||
| 			[][]byte{ | ||||
| 				{0x0E, 0x00, 0x00, 0x00, 0x10, 'f', 'o', 'o', 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"two top-level documents", | ||||
| 			"{\"foo\":1}{\"foo\":2}", | ||||
| 			[][]byte{ | ||||
| 				{0x0E, 0x00, 0x00, 0x00, 0x10, 'f', 'o', 'o', 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}, | ||||
| 				{0x0E, 0x00, 0x00, 0x00, 0x10, 'f', 'o', 'o', 0x00, 0x02, 0x00, 0x00, 0x00, 0x00}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"two top-level documents with leading and trailing whitespace and whitespace separation ", | ||||
| 			"\n\n  {\"foo\":1}\n{\"foo\":2}\n  ", | ||||
| 			[][]byte{ | ||||
| 				{0x0E, 0x00, 0x00, 0x00, 0x10, 'f', 'o', 'o', 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}, | ||||
| 				{0x0E, 0x00, 0x00, 0x00, 0x10, 'f', 'o', 'o', 0x00, 0x02, 0x00, 0x00, 0x00, 0x00}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"top-level array with single document", | ||||
| 			"[{\"foo\":1}]", | ||||
| 			[][]byte{ | ||||
| 				{0x0E, 0x00, 0x00, 0x00, 0x10, 'f', 'o', 'o', 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"top-level array with 2 documents", | ||||
| 			"[{\"foo\":1},{\"foo\":2}]", | ||||
| 			[][]byte{ | ||||
| 				{0x0E, 0x00, 0x00, 0x00, 0x10, 'f', 'o', 'o', 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}, | ||||
| 				{0x0E, 0x00, 0x00, 0x00, 0x10, 'f', 'o', 'o', 0x00, 0x02, 0x00, 0x00, 0x00, 0x00}, | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for _, tc := range testCases { | ||||
| 		t.Run(tc.name, func(t *testing.T) { | ||||
| 			r := strings.NewReader(tc.input) | ||||
| 			vr, err := NewExtJSONValueReader(r, false) | ||||
| 			if err != nil { | ||||
| 				t.Fatalf("expected no error, but got %v", err) | ||||
| 			} | ||||
|  | ||||
| 			actual, err := readAllDocuments(vr) | ||||
| 			if err != nil { | ||||
| 				t.Fatalf("expected no error, but got %v", err) | ||||
| 			} | ||||
|  | ||||
| 			if diff := cmp.Diff(tc.expected, actual); diff != "" { | ||||
| 				t.Fatalf("expected does not match actual: %v", diff) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func readAllDocuments(vr ValueReader) ([][]byte, error) { | ||||
| 	c := NewCopier() | ||||
| 	var actual [][]byte | ||||
|  | ||||
| 	switch vr.Type() { | ||||
| 	case bsontype.EmbeddedDocument: | ||||
| 		for { | ||||
| 			result, err := c.CopyDocumentToBytes(vr) | ||||
| 			if err != nil { | ||||
| 				if err == io.EOF { | ||||
| 					break | ||||
| 				} | ||||
| 				return nil, err | ||||
| 			} | ||||
|  | ||||
| 			actual = append(actual, result) | ||||
| 		} | ||||
| 	case bsontype.Array: | ||||
| 		ar, err := vr.ReadArray() | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		for { | ||||
| 			evr, err := ar.ReadValue() | ||||
| 			if err != nil { | ||||
| 				if err == ErrEOA { | ||||
| 					break | ||||
| 				} | ||||
| 				return nil, err | ||||
| 			} | ||||
|  | ||||
| 			result, err := c.CopyDocumentToBytes(evr) | ||||
| 			if err != nil { | ||||
| 				return nil, err | ||||
| 			} | ||||
|  | ||||
| 			actual = append(actual, result) | ||||
| 		} | ||||
| 	default: | ||||
| 		return nil, fmt.Errorf("expected an array or a document, but got %s", vr.Type()) | ||||
| 	} | ||||
|  | ||||
| 	return actual, nil | ||||
| } | ||||
							
								
								
									
										223
									
								
								mongo/bson/bsonrw/extjson_tables.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										223
									
								
								mongo/bson/bsonrw/extjson_tables.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,223 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
| // | ||||
| // Based on github.com/golang/go by The Go Authors | ||||
| // See THIRD-PARTY-NOTICES for original license terms. | ||||
|  | ||||
| package bsonrw | ||||
|  | ||||
| import "unicode/utf8" | ||||
|  | ||||
| // safeSet holds the value true if the ASCII character with the given array | ||||
| // position can be represented inside a JSON string without any further | ||||
| // escaping. | ||||
| // | ||||
| // All values are true except for the ASCII control characters (0-31), the | ||||
| // double quote ("), and the backslash character ("\"). | ||||
| var safeSet = [utf8.RuneSelf]bool{ | ||||
| 	' ':      true, | ||||
| 	'!':      true, | ||||
| 	'"':      false, | ||||
| 	'#':      true, | ||||
| 	'$':      true, | ||||
| 	'%':      true, | ||||
| 	'&':      true, | ||||
| 	'\'':     true, | ||||
| 	'(':      true, | ||||
| 	')':      true, | ||||
| 	'*':      true, | ||||
| 	'+':      true, | ||||
| 	',':      true, | ||||
| 	'-':      true, | ||||
| 	'.':      true, | ||||
| 	'/':      true, | ||||
| 	'0':      true, | ||||
| 	'1':      true, | ||||
| 	'2':      true, | ||||
| 	'3':      true, | ||||
| 	'4':      true, | ||||
| 	'5':      true, | ||||
| 	'6':      true, | ||||
| 	'7':      true, | ||||
| 	'8':      true, | ||||
| 	'9':      true, | ||||
| 	':':      true, | ||||
| 	';':      true, | ||||
| 	'<':      true, | ||||
| 	'=':      true, | ||||
| 	'>':      true, | ||||
| 	'?':      true, | ||||
| 	'@':      true, | ||||
| 	'A':      true, | ||||
| 	'B':      true, | ||||
| 	'C':      true, | ||||
| 	'D':      true, | ||||
| 	'E':      true, | ||||
| 	'F':      true, | ||||
| 	'G':      true, | ||||
| 	'H':      true, | ||||
| 	'I':      true, | ||||
| 	'J':      true, | ||||
| 	'K':      true, | ||||
| 	'L':      true, | ||||
| 	'M':      true, | ||||
| 	'N':      true, | ||||
| 	'O':      true, | ||||
| 	'P':      true, | ||||
| 	'Q':      true, | ||||
| 	'R':      true, | ||||
| 	'S':      true, | ||||
| 	'T':      true, | ||||
| 	'U':      true, | ||||
| 	'V':      true, | ||||
| 	'W':      true, | ||||
| 	'X':      true, | ||||
| 	'Y':      true, | ||||
| 	'Z':      true, | ||||
| 	'[':      true, | ||||
| 	'\\':     false, | ||||
| 	']':      true, | ||||
| 	'^':      true, | ||||
| 	'_':      true, | ||||
| 	'`':      true, | ||||
| 	'a':      true, | ||||
| 	'b':      true, | ||||
| 	'c':      true, | ||||
| 	'd':      true, | ||||
| 	'e':      true, | ||||
| 	'f':      true, | ||||
| 	'g':      true, | ||||
| 	'h':      true, | ||||
| 	'i':      true, | ||||
| 	'j':      true, | ||||
| 	'k':      true, | ||||
| 	'l':      true, | ||||
| 	'm':      true, | ||||
| 	'n':      true, | ||||
| 	'o':      true, | ||||
| 	'p':      true, | ||||
| 	'q':      true, | ||||
| 	'r':      true, | ||||
| 	's':      true, | ||||
| 	't':      true, | ||||
| 	'u':      true, | ||||
| 	'v':      true, | ||||
| 	'w':      true, | ||||
| 	'x':      true, | ||||
| 	'y':      true, | ||||
| 	'z':      true, | ||||
| 	'{':      true, | ||||
| 	'|':      true, | ||||
| 	'}':      true, | ||||
| 	'~':      true, | ||||
| 	'\u007f': true, | ||||
| } | ||||
|  | ||||
| // htmlSafeSet holds the value true if the ASCII character with the given | ||||
| // array position can be safely represented inside a JSON string, embedded | ||||
| // inside of HTML <script> tags, without any additional escaping. | ||||
| // | ||||
| // All values are true except for the ASCII control characters (0-31), the | ||||
| // double quote ("), the backslash character ("\"), HTML opening and closing | ||||
| // tags ("<" and ">"), and the ampersand ("&"). | ||||
| var htmlSafeSet = [utf8.RuneSelf]bool{ | ||||
| 	' ':      true, | ||||
| 	'!':      true, | ||||
| 	'"':      false, | ||||
| 	'#':      true, | ||||
| 	'$':      true, | ||||
| 	'%':      true, | ||||
| 	'&':      false, | ||||
| 	'\'':     true, | ||||
| 	'(':      true, | ||||
| 	')':      true, | ||||
| 	'*':      true, | ||||
| 	'+':      true, | ||||
| 	',':      true, | ||||
| 	'-':      true, | ||||
| 	'.':      true, | ||||
| 	'/':      true, | ||||
| 	'0':      true, | ||||
| 	'1':      true, | ||||
| 	'2':      true, | ||||
| 	'3':      true, | ||||
| 	'4':      true, | ||||
| 	'5':      true, | ||||
| 	'6':      true, | ||||
| 	'7':      true, | ||||
| 	'8':      true, | ||||
| 	'9':      true, | ||||
| 	':':      true, | ||||
| 	';':      true, | ||||
| 	'<':      false, | ||||
| 	'=':      true, | ||||
| 	'>':      false, | ||||
| 	'?':      true, | ||||
| 	'@':      true, | ||||
| 	'A':      true, | ||||
| 	'B':      true, | ||||
| 	'C':      true, | ||||
| 	'D':      true, | ||||
| 	'E':      true, | ||||
| 	'F':      true, | ||||
| 	'G':      true, | ||||
| 	'H':      true, | ||||
| 	'I':      true, | ||||
| 	'J':      true, | ||||
| 	'K':      true, | ||||
| 	'L':      true, | ||||
| 	'M':      true, | ||||
| 	'N':      true, | ||||
| 	'O':      true, | ||||
| 	'P':      true, | ||||
| 	'Q':      true, | ||||
| 	'R':      true, | ||||
| 	'S':      true, | ||||
| 	'T':      true, | ||||
| 	'U':      true, | ||||
| 	'V':      true, | ||||
| 	'W':      true, | ||||
| 	'X':      true, | ||||
| 	'Y':      true, | ||||
| 	'Z':      true, | ||||
| 	'[':      true, | ||||
| 	'\\':     false, | ||||
| 	']':      true, | ||||
| 	'^':      true, | ||||
| 	'_':      true, | ||||
| 	'`':      true, | ||||
| 	'a':      true, | ||||
| 	'b':      true, | ||||
| 	'c':      true, | ||||
| 	'd':      true, | ||||
| 	'e':      true, | ||||
| 	'f':      true, | ||||
| 	'g':      true, | ||||
| 	'h':      true, | ||||
| 	'i':      true, | ||||
| 	'j':      true, | ||||
| 	'k':      true, | ||||
| 	'l':      true, | ||||
| 	'm':      true, | ||||
| 	'n':      true, | ||||
| 	'o':      true, | ||||
| 	'p':      true, | ||||
| 	'q':      true, | ||||
| 	'r':      true, | ||||
| 	's':      true, | ||||
| 	't':      true, | ||||
| 	'u':      true, | ||||
| 	'v':      true, | ||||
| 	'w':      true, | ||||
| 	'x':      true, | ||||
| 	'y':      true, | ||||
| 	'z':      true, | ||||
| 	'{':      true, | ||||
| 	'|':      true, | ||||
| 	'}':      true, | ||||
| 	'~':      true, | ||||
| 	'\u007f': true, | ||||
| } | ||||
							
								
								
									
										492
									
								
								mongo/bson/bsonrw/extjson_wrappers.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										492
									
								
								mongo/bson/bsonrw/extjson_wrappers.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,492 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsonrw | ||||
|  | ||||
| import ( | ||||
| 	"encoding/base64" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"math" | ||||
| 	"strconv" | ||||
| 	"time" | ||||
|  | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsontype" | ||||
| 	"go.mongodb.org/mongo-driver/bson/primitive" | ||||
| ) | ||||
|  | ||||
| func wrapperKeyBSONType(key string) bsontype.Type { | ||||
| 	switch key { | ||||
| 	case "$numberInt": | ||||
| 		return bsontype.Int32 | ||||
| 	case "$numberLong": | ||||
| 		return bsontype.Int64 | ||||
| 	case "$oid": | ||||
| 		return bsontype.ObjectID | ||||
| 	case "$symbol": | ||||
| 		return bsontype.Symbol | ||||
| 	case "$numberDouble": | ||||
| 		return bsontype.Double | ||||
| 	case "$numberDecimal": | ||||
| 		return bsontype.Decimal128 | ||||
| 	case "$binary": | ||||
| 		return bsontype.Binary | ||||
| 	case "$code": | ||||
| 		return bsontype.JavaScript | ||||
| 	case "$scope": | ||||
| 		return bsontype.CodeWithScope | ||||
| 	case "$timestamp": | ||||
| 		return bsontype.Timestamp | ||||
| 	case "$regularExpression": | ||||
| 		return bsontype.Regex | ||||
| 	case "$dbPointer": | ||||
| 		return bsontype.DBPointer | ||||
| 	case "$date": | ||||
| 		return bsontype.DateTime | ||||
| 	case "$minKey": | ||||
| 		return bsontype.MinKey | ||||
| 	case "$maxKey": | ||||
| 		return bsontype.MaxKey | ||||
| 	case "$undefined": | ||||
| 		return bsontype.Undefined | ||||
| 	} | ||||
|  | ||||
| 	return bsontype.EmbeddedDocument | ||||
| } | ||||
|  | ||||
| func (ejv *extJSONValue) parseBinary() (b []byte, subType byte, err error) { | ||||
| 	if ejv.t != bsontype.EmbeddedDocument { | ||||
| 		return nil, 0, fmt.Errorf("$binary value should be object, but instead is %s", ejv.t) | ||||
| 	} | ||||
|  | ||||
| 	binObj := ejv.v.(*extJSONObject) | ||||
| 	bFound := false | ||||
| 	stFound := false | ||||
|  | ||||
| 	for i, key := range binObj.keys { | ||||
| 		val := binObj.values[i] | ||||
|  | ||||
| 		switch key { | ||||
| 		case "base64": | ||||
| 			if bFound { | ||||
| 				return nil, 0, errors.New("duplicate base64 key in $binary") | ||||
| 			} | ||||
|  | ||||
| 			if val.t != bsontype.String { | ||||
| 				return nil, 0, fmt.Errorf("$binary base64 value should be string, but instead is %s", val.t) | ||||
| 			} | ||||
|  | ||||
| 			base64Bytes, err := base64.StdEncoding.DecodeString(val.v.(string)) | ||||
| 			if err != nil { | ||||
| 				return nil, 0, fmt.Errorf("invalid $binary base64 string: %s", val.v.(string)) | ||||
| 			} | ||||
|  | ||||
| 			b = base64Bytes | ||||
| 			bFound = true | ||||
| 		case "subType": | ||||
| 			if stFound { | ||||
| 				return nil, 0, errors.New("duplicate subType key in $binary") | ||||
| 			} | ||||
|  | ||||
| 			if val.t != bsontype.String { | ||||
| 				return nil, 0, fmt.Errorf("$binary subType value should be string, but instead is %s", val.t) | ||||
| 			} | ||||
|  | ||||
| 			i, err := strconv.ParseInt(val.v.(string), 16, 64) | ||||
| 			if err != nil { | ||||
| 				return nil, 0, fmt.Errorf("invalid $binary subType string: %s", val.v.(string)) | ||||
| 			} | ||||
|  | ||||
| 			subType = byte(i) | ||||
| 			stFound = true | ||||
| 		default: | ||||
| 			return nil, 0, fmt.Errorf("invalid key in $binary object: %s", key) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if !bFound { | ||||
| 		return nil, 0, errors.New("missing base64 field in $binary object") | ||||
| 	} | ||||
|  | ||||
| 	if !stFound { | ||||
| 		return nil, 0, errors.New("missing subType field in $binary object") | ||||
|  | ||||
| 	} | ||||
|  | ||||
| 	return b, subType, nil | ||||
| } | ||||
|  | ||||
| func (ejv *extJSONValue) parseDBPointer() (ns string, oid primitive.ObjectID, err error) { | ||||
| 	if ejv.t != bsontype.EmbeddedDocument { | ||||
| 		return "", primitive.NilObjectID, fmt.Errorf("$dbPointer value should be object, but instead is %s", ejv.t) | ||||
| 	} | ||||
|  | ||||
| 	dbpObj := ejv.v.(*extJSONObject) | ||||
| 	oidFound := false | ||||
| 	nsFound := false | ||||
|  | ||||
| 	for i, key := range dbpObj.keys { | ||||
| 		val := dbpObj.values[i] | ||||
|  | ||||
| 		switch key { | ||||
| 		case "$ref": | ||||
| 			if nsFound { | ||||
| 				return "", primitive.NilObjectID, errors.New("duplicate $ref key in $dbPointer") | ||||
| 			} | ||||
|  | ||||
| 			if val.t != bsontype.String { | ||||
| 				return "", primitive.NilObjectID, fmt.Errorf("$dbPointer $ref value should be string, but instead is %s", val.t) | ||||
| 			} | ||||
|  | ||||
| 			ns = val.v.(string) | ||||
| 			nsFound = true | ||||
| 		case "$id": | ||||
| 			if oidFound { | ||||
| 				return "", primitive.NilObjectID, errors.New("duplicate $id key in $dbPointer") | ||||
| 			} | ||||
|  | ||||
| 			if val.t != bsontype.String { | ||||
| 				return "", primitive.NilObjectID, fmt.Errorf("$dbPointer $id value should be string, but instead is %s", val.t) | ||||
| 			} | ||||
|  | ||||
| 			oid, err = primitive.ObjectIDFromHex(val.v.(string)) | ||||
| 			if err != nil { | ||||
| 				return "", primitive.NilObjectID, err | ||||
| 			} | ||||
|  | ||||
| 			oidFound = true | ||||
| 		default: | ||||
| 			return "", primitive.NilObjectID, fmt.Errorf("invalid key in $dbPointer object: %s", key) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if !nsFound { | ||||
| 		return "", oid, errors.New("missing $ref field in $dbPointer object") | ||||
| 	} | ||||
|  | ||||
| 	if !oidFound { | ||||
| 		return "", oid, errors.New("missing $id field in $dbPointer object") | ||||
| 	} | ||||
|  | ||||
| 	return ns, oid, nil | ||||
| } | ||||
|  | ||||
| const ( | ||||
| 	rfc3339Milli = "2006-01-02T15:04:05.999Z07:00" | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	timeFormats = []string{rfc3339Milli, "2006-01-02T15:04:05.999Z0700"} | ||||
| ) | ||||
|  | ||||
| func (ejv *extJSONValue) parseDateTime() (int64, error) { | ||||
| 	switch ejv.t { | ||||
| 	case bsontype.Int32: | ||||
| 		return int64(ejv.v.(int32)), nil | ||||
| 	case bsontype.Int64: | ||||
| 		return ejv.v.(int64), nil | ||||
| 	case bsontype.String: | ||||
| 		return parseDatetimeString(ejv.v.(string)) | ||||
| 	case bsontype.EmbeddedDocument: | ||||
| 		return parseDatetimeObject(ejv.v.(*extJSONObject)) | ||||
| 	default: | ||||
| 		return 0, fmt.Errorf("$date value should be string or object, but instead is %s", ejv.t) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func parseDatetimeString(data string) (int64, error) { | ||||
| 	var t time.Time | ||||
| 	var err error | ||||
| 	// try acceptable time formats until one matches | ||||
| 	for _, format := range timeFormats { | ||||
| 		t, err = time.Parse(format, data) | ||||
| 		if err == nil { | ||||
| 			break | ||||
| 		} | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		return 0, fmt.Errorf("invalid $date value string: %s", data) | ||||
| 	} | ||||
|  | ||||
| 	return int64(primitive.NewDateTimeFromTime(t)), nil | ||||
| } | ||||
|  | ||||
| func parseDatetimeObject(data *extJSONObject) (d int64, err error) { | ||||
| 	dFound := false | ||||
|  | ||||
| 	for i, key := range data.keys { | ||||
| 		val := data.values[i] | ||||
|  | ||||
| 		switch key { | ||||
| 		case "$numberLong": | ||||
| 			if dFound { | ||||
| 				return 0, errors.New("duplicate $numberLong key in $date") | ||||
| 			} | ||||
|  | ||||
| 			if val.t != bsontype.String { | ||||
| 				return 0, fmt.Errorf("$date $numberLong field should be string, but instead is %s", val.t) | ||||
| 			} | ||||
|  | ||||
| 			d, err = val.parseInt64() | ||||
| 			if err != nil { | ||||
| 				return 0, err | ||||
| 			} | ||||
| 			dFound = true | ||||
| 		default: | ||||
| 			return 0, fmt.Errorf("invalid key in $date object: %s", key) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if !dFound { | ||||
| 		return 0, errors.New("missing $numberLong field in $date object") | ||||
| 	} | ||||
|  | ||||
| 	return d, nil | ||||
| } | ||||
|  | ||||
| func (ejv *extJSONValue) parseDecimal128() (primitive.Decimal128, error) { | ||||
| 	if ejv.t != bsontype.String { | ||||
| 		return primitive.Decimal128{}, fmt.Errorf("$numberDecimal value should be string, but instead is %s", ejv.t) | ||||
| 	} | ||||
|  | ||||
| 	d, err := primitive.ParseDecimal128(ejv.v.(string)) | ||||
| 	if err != nil { | ||||
| 		return primitive.Decimal128{}, fmt.Errorf("$invalid $numberDecimal string: %s", ejv.v.(string)) | ||||
| 	} | ||||
|  | ||||
| 	return d, nil | ||||
| } | ||||
|  | ||||
| func (ejv *extJSONValue) parseDouble() (float64, error) { | ||||
| 	if ejv.t == bsontype.Double { | ||||
| 		return ejv.v.(float64), nil | ||||
| 	} | ||||
|  | ||||
| 	if ejv.t != bsontype.String { | ||||
| 		return 0, fmt.Errorf("$numberDouble value should be string, but instead is %s", ejv.t) | ||||
| 	} | ||||
|  | ||||
| 	switch ejv.v.(string) { | ||||
| 	case "Infinity": | ||||
| 		return math.Inf(1), nil | ||||
| 	case "-Infinity": | ||||
| 		return math.Inf(-1), nil | ||||
| 	case "NaN": | ||||
| 		return math.NaN(), nil | ||||
| 	} | ||||
|  | ||||
| 	f, err := strconv.ParseFloat(ejv.v.(string), 64) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
|  | ||||
| 	return f, nil | ||||
| } | ||||
|  | ||||
| func (ejv *extJSONValue) parseInt32() (int32, error) { | ||||
| 	if ejv.t == bsontype.Int32 { | ||||
| 		return ejv.v.(int32), nil | ||||
| 	} | ||||
|  | ||||
| 	if ejv.t != bsontype.String { | ||||
| 		return 0, fmt.Errorf("$numberInt value should be string, but instead is %s", ejv.t) | ||||
| 	} | ||||
|  | ||||
| 	i, err := strconv.ParseInt(ejv.v.(string), 10, 64) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
|  | ||||
| 	if i < math.MinInt32 || i > math.MaxInt32 { | ||||
| 		return 0, fmt.Errorf("$numberInt value should be int32 but instead is int64: %d", i) | ||||
| 	} | ||||
|  | ||||
| 	return int32(i), nil | ||||
| } | ||||
|  | ||||
| func (ejv *extJSONValue) parseInt64() (int64, error) { | ||||
| 	if ejv.t == bsontype.Int64 { | ||||
| 		return ejv.v.(int64), nil | ||||
| 	} | ||||
|  | ||||
| 	if ejv.t != bsontype.String { | ||||
| 		return 0, fmt.Errorf("$numberLong value should be string, but instead is %s", ejv.t) | ||||
| 	} | ||||
|  | ||||
| 	i, err := strconv.ParseInt(ejv.v.(string), 10, 64) | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
|  | ||||
| 	return i, nil | ||||
| } | ||||
|  | ||||
| func (ejv *extJSONValue) parseJavascript() (code string, err error) { | ||||
| 	if ejv.t != bsontype.String { | ||||
| 		return "", fmt.Errorf("$code value should be string, but instead is %s", ejv.t) | ||||
| 	} | ||||
|  | ||||
| 	return ejv.v.(string), nil | ||||
| } | ||||
|  | ||||
| func (ejv *extJSONValue) parseMinMaxKey(minmax string) error { | ||||
| 	if ejv.t != bsontype.Int32 { | ||||
| 		return fmt.Errorf("$%sKey value should be int32, but instead is %s", minmax, ejv.t) | ||||
| 	} | ||||
|  | ||||
| 	if ejv.v.(int32) != 1 { | ||||
| 		return fmt.Errorf("$%sKey value must be 1, but instead is %d", minmax, ejv.v.(int32)) | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (ejv *extJSONValue) parseObjectID() (primitive.ObjectID, error) { | ||||
| 	if ejv.t != bsontype.String { | ||||
| 		return primitive.NilObjectID, fmt.Errorf("$oid value should be string, but instead is %s", ejv.t) | ||||
| 	} | ||||
|  | ||||
| 	return primitive.ObjectIDFromHex(ejv.v.(string)) | ||||
| } | ||||
|  | ||||
| func (ejv *extJSONValue) parseRegex() (pattern, options string, err error) { | ||||
| 	if ejv.t != bsontype.EmbeddedDocument { | ||||
| 		return "", "", fmt.Errorf("$regularExpression value should be object, but instead is %s", ejv.t) | ||||
| 	} | ||||
|  | ||||
| 	regexObj := ejv.v.(*extJSONObject) | ||||
| 	patFound := false | ||||
| 	optFound := false | ||||
|  | ||||
| 	for i, key := range regexObj.keys { | ||||
| 		val := regexObj.values[i] | ||||
|  | ||||
| 		switch key { | ||||
| 		case "pattern": | ||||
| 			if patFound { | ||||
| 				return "", "", errors.New("duplicate pattern key in $regularExpression") | ||||
| 			} | ||||
|  | ||||
| 			if val.t != bsontype.String { | ||||
| 				return "", "", fmt.Errorf("$regularExpression pattern value should be string, but instead is %s", val.t) | ||||
| 			} | ||||
|  | ||||
| 			pattern = val.v.(string) | ||||
| 			patFound = true | ||||
| 		case "options": | ||||
| 			if optFound { | ||||
| 				return "", "", errors.New("duplicate options key in $regularExpression") | ||||
| 			} | ||||
|  | ||||
| 			if val.t != bsontype.String { | ||||
| 				return "", "", fmt.Errorf("$regularExpression options value should be string, but instead is %s", val.t) | ||||
| 			} | ||||
|  | ||||
| 			options = val.v.(string) | ||||
| 			optFound = true | ||||
| 		default: | ||||
| 			return "", "", fmt.Errorf("invalid key in $regularExpression object: %s", key) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if !patFound { | ||||
| 		return "", "", errors.New("missing pattern field in $regularExpression object") | ||||
| 	} | ||||
|  | ||||
| 	if !optFound { | ||||
| 		return "", "", errors.New("missing options field in $regularExpression object") | ||||
|  | ||||
| 	} | ||||
|  | ||||
| 	return pattern, options, nil | ||||
| } | ||||
|  | ||||
| func (ejv *extJSONValue) parseSymbol() (string, error) { | ||||
| 	if ejv.t != bsontype.String { | ||||
| 		return "", fmt.Errorf("$symbol value should be string, but instead is %s", ejv.t) | ||||
| 	} | ||||
|  | ||||
| 	return ejv.v.(string), nil | ||||
| } | ||||
|  | ||||
| func (ejv *extJSONValue) parseTimestamp() (t, i uint32, err error) { | ||||
| 	if ejv.t != bsontype.EmbeddedDocument { | ||||
| 		return 0, 0, fmt.Errorf("$timestamp value should be object, but instead is %s", ejv.t) | ||||
| 	} | ||||
|  | ||||
| 	handleKey := func(key string, val *extJSONValue, flag bool) (uint32, error) { | ||||
| 		if flag { | ||||
| 			return 0, fmt.Errorf("duplicate %s key in $timestamp", key) | ||||
| 		} | ||||
|  | ||||
| 		switch val.t { | ||||
| 		case bsontype.Int32: | ||||
| 			value := val.v.(int32) | ||||
|  | ||||
| 			if value < 0 { | ||||
| 				return 0, fmt.Errorf("$timestamp %s number should be uint32: %d", key, value) | ||||
| 			} | ||||
|  | ||||
| 			return uint32(value), nil | ||||
| 		case bsontype.Int64: | ||||
| 			value := val.v.(int64) | ||||
| 			if value < 0 || value > int64(math.MaxUint32) { | ||||
| 				return 0, fmt.Errorf("$timestamp %s number should be uint32: %d", key, value) | ||||
| 			} | ||||
|  | ||||
| 			return uint32(value), nil | ||||
| 		default: | ||||
| 			return 0, fmt.Errorf("$timestamp %s value should be uint32, but instead is %s", key, val.t) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	tsObj := ejv.v.(*extJSONObject) | ||||
| 	tFound := false | ||||
| 	iFound := false | ||||
|  | ||||
| 	for j, key := range tsObj.keys { | ||||
| 		val := tsObj.values[j] | ||||
|  | ||||
| 		switch key { | ||||
| 		case "t": | ||||
| 			if t, err = handleKey(key, val, tFound); err != nil { | ||||
| 				return 0, 0, err | ||||
| 			} | ||||
|  | ||||
| 			tFound = true | ||||
| 		case "i": | ||||
| 			if i, err = handleKey(key, val, iFound); err != nil { | ||||
| 				return 0, 0, err | ||||
| 			} | ||||
|  | ||||
| 			iFound = true | ||||
| 		default: | ||||
| 			return 0, 0, fmt.Errorf("invalid key in $timestamp object: %s", key) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if !tFound { | ||||
| 		return 0, 0, errors.New("missing t field in $timestamp object") | ||||
| 	} | ||||
|  | ||||
| 	if !iFound { | ||||
| 		return 0, 0, errors.New("missing i field in $timestamp object") | ||||
| 	} | ||||
|  | ||||
| 	return t, i, nil | ||||
| } | ||||
|  | ||||
| func (ejv *extJSONValue) parseUndefined() error { | ||||
| 	if ejv.t != bsontype.Boolean { | ||||
| 		return fmt.Errorf("undefined value should be boolean, but instead is %s", ejv.t) | ||||
| 	} | ||||
|  | ||||
| 	if !ejv.v.(bool) { | ||||
| 		return fmt.Errorf("$undefined balue boolean should be true, but instead is %v", ejv.v.(bool)) | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
							
								
								
									
										732
									
								
								mongo/bson/bsonrw/extjson_writer.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										732
									
								
								mongo/bson/bsonrw/extjson_writer.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,732 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsonrw | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"encoding/base64" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"math" | ||||
| 	"sort" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| 	"unicode/utf8" | ||||
|  | ||||
| 	"go.mongodb.org/mongo-driver/bson/primitive" | ||||
| ) | ||||
|  | ||||
| // ExtJSONValueWriterPool is a pool for ExtJSON ValueWriters. | ||||
| type ExtJSONValueWriterPool struct { | ||||
| 	pool sync.Pool | ||||
| } | ||||
|  | ||||
| // NewExtJSONValueWriterPool creates a new pool for ValueWriter instances that write to ExtJSON. | ||||
| func NewExtJSONValueWriterPool() *ExtJSONValueWriterPool { | ||||
| 	return &ExtJSONValueWriterPool{ | ||||
| 		pool: sync.Pool{ | ||||
| 			New: func() interface{} { | ||||
| 				return new(extJSONValueWriter) | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Get retrieves a ExtJSON ValueWriter from the pool and resets it to use w as the destination. | ||||
| func (bvwp *ExtJSONValueWriterPool) Get(w io.Writer, canonical, escapeHTML bool) ValueWriter { | ||||
| 	vw := bvwp.pool.Get().(*extJSONValueWriter) | ||||
| 	if writer, ok := w.(*SliceWriter); ok { | ||||
| 		vw.reset(*writer, canonical, escapeHTML) | ||||
| 		vw.w = writer | ||||
| 		return vw | ||||
| 	} | ||||
| 	vw.buf = vw.buf[:0] | ||||
| 	vw.w = w | ||||
| 	return vw | ||||
| } | ||||
|  | ||||
| // Put inserts a ValueWriter into the pool. If the ValueWriter is not a ExtJSON ValueWriter, nothing | ||||
| // happens and ok will be false. | ||||
| func (bvwp *ExtJSONValueWriterPool) Put(vw ValueWriter) (ok bool) { | ||||
| 	bvw, ok := vw.(*extJSONValueWriter) | ||||
| 	if !ok { | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	if _, ok := bvw.w.(*SliceWriter); ok { | ||||
| 		bvw.buf = nil | ||||
| 	} | ||||
| 	bvw.w = nil | ||||
|  | ||||
| 	bvwp.pool.Put(bvw) | ||||
| 	return true | ||||
| } | ||||
|  | ||||
| type ejvwState struct { | ||||
| 	mode mode | ||||
| } | ||||
|  | ||||
| type extJSONValueWriter struct { | ||||
| 	w   io.Writer | ||||
| 	buf []byte | ||||
|  | ||||
| 	stack      []ejvwState | ||||
| 	frame      int64 | ||||
| 	canonical  bool | ||||
| 	escapeHTML bool | ||||
| } | ||||
|  | ||||
| // NewExtJSONValueWriter creates a ValueWriter that writes Extended JSON to w. | ||||
| func NewExtJSONValueWriter(w io.Writer, canonical, escapeHTML bool) (ValueWriter, error) { | ||||
| 	if w == nil { | ||||
| 		return nil, errNilWriter | ||||
| 	} | ||||
|  | ||||
| 	return newExtJSONWriter(w, canonical, escapeHTML), nil | ||||
| } | ||||
|  | ||||
| func newExtJSONWriter(w io.Writer, canonical, escapeHTML bool) *extJSONValueWriter { | ||||
| 	stack := make([]ejvwState, 1, 5) | ||||
| 	stack[0] = ejvwState{mode: mTopLevel} | ||||
|  | ||||
| 	return &extJSONValueWriter{ | ||||
| 		w:          w, | ||||
| 		buf:        []byte{}, | ||||
| 		stack:      stack, | ||||
| 		canonical:  canonical, | ||||
| 		escapeHTML: escapeHTML, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func newExtJSONWriterFromSlice(buf []byte, canonical, escapeHTML bool) *extJSONValueWriter { | ||||
| 	stack := make([]ejvwState, 1, 5) | ||||
| 	stack[0] = ejvwState{mode: mTopLevel} | ||||
|  | ||||
| 	return &extJSONValueWriter{ | ||||
| 		buf:        buf, | ||||
| 		stack:      stack, | ||||
| 		canonical:  canonical, | ||||
| 		escapeHTML: escapeHTML, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (ejvw *extJSONValueWriter) reset(buf []byte, canonical, escapeHTML bool) { | ||||
| 	if ejvw.stack == nil { | ||||
| 		ejvw.stack = make([]ejvwState, 1, 5) | ||||
| 	} | ||||
|  | ||||
| 	ejvw.stack = ejvw.stack[:1] | ||||
| 	ejvw.stack[0] = ejvwState{mode: mTopLevel} | ||||
| 	ejvw.canonical = canonical | ||||
| 	ejvw.escapeHTML = escapeHTML | ||||
| 	ejvw.frame = 0 | ||||
| 	ejvw.buf = buf | ||||
| 	ejvw.w = nil | ||||
| } | ||||
|  | ||||
| func (ejvw *extJSONValueWriter) advanceFrame() { | ||||
| 	if ejvw.frame+1 >= int64(len(ejvw.stack)) { // We need to grow the stack | ||||
| 		length := len(ejvw.stack) | ||||
| 		if length+1 >= cap(ejvw.stack) { | ||||
| 			// double it | ||||
| 			buf := make([]ejvwState, 2*cap(ejvw.stack)+1) | ||||
| 			copy(buf, ejvw.stack) | ||||
| 			ejvw.stack = buf | ||||
| 		} | ||||
| 		ejvw.stack = ejvw.stack[:length+1] | ||||
| 	} | ||||
| 	ejvw.frame++ | ||||
| } | ||||
|  | ||||
| func (ejvw *extJSONValueWriter) push(m mode) { | ||||
| 	ejvw.advanceFrame() | ||||
|  | ||||
| 	ejvw.stack[ejvw.frame].mode = m | ||||
| } | ||||
|  | ||||
| func (ejvw *extJSONValueWriter) pop() { | ||||
| 	switch ejvw.stack[ejvw.frame].mode { | ||||
| 	case mElement, mValue: | ||||
| 		ejvw.frame-- | ||||
| 	case mDocument, mArray, mCodeWithScope: | ||||
| 		ejvw.frame -= 2 // we pop twice to jump over the mElement: mDocument -> mElement -> mDocument/mTopLevel/etc... | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (ejvw *extJSONValueWriter) invalidTransitionErr(destination mode, name string, modes []mode) error { | ||||
| 	te := TransitionError{ | ||||
| 		name:        name, | ||||
| 		current:     ejvw.stack[ejvw.frame].mode, | ||||
| 		destination: destination, | ||||
| 		modes:       modes, | ||||
| 		action:      "write", | ||||
| 	} | ||||
| 	if ejvw.frame != 0 { | ||||
| 		te.parent = ejvw.stack[ejvw.frame-1].mode | ||||
| 	} | ||||
| 	return te | ||||
| } | ||||
|  | ||||
| func (ejvw *extJSONValueWriter) ensureElementValue(destination mode, callerName string, addmodes ...mode) error { | ||||
| 	switch ejvw.stack[ejvw.frame].mode { | ||||
| 	case mElement, mValue: | ||||
| 	default: | ||||
| 		modes := []mode{mElement, mValue} | ||||
| 		if addmodes != nil { | ||||
| 			modes = append(modes, addmodes...) | ||||
| 		} | ||||
| 		return ejvw.invalidTransitionErr(destination, callerName, modes) | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (ejvw *extJSONValueWriter) writeExtendedSingleValue(key string, value string, quotes bool) { | ||||
| 	var s string | ||||
| 	if quotes { | ||||
| 		s = fmt.Sprintf(`{"$%s":"%s"}`, key, value) | ||||
| 	} else { | ||||
| 		s = fmt.Sprintf(`{"$%s":%s}`, key, value) | ||||
| 	} | ||||
|  | ||||
| 	ejvw.buf = append(ejvw.buf, []byte(s)...) | ||||
| } | ||||
|  | ||||
| func (ejvw *extJSONValueWriter) WriteArray() (ArrayWriter, error) { | ||||
| 	if err := ejvw.ensureElementValue(mArray, "WriteArray"); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	ejvw.buf = append(ejvw.buf, '[') | ||||
|  | ||||
| 	ejvw.push(mArray) | ||||
| 	return ejvw, nil | ||||
| } | ||||
|  | ||||
| func (ejvw *extJSONValueWriter) WriteBinary(b []byte) error { | ||||
| 	return ejvw.WriteBinaryWithSubtype(b, 0x00) | ||||
| } | ||||
|  | ||||
| func (ejvw *extJSONValueWriter) WriteBinaryWithSubtype(b []byte, btype byte) error { | ||||
| 	if err := ejvw.ensureElementValue(mode(0), "WriteBinaryWithSubtype"); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	var buf bytes.Buffer | ||||
| 	buf.WriteString(`{"$binary":{"base64":"`) | ||||
| 	buf.WriteString(base64.StdEncoding.EncodeToString(b)) | ||||
| 	buf.WriteString(fmt.Sprintf(`","subType":"%02x"}},`, btype)) | ||||
|  | ||||
| 	ejvw.buf = append(ejvw.buf, buf.Bytes()...) | ||||
|  | ||||
| 	ejvw.pop() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (ejvw *extJSONValueWriter) WriteBoolean(b bool) error { | ||||
| 	if err := ejvw.ensureElementValue(mode(0), "WriteBoolean"); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	ejvw.buf = append(ejvw.buf, []byte(strconv.FormatBool(b))...) | ||||
| 	ejvw.buf = append(ejvw.buf, ',') | ||||
|  | ||||
| 	ejvw.pop() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (ejvw *extJSONValueWriter) WriteCodeWithScope(code string) (DocumentWriter, error) { | ||||
| 	if err := ejvw.ensureElementValue(mCodeWithScope, "WriteCodeWithScope"); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	var buf bytes.Buffer | ||||
| 	buf.WriteString(`{"$code":`) | ||||
| 	writeStringWithEscapes(code, &buf, ejvw.escapeHTML) | ||||
| 	buf.WriteString(`,"$scope":{`) | ||||
|  | ||||
| 	ejvw.buf = append(ejvw.buf, buf.Bytes()...) | ||||
|  | ||||
| 	ejvw.push(mCodeWithScope) | ||||
| 	return ejvw, nil | ||||
| } | ||||
|  | ||||
| func (ejvw *extJSONValueWriter) WriteDBPointer(ns string, oid primitive.ObjectID) error { | ||||
| 	if err := ejvw.ensureElementValue(mode(0), "WriteDBPointer"); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	var buf bytes.Buffer | ||||
| 	buf.WriteString(`{"$dbPointer":{"$ref":"`) | ||||
| 	buf.WriteString(ns) | ||||
| 	buf.WriteString(`","$id":{"$oid":"`) | ||||
| 	buf.WriteString(oid.Hex()) | ||||
| 	buf.WriteString(`"}}},`) | ||||
|  | ||||
| 	ejvw.buf = append(ejvw.buf, buf.Bytes()...) | ||||
|  | ||||
| 	ejvw.pop() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (ejvw *extJSONValueWriter) WriteDateTime(dt int64) error { | ||||
| 	if err := ejvw.ensureElementValue(mode(0), "WriteDateTime"); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	t := time.Unix(dt/1e3, dt%1e3*1e6).UTC() | ||||
|  | ||||
| 	if ejvw.canonical || t.Year() < 1970 || t.Year() > 9999 { | ||||
| 		s := fmt.Sprintf(`{"$numberLong":"%d"}`, dt) | ||||
| 		ejvw.writeExtendedSingleValue("date", s, false) | ||||
| 	} else { | ||||
| 		ejvw.writeExtendedSingleValue("date", t.Format(rfc3339Milli), true) | ||||
| 	} | ||||
|  | ||||
| 	ejvw.buf = append(ejvw.buf, ',') | ||||
|  | ||||
| 	ejvw.pop() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (ejvw *extJSONValueWriter) WriteDecimal128(d primitive.Decimal128) error { | ||||
| 	if err := ejvw.ensureElementValue(mode(0), "WriteDecimal128"); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	ejvw.writeExtendedSingleValue("numberDecimal", d.String(), true) | ||||
| 	ejvw.buf = append(ejvw.buf, ',') | ||||
|  | ||||
| 	ejvw.pop() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (ejvw *extJSONValueWriter) WriteDocument() (DocumentWriter, error) { | ||||
| 	if ejvw.stack[ejvw.frame].mode == mTopLevel { | ||||
| 		ejvw.buf = append(ejvw.buf, '{') | ||||
| 		return ejvw, nil | ||||
| 	} | ||||
|  | ||||
| 	if err := ejvw.ensureElementValue(mDocument, "WriteDocument", mTopLevel); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	ejvw.buf = append(ejvw.buf, '{') | ||||
| 	ejvw.push(mDocument) | ||||
| 	return ejvw, nil | ||||
| } | ||||
|  | ||||
| func (ejvw *extJSONValueWriter) WriteDouble(f float64) error { | ||||
| 	if err := ejvw.ensureElementValue(mode(0), "WriteDouble"); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	s := formatDouble(f) | ||||
|  | ||||
| 	if ejvw.canonical { | ||||
| 		ejvw.writeExtendedSingleValue("numberDouble", s, true) | ||||
| 	} else { | ||||
| 		switch s { | ||||
| 		case "Infinity": | ||||
| 			fallthrough | ||||
| 		case "-Infinity": | ||||
| 			fallthrough | ||||
| 		case "NaN": | ||||
| 			s = fmt.Sprintf(`{"$numberDouble":"%s"}`, s) | ||||
| 		} | ||||
| 		ejvw.buf = append(ejvw.buf, []byte(s)...) | ||||
| 	} | ||||
|  | ||||
| 	ejvw.buf = append(ejvw.buf, ',') | ||||
|  | ||||
| 	ejvw.pop() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (ejvw *extJSONValueWriter) WriteInt32(i int32) error { | ||||
| 	if err := ejvw.ensureElementValue(mode(0), "WriteInt32"); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	s := strconv.FormatInt(int64(i), 10) | ||||
|  | ||||
| 	if ejvw.canonical { | ||||
| 		ejvw.writeExtendedSingleValue("numberInt", s, true) | ||||
| 	} else { | ||||
| 		ejvw.buf = append(ejvw.buf, []byte(s)...) | ||||
| 	} | ||||
|  | ||||
| 	ejvw.buf = append(ejvw.buf, ',') | ||||
|  | ||||
| 	ejvw.pop() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (ejvw *extJSONValueWriter) WriteInt64(i int64) error { | ||||
| 	if err := ejvw.ensureElementValue(mode(0), "WriteInt64"); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	s := strconv.FormatInt(i, 10) | ||||
|  | ||||
| 	if ejvw.canonical { | ||||
| 		ejvw.writeExtendedSingleValue("numberLong", s, true) | ||||
| 	} else { | ||||
| 		ejvw.buf = append(ejvw.buf, []byte(s)...) | ||||
| 	} | ||||
|  | ||||
| 	ejvw.buf = append(ejvw.buf, ',') | ||||
|  | ||||
| 	ejvw.pop() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (ejvw *extJSONValueWriter) WriteJavascript(code string) error { | ||||
| 	if err := ejvw.ensureElementValue(mode(0), "WriteJavascript"); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	var buf bytes.Buffer | ||||
| 	writeStringWithEscapes(code, &buf, ejvw.escapeHTML) | ||||
|  | ||||
| 	ejvw.writeExtendedSingleValue("code", buf.String(), false) | ||||
| 	ejvw.buf = append(ejvw.buf, ',') | ||||
|  | ||||
| 	ejvw.pop() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (ejvw *extJSONValueWriter) WriteMaxKey() error { | ||||
| 	if err := ejvw.ensureElementValue(mode(0), "WriteMaxKey"); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	ejvw.writeExtendedSingleValue("maxKey", "1", false) | ||||
| 	ejvw.buf = append(ejvw.buf, ',') | ||||
|  | ||||
| 	ejvw.pop() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (ejvw *extJSONValueWriter) WriteMinKey() error { | ||||
| 	if err := ejvw.ensureElementValue(mode(0), "WriteMinKey"); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	ejvw.writeExtendedSingleValue("minKey", "1", false) | ||||
| 	ejvw.buf = append(ejvw.buf, ',') | ||||
|  | ||||
| 	ejvw.pop() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (ejvw *extJSONValueWriter) WriteNull() error { | ||||
| 	if err := ejvw.ensureElementValue(mode(0), "WriteNull"); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	ejvw.buf = append(ejvw.buf, []byte("null")...) | ||||
| 	ejvw.buf = append(ejvw.buf, ',') | ||||
|  | ||||
| 	ejvw.pop() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (ejvw *extJSONValueWriter) WriteObjectID(oid primitive.ObjectID) error { | ||||
| 	if err := ejvw.ensureElementValue(mode(0), "WriteObjectID"); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	ejvw.writeExtendedSingleValue("oid", oid.Hex(), true) | ||||
| 	ejvw.buf = append(ejvw.buf, ',') | ||||
|  | ||||
| 	ejvw.pop() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (ejvw *extJSONValueWriter) WriteRegex(pattern string, options string) error { | ||||
| 	if err := ejvw.ensureElementValue(mode(0), "WriteRegex"); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	var buf bytes.Buffer | ||||
| 	buf.WriteString(`{"$regularExpression":{"pattern":`) | ||||
| 	writeStringWithEscapes(pattern, &buf, ejvw.escapeHTML) | ||||
| 	buf.WriteString(`,"options":"`) | ||||
| 	buf.WriteString(sortStringAlphebeticAscending(options)) | ||||
| 	buf.WriteString(`"}},`) | ||||
|  | ||||
| 	ejvw.buf = append(ejvw.buf, buf.Bytes()...) | ||||
|  | ||||
| 	ejvw.pop() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (ejvw *extJSONValueWriter) WriteString(s string) error { | ||||
| 	if err := ejvw.ensureElementValue(mode(0), "WriteString"); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	var buf bytes.Buffer | ||||
| 	writeStringWithEscapes(s, &buf, ejvw.escapeHTML) | ||||
|  | ||||
| 	ejvw.buf = append(ejvw.buf, buf.Bytes()...) | ||||
| 	ejvw.buf = append(ejvw.buf, ',') | ||||
|  | ||||
| 	ejvw.pop() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (ejvw *extJSONValueWriter) WriteSymbol(symbol string) error { | ||||
| 	if err := ejvw.ensureElementValue(mode(0), "WriteSymbol"); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	var buf bytes.Buffer | ||||
| 	writeStringWithEscapes(symbol, &buf, ejvw.escapeHTML) | ||||
|  | ||||
| 	ejvw.writeExtendedSingleValue("symbol", buf.String(), false) | ||||
| 	ejvw.buf = append(ejvw.buf, ',') | ||||
|  | ||||
| 	ejvw.pop() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (ejvw *extJSONValueWriter) WriteTimestamp(t uint32, i uint32) error { | ||||
| 	if err := ejvw.ensureElementValue(mode(0), "WriteTimestamp"); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	var buf bytes.Buffer | ||||
| 	buf.WriteString(`{"$timestamp":{"t":`) | ||||
| 	buf.WriteString(strconv.FormatUint(uint64(t), 10)) | ||||
| 	buf.WriteString(`,"i":`) | ||||
| 	buf.WriteString(strconv.FormatUint(uint64(i), 10)) | ||||
| 	buf.WriteString(`}},`) | ||||
|  | ||||
| 	ejvw.buf = append(ejvw.buf, buf.Bytes()...) | ||||
|  | ||||
| 	ejvw.pop() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (ejvw *extJSONValueWriter) WriteUndefined() error { | ||||
| 	if err := ejvw.ensureElementValue(mode(0), "WriteUndefined"); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	ejvw.writeExtendedSingleValue("undefined", "true", false) | ||||
| 	ejvw.buf = append(ejvw.buf, ',') | ||||
|  | ||||
| 	ejvw.pop() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (ejvw *extJSONValueWriter) WriteDocumentElement(key string) (ValueWriter, error) { | ||||
| 	switch ejvw.stack[ejvw.frame].mode { | ||||
| 	case mDocument, mTopLevel, mCodeWithScope: | ||||
| 		var buf bytes.Buffer | ||||
| 		writeStringWithEscapes(key, &buf, ejvw.escapeHTML) | ||||
|  | ||||
| 		ejvw.buf = append(ejvw.buf, []byte(fmt.Sprintf(`%s:`, buf.String()))...) | ||||
| 		ejvw.push(mElement) | ||||
| 	default: | ||||
| 		return nil, ejvw.invalidTransitionErr(mElement, "WriteDocumentElement", []mode{mDocument, mTopLevel, mCodeWithScope}) | ||||
| 	} | ||||
|  | ||||
| 	return ejvw, nil | ||||
| } | ||||
|  | ||||
| func (ejvw *extJSONValueWriter) WriteDocumentEnd() error { | ||||
| 	switch ejvw.stack[ejvw.frame].mode { | ||||
| 	case mDocument, mTopLevel, mCodeWithScope: | ||||
| 	default: | ||||
| 		return fmt.Errorf("incorrect mode to end document: %s", ejvw.stack[ejvw.frame].mode) | ||||
| 	} | ||||
|  | ||||
| 	// close the document | ||||
| 	if ejvw.buf[len(ejvw.buf)-1] == ',' { | ||||
| 		ejvw.buf[len(ejvw.buf)-1] = '}' | ||||
| 	} else { | ||||
| 		ejvw.buf = append(ejvw.buf, '}') | ||||
| 	} | ||||
|  | ||||
| 	switch ejvw.stack[ejvw.frame].mode { | ||||
| 	case mCodeWithScope: | ||||
| 		ejvw.buf = append(ejvw.buf, '}') | ||||
| 		fallthrough | ||||
| 	case mDocument: | ||||
| 		ejvw.buf = append(ejvw.buf, ',') | ||||
| 	case mTopLevel: | ||||
| 		if ejvw.w != nil { | ||||
| 			if _, err := ejvw.w.Write(ejvw.buf); err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			ejvw.buf = ejvw.buf[:0] | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	ejvw.pop() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (ejvw *extJSONValueWriter) WriteArrayElement() (ValueWriter, error) { | ||||
| 	switch ejvw.stack[ejvw.frame].mode { | ||||
| 	case mArray: | ||||
| 		ejvw.push(mValue) | ||||
| 	default: | ||||
| 		return nil, ejvw.invalidTransitionErr(mValue, "WriteArrayElement", []mode{mArray}) | ||||
| 	} | ||||
|  | ||||
| 	return ejvw, nil | ||||
| } | ||||
|  | ||||
| func (ejvw *extJSONValueWriter) WriteArrayEnd() error { | ||||
| 	switch ejvw.stack[ejvw.frame].mode { | ||||
| 	case mArray: | ||||
| 		// close the array | ||||
| 		if ejvw.buf[len(ejvw.buf)-1] == ',' { | ||||
| 			ejvw.buf[len(ejvw.buf)-1] = ']' | ||||
| 		} else { | ||||
| 			ejvw.buf = append(ejvw.buf, ']') | ||||
| 		} | ||||
|  | ||||
| 		ejvw.buf = append(ejvw.buf, ',') | ||||
|  | ||||
| 		ejvw.pop() | ||||
| 	default: | ||||
| 		return fmt.Errorf("incorrect mode to end array: %s", ejvw.stack[ejvw.frame].mode) | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func formatDouble(f float64) string { | ||||
| 	var s string | ||||
| 	if math.IsInf(f, 1) { | ||||
| 		s = "Infinity" | ||||
| 	} else if math.IsInf(f, -1) { | ||||
| 		s = "-Infinity" | ||||
| 	} else if math.IsNaN(f) { | ||||
| 		s = "NaN" | ||||
| 	} else { | ||||
| 		// Print exactly one decimalType place for integers; otherwise, print as many are necessary to | ||||
| 		// perfectly represent it. | ||||
| 		s = strconv.FormatFloat(f, 'G', -1, 64) | ||||
| 		if !strings.ContainsRune(s, 'E') && !strings.ContainsRune(s, '.') { | ||||
| 			s += ".0" | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return s | ||||
| } | ||||
|  | ||||
| var hexChars = "0123456789abcdef" | ||||
|  | ||||
| func writeStringWithEscapes(s string, buf *bytes.Buffer, escapeHTML bool) { | ||||
| 	buf.WriteByte('"') | ||||
| 	start := 0 | ||||
| 	for i := 0; i < len(s); { | ||||
| 		if b := s[i]; b < utf8.RuneSelf { | ||||
| 			if htmlSafeSet[b] || (!escapeHTML && safeSet[b]) { | ||||
| 				i++ | ||||
| 				continue | ||||
| 			} | ||||
| 			if start < i { | ||||
| 				buf.WriteString(s[start:i]) | ||||
| 			} | ||||
| 			switch b { | ||||
| 			case '\\', '"': | ||||
| 				buf.WriteByte('\\') | ||||
| 				buf.WriteByte(b) | ||||
| 			case '\n': | ||||
| 				buf.WriteByte('\\') | ||||
| 				buf.WriteByte('n') | ||||
| 			case '\r': | ||||
| 				buf.WriteByte('\\') | ||||
| 				buf.WriteByte('r') | ||||
| 			case '\t': | ||||
| 				buf.WriteByte('\\') | ||||
| 				buf.WriteByte('t') | ||||
| 			case '\b': | ||||
| 				buf.WriteByte('\\') | ||||
| 				buf.WriteByte('b') | ||||
| 			case '\f': | ||||
| 				buf.WriteByte('\\') | ||||
| 				buf.WriteByte('f') | ||||
| 			default: | ||||
| 				// This encodes bytes < 0x20 except for \t, \n and \r. | ||||
| 				// If escapeHTML is set, it also escapes <, >, and & | ||||
| 				// because they can lead to security holes when | ||||
| 				// user-controlled strings are rendered into JSON | ||||
| 				// and served to some browsers. | ||||
| 				buf.WriteString(`\u00`) | ||||
| 				buf.WriteByte(hexChars[b>>4]) | ||||
| 				buf.WriteByte(hexChars[b&0xF]) | ||||
| 			} | ||||
| 			i++ | ||||
| 			start = i | ||||
| 			continue | ||||
| 		} | ||||
| 		c, size := utf8.DecodeRuneInString(s[i:]) | ||||
| 		if c == utf8.RuneError && size == 1 { | ||||
| 			if start < i { | ||||
| 				buf.WriteString(s[start:i]) | ||||
| 			} | ||||
| 			buf.WriteString(`\ufffd`) | ||||
| 			i += size | ||||
| 			start = i | ||||
| 			continue | ||||
| 		} | ||||
| 		// U+2028 is LINE SEPARATOR. | ||||
| 		// U+2029 is PARAGRAPH SEPARATOR. | ||||
| 		// They are both technically valid characters in JSON strings, | ||||
| 		// but don't work in JSONP, which has to be evaluated as JavaScript, | ||||
| 		// and can lead to security holes there. It is valid JSON to | ||||
| 		// escape them, so we do so unconditionally. | ||||
| 		// See http://timelessrepo.com/json-isnt-a-javascript-subset for discussion. | ||||
| 		if c == '\u2028' || c == '\u2029' { | ||||
| 			if start < i { | ||||
| 				buf.WriteString(s[start:i]) | ||||
| 			} | ||||
| 			buf.WriteString(`\u202`) | ||||
| 			buf.WriteByte(hexChars[c&0xF]) | ||||
| 			i += size | ||||
| 			start = i | ||||
| 			continue | ||||
| 		} | ||||
| 		i += size | ||||
| 	} | ||||
| 	if start < len(s) { | ||||
| 		buf.WriteString(s[start:]) | ||||
| 	} | ||||
| 	buf.WriteByte('"') | ||||
| } | ||||
|  | ||||
| type sortableString []rune | ||||
|  | ||||
| func (ss sortableString) Len() int { | ||||
| 	return len(ss) | ||||
| } | ||||
|  | ||||
| func (ss sortableString) Less(i, j int) bool { | ||||
| 	return ss[i] < ss[j] | ||||
| } | ||||
|  | ||||
| func (ss sortableString) Swap(i, j int) { | ||||
| 	oldI := ss[i] | ||||
| 	ss[i] = ss[j] | ||||
| 	ss[j] = oldI | ||||
| } | ||||
|  | ||||
| func sortStringAlphebeticAscending(s string) string { | ||||
| 	ss := sortableString([]rune(s)) | ||||
| 	sort.Sort(ss) | ||||
| 	return string([]rune(ss)) | ||||
| } | ||||
							
								
								
									
										260
									
								
								mongo/bson/bsonrw/extjson_writer_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										260
									
								
								mongo/bson/bsonrw/extjson_writer_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,260 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsonrw | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"io/ioutil" | ||||
| 	"reflect" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
|  | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsontype" | ||||
| 	"go.mongodb.org/mongo-driver/bson/primitive" | ||||
| ) | ||||
|  | ||||
| func TestExtJSONValueWriter(t *testing.T) { | ||||
| 	oid := primitive.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C} | ||||
| 	testCases := []struct { | ||||
| 		name   string | ||||
| 		fn     interface{} | ||||
| 		params []interface{} | ||||
| 	}{ | ||||
| 		{ | ||||
| 			"WriteBinary", | ||||
| 			(*extJSONValueWriter).WriteBinary, | ||||
| 			[]interface{}{[]byte{0x01, 0x02, 0x03}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"WriteBinaryWithSubtype (not 0x02)", | ||||
| 			(*extJSONValueWriter).WriteBinaryWithSubtype, | ||||
| 			[]interface{}{[]byte{0x01, 0x02, 0x03}, byte(0xFF)}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"WriteBinaryWithSubtype (0x02)", | ||||
| 			(*extJSONValueWriter).WriteBinaryWithSubtype, | ||||
| 			[]interface{}{[]byte{0x01, 0x02, 0x03}, byte(0x02)}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"WriteBoolean", | ||||
| 			(*extJSONValueWriter).WriteBoolean, | ||||
| 			[]interface{}{true}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"WriteDBPointer", | ||||
| 			(*extJSONValueWriter).WriteDBPointer, | ||||
| 			[]interface{}{"bar", oid}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"WriteDateTime", | ||||
| 			(*extJSONValueWriter).WriteDateTime, | ||||
| 			[]interface{}{int64(12345678)}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"WriteDecimal128", | ||||
| 			(*extJSONValueWriter).WriteDecimal128, | ||||
| 			[]interface{}{primitive.NewDecimal128(10, 20)}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"WriteDouble", | ||||
| 			(*extJSONValueWriter).WriteDouble, | ||||
| 			[]interface{}{float64(3.14159)}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"WriteInt32", | ||||
| 			(*extJSONValueWriter).WriteInt32, | ||||
| 			[]interface{}{int32(123456)}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"WriteInt64", | ||||
| 			(*extJSONValueWriter).WriteInt64, | ||||
| 			[]interface{}{int64(1234567890)}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"WriteJavascript", | ||||
| 			(*extJSONValueWriter).WriteJavascript, | ||||
| 			[]interface{}{"var foo = 'bar';"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"WriteMaxKey", | ||||
| 			(*extJSONValueWriter).WriteMaxKey, | ||||
| 			[]interface{}{}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"WriteMinKey", | ||||
| 			(*extJSONValueWriter).WriteMinKey, | ||||
| 			[]interface{}{}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"WriteNull", | ||||
| 			(*extJSONValueWriter).WriteNull, | ||||
| 			[]interface{}{}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"WriteObjectID", | ||||
| 			(*extJSONValueWriter).WriteObjectID, | ||||
| 			[]interface{}{oid}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"WriteRegex", | ||||
| 			(*extJSONValueWriter).WriteRegex, | ||||
| 			[]interface{}{"bar", "baz"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"WriteString", | ||||
| 			(*extJSONValueWriter).WriteString, | ||||
| 			[]interface{}{"hello, world!"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"WriteSymbol", | ||||
| 			(*extJSONValueWriter).WriteSymbol, | ||||
| 			[]interface{}{"symbollolz"}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"WriteTimestamp", | ||||
| 			(*extJSONValueWriter).WriteTimestamp, | ||||
| 			[]interface{}{uint32(10), uint32(20)}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"WriteUndefined", | ||||
| 			(*extJSONValueWriter).WriteUndefined, | ||||
| 			[]interface{}{}, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for _, tc := range testCases { | ||||
| 		t.Run(tc.name, func(t *testing.T) { | ||||
| 			fn := reflect.ValueOf(tc.fn) | ||||
| 			if fn.Kind() != reflect.Func { | ||||
| 				t.Fatalf("fn must be of kind Func but it is a %v", fn.Kind()) | ||||
| 			} | ||||
| 			if fn.Type().NumIn() != len(tc.params)+1 || fn.Type().In(0) != reflect.TypeOf((*extJSONValueWriter)(nil)) { | ||||
| 				t.Fatalf("fn must have at least one parameter and the first parameter must be a *valueWriter") | ||||
| 			} | ||||
| 			if fn.Type().NumOut() != 1 || fn.Type().Out(0) != reflect.TypeOf((*error)(nil)).Elem() { | ||||
| 				t.Fatalf("fn must have one return value and it must be an error.") | ||||
| 			} | ||||
| 			params := make([]reflect.Value, 1, len(tc.params)+1) | ||||
| 			ejvw := newExtJSONWriter(ioutil.Discard, true, true) | ||||
| 			params[0] = reflect.ValueOf(ejvw) | ||||
| 			for _, param := range tc.params { | ||||
| 				params = append(params, reflect.ValueOf(param)) | ||||
| 			} | ||||
|  | ||||
| 			t.Run("incorrect transition", func(t *testing.T) { | ||||
| 				results := fn.Call(params) | ||||
| 				got := results[0].Interface().(error) | ||||
| 				fnName := tc.name | ||||
| 				if strings.Contains(fnName, "WriteBinary") { | ||||
| 					fnName = "WriteBinaryWithSubtype" | ||||
| 				} | ||||
| 				want := TransitionError{current: mTopLevel, name: fnName, modes: []mode{mElement, mValue}, | ||||
| 					action: "write"} | ||||
| 				if !compareErrors(got, want) { | ||||
| 					t.Errorf("Errors do not match. got %v; want %v", got, want) | ||||
| 				} | ||||
| 			}) | ||||
| 		}) | ||||
| 	} | ||||
|  | ||||
| 	t.Run("WriteArray", func(t *testing.T) { | ||||
| 		ejvw := newExtJSONWriter(ioutil.Discard, true, true) | ||||
| 		ejvw.push(mArray) | ||||
| 		want := TransitionError{current: mArray, destination: mArray, parent: mTopLevel, | ||||
| 			name: "WriteArray", modes: []mode{mElement, mValue}, action: "write"} | ||||
| 		_, got := ejvw.WriteArray() | ||||
| 		if !compareErrors(got, want) { | ||||
| 			t.Errorf("Did not get expected error. got %v; want %v", got, want) | ||||
| 		} | ||||
| 	}) | ||||
| 	t.Run("WriteCodeWithScope", func(t *testing.T) { | ||||
| 		ejvw := newExtJSONWriter(ioutil.Discard, true, true) | ||||
| 		ejvw.push(mArray) | ||||
| 		want := TransitionError{current: mArray, destination: mCodeWithScope, parent: mTopLevel, | ||||
| 			name: "WriteCodeWithScope", modes: []mode{mElement, mValue}, action: "write"} | ||||
| 		_, got := ejvw.WriteCodeWithScope("") | ||||
| 		if !compareErrors(got, want) { | ||||
| 			t.Errorf("Did not get expected error. got %v; want %v", got, want) | ||||
| 		} | ||||
| 	}) | ||||
| 	t.Run("WriteDocument", func(t *testing.T) { | ||||
| 		ejvw := newExtJSONWriter(ioutil.Discard, true, true) | ||||
| 		ejvw.push(mArray) | ||||
| 		want := TransitionError{current: mArray, destination: mDocument, parent: mTopLevel, | ||||
| 			name: "WriteDocument", modes: []mode{mElement, mValue, mTopLevel}, action: "write"} | ||||
| 		_, got := ejvw.WriteDocument() | ||||
| 		if !compareErrors(got, want) { | ||||
| 			t.Errorf("Did not get expected error. got %v; want %v", got, want) | ||||
| 		} | ||||
| 	}) | ||||
| 	t.Run("WriteDocumentElement", func(t *testing.T) { | ||||
| 		ejvw := newExtJSONWriter(ioutil.Discard, true, true) | ||||
| 		ejvw.push(mElement) | ||||
| 		want := TransitionError{current: mElement, | ||||
| 			destination: mElement, | ||||
| 			parent:      mTopLevel, | ||||
| 			name:        "WriteDocumentElement", | ||||
| 			modes:       []mode{mDocument, mTopLevel, mCodeWithScope}, | ||||
| 			action:      "write"} | ||||
| 		_, got := ejvw.WriteDocumentElement("") | ||||
| 		if !compareErrors(got, want) { | ||||
| 			t.Errorf("Did not get expected error. got %v; want %v", got, want) | ||||
| 		} | ||||
| 	}) | ||||
| 	t.Run("WriteDocumentEnd", func(t *testing.T) { | ||||
| 		ejvw := newExtJSONWriter(ioutil.Discard, true, true) | ||||
| 		ejvw.push(mElement) | ||||
| 		want := fmt.Errorf("incorrect mode to end document: %s", mElement) | ||||
| 		got := ejvw.WriteDocumentEnd() | ||||
| 		if !compareErrors(got, want) { | ||||
| 			t.Errorf("Did not get expected error. got %v; want %v", got, want) | ||||
| 		} | ||||
| 	}) | ||||
| 	t.Run("WriteArrayElement", func(t *testing.T) { | ||||
| 		ejvw := newExtJSONWriter(ioutil.Discard, true, true) | ||||
| 		ejvw.push(mElement) | ||||
| 		want := TransitionError{current: mElement, | ||||
| 			destination: mValue, | ||||
| 			parent:      mTopLevel, | ||||
| 			name:        "WriteArrayElement", | ||||
| 			modes:       []mode{mArray}, | ||||
| 			action:      "write"} | ||||
| 		_, got := ejvw.WriteArrayElement() | ||||
| 		if !compareErrors(got, want) { | ||||
| 			t.Errorf("Did not get expected error. got %v; want %v", got, want) | ||||
| 		} | ||||
| 	}) | ||||
| 	t.Run("WriteArrayEnd", func(t *testing.T) { | ||||
| 		ejvw := newExtJSONWriter(ioutil.Discard, true, true) | ||||
| 		ejvw.push(mElement) | ||||
| 		want := fmt.Errorf("incorrect mode to end array: %s", mElement) | ||||
| 		got := ejvw.WriteArrayEnd() | ||||
| 		if !compareErrors(got, want) { | ||||
| 			t.Errorf("Did not get expected error. got %v; want %v", got, want) | ||||
| 		} | ||||
| 	}) | ||||
|  | ||||
| 	t.Run("WriteBytes", func(t *testing.T) { | ||||
| 		t.Run("writeElementHeader error", func(t *testing.T) { | ||||
| 			ejvw := newExtJSONWriterFromSlice(nil, true, true) | ||||
| 			want := TransitionError{current: mTopLevel, destination: mode(0), | ||||
| 				name: "WriteBinaryWithSubtype", modes: []mode{mElement, mValue}, action: "write"} | ||||
| 			got := ejvw.WriteBinaryWithSubtype(nil, (byte)(bsontype.EmbeddedDocument)) | ||||
| 			if !compareErrors(got, want) { | ||||
| 				t.Errorf("Did not received expected error. got %v; want %v", got, want) | ||||
| 			} | ||||
| 		}) | ||||
| 	}) | ||||
|  | ||||
| 	t.Run("FormatDoubleWithExponent", func(t *testing.T) { | ||||
| 		want := "3E-12" | ||||
| 		got := formatDouble(float64(0.000000000003)) | ||||
| 		if got != want { | ||||
| 			t.Errorf("Did not receive expected string. got %s: want %s", got, want) | ||||
| 		} | ||||
| 	}) | ||||
| } | ||||
							
								
								
									
										528
									
								
								mongo/bson/bsonrw/json_scanner.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										528
									
								
								mongo/bson/bsonrw/json_scanner.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,528 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsonrw | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"math" | ||||
| 	"strconv" | ||||
| 	"unicode" | ||||
| 	"unicode/utf16" | ||||
| ) | ||||
|  | ||||
| type jsonTokenType byte | ||||
|  | ||||
| const ( | ||||
| 	jttBeginObject jsonTokenType = iota | ||||
| 	jttEndObject | ||||
| 	jttBeginArray | ||||
| 	jttEndArray | ||||
| 	jttColon | ||||
| 	jttComma | ||||
| 	jttInt32 | ||||
| 	jttInt64 | ||||
| 	jttDouble | ||||
| 	jttString | ||||
| 	jttBool | ||||
| 	jttNull | ||||
| 	jttEOF | ||||
| ) | ||||
|  | ||||
| type jsonToken struct { | ||||
| 	t jsonTokenType | ||||
| 	v interface{} | ||||
| 	p int | ||||
| } | ||||
|  | ||||
| type jsonScanner struct { | ||||
| 	r           io.Reader | ||||
| 	buf         []byte | ||||
| 	pos         int | ||||
| 	lastReadErr error | ||||
| } | ||||
|  | ||||
| // nextToken returns the next JSON token if one exists. A token is a character | ||||
| // of the JSON grammar, a number, a string, or a literal. | ||||
| func (js *jsonScanner) nextToken() (*jsonToken, error) { | ||||
| 	c, err := js.readNextByte() | ||||
|  | ||||
| 	// keep reading until a non-space is encountered (break on read error or EOF) | ||||
| 	for isWhiteSpace(c) && err == nil { | ||||
| 		c, err = js.readNextByte() | ||||
| 	} | ||||
|  | ||||
| 	if err == io.EOF { | ||||
| 		return &jsonToken{t: jttEOF}, nil | ||||
| 	} else if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	// switch on the character | ||||
| 	switch c { | ||||
| 	case '{': | ||||
| 		return &jsonToken{t: jttBeginObject, v: byte('{'), p: js.pos - 1}, nil | ||||
| 	case '}': | ||||
| 		return &jsonToken{t: jttEndObject, v: byte('}'), p: js.pos - 1}, nil | ||||
| 	case '[': | ||||
| 		return &jsonToken{t: jttBeginArray, v: byte('['), p: js.pos - 1}, nil | ||||
| 	case ']': | ||||
| 		return &jsonToken{t: jttEndArray, v: byte(']'), p: js.pos - 1}, nil | ||||
| 	case ':': | ||||
| 		return &jsonToken{t: jttColon, v: byte(':'), p: js.pos - 1}, nil | ||||
| 	case ',': | ||||
| 		return &jsonToken{t: jttComma, v: byte(','), p: js.pos - 1}, nil | ||||
| 	case '"': // RFC-8259 only allows for double quotes (") not single (') | ||||
| 		return js.scanString() | ||||
| 	default: | ||||
| 		// check if it's a number | ||||
| 		if c == '-' || isDigit(c) { | ||||
| 			return js.scanNumber(c) | ||||
| 		} else if c == 't' || c == 'f' || c == 'n' { | ||||
| 			// maybe a literal | ||||
| 			return js.scanLiteral(c) | ||||
| 		} else { | ||||
| 			return nil, fmt.Errorf("invalid JSON input. Position: %d. Character: %c", js.pos-1, c) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // readNextByte attempts to read the next byte from the buffer. If the buffer | ||||
| // has been exhausted, this function calls readIntoBuf, thus refilling the | ||||
| // buffer and resetting the read position to 0 | ||||
| func (js *jsonScanner) readNextByte() (byte, error) { | ||||
| 	if js.pos >= len(js.buf) { | ||||
| 		err := js.readIntoBuf() | ||||
|  | ||||
| 		if err != nil { | ||||
| 			return 0, err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	b := js.buf[js.pos] | ||||
| 	js.pos++ | ||||
|  | ||||
| 	return b, nil | ||||
| } | ||||
|  | ||||
| // readNNextBytes reads n bytes into dst, starting at offset | ||||
| func (js *jsonScanner) readNNextBytes(dst []byte, n, offset int) error { | ||||
| 	var err error | ||||
|  | ||||
| 	for i := 0; i < n; i++ { | ||||
| 		dst[i+offset], err = js.readNextByte() | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // readIntoBuf reads up to 512 bytes from the scanner's io.Reader into the buffer | ||||
| func (js *jsonScanner) readIntoBuf() error { | ||||
| 	if js.lastReadErr != nil { | ||||
| 		js.buf = js.buf[:0] | ||||
| 		js.pos = 0 | ||||
| 		return js.lastReadErr | ||||
| 	} | ||||
|  | ||||
| 	if cap(js.buf) == 0 { | ||||
| 		js.buf = make([]byte, 0, 512) | ||||
| 	} | ||||
|  | ||||
| 	n, err := js.r.Read(js.buf[:cap(js.buf)]) | ||||
| 	if err != nil { | ||||
| 		js.lastReadErr = err | ||||
| 		if n > 0 { | ||||
| 			err = nil | ||||
| 		} | ||||
| 	} | ||||
| 	js.buf = js.buf[:n] | ||||
| 	js.pos = 0 | ||||
|  | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func isWhiteSpace(c byte) bool { | ||||
| 	return c == ' ' || c == '\t' || c == '\r' || c == '\n' | ||||
| } | ||||
|  | ||||
| func isDigit(c byte) bool { | ||||
| 	return unicode.IsDigit(rune(c)) | ||||
| } | ||||
|  | ||||
| func isValueTerminator(c byte) bool { | ||||
| 	return c == ',' || c == '}' || c == ']' || isWhiteSpace(c) | ||||
| } | ||||
|  | ||||
| // getu4 decodes the 4-byte hex sequence from the beginning of s, returning the hex value as a rune, | ||||
| // or it returns -1. Note that the "\u" from the unicode escape sequence should not be present. | ||||
| // It is copied and lightly modified from the Go JSON decode function at | ||||
| // https://github.com/golang/go/blob/1b0a0316802b8048d69da49dc23c5a5ab08e8ae8/src/encoding/json/decode.go#L1169-L1188 | ||||
| func getu4(s []byte) rune { | ||||
| 	if len(s) < 4 { | ||||
| 		return -1 | ||||
| 	} | ||||
| 	var r rune | ||||
| 	for _, c := range s[:4] { | ||||
| 		switch { | ||||
| 		case '0' <= c && c <= '9': | ||||
| 			c = c - '0' | ||||
| 		case 'a' <= c && c <= 'f': | ||||
| 			c = c - 'a' + 10 | ||||
| 		case 'A' <= c && c <= 'F': | ||||
| 			c = c - 'A' + 10 | ||||
| 		default: | ||||
| 			return -1 | ||||
| 		} | ||||
| 		r = r*16 + rune(c) | ||||
| 	} | ||||
| 	return r | ||||
| } | ||||
|  | ||||
| // scanString reads from an opening '"' to a closing '"' and handles escaped characters | ||||
| func (js *jsonScanner) scanString() (*jsonToken, error) { | ||||
| 	var b bytes.Buffer | ||||
| 	var c byte | ||||
| 	var err error | ||||
|  | ||||
| 	p := js.pos - 1 | ||||
|  | ||||
| 	for { | ||||
| 		c, err = js.readNextByte() | ||||
| 		if err != nil { | ||||
| 			if err == io.EOF { | ||||
| 				return nil, errors.New("end of input in JSON string") | ||||
| 			} | ||||
| 			return nil, err | ||||
| 		} | ||||
|  | ||||
| 	evalNextChar: | ||||
| 		switch c { | ||||
| 		case '\\': | ||||
| 			c, err = js.readNextByte() | ||||
| 			if err != nil { | ||||
| 				if err == io.EOF { | ||||
| 					return nil, errors.New("end of input in JSON string") | ||||
| 				} | ||||
| 				return nil, err | ||||
| 			} | ||||
|  | ||||
| 		evalNextEscapeChar: | ||||
| 			switch c { | ||||
| 			case '"', '\\', '/': | ||||
| 				b.WriteByte(c) | ||||
| 			case 'b': | ||||
| 				b.WriteByte('\b') | ||||
| 			case 'f': | ||||
| 				b.WriteByte('\f') | ||||
| 			case 'n': | ||||
| 				b.WriteByte('\n') | ||||
| 			case 'r': | ||||
| 				b.WriteByte('\r') | ||||
| 			case 't': | ||||
| 				b.WriteByte('\t') | ||||
| 			case 'u': | ||||
| 				us := make([]byte, 4) | ||||
| 				err = js.readNNextBytes(us, 4, 0) | ||||
| 				if err != nil { | ||||
| 					return nil, fmt.Errorf("invalid unicode sequence in JSON string: %s", us) | ||||
| 				} | ||||
|  | ||||
| 				rn := getu4(us) | ||||
|  | ||||
| 				// If the rune we just decoded is the high or low value of a possible surrogate pair, | ||||
| 				// try to decode the next sequence as the low value of a surrogate pair. We're | ||||
| 				// expecting the next sequence to be another Unicode escape sequence (e.g. "\uDD1E"), | ||||
| 				// but need to handle cases where the input is not a valid surrogate pair. | ||||
| 				// For more context on unicode surrogate pairs, see: | ||||
| 				// https://www.christianfscott.com/rust-chars-vs-go-runes/ | ||||
| 				// https://www.unicode.org/glossary/#high_surrogate_code_point | ||||
| 				if utf16.IsSurrogate(rn) { | ||||
| 					c, err = js.readNextByte() | ||||
| 					if err != nil { | ||||
| 						if err == io.EOF { | ||||
| 							return nil, errors.New("end of input in JSON string") | ||||
| 						} | ||||
| 						return nil, err | ||||
| 					} | ||||
|  | ||||
| 					// If the next value isn't the beginning of a backslash escape sequence, write | ||||
| 					// the Unicode replacement character for the surrogate value and goto the | ||||
| 					// beginning of the next char eval block. | ||||
| 					if c != '\\' { | ||||
| 						b.WriteRune(unicode.ReplacementChar) | ||||
| 						goto evalNextChar | ||||
| 					} | ||||
|  | ||||
| 					c, err = js.readNextByte() | ||||
| 					if err != nil { | ||||
| 						if err == io.EOF { | ||||
| 							return nil, errors.New("end of input in JSON string") | ||||
| 						} | ||||
| 						return nil, err | ||||
| 					} | ||||
|  | ||||
| 					// If the next value isn't the beginning of a unicode escape sequence, write the | ||||
| 					// Unicode replacement character for the surrogate value and goto the beginning | ||||
| 					// of the next escape char eval block. | ||||
| 					if c != 'u' { | ||||
| 						b.WriteRune(unicode.ReplacementChar) | ||||
| 						goto evalNextEscapeChar | ||||
| 					} | ||||
|  | ||||
| 					err = js.readNNextBytes(us, 4, 0) | ||||
| 					if err != nil { | ||||
| 						return nil, fmt.Errorf("invalid unicode sequence in JSON string: %s", us) | ||||
| 					} | ||||
|  | ||||
| 					rn2 := getu4(us) | ||||
|  | ||||
| 					// Try to decode the pair of runes as a utf16 surrogate pair. If that fails, write | ||||
| 					// the Unicode replacement character for the surrogate value and the 2nd decoded rune. | ||||
| 					if rnPair := utf16.DecodeRune(rn, rn2); rnPair != unicode.ReplacementChar { | ||||
| 						b.WriteRune(rnPair) | ||||
| 					} else { | ||||
| 						b.WriteRune(unicode.ReplacementChar) | ||||
| 						b.WriteRune(rn2) | ||||
| 					} | ||||
|  | ||||
| 					break | ||||
| 				} | ||||
|  | ||||
| 				b.WriteRune(rn) | ||||
| 			default: | ||||
| 				return nil, fmt.Errorf("invalid escape sequence in JSON string '\\%c'", c) | ||||
| 			} | ||||
| 		case '"': | ||||
| 			return &jsonToken{t: jttString, v: b.String(), p: p}, nil | ||||
| 		default: | ||||
| 			b.WriteByte(c) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // scanLiteral reads an unquoted sequence of characters and determines if it is one of | ||||
| // three valid JSON literals (true, false, null); if so, it returns the appropriate | ||||
| // jsonToken; otherwise, it returns an error | ||||
| func (js *jsonScanner) scanLiteral(first byte) (*jsonToken, error) { | ||||
| 	p := js.pos - 1 | ||||
|  | ||||
| 	lit := make([]byte, 4) | ||||
| 	lit[0] = first | ||||
|  | ||||
| 	err := js.readNNextBytes(lit, 3, 1) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	c5, err := js.readNextByte() | ||||
|  | ||||
| 	if bytes.Equal([]byte("true"), lit) && (isValueTerminator(c5) || err == io.EOF) { | ||||
| 		js.pos = int(math.Max(0, float64(js.pos-1))) | ||||
| 		return &jsonToken{t: jttBool, v: true, p: p}, nil | ||||
| 	} else if bytes.Equal([]byte("null"), lit) && (isValueTerminator(c5) || err == io.EOF) { | ||||
| 		js.pos = int(math.Max(0, float64(js.pos-1))) | ||||
| 		return &jsonToken{t: jttNull, v: nil, p: p}, nil | ||||
| 	} else if bytes.Equal([]byte("fals"), lit) { | ||||
| 		if c5 == 'e' { | ||||
| 			c5, err = js.readNextByte() | ||||
|  | ||||
| 			if isValueTerminator(c5) || err == io.EOF { | ||||
| 				js.pos = int(math.Max(0, float64(js.pos-1))) | ||||
| 				return &jsonToken{t: jttBool, v: false, p: p}, nil | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return nil, fmt.Errorf("invalid JSON literal. Position: %d, literal: %s", p, lit) | ||||
| } | ||||
|  | ||||
| type numberScanState byte | ||||
|  | ||||
| const ( | ||||
| 	nssSawLeadingMinus numberScanState = iota | ||||
| 	nssSawLeadingZero | ||||
| 	nssSawIntegerDigits | ||||
| 	nssSawDecimalPoint | ||||
| 	nssSawFractionDigits | ||||
| 	nssSawExponentLetter | ||||
| 	nssSawExponentSign | ||||
| 	nssSawExponentDigits | ||||
| 	nssDone | ||||
| 	nssInvalid | ||||
| ) | ||||
|  | ||||
| // scanNumber reads a JSON number (according to RFC-8259) | ||||
| func (js *jsonScanner) scanNumber(first byte) (*jsonToken, error) { | ||||
| 	var b bytes.Buffer | ||||
| 	var s numberScanState | ||||
| 	var c byte | ||||
| 	var err error | ||||
|  | ||||
| 	t := jttInt64 // assume it's an int64 until the type can be determined | ||||
| 	start := js.pos - 1 | ||||
|  | ||||
| 	b.WriteByte(first) | ||||
|  | ||||
| 	switch first { | ||||
| 	case '-': | ||||
| 		s = nssSawLeadingMinus | ||||
| 	case '0': | ||||
| 		s = nssSawLeadingZero | ||||
| 	default: | ||||
| 		s = nssSawIntegerDigits | ||||
| 	} | ||||
|  | ||||
| 	for { | ||||
| 		c, err = js.readNextByte() | ||||
|  | ||||
| 		if err != nil && err != io.EOF { | ||||
| 			return nil, err | ||||
| 		} | ||||
|  | ||||
| 		switch s { | ||||
| 		case nssSawLeadingMinus: | ||||
| 			switch c { | ||||
| 			case '0': | ||||
| 				s = nssSawLeadingZero | ||||
| 				b.WriteByte(c) | ||||
| 			default: | ||||
| 				if isDigit(c) { | ||||
| 					s = nssSawIntegerDigits | ||||
| 					b.WriteByte(c) | ||||
| 				} else { | ||||
| 					s = nssInvalid | ||||
| 				} | ||||
| 			} | ||||
| 		case nssSawLeadingZero: | ||||
| 			switch c { | ||||
| 			case '.': | ||||
| 				s = nssSawDecimalPoint | ||||
| 				b.WriteByte(c) | ||||
| 			case 'e', 'E': | ||||
| 				s = nssSawExponentLetter | ||||
| 				b.WriteByte(c) | ||||
| 			case '}', ']', ',': | ||||
| 				s = nssDone | ||||
| 			default: | ||||
| 				if isWhiteSpace(c) || err == io.EOF { | ||||
| 					s = nssDone | ||||
| 				} else { | ||||
| 					s = nssInvalid | ||||
| 				} | ||||
| 			} | ||||
| 		case nssSawIntegerDigits: | ||||
| 			switch c { | ||||
| 			case '.': | ||||
| 				s = nssSawDecimalPoint | ||||
| 				b.WriteByte(c) | ||||
| 			case 'e', 'E': | ||||
| 				s = nssSawExponentLetter | ||||
| 				b.WriteByte(c) | ||||
| 			case '}', ']', ',': | ||||
| 				s = nssDone | ||||
| 			default: | ||||
| 				if isWhiteSpace(c) || err == io.EOF { | ||||
| 					s = nssDone | ||||
| 				} else if isDigit(c) { | ||||
| 					s = nssSawIntegerDigits | ||||
| 					b.WriteByte(c) | ||||
| 				} else { | ||||
| 					s = nssInvalid | ||||
| 				} | ||||
| 			} | ||||
| 		case nssSawDecimalPoint: | ||||
| 			t = jttDouble | ||||
| 			if isDigit(c) { | ||||
| 				s = nssSawFractionDigits | ||||
| 				b.WriteByte(c) | ||||
| 			} else { | ||||
| 				s = nssInvalid | ||||
| 			} | ||||
| 		case nssSawFractionDigits: | ||||
| 			switch c { | ||||
| 			case 'e', 'E': | ||||
| 				s = nssSawExponentLetter | ||||
| 				b.WriteByte(c) | ||||
| 			case '}', ']', ',': | ||||
| 				s = nssDone | ||||
| 			default: | ||||
| 				if isWhiteSpace(c) || err == io.EOF { | ||||
| 					s = nssDone | ||||
| 				} else if isDigit(c) { | ||||
| 					s = nssSawFractionDigits | ||||
| 					b.WriteByte(c) | ||||
| 				} else { | ||||
| 					s = nssInvalid | ||||
| 				} | ||||
| 			} | ||||
| 		case nssSawExponentLetter: | ||||
| 			t = jttDouble | ||||
| 			switch c { | ||||
| 			case '+', '-': | ||||
| 				s = nssSawExponentSign | ||||
| 				b.WriteByte(c) | ||||
| 			default: | ||||
| 				if isDigit(c) { | ||||
| 					s = nssSawExponentDigits | ||||
| 					b.WriteByte(c) | ||||
| 				} else { | ||||
| 					s = nssInvalid | ||||
| 				} | ||||
| 			} | ||||
| 		case nssSawExponentSign: | ||||
| 			if isDigit(c) { | ||||
| 				s = nssSawExponentDigits | ||||
| 				b.WriteByte(c) | ||||
| 			} else { | ||||
| 				s = nssInvalid | ||||
| 			} | ||||
| 		case nssSawExponentDigits: | ||||
| 			switch c { | ||||
| 			case '}', ']', ',': | ||||
| 				s = nssDone | ||||
| 			default: | ||||
| 				if isWhiteSpace(c) || err == io.EOF { | ||||
| 					s = nssDone | ||||
| 				} else if isDigit(c) { | ||||
| 					s = nssSawExponentDigits | ||||
| 					b.WriteByte(c) | ||||
| 				} else { | ||||
| 					s = nssInvalid | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		switch s { | ||||
| 		case nssInvalid: | ||||
| 			return nil, fmt.Errorf("invalid JSON number. Position: %d", start) | ||||
| 		case nssDone: | ||||
| 			js.pos = int(math.Max(0, float64(js.pos-1))) | ||||
| 			if t != jttDouble { | ||||
| 				v, err := strconv.ParseInt(b.String(), 10, 64) | ||||
| 				if err == nil { | ||||
| 					if v < math.MinInt32 || v > math.MaxInt32 { | ||||
| 						return &jsonToken{t: jttInt64, v: v, p: start}, nil | ||||
| 					} | ||||
|  | ||||
| 					return &jsonToken{t: jttInt32, v: int32(v), p: start}, nil | ||||
| 				} | ||||
| 			} | ||||
|  | ||||
| 			v, err := strconv.ParseFloat(b.String(), 64) | ||||
| 			if err != nil { | ||||
| 				return nil, err | ||||
| 			} | ||||
|  | ||||
| 			return &jsonToken{t: jttDouble, v: v, p: start}, nil | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										376
									
								
								mongo/bson/bsonrw/json_scanner_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										376
									
								
								mongo/bson/bsonrw/json_scanner_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,376 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsonrw | ||||
|  | ||||
| import ( | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| 	"testing/iotest" | ||||
|  | ||||
| 	"github.com/google/go-cmp/cmp" | ||||
| ) | ||||
|  | ||||
| func jttDiff(t *testing.T, expected, actual jsonTokenType, desc string) { | ||||
| 	if diff := cmp.Diff(expected, actual); diff != "" { | ||||
| 		t.Helper() | ||||
| 		t.Errorf("%s: Incorrect JSON Token Type (-want, +got): %s\n", desc, diff) | ||||
| 		t.FailNow() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func jtvDiff(t *testing.T, expected, actual interface{}, desc string) { | ||||
| 	if diff := cmp.Diff(expected, actual); diff != "" { | ||||
| 		t.Helper() | ||||
| 		t.Errorf("%s: Incorrect JSON Token Value (-want, +got): %s\n", desc, diff) | ||||
| 		t.FailNow() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func expectNilToken(t *testing.T, v *jsonToken, desc string) { | ||||
| 	if v != nil { | ||||
| 		t.Helper() | ||||
| 		t.Errorf("%s: Expected nil JSON token", desc) | ||||
| 		t.FailNow() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func expectError(t *testing.T, err error, desc string) { | ||||
| 	if err == nil { | ||||
| 		t.Helper() | ||||
| 		t.Errorf("%s: Expected error", desc) | ||||
| 		t.FailNow() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func expectNoError(t *testing.T, err error, desc string) { | ||||
| 	if err != nil { | ||||
| 		t.Helper() | ||||
| 		t.Errorf("%s: Unepexted error: %v", desc, err) | ||||
| 		t.FailNow() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type jsonScannerTestCase struct { | ||||
| 	desc   string | ||||
| 	input  string | ||||
| 	tokens []jsonToken | ||||
| } | ||||
|  | ||||
| // length = 512 | ||||
| const longKey = "abcdefghijklmnopqrstuvwxyz" + "abcdefghijklmnopqrstuvwxyz" + | ||||
| 	"abcdefghijklmnopqrstuvwxyz" + "abcdefghijklmnopqrstuvwxyz" + | ||||
| 	"abcdefghijklmnopqrstuvwxyz" + "abcdefghijklmnopqrstuvwxyz" + | ||||
| 	"abcdefghijklmnopqrstuvwxyz" + "abcdefghijklmnopqrstuvwxyz" + | ||||
| 	"abcdefghijklmnopqrstuvwxyz" + "abcdefghijklmnopqrstuvwxyz" + | ||||
| 	"abcdefghijklmnopqrstuvwxyz" + "abcdefghijklmnopqrstuvwxyz" + | ||||
| 	"abcdefghijklmnopqrstuvwxyz" + "abcdefghijklmnopqrstuvwxyz" + | ||||
| 	"abcdefghijklmnopqrstuvwxyz" + "abcdefghijklmnopqrstuvwxyz" + | ||||
| 	"abcdefghijklmnopqrstuvwxyz" + "abcdefghijklmnopqrstuvwxyz" + | ||||
| 	"abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqr" | ||||
|  | ||||
| func TestJsonScannerValidInputs(t *testing.T) { | ||||
| 	cases := []jsonScannerTestCase{ | ||||
| 		{ | ||||
| 			desc: "empty input", input: "", | ||||
| 			tokens: []jsonToken{}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "empty object", input: "{}", | ||||
| 			tokens: []jsonToken{{t: jttBeginObject, v: byte('{')}, {t: jttEndObject, v: byte('}')}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "empty array", input: "[]", | ||||
| 			tokens: []jsonToken{{t: jttBeginArray, v: byte('[')}, {t: jttEndArray, v: byte(']')}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "valid empty string", input: `""`, | ||||
| 			tokens: []jsonToken{{t: jttString, v: ""}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc:   "valid string--no escaped characters", | ||||
| 			input:  `"string"`, | ||||
| 			tokens: []jsonToken{{t: jttString, v: "string"}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc:   "valid string--escaped characters", | ||||
| 			input:  `"\"\\\/\b\f\n\r\t"`, | ||||
| 			tokens: []jsonToken{{t: jttString, v: "\"\\/\b\f\n\r\t"}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc:   "valid string--surrogate pair", | ||||
| 			input:  `"abc \uD834\uDd1e 123"`, | ||||
| 			tokens: []jsonToken{{t: jttString, v: "abc 𝄞 123"}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc:   "valid string--high surrogate at end of string", | ||||
| 			input:  `"abc \uD834"`, | ||||
| 			tokens: []jsonToken{{t: jttString, v: "abc <20>"}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc:   "valid string--low surrogate at end of string", | ||||
| 			input:  `"abc \uDD1E"`, | ||||
| 			tokens: []jsonToken{{t: jttString, v: "abc <20>"}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc:   "valid string--high surrogate with non-surrogate Unicode value", | ||||
| 			input:  `"abc \uDD1E\u00BF"`, | ||||
| 			tokens: []jsonToken{{t: jttString, v: "abc <20>¿"}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc:   "valid string--high surrogate with non-Unicode escape sequence", | ||||
| 			input:  `"abc \uDD1E\t"`, | ||||
| 			tokens: []jsonToken{{t: jttString, v: "abc <20>\t"}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "valid literal--true", input: "true", | ||||
| 			tokens: []jsonToken{{t: jttBool, v: true}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "valid literal--false", input: "false", | ||||
| 			tokens: []jsonToken{{t: jttBool, v: false}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "valid literal--null", input: "null", | ||||
| 			tokens: []jsonToken{{t: jttNull}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "valid int32: 0", input: "0", | ||||
| 			tokens: []jsonToken{{t: jttInt32, v: int32(0)}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "valid int32: -0", input: "-0", | ||||
| 			tokens: []jsonToken{{t: jttInt32, v: int32(0)}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "valid int32: 1", input: "1", | ||||
| 			tokens: []jsonToken{{t: jttInt32, v: int32(1)}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "valid int32: -1", input: "-1", | ||||
| 			tokens: []jsonToken{{t: jttInt32, v: int32(-1)}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "valid int32: 10", input: "10", | ||||
| 			tokens: []jsonToken{{t: jttInt32, v: int32(10)}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "valid int32: 1234", input: "1234", | ||||
| 			tokens: []jsonToken{{t: jttInt32, v: int32(1234)}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "valid int32: -10", input: "-10", | ||||
| 			tokens: []jsonToken{{t: jttInt32, v: int32(-10)}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "valid int32: -1234", input: "-1234", | ||||
| 			tokens: []jsonToken{{t: jttInt32, v: int32(-1234)}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "valid int64: 2147483648", input: "2147483648", | ||||
| 			tokens: []jsonToken{{t: jttInt64, v: int64(2147483648)}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "valid int64: -2147483649", input: "-2147483649", | ||||
| 			tokens: []jsonToken{{t: jttInt64, v: int64(-2147483649)}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "valid double: 0.0", input: "0.0", | ||||
| 			tokens: []jsonToken{{t: jttDouble, v: 0.0}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "valid double: -0.0", input: "-0.0", | ||||
| 			tokens: []jsonToken{{t: jttDouble, v: 0.0}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "valid double: 0.1", input: "0.1", | ||||
| 			tokens: []jsonToken{{t: jttDouble, v: 0.1}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "valid double: 0.1234", input: "0.1234", | ||||
| 			tokens: []jsonToken{{t: jttDouble, v: 0.1234}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "valid double: 1.0", input: "1.0", | ||||
| 			tokens: []jsonToken{{t: jttDouble, v: 1.0}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "valid double: -1.0", input: "-1.0", | ||||
| 			tokens: []jsonToken{{t: jttDouble, v: -1.0}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "valid double: 1.234", input: "1.234", | ||||
| 			tokens: []jsonToken{{t: jttDouble, v: 1.234}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "valid double: -1.234", input: "-1.234", | ||||
| 			tokens: []jsonToken{{t: jttDouble, v: -1.234}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "valid double: 1e10", input: "1e10", | ||||
| 			tokens: []jsonToken{{t: jttDouble, v: 1e+10}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "valid double: 1E10", input: "1E10", | ||||
| 			tokens: []jsonToken{{t: jttDouble, v: 1e+10}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "valid double: 1.2e10", input: "1.2e10", | ||||
| 			tokens: []jsonToken{{t: jttDouble, v: 1.2e+10}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "valid double: 1.2E10", input: "1.2E10", | ||||
| 			tokens: []jsonToken{{t: jttDouble, v: 1.2e+10}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "valid double: -1.2e10", input: "-1.2e10", | ||||
| 			tokens: []jsonToken{{t: jttDouble, v: -1.2e+10}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "valid double: -1.2E10", input: "-1.2E10", | ||||
| 			tokens: []jsonToken{{t: jttDouble, v: -1.2e+10}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "valid double: -1.2e+10", input: "-1.2e+10", | ||||
| 			tokens: []jsonToken{{t: jttDouble, v: -1.2e+10}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "valid double: -1.2E+10", input: "-1.2E+10", | ||||
| 			tokens: []jsonToken{{t: jttDouble, v: -1.2e+10}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "valid double: 1.2e-10", input: "1.2e-10", | ||||
| 			tokens: []jsonToken{{t: jttDouble, v: 1.2e-10}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "valid double: 1.2E-10", input: "1.2e-10", | ||||
| 			tokens: []jsonToken{{t: jttDouble, v: 1.2e-10}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "valid double: -1.2e-10", input: "-1.2e-10", | ||||
| 			tokens: []jsonToken{{t: jttDouble, v: -1.2e-10}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "valid double: -1.2E-10", input: "-1.2E-10", | ||||
| 			tokens: []jsonToken{{t: jttDouble, v: -1.2e-10}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "valid double: 8005332285744496613785600", input: "8005332285744496613785600", | ||||
| 			tokens: []jsonToken{{t: jttDouble, v: float64(8005332285744496613785600)}}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc:  "valid object, only spaces", | ||||
| 			input: `{"key": "string", "key2": 2, "key3": {}, "key4": [], "key5": false }`, | ||||
| 			tokens: []jsonToken{ | ||||
| 				{t: jttBeginObject, v: byte('{')}, {t: jttString, v: "key"}, {t: jttColon, v: byte(':')}, {t: jttString, v: "string"}, | ||||
| 				{t: jttComma, v: byte(',')}, {t: jttString, v: "key2"}, {t: jttColon, v: byte(':')}, {t: jttInt32, v: int32(2)}, | ||||
| 				{t: jttComma, v: byte(',')}, {t: jttString, v: "key3"}, {t: jttColon, v: byte(':')}, {t: jttBeginObject, v: byte('{')}, {t: jttEndObject, v: byte('}')}, | ||||
| 				{t: jttComma, v: byte(',')}, {t: jttString, v: "key4"}, {t: jttColon, v: byte(':')}, {t: jttBeginArray, v: byte('[')}, {t: jttEndArray, v: byte(']')}, | ||||
| 				{t: jttComma, v: byte(',')}, {t: jttString, v: "key5"}, {t: jttColon, v: byte(':')}, {t: jttBool, v: false}, {t: jttEndObject, v: byte('}')}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "valid object, mixed whitespace", | ||||
| 			input: ` | ||||
| 					{ "key" : "string" | ||||
| 					, "key2": 2 | ||||
| 					, "key3": {} | ||||
| 					, "key4": [] | ||||
| 					, "key5": false | ||||
| 					}`, | ||||
| 			tokens: []jsonToken{ | ||||
| 				{t: jttBeginObject, v: byte('{')}, {t: jttString, v: "key"}, {t: jttColon, v: byte(':')}, {t: jttString, v: "string"}, | ||||
| 				{t: jttComma, v: byte(',')}, {t: jttString, v: "key2"}, {t: jttColon, v: byte(':')}, {t: jttInt32, v: int32(2)}, | ||||
| 				{t: jttComma, v: byte(',')}, {t: jttString, v: "key3"}, {t: jttColon, v: byte(':')}, {t: jttBeginObject, v: byte('{')}, {t: jttEndObject, v: byte('}')}, | ||||
| 				{t: jttComma, v: byte(',')}, {t: jttString, v: "key4"}, {t: jttColon, v: byte(':')}, {t: jttBeginArray, v: byte('[')}, {t: jttEndArray, v: byte(']')}, | ||||
| 				{t: jttComma, v: byte(',')}, {t: jttString, v: "key5"}, {t: jttColon, v: byte(':')}, {t: jttBool, v: false}, {t: jttEndObject, v: byte('}')}, | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc:  "input greater than buffer size", | ||||
| 			input: `{"` + longKey + `": 1}`, | ||||
| 			tokens: []jsonToken{ | ||||
| 				{t: jttBeginObject, v: byte('{')}, {t: jttString, v: longKey}, {t: jttColon, v: byte(':')}, | ||||
| 				{t: jttInt32, v: int32(1)}, {t: jttEndObject, v: byte('}')}, | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for _, tc := range cases { | ||||
| 		t.Run(tc.desc, func(t *testing.T) { | ||||
| 			js := &jsonScanner{r: strings.NewReader(tc.input)} | ||||
|  | ||||
| 			for _, token := range tc.tokens { | ||||
| 				c, err := js.nextToken() | ||||
| 				expectNoError(t, err, tc.desc) | ||||
| 				jttDiff(t, token.t, c.t, tc.desc) | ||||
| 				jtvDiff(t, token.v, c.v, tc.desc) | ||||
| 			} | ||||
|  | ||||
| 			c, err := js.nextToken() | ||||
| 			noerr(t, err) | ||||
| 			jttDiff(t, jttEOF, c.t, tc.desc) | ||||
|  | ||||
| 			// testing early EOF reading | ||||
| 			js = &jsonScanner{r: iotest.DataErrReader(strings.NewReader(tc.input))} | ||||
|  | ||||
| 			for _, token := range tc.tokens { | ||||
| 				c, err := js.nextToken() | ||||
| 				expectNoError(t, err, tc.desc) | ||||
| 				jttDiff(t, token.t, c.t, tc.desc) | ||||
| 				jtvDiff(t, token.v, c.v, tc.desc) | ||||
| 			} | ||||
|  | ||||
| 			c, err = js.nextToken() | ||||
| 			noerr(t, err) | ||||
| 			jttDiff(t, jttEOF, c.t, tc.desc) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestJsonScannerInvalidInputs(t *testing.T) { | ||||
| 	cases := []jsonScannerTestCase{ | ||||
| 		{desc: "missing quotation", input: `"missing`}, | ||||
| 		{desc: "invalid escape character--first character", input: `"\invalid"`}, | ||||
| 		{desc: "invalid escape character--middle", input: `"i\nv\alid"`}, | ||||
| 		{desc: "invalid escape character--single quote", input: `"f\'oo"`}, | ||||
| 		{desc: "invalid literal--trueee", input: "trueee"}, | ||||
| 		{desc: "invalid literal--tire", input: "tire"}, | ||||
| 		{desc: "invalid literal--nulll", input: "nulll"}, | ||||
| 		{desc: "invalid literal--fals", input: "fals"}, | ||||
| 		{desc: "invalid literal--falsee", input: "falsee"}, | ||||
| 		{desc: "invalid literal--fake", input: "fake"}, | ||||
| 		{desc: "invalid literal--bad", input: "bad"}, | ||||
| 		{desc: "invalid number: -", input: "-"}, | ||||
| 		{desc: "invalid number: --0", input: "--0"}, | ||||
| 		{desc: "invalid number: -a", input: "-a"}, | ||||
| 		{desc: "invalid number: 00", input: "00"}, | ||||
| 		{desc: "invalid number: 01", input: "01"}, | ||||
| 		{desc: "invalid number: 0-", input: "0-"}, | ||||
| 		{desc: "invalid number: 1-", input: "1-"}, | ||||
| 		{desc: "invalid number: 0..", input: "0.."}, | ||||
| 		{desc: "invalid number: 0.-", input: "0.-"}, | ||||
| 		{desc: "invalid number: 0..0", input: "0..0"}, | ||||
| 		{desc: "invalid number: 0.1.0", input: "0.1.0"}, | ||||
| 		{desc: "invalid number: 0e", input: "0e"}, | ||||
| 		{desc: "invalid number: 0e.", input: "0e."}, | ||||
| 		{desc: "invalid number: 0e1.", input: "0e1."}, | ||||
| 		{desc: "invalid number: 0e1e", input: "0e1e"}, | ||||
| 		{desc: "invalid number: 0e+.1", input: "0e+.1"}, | ||||
| 		{desc: "invalid number: 0e+1.", input: "0e+1."}, | ||||
| 		{desc: "invalid number: 0e+1e", input: "0e+1e"}, | ||||
| 	} | ||||
|  | ||||
| 	for _, tc := range cases { | ||||
| 		t.Run(tc.desc, func(t *testing.T) { | ||||
| 			js := &jsonScanner{r: strings.NewReader(tc.input)} | ||||
|  | ||||
| 			c, err := js.nextToken() | ||||
| 			expectNilToken(t, c, tc.desc) | ||||
| 			expectError(t, err, tc.desc) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										108
									
								
								mongo/bson/bsonrw/mode.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										108
									
								
								mongo/bson/bsonrw/mode.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,108 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsonrw | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| ) | ||||
|  | ||||
| type mode int | ||||
|  | ||||
| const ( | ||||
| 	_ mode = iota | ||||
| 	mTopLevel | ||||
| 	mDocument | ||||
| 	mArray | ||||
| 	mValue | ||||
| 	mElement | ||||
| 	mCodeWithScope | ||||
| 	mSpacer | ||||
| ) | ||||
|  | ||||
| func (m mode) String() string { | ||||
| 	var str string | ||||
|  | ||||
| 	switch m { | ||||
| 	case mTopLevel: | ||||
| 		str = "TopLevel" | ||||
| 	case mDocument: | ||||
| 		str = "DocumentMode" | ||||
| 	case mArray: | ||||
| 		str = "ArrayMode" | ||||
| 	case mValue: | ||||
| 		str = "ValueMode" | ||||
| 	case mElement: | ||||
| 		str = "ElementMode" | ||||
| 	case mCodeWithScope: | ||||
| 		str = "CodeWithScopeMode" | ||||
| 	case mSpacer: | ||||
| 		str = "CodeWithScopeSpacerFrame" | ||||
| 	default: | ||||
| 		str = "UnknownMode" | ||||
| 	} | ||||
|  | ||||
| 	return str | ||||
| } | ||||
|  | ||||
| func (m mode) TypeString() string { | ||||
| 	var str string | ||||
|  | ||||
| 	switch m { | ||||
| 	case mTopLevel: | ||||
| 		str = "TopLevel" | ||||
| 	case mDocument: | ||||
| 		str = "Document" | ||||
| 	case mArray: | ||||
| 		str = "Array" | ||||
| 	case mValue: | ||||
| 		str = "Value" | ||||
| 	case mElement: | ||||
| 		str = "Element" | ||||
| 	case mCodeWithScope: | ||||
| 		str = "CodeWithScope" | ||||
| 	case mSpacer: | ||||
| 		str = "CodeWithScopeSpacer" | ||||
| 	default: | ||||
| 		str = "Unknown" | ||||
| 	} | ||||
|  | ||||
| 	return str | ||||
| } | ||||
|  | ||||
| // TransitionError is an error returned when an invalid progressing a | ||||
| // ValueReader or ValueWriter state machine occurs. | ||||
| // If read is false, the error is for writing | ||||
| type TransitionError struct { | ||||
| 	name        string | ||||
| 	parent      mode | ||||
| 	current     mode | ||||
| 	destination mode | ||||
| 	modes       []mode | ||||
| 	action      string | ||||
| } | ||||
|  | ||||
| func (te TransitionError) Error() string { | ||||
| 	errString := fmt.Sprintf("%s can only %s", te.name, te.action) | ||||
| 	if te.destination != mode(0) { | ||||
| 		errString = fmt.Sprintf("%s a %s", errString, te.destination.TypeString()) | ||||
| 	} | ||||
| 	errString = fmt.Sprintf("%s while positioned on a", errString) | ||||
| 	for ind, m := range te.modes { | ||||
| 		if ind != 0 && len(te.modes) > 2 { | ||||
| 			errString = fmt.Sprintf("%s,", errString) | ||||
| 		} | ||||
| 		if ind == len(te.modes)-1 && len(te.modes) > 1 { | ||||
| 			errString = fmt.Sprintf("%s or", errString) | ||||
| 		} | ||||
| 		errString = fmt.Sprintf("%s %s", errString, m.TypeString()) | ||||
| 	} | ||||
| 	errString = fmt.Sprintf("%s but is positioned on a %s", errString, te.current.TypeString()) | ||||
| 	if te.parent != mode(0) { | ||||
| 		errString = fmt.Sprintf("%s with parent %s", errString, te.parent.TypeString()) | ||||
| 	} | ||||
| 	return errString | ||||
| } | ||||
							
								
								
									
										63
									
								
								mongo/bson/bsonrw/reader.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										63
									
								
								mongo/bson/bsonrw/reader.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,63 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsonrw | ||||
|  | ||||
| import ( | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsontype" | ||||
| 	"go.mongodb.org/mongo-driver/bson/primitive" | ||||
| ) | ||||
|  | ||||
| // ArrayReader is implemented by types that allow reading values from a BSON | ||||
| // array. | ||||
| type ArrayReader interface { | ||||
| 	ReadValue() (ValueReader, error) | ||||
| } | ||||
|  | ||||
| // DocumentReader is implemented by types that allow reading elements from a | ||||
| // BSON document. | ||||
| type DocumentReader interface { | ||||
| 	ReadElement() (string, ValueReader, error) | ||||
| } | ||||
|  | ||||
| // ValueReader is a generic interface used to read values from BSON. This type | ||||
| // is implemented by several types with different underlying representations of | ||||
| // BSON, such as a bson.Document, raw BSON bytes, or extended JSON. | ||||
| type ValueReader interface { | ||||
| 	Type() bsontype.Type | ||||
| 	Skip() error | ||||
|  | ||||
| 	ReadArray() (ArrayReader, error) | ||||
| 	ReadBinary() (b []byte, btype byte, err error) | ||||
| 	ReadBoolean() (bool, error) | ||||
| 	ReadDocument() (DocumentReader, error) | ||||
| 	ReadCodeWithScope() (code string, dr DocumentReader, err error) | ||||
| 	ReadDBPointer() (ns string, oid primitive.ObjectID, err error) | ||||
| 	ReadDateTime() (int64, error) | ||||
| 	ReadDecimal128() (primitive.Decimal128, error) | ||||
| 	ReadDouble() (float64, error) | ||||
| 	ReadInt32() (int32, error) | ||||
| 	ReadInt64() (int64, error) | ||||
| 	ReadJavascript() (code string, err error) | ||||
| 	ReadMaxKey() error | ||||
| 	ReadMinKey() error | ||||
| 	ReadNull() error | ||||
| 	ReadObjectID() (primitive.ObjectID, error) | ||||
| 	ReadRegex() (pattern, options string, err error) | ||||
| 	ReadString() (string, error) | ||||
| 	ReadSymbol() (symbol string, err error) | ||||
| 	ReadTimestamp() (t, i uint32, err error) | ||||
| 	ReadUndefined() error | ||||
| } | ||||
|  | ||||
| // BytesReader is a generic interface used to read BSON bytes from a | ||||
| // ValueReader. This imterface is meant to be a superset of ValueReader, so that | ||||
| // types that implement ValueReader may also implement this interface. | ||||
| // | ||||
| // The bytes of the value will be appended to dst. | ||||
| type BytesReader interface { | ||||
| 	ReadValueBytes(dst []byte) (bsontype.Type, []byte, error) | ||||
| } | ||||
							
								
								
									
										874
									
								
								mongo/bson/bsonrw/value_reader.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										874
									
								
								mongo/bson/bsonrw/value_reader.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,874 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsonrw | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"encoding/binary" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"math" | ||||
| 	"sync" | ||||
|  | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsontype" | ||||
| 	"go.mongodb.org/mongo-driver/bson/primitive" | ||||
| ) | ||||
|  | ||||
| var _ ValueReader = (*valueReader)(nil) | ||||
|  | ||||
| var vrPool = sync.Pool{ | ||||
| 	New: func() interface{} { | ||||
| 		return new(valueReader) | ||||
| 	}, | ||||
| } | ||||
|  | ||||
| // BSONValueReaderPool is a pool for ValueReaders that read BSON. | ||||
| type BSONValueReaderPool struct { | ||||
| 	pool sync.Pool | ||||
| } | ||||
|  | ||||
| // NewBSONValueReaderPool instantiates a new BSONValueReaderPool. | ||||
| func NewBSONValueReaderPool() *BSONValueReaderPool { | ||||
| 	return &BSONValueReaderPool{ | ||||
| 		pool: sync.Pool{ | ||||
| 			New: func() interface{} { | ||||
| 				return new(valueReader) | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Get retrieves a ValueReader from the pool and uses src as the underlying BSON. | ||||
| func (bvrp *BSONValueReaderPool) Get(src []byte) ValueReader { | ||||
| 	vr := bvrp.pool.Get().(*valueReader) | ||||
| 	vr.reset(src) | ||||
| 	return vr | ||||
| } | ||||
|  | ||||
| // Put inserts a ValueReader into the pool. If the ValueReader is not a BSON ValueReader nothing | ||||
| // is inserted into the pool and ok will be false. | ||||
| func (bvrp *BSONValueReaderPool) Put(vr ValueReader) (ok bool) { | ||||
| 	bvr, ok := vr.(*valueReader) | ||||
| 	if !ok { | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	bvr.reset(nil) | ||||
| 	bvrp.pool.Put(bvr) | ||||
| 	return true | ||||
| } | ||||
|  | ||||
| // ErrEOA is the error returned when the end of a BSON array has been reached. | ||||
| var ErrEOA = errors.New("end of array") | ||||
|  | ||||
| // ErrEOD is the error returned when the end of a BSON document has been reached. | ||||
| var ErrEOD = errors.New("end of document") | ||||
|  | ||||
| type vrState struct { | ||||
| 	mode  mode | ||||
| 	vType bsontype.Type | ||||
| 	end   int64 | ||||
| } | ||||
|  | ||||
| // valueReader is for reading BSON values. | ||||
| type valueReader struct { | ||||
| 	offset int64 | ||||
| 	d      []byte | ||||
|  | ||||
| 	stack []vrState | ||||
| 	frame int64 | ||||
| } | ||||
|  | ||||
| // NewBSONDocumentReader returns a ValueReader using b for the underlying BSON | ||||
| // representation. Parameter b must be a BSON Document. | ||||
| func NewBSONDocumentReader(b []byte) ValueReader { | ||||
| 	// TODO(skriptble): There's a lack of symmetry between the reader and writer, since the reader takes a []byte while the | ||||
| 	// TODO writer takes an io.Writer. We should have two versions of each, one that takes a []byte and one that takes an | ||||
| 	// TODO io.Reader or io.Writer. The []byte version will need to return a thing that can return the finished []byte since | ||||
| 	// TODO it might be reallocated when appended to. | ||||
| 	return newValueReader(b) | ||||
| } | ||||
|  | ||||
| // NewBSONValueReader returns a ValueReader that starts in the Value mode instead of in top | ||||
| // level document mode. This enables the creation of a ValueReader for a single BSON value. | ||||
| func NewBSONValueReader(t bsontype.Type, val []byte) ValueReader { | ||||
| 	stack := make([]vrState, 1, 5) | ||||
| 	stack[0] = vrState{ | ||||
| 		mode:  mValue, | ||||
| 		vType: t, | ||||
| 	} | ||||
| 	return &valueReader{ | ||||
| 		d:     val, | ||||
| 		stack: stack, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func newValueReader(b []byte) *valueReader { | ||||
| 	stack := make([]vrState, 1, 5) | ||||
| 	stack[0] = vrState{ | ||||
| 		mode: mTopLevel, | ||||
| 	} | ||||
| 	return &valueReader{ | ||||
| 		d:     b, | ||||
| 		stack: stack, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) reset(b []byte) { | ||||
| 	if vr.stack == nil { | ||||
| 		vr.stack = make([]vrState, 1, 5) | ||||
| 	} | ||||
| 	vr.stack = vr.stack[:1] | ||||
| 	vr.stack[0] = vrState{mode: mTopLevel} | ||||
| 	vr.d = b | ||||
| 	vr.offset = 0 | ||||
| 	vr.frame = 0 | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) advanceFrame() { | ||||
| 	if vr.frame+1 >= int64(len(vr.stack)) { // We need to grow the stack | ||||
| 		length := len(vr.stack) | ||||
| 		if length+1 >= cap(vr.stack) { | ||||
| 			// double it | ||||
| 			buf := make([]vrState, 2*cap(vr.stack)+1) | ||||
| 			copy(buf, vr.stack) | ||||
| 			vr.stack = buf | ||||
| 		} | ||||
| 		vr.stack = vr.stack[:length+1] | ||||
| 	} | ||||
| 	vr.frame++ | ||||
|  | ||||
| 	// Clean the stack | ||||
| 	vr.stack[vr.frame].mode = 0 | ||||
| 	vr.stack[vr.frame].vType = 0 | ||||
| 	vr.stack[vr.frame].end = 0 | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) pushDocument() error { | ||||
| 	vr.advanceFrame() | ||||
|  | ||||
| 	vr.stack[vr.frame].mode = mDocument | ||||
|  | ||||
| 	size, err := vr.readLength() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	vr.stack[vr.frame].end = int64(size) + vr.offset - 4 | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) pushArray() error { | ||||
| 	vr.advanceFrame() | ||||
|  | ||||
| 	vr.stack[vr.frame].mode = mArray | ||||
|  | ||||
| 	size, err := vr.readLength() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	vr.stack[vr.frame].end = int64(size) + vr.offset - 4 | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) pushElement(t bsontype.Type) { | ||||
| 	vr.advanceFrame() | ||||
|  | ||||
| 	vr.stack[vr.frame].mode = mElement | ||||
| 	vr.stack[vr.frame].vType = t | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) pushValue(t bsontype.Type) { | ||||
| 	vr.advanceFrame() | ||||
|  | ||||
| 	vr.stack[vr.frame].mode = mValue | ||||
| 	vr.stack[vr.frame].vType = t | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) pushCodeWithScope() (int64, error) { | ||||
| 	vr.advanceFrame() | ||||
|  | ||||
| 	vr.stack[vr.frame].mode = mCodeWithScope | ||||
|  | ||||
| 	size, err := vr.readLength() | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
| 	vr.stack[vr.frame].end = int64(size) + vr.offset - 4 | ||||
|  | ||||
| 	return int64(size), nil | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) pop() { | ||||
| 	switch vr.stack[vr.frame].mode { | ||||
| 	case mElement, mValue: | ||||
| 		vr.frame-- | ||||
| 	case mDocument, mArray, mCodeWithScope: | ||||
| 		vr.frame -= 2 // we pop twice to jump over the vrElement: vrDocument -> vrElement -> vrDocument/TopLevel/etc... | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) invalidTransitionErr(destination mode, name string, modes []mode) error { | ||||
| 	te := TransitionError{ | ||||
| 		name:        name, | ||||
| 		current:     vr.stack[vr.frame].mode, | ||||
| 		destination: destination, | ||||
| 		modes:       modes, | ||||
| 		action:      "read", | ||||
| 	} | ||||
| 	if vr.frame != 0 { | ||||
| 		te.parent = vr.stack[vr.frame-1].mode | ||||
| 	} | ||||
| 	return te | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) typeError(t bsontype.Type) error { | ||||
| 	return fmt.Errorf("positioned on %s, but attempted to read %s", vr.stack[vr.frame].vType, t) | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) invalidDocumentLengthError() error { | ||||
| 	return fmt.Errorf("document is invalid, end byte is at %d, but null byte found at %d", vr.stack[vr.frame].end, vr.offset) | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) ensureElementValue(t bsontype.Type, destination mode, callerName string) error { | ||||
| 	switch vr.stack[vr.frame].mode { | ||||
| 	case mElement, mValue: | ||||
| 		if vr.stack[vr.frame].vType != t { | ||||
| 			return vr.typeError(t) | ||||
| 		} | ||||
| 	default: | ||||
| 		return vr.invalidTransitionErr(destination, callerName, []mode{mElement, mValue}) | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) Type() bsontype.Type { | ||||
| 	return vr.stack[vr.frame].vType | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) nextElementLength() (int32, error) { | ||||
| 	var length int32 | ||||
| 	var err error | ||||
| 	switch vr.stack[vr.frame].vType { | ||||
| 	case bsontype.Array, bsontype.EmbeddedDocument, bsontype.CodeWithScope: | ||||
| 		length, err = vr.peekLength() | ||||
| 	case bsontype.Binary: | ||||
| 		length, err = vr.peekLength() | ||||
| 		length += 4 + 1 // binary length + subtype byte | ||||
| 	case bsontype.Boolean: | ||||
| 		length = 1 | ||||
| 	case bsontype.DBPointer: | ||||
| 		length, err = vr.peekLength() | ||||
| 		length += 4 + 12 // string length + ObjectID length | ||||
| 	case bsontype.DateTime, bsontype.Double, bsontype.Int64, bsontype.Timestamp: | ||||
| 		length = 8 | ||||
| 	case bsontype.Decimal128: | ||||
| 		length = 16 | ||||
| 	case bsontype.Int32: | ||||
| 		length = 4 | ||||
| 	case bsontype.JavaScript, bsontype.String, bsontype.Symbol: | ||||
| 		length, err = vr.peekLength() | ||||
| 		length += 4 | ||||
| 	case bsontype.MaxKey, bsontype.MinKey, bsontype.Null, bsontype.Undefined: | ||||
| 		length = 0 | ||||
| 	case bsontype.ObjectID: | ||||
| 		length = 12 | ||||
| 	case bsontype.Regex: | ||||
| 		regex := bytes.IndexByte(vr.d[vr.offset:], 0x00) | ||||
| 		if regex < 0 { | ||||
| 			err = io.EOF | ||||
| 			break | ||||
| 		} | ||||
| 		pattern := bytes.IndexByte(vr.d[vr.offset+int64(regex)+1:], 0x00) | ||||
| 		if pattern < 0 { | ||||
| 			err = io.EOF | ||||
| 			break | ||||
| 		} | ||||
| 		length = int32(int64(regex) + 1 + int64(pattern) + 1) | ||||
| 	default: | ||||
| 		return 0, fmt.Errorf("attempted to read bytes of unknown BSON type %v", vr.stack[vr.frame].vType) | ||||
| 	} | ||||
|  | ||||
| 	return length, err | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) ReadValueBytes(dst []byte) (bsontype.Type, []byte, error) { | ||||
| 	switch vr.stack[vr.frame].mode { | ||||
| 	case mTopLevel: | ||||
| 		length, err := vr.peekLength() | ||||
| 		if err != nil { | ||||
| 			return bsontype.Type(0), nil, err | ||||
| 		} | ||||
| 		dst, err = vr.appendBytes(dst, length) | ||||
| 		if err != nil { | ||||
| 			return bsontype.Type(0), nil, err | ||||
| 		} | ||||
| 		return bsontype.Type(0), dst, nil | ||||
| 	case mElement, mValue: | ||||
| 		length, err := vr.nextElementLength() | ||||
| 		if err != nil { | ||||
| 			return bsontype.Type(0), dst, err | ||||
| 		} | ||||
|  | ||||
| 		dst, err = vr.appendBytes(dst, length) | ||||
| 		t := vr.stack[vr.frame].vType | ||||
| 		vr.pop() | ||||
| 		return t, dst, err | ||||
| 	default: | ||||
| 		return bsontype.Type(0), nil, vr.invalidTransitionErr(0, "ReadValueBytes", []mode{mElement, mValue}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) Skip() error { | ||||
| 	switch vr.stack[vr.frame].mode { | ||||
| 	case mElement, mValue: | ||||
| 	default: | ||||
| 		return vr.invalidTransitionErr(0, "Skip", []mode{mElement, mValue}) | ||||
| 	} | ||||
|  | ||||
| 	length, err := vr.nextElementLength() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	err = vr.skipBytes(length) | ||||
| 	vr.pop() | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) ReadArray() (ArrayReader, error) { | ||||
| 	if err := vr.ensureElementValue(bsontype.Array, mArray, "ReadArray"); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	err := vr.pushArray() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return vr, nil | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) ReadBinary() (b []byte, btype byte, err error) { | ||||
| 	if err := vr.ensureElementValue(bsontype.Binary, 0, "ReadBinary"); err != nil { | ||||
| 		return nil, 0, err | ||||
| 	} | ||||
|  | ||||
| 	length, err := vr.readLength() | ||||
| 	if err != nil { | ||||
| 		return nil, 0, err | ||||
| 	} | ||||
|  | ||||
| 	btype, err = vr.readByte() | ||||
| 	if err != nil { | ||||
| 		return nil, 0, err | ||||
| 	} | ||||
|  | ||||
| 	// Check length in case it is an old binary without a length. | ||||
| 	if btype == 0x02 && length > 4 { | ||||
| 		length, err = vr.readLength() | ||||
| 		if err != nil { | ||||
| 			return nil, 0, err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	b, err = vr.readBytes(length) | ||||
| 	if err != nil { | ||||
| 		return nil, 0, err | ||||
| 	} | ||||
| 	// Make a copy of the returned byte slice because it's just a subslice from the valueReader's | ||||
| 	// buffer and is not safe to return in the unmarshaled value. | ||||
| 	cp := make([]byte, len(b)) | ||||
| 	copy(cp, b) | ||||
|  | ||||
| 	vr.pop() | ||||
| 	return cp, btype, nil | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) ReadBoolean() (bool, error) { | ||||
| 	if err := vr.ensureElementValue(bsontype.Boolean, 0, "ReadBoolean"); err != nil { | ||||
| 		return false, err | ||||
| 	} | ||||
|  | ||||
| 	b, err := vr.readByte() | ||||
| 	if err != nil { | ||||
| 		return false, err | ||||
| 	} | ||||
|  | ||||
| 	if b > 1 { | ||||
| 		return false, fmt.Errorf("invalid byte for boolean, %b", b) | ||||
| 	} | ||||
|  | ||||
| 	vr.pop() | ||||
| 	return b == 1, nil | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) ReadDocument() (DocumentReader, error) { | ||||
| 	switch vr.stack[vr.frame].mode { | ||||
| 	case mTopLevel: | ||||
| 		// read size | ||||
| 		size, err := vr.readLength() | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		if int(size) != len(vr.d) { | ||||
| 			return nil, fmt.Errorf("invalid document length") | ||||
| 		} | ||||
| 		vr.stack[vr.frame].end = int64(size) + vr.offset - 4 | ||||
| 		return vr, nil | ||||
| 	case mElement, mValue: | ||||
| 		if vr.stack[vr.frame].vType != bsontype.EmbeddedDocument { | ||||
| 			return nil, vr.typeError(bsontype.EmbeddedDocument) | ||||
| 		} | ||||
| 	default: | ||||
| 		return nil, vr.invalidTransitionErr(mDocument, "ReadDocument", []mode{mTopLevel, mElement, mValue}) | ||||
| 	} | ||||
|  | ||||
| 	err := vr.pushDocument() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return vr, nil | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) ReadCodeWithScope() (code string, dr DocumentReader, err error) { | ||||
| 	if err := vr.ensureElementValue(bsontype.CodeWithScope, 0, "ReadCodeWithScope"); err != nil { | ||||
| 		return "", nil, err | ||||
| 	} | ||||
|  | ||||
| 	totalLength, err := vr.readLength() | ||||
| 	if err != nil { | ||||
| 		return "", nil, err | ||||
| 	} | ||||
| 	strLength, err := vr.readLength() | ||||
| 	if err != nil { | ||||
| 		return "", nil, err | ||||
| 	} | ||||
| 	if strLength <= 0 { | ||||
| 		return "", nil, fmt.Errorf("invalid string length: %d", strLength) | ||||
| 	} | ||||
| 	strBytes, err := vr.readBytes(strLength) | ||||
| 	if err != nil { | ||||
| 		return "", nil, err | ||||
| 	} | ||||
| 	code = string(strBytes[:len(strBytes)-1]) | ||||
|  | ||||
| 	size, err := vr.pushCodeWithScope() | ||||
| 	if err != nil { | ||||
| 		return "", nil, err | ||||
| 	} | ||||
|  | ||||
| 	// The total length should equal: | ||||
| 	// 4 (total length) + strLength + 4 (the length of str itself) + (document length) | ||||
| 	componentsLength := int64(4+strLength+4) + size | ||||
| 	if int64(totalLength) != componentsLength { | ||||
| 		return "", nil, fmt.Errorf( | ||||
| 			"length of CodeWithScope does not match lengths of components; total: %d; components: %d", | ||||
| 			totalLength, componentsLength, | ||||
| 		) | ||||
| 	} | ||||
| 	return code, vr, nil | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) ReadDBPointer() (ns string, oid primitive.ObjectID, err error) { | ||||
| 	if err := vr.ensureElementValue(bsontype.DBPointer, 0, "ReadDBPointer"); err != nil { | ||||
| 		return "", oid, err | ||||
| 	} | ||||
|  | ||||
| 	ns, err = vr.readString() | ||||
| 	if err != nil { | ||||
| 		return "", oid, err | ||||
| 	} | ||||
|  | ||||
| 	oidbytes, err := vr.readBytes(12) | ||||
| 	if err != nil { | ||||
| 		return "", oid, err | ||||
| 	} | ||||
|  | ||||
| 	copy(oid[:], oidbytes) | ||||
|  | ||||
| 	vr.pop() | ||||
| 	return ns, oid, nil | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) ReadDateTime() (int64, error) { | ||||
| 	if err := vr.ensureElementValue(bsontype.DateTime, 0, "ReadDateTime"); err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
|  | ||||
| 	i, err := vr.readi64() | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
|  | ||||
| 	vr.pop() | ||||
| 	return i, nil | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) ReadDecimal128() (primitive.Decimal128, error) { | ||||
| 	if err := vr.ensureElementValue(bsontype.Decimal128, 0, "ReadDecimal128"); err != nil { | ||||
| 		return primitive.Decimal128{}, err | ||||
| 	} | ||||
|  | ||||
| 	b, err := vr.readBytes(16) | ||||
| 	if err != nil { | ||||
| 		return primitive.Decimal128{}, err | ||||
| 	} | ||||
|  | ||||
| 	l := binary.LittleEndian.Uint64(b[0:8]) | ||||
| 	h := binary.LittleEndian.Uint64(b[8:16]) | ||||
|  | ||||
| 	vr.pop() | ||||
| 	return primitive.NewDecimal128(h, l), nil | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) ReadDouble() (float64, error) { | ||||
| 	if err := vr.ensureElementValue(bsontype.Double, 0, "ReadDouble"); err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
|  | ||||
| 	u, err := vr.readu64() | ||||
| 	if err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
|  | ||||
| 	vr.pop() | ||||
| 	return math.Float64frombits(u), nil | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) ReadInt32() (int32, error) { | ||||
| 	if err := vr.ensureElementValue(bsontype.Int32, 0, "ReadInt32"); err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
|  | ||||
| 	vr.pop() | ||||
| 	return vr.readi32() | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) ReadInt64() (int64, error) { | ||||
| 	if err := vr.ensureElementValue(bsontype.Int64, 0, "ReadInt64"); err != nil { | ||||
| 		return 0, err | ||||
| 	} | ||||
|  | ||||
| 	vr.pop() | ||||
| 	return vr.readi64() | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) ReadJavascript() (code string, err error) { | ||||
| 	if err := vr.ensureElementValue(bsontype.JavaScript, 0, "ReadJavascript"); err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
|  | ||||
| 	vr.pop() | ||||
| 	return vr.readString() | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) ReadMaxKey() error { | ||||
| 	if err := vr.ensureElementValue(bsontype.MaxKey, 0, "ReadMaxKey"); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	vr.pop() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) ReadMinKey() error { | ||||
| 	if err := vr.ensureElementValue(bsontype.MinKey, 0, "ReadMinKey"); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	vr.pop() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) ReadNull() error { | ||||
| 	if err := vr.ensureElementValue(bsontype.Null, 0, "ReadNull"); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	vr.pop() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) ReadObjectID() (primitive.ObjectID, error) { | ||||
| 	if err := vr.ensureElementValue(bsontype.ObjectID, 0, "ReadObjectID"); err != nil { | ||||
| 		return primitive.ObjectID{}, err | ||||
| 	} | ||||
|  | ||||
| 	oidbytes, err := vr.readBytes(12) | ||||
| 	if err != nil { | ||||
| 		return primitive.ObjectID{}, err | ||||
| 	} | ||||
|  | ||||
| 	var oid primitive.ObjectID | ||||
| 	copy(oid[:], oidbytes) | ||||
|  | ||||
| 	vr.pop() | ||||
| 	return oid, nil | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) ReadRegex() (string, string, error) { | ||||
| 	if err := vr.ensureElementValue(bsontype.Regex, 0, "ReadRegex"); err != nil { | ||||
| 		return "", "", err | ||||
| 	} | ||||
|  | ||||
| 	pattern, err := vr.readCString() | ||||
| 	if err != nil { | ||||
| 		return "", "", err | ||||
| 	} | ||||
|  | ||||
| 	options, err := vr.readCString() | ||||
| 	if err != nil { | ||||
| 		return "", "", err | ||||
| 	} | ||||
|  | ||||
| 	vr.pop() | ||||
| 	return pattern, options, nil | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) ReadString() (string, error) { | ||||
| 	if err := vr.ensureElementValue(bsontype.String, 0, "ReadString"); err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
|  | ||||
| 	vr.pop() | ||||
| 	return vr.readString() | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) ReadSymbol() (symbol string, err error) { | ||||
| 	if err := vr.ensureElementValue(bsontype.Symbol, 0, "ReadSymbol"); err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
|  | ||||
| 	vr.pop() | ||||
| 	return vr.readString() | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) ReadTimestamp() (t uint32, i uint32, err error) { | ||||
| 	if err := vr.ensureElementValue(bsontype.Timestamp, 0, "ReadTimestamp"); err != nil { | ||||
| 		return 0, 0, err | ||||
| 	} | ||||
|  | ||||
| 	i, err = vr.readu32() | ||||
| 	if err != nil { | ||||
| 		return 0, 0, err | ||||
| 	} | ||||
|  | ||||
| 	t, err = vr.readu32() | ||||
| 	if err != nil { | ||||
| 		return 0, 0, err | ||||
| 	} | ||||
|  | ||||
| 	vr.pop() | ||||
| 	return t, i, nil | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) ReadUndefined() error { | ||||
| 	if err := vr.ensureElementValue(bsontype.Undefined, 0, "ReadUndefined"); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	vr.pop() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) ReadElement() (string, ValueReader, error) { | ||||
| 	switch vr.stack[vr.frame].mode { | ||||
| 	case mTopLevel, mDocument, mCodeWithScope: | ||||
| 	default: | ||||
| 		return "", nil, vr.invalidTransitionErr(mElement, "ReadElement", []mode{mTopLevel, mDocument, mCodeWithScope}) | ||||
| 	} | ||||
|  | ||||
| 	t, err := vr.readByte() | ||||
| 	if err != nil { | ||||
| 		return "", nil, err | ||||
| 	} | ||||
|  | ||||
| 	if t == 0 { | ||||
| 		if vr.offset != vr.stack[vr.frame].end { | ||||
| 			return "", nil, vr.invalidDocumentLengthError() | ||||
| 		} | ||||
|  | ||||
| 		vr.pop() | ||||
| 		return "", nil, ErrEOD | ||||
| 	} | ||||
|  | ||||
| 	name, err := vr.readCString() | ||||
| 	if err != nil { | ||||
| 		return "", nil, err | ||||
| 	} | ||||
|  | ||||
| 	vr.pushElement(bsontype.Type(t)) | ||||
| 	return name, vr, nil | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) ReadValue() (ValueReader, error) { | ||||
| 	switch vr.stack[vr.frame].mode { | ||||
| 	case mArray: | ||||
| 	default: | ||||
| 		return nil, vr.invalidTransitionErr(mValue, "ReadValue", []mode{mArray}) | ||||
| 	} | ||||
|  | ||||
| 	t, err := vr.readByte() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	if t == 0 { | ||||
| 		if vr.offset != vr.stack[vr.frame].end { | ||||
| 			return nil, vr.invalidDocumentLengthError() | ||||
| 		} | ||||
|  | ||||
| 		vr.pop() | ||||
| 		return nil, ErrEOA | ||||
| 	} | ||||
|  | ||||
| 	_, err = vr.readCString() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	vr.pushValue(bsontype.Type(t)) | ||||
| 	return vr, nil | ||||
| } | ||||
|  | ||||
| // readBytes reads length bytes from the valueReader starting at the current offset. Note that the | ||||
| // returned byte slice is a subslice from the valueReader buffer and must be converted or copied | ||||
| // before returning in an unmarshaled value. | ||||
| func (vr *valueReader) readBytes(length int32) ([]byte, error) { | ||||
| 	if length < 0 { | ||||
| 		return nil, fmt.Errorf("invalid length: %d", length) | ||||
| 	} | ||||
|  | ||||
| 	if vr.offset+int64(length) > int64(len(vr.d)) { | ||||
| 		return nil, io.EOF | ||||
| 	} | ||||
|  | ||||
| 	start := vr.offset | ||||
| 	vr.offset += int64(length) | ||||
|  | ||||
| 	return vr.d[start : start+int64(length)], nil | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) appendBytes(dst []byte, length int32) ([]byte, error) { | ||||
| 	if vr.offset+int64(length) > int64(len(vr.d)) { | ||||
| 		return nil, io.EOF | ||||
| 	} | ||||
|  | ||||
| 	start := vr.offset | ||||
| 	vr.offset += int64(length) | ||||
| 	return append(dst, vr.d[start:start+int64(length)]...), nil | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) skipBytes(length int32) error { | ||||
| 	if vr.offset+int64(length) > int64(len(vr.d)) { | ||||
| 		return io.EOF | ||||
| 	} | ||||
|  | ||||
| 	vr.offset += int64(length) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) readByte() (byte, error) { | ||||
| 	if vr.offset+1 > int64(len(vr.d)) { | ||||
| 		return 0x0, io.EOF | ||||
| 	} | ||||
|  | ||||
| 	vr.offset++ | ||||
| 	return vr.d[vr.offset-1], nil | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) readCString() (string, error) { | ||||
| 	idx := bytes.IndexByte(vr.d[vr.offset:], 0x00) | ||||
| 	if idx < 0 { | ||||
| 		return "", io.EOF | ||||
| 	} | ||||
| 	start := vr.offset | ||||
| 	// idx does not include the null byte | ||||
| 	vr.offset += int64(idx) + 1 | ||||
| 	return string(vr.d[start : start+int64(idx)]), nil | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) readString() (string, error) { | ||||
| 	length, err := vr.readLength() | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
|  | ||||
| 	if int64(length)+vr.offset > int64(len(vr.d)) { | ||||
| 		return "", io.EOF | ||||
| 	} | ||||
|  | ||||
| 	if length <= 0 { | ||||
| 		return "", fmt.Errorf("invalid string length: %d", length) | ||||
| 	} | ||||
|  | ||||
| 	if vr.d[vr.offset+int64(length)-1] != 0x00 { | ||||
| 		return "", fmt.Errorf("string does not end with null byte, but with %v", vr.d[vr.offset+int64(length)-1]) | ||||
| 	} | ||||
|  | ||||
| 	start := vr.offset | ||||
| 	vr.offset += int64(length) | ||||
| 	return string(vr.d[start : start+int64(length)-1]), nil | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) peekLength() (int32, error) { | ||||
| 	if vr.offset+4 > int64(len(vr.d)) { | ||||
| 		return 0, io.EOF | ||||
| 	} | ||||
|  | ||||
| 	idx := vr.offset | ||||
| 	return (int32(vr.d[idx]) | int32(vr.d[idx+1])<<8 | int32(vr.d[idx+2])<<16 | int32(vr.d[idx+3])<<24), nil | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) readLength() (int32, error) { return vr.readi32() } | ||||
|  | ||||
| func (vr *valueReader) readi32() (int32, error) { | ||||
| 	if vr.offset+4 > int64(len(vr.d)) { | ||||
| 		return 0, io.EOF | ||||
| 	} | ||||
|  | ||||
| 	idx := vr.offset | ||||
| 	vr.offset += 4 | ||||
| 	return (int32(vr.d[idx]) | int32(vr.d[idx+1])<<8 | int32(vr.d[idx+2])<<16 | int32(vr.d[idx+3])<<24), nil | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) readu32() (uint32, error) { | ||||
| 	if vr.offset+4 > int64(len(vr.d)) { | ||||
| 		return 0, io.EOF | ||||
| 	} | ||||
|  | ||||
| 	idx := vr.offset | ||||
| 	vr.offset += 4 | ||||
| 	return (uint32(vr.d[idx]) | uint32(vr.d[idx+1])<<8 | uint32(vr.d[idx+2])<<16 | uint32(vr.d[idx+3])<<24), nil | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) readi64() (int64, error) { | ||||
| 	if vr.offset+8 > int64(len(vr.d)) { | ||||
| 		return 0, io.EOF | ||||
| 	} | ||||
|  | ||||
| 	idx := vr.offset | ||||
| 	vr.offset += 8 | ||||
| 	return int64(vr.d[idx]) | int64(vr.d[idx+1])<<8 | int64(vr.d[idx+2])<<16 | int64(vr.d[idx+3])<<24 | | ||||
| 		int64(vr.d[idx+4])<<32 | int64(vr.d[idx+5])<<40 | int64(vr.d[idx+6])<<48 | int64(vr.d[idx+7])<<56, nil | ||||
| } | ||||
|  | ||||
| func (vr *valueReader) readu64() (uint64, error) { | ||||
| 	if vr.offset+8 > int64(len(vr.d)) { | ||||
| 		return 0, io.EOF | ||||
| 	} | ||||
|  | ||||
| 	idx := vr.offset | ||||
| 	vr.offset += 8 | ||||
| 	return uint64(vr.d[idx]) | uint64(vr.d[idx+1])<<8 | uint64(vr.d[idx+2])<<16 | uint64(vr.d[idx+3])<<24 | | ||||
| 		uint64(vr.d[idx+4])<<32 | uint64(vr.d[idx+5])<<40 | uint64(vr.d[idx+6])<<48 | uint64(vr.d[idx+7])<<56, nil | ||||
| } | ||||
							
								
								
									
										1538
									
								
								mongo/bson/bsonrw/value_reader_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1538
									
								
								mongo/bson/bsonrw/value_reader_test.go
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										608
									
								
								mongo/bson/bsonrw/value_reader_writer_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										608
									
								
								mongo/bson/bsonrw/value_reader_writer_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,608 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsonrw | ||||
|  | ||||
| import ( | ||||
| 	"testing" | ||||
|  | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsontype" | ||||
| 	"go.mongodb.org/mongo-driver/bson/primitive" | ||||
| 	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore" | ||||
| ) | ||||
|  | ||||
| type VRWInvoked byte | ||||
|  | ||||
| const ( | ||||
| 	_                         = iota | ||||
| 	llvrwReadArray VRWInvoked = 1 | ||||
| 	llvrwReadBinary | ||||
| 	llvrwReadBoolean | ||||
| 	llvrwReadDocument | ||||
| 	llvrwReadCodeWithScope | ||||
| 	llvrwReadDBPointer | ||||
| 	llvrwReadDateTime | ||||
| 	llvrwReadDecimal128 | ||||
| 	llvrwReadDouble | ||||
| 	llvrwReadInt32 | ||||
| 	llvrwReadInt64 | ||||
| 	llvrwReadJavascript | ||||
| 	llvrwReadMaxKey | ||||
| 	llvrwReadMinKey | ||||
| 	llvrwReadNull | ||||
| 	llvrwReadObjectID | ||||
| 	llvrwReadRegex | ||||
| 	llvrwReadString | ||||
| 	llvrwReadSymbol | ||||
| 	llvrwReadTimestamp | ||||
| 	llvrwReadUndefined | ||||
| 	llvrwReadElement | ||||
| 	llvrwReadValue | ||||
| 	llvrwWriteArray | ||||
| 	llvrwWriteBinary | ||||
| 	llvrwWriteBinaryWithSubtype | ||||
| 	llvrwWriteBoolean | ||||
| 	llvrwWriteCodeWithScope | ||||
| 	llvrwWriteDBPointer | ||||
| 	llvrwWriteDateTime | ||||
| 	llvrwWriteDecimal128 | ||||
| 	llvrwWriteDouble | ||||
| 	llvrwWriteInt32 | ||||
| 	llvrwWriteInt64 | ||||
| 	llvrwWriteJavascript | ||||
| 	llvrwWriteMaxKey | ||||
| 	llvrwWriteMinKey | ||||
| 	llvrwWriteNull | ||||
| 	llvrwWriteObjectID | ||||
| 	llvrwWriteRegex | ||||
| 	llvrwWriteString | ||||
| 	llvrwWriteDocument | ||||
| 	llvrwWriteSymbol | ||||
| 	llvrwWriteTimestamp | ||||
| 	llvrwWriteUndefined | ||||
| 	llvrwWriteDocumentElement | ||||
| 	llvrwWriteDocumentEnd | ||||
| 	llvrwWriteArrayElement | ||||
| 	llvrwWriteArrayEnd | ||||
| ) | ||||
|  | ||||
| type TestValueReaderWriter struct { | ||||
| 	t        *testing.T | ||||
| 	invoked  VRWInvoked | ||||
| 	readval  interface{} | ||||
| 	bsontype bsontype.Type | ||||
| 	err      error | ||||
| 	errAfter VRWInvoked // error after this method is called | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) Type() bsontype.Type { | ||||
| 	return llvrw.bsontype | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) Skip() error { | ||||
| 	panic("not implemented") | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) ReadArray() (ArrayReader, error) { | ||||
| 	llvrw.invoked = llvrwReadArray | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return nil, llvrw.err | ||||
| 	} | ||||
|  | ||||
| 	return llvrw, nil | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) ReadBinary() (b []byte, btype byte, err error) { | ||||
| 	llvrw.invoked = llvrwReadBinary | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return nil, 0x00, llvrw.err | ||||
| 	} | ||||
|  | ||||
| 	switch tt := llvrw.readval.(type) { | ||||
| 	case bsoncore.Value: | ||||
| 		subtype, data, _, ok := bsoncore.ReadBinary(tt.Data) | ||||
| 		if !ok { | ||||
| 			llvrw.t.Error("Invalid Value provided for return value of ReadBinary.") | ||||
| 			return nil, 0x00, nil | ||||
| 		} | ||||
| 		return data, subtype, nil | ||||
| 	default: | ||||
| 		llvrw.t.Errorf("Incorrect type provided for return value of ReadBinary: %T", llvrw.readval) | ||||
| 		return nil, 0x00, nil | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) ReadBoolean() (bool, error) { | ||||
| 	llvrw.invoked = llvrwReadBoolean | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return false, llvrw.err | ||||
| 	} | ||||
|  | ||||
| 	b, ok := llvrw.readval.(bool) | ||||
| 	if !ok { | ||||
| 		llvrw.t.Errorf("Incorrect type provided for return value of ReadBoolean: %T", llvrw.readval) | ||||
| 		return false, nil | ||||
| 	} | ||||
|  | ||||
| 	return b, llvrw.err | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) ReadDocument() (DocumentReader, error) { | ||||
| 	llvrw.invoked = llvrwReadDocument | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return nil, llvrw.err | ||||
| 	} | ||||
|  | ||||
| 	return llvrw, nil | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) ReadCodeWithScope() (code string, dr DocumentReader, err error) { | ||||
| 	llvrw.invoked = llvrwReadCodeWithScope | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return "", nil, llvrw.err | ||||
| 	} | ||||
|  | ||||
| 	return "", llvrw, nil | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) ReadDBPointer() (ns string, oid primitive.ObjectID, err error) { | ||||
| 	llvrw.invoked = llvrwReadDBPointer | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return "", primitive.ObjectID{}, llvrw.err | ||||
| 	} | ||||
|  | ||||
| 	switch tt := llvrw.readval.(type) { | ||||
| 	case bsoncore.Value: | ||||
| 		ns, oid, _, ok := bsoncore.ReadDBPointer(tt.Data) | ||||
| 		if !ok { | ||||
| 			llvrw.t.Error("Invalid Value instance provided for return value of ReadDBPointer") | ||||
| 			return "", primitive.ObjectID{}, nil | ||||
| 		} | ||||
| 		return ns, oid, nil | ||||
| 	default: | ||||
| 		llvrw.t.Errorf("Incorrect type provided for return value of ReadDBPointer: %T", llvrw.readval) | ||||
| 		return "", primitive.ObjectID{}, nil | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) ReadDateTime() (int64, error) { | ||||
| 	llvrw.invoked = llvrwReadDateTime | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return 0, llvrw.err | ||||
| 	} | ||||
|  | ||||
| 	dt, ok := llvrw.readval.(int64) | ||||
| 	if !ok { | ||||
| 		llvrw.t.Errorf("Incorrect type provided for return value of ReadDateTime: %T", llvrw.readval) | ||||
| 		return 0, nil | ||||
| 	} | ||||
|  | ||||
| 	return dt, nil | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) ReadDecimal128() (primitive.Decimal128, error) { | ||||
| 	llvrw.invoked = llvrwReadDecimal128 | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return primitive.Decimal128{}, llvrw.err | ||||
| 	} | ||||
|  | ||||
| 	d128, ok := llvrw.readval.(primitive.Decimal128) | ||||
| 	if !ok { | ||||
| 		llvrw.t.Errorf("Incorrect type provided for return value of ReadDecimal128: %T", llvrw.readval) | ||||
| 		return primitive.Decimal128{}, nil | ||||
| 	} | ||||
|  | ||||
| 	return d128, nil | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) ReadDouble() (float64, error) { | ||||
| 	llvrw.invoked = llvrwReadDouble | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return 0, llvrw.err | ||||
| 	} | ||||
|  | ||||
| 	f64, ok := llvrw.readval.(float64) | ||||
| 	if !ok { | ||||
| 		llvrw.t.Errorf("Incorrect type provided for return value of ReadDouble: %T", llvrw.readval) | ||||
| 		return 0, nil | ||||
| 	} | ||||
|  | ||||
| 	return f64, nil | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) ReadInt32() (int32, error) { | ||||
| 	llvrw.invoked = llvrwReadInt32 | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return 0, llvrw.err | ||||
| 	} | ||||
|  | ||||
| 	i32, ok := llvrw.readval.(int32) | ||||
| 	if !ok { | ||||
| 		llvrw.t.Errorf("Incorrect type provided for return value of ReadInt32: %T", llvrw.readval) | ||||
| 		return 0, nil | ||||
| 	} | ||||
|  | ||||
| 	return i32, nil | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) ReadInt64() (int64, error) { | ||||
| 	llvrw.invoked = llvrwReadInt64 | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return 0, llvrw.err | ||||
| 	} | ||||
| 	i64, ok := llvrw.readval.(int64) | ||||
| 	if !ok { | ||||
| 		llvrw.t.Errorf("Incorrect type provided for return value of ReadInt64: %T", llvrw.readval) | ||||
| 		return 0, nil | ||||
| 	} | ||||
|  | ||||
| 	return i64, nil | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) ReadJavascript() (code string, err error) { | ||||
| 	llvrw.invoked = llvrwReadJavascript | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return "", llvrw.err | ||||
| 	} | ||||
| 	js, ok := llvrw.readval.(string) | ||||
| 	if !ok { | ||||
| 		llvrw.t.Errorf("Incorrect type provided for return value of ReadJavascript: %T", llvrw.readval) | ||||
| 		return "", nil | ||||
| 	} | ||||
|  | ||||
| 	return js, nil | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) ReadMaxKey() error { | ||||
| 	llvrw.invoked = llvrwReadMaxKey | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return llvrw.err | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) ReadMinKey() error { | ||||
| 	llvrw.invoked = llvrwReadMinKey | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return llvrw.err | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) ReadNull() error { | ||||
| 	llvrw.invoked = llvrwReadNull | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return llvrw.err | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) ReadObjectID() (primitive.ObjectID, error) { | ||||
| 	llvrw.invoked = llvrwReadObjectID | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return primitive.ObjectID{}, llvrw.err | ||||
| 	} | ||||
| 	oid, ok := llvrw.readval.(primitive.ObjectID) | ||||
| 	if !ok { | ||||
| 		llvrw.t.Errorf("Incorrect type provided for return value of ReadObjectID: %T", llvrw.readval) | ||||
| 		return primitive.ObjectID{}, nil | ||||
| 	} | ||||
|  | ||||
| 	return oid, nil | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) ReadRegex() (pattern string, options string, err error) { | ||||
| 	llvrw.invoked = llvrwReadRegex | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return "", "", llvrw.err | ||||
| 	} | ||||
| 	switch tt := llvrw.readval.(type) { | ||||
| 	case bsoncore.Value: | ||||
| 		pattern, options, _, ok := bsoncore.ReadRegex(tt.Data) | ||||
| 		if !ok { | ||||
| 			llvrw.t.Error("Invalid Value instance provided for ReadRegex") | ||||
| 			return "", "", nil | ||||
| 		} | ||||
| 		return pattern, options, nil | ||||
| 	default: | ||||
| 		llvrw.t.Errorf("Incorrect type provided for return value of ReadRegex: %T", llvrw.readval) | ||||
| 		return "", "", nil | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) ReadString() (string, error) { | ||||
| 	llvrw.invoked = llvrwReadString | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return "", llvrw.err | ||||
| 	} | ||||
| 	str, ok := llvrw.readval.(string) | ||||
| 	if !ok { | ||||
| 		llvrw.t.Errorf("Incorrect type provided for return value of ReadString: %T", llvrw.readval) | ||||
| 		return "", nil | ||||
| 	} | ||||
|  | ||||
| 	return str, nil | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) ReadSymbol() (symbol string, err error) { | ||||
| 	llvrw.invoked = llvrwReadSymbol | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return "", llvrw.err | ||||
| 	} | ||||
| 	switch tt := llvrw.readval.(type) { | ||||
| 	case bsoncore.Value: | ||||
| 		symbol, _, ok := bsoncore.ReadSymbol(tt.Data) | ||||
| 		if !ok { | ||||
| 			llvrw.t.Error("Invalid Value instance provided for ReadSymbol") | ||||
| 			return "", nil | ||||
| 		} | ||||
| 		return symbol, nil | ||||
| 	default: | ||||
| 		llvrw.t.Errorf("Incorrect type provided for return value of ReadSymbol: %T", llvrw.readval) | ||||
| 		return "", nil | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) ReadTimestamp() (t uint32, i uint32, err error) { | ||||
| 	llvrw.invoked = llvrwReadTimestamp | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return 0, 0, llvrw.err | ||||
| 	} | ||||
| 	switch tt := llvrw.readval.(type) { | ||||
| 	case bsoncore.Value: | ||||
| 		t, i, _, ok := bsoncore.ReadTimestamp(tt.Data) | ||||
| 		if !ok { | ||||
| 			llvrw.t.Errorf("Invalid Value instance provided for return value of ReadTimestamp") | ||||
| 			return 0, 0, nil | ||||
| 		} | ||||
| 		return t, i, nil | ||||
| 	default: | ||||
| 		llvrw.t.Errorf("Incorrect type provided for return value of ReadTimestamp: %T", llvrw.readval) | ||||
| 		return 0, 0, nil | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) ReadUndefined() error { | ||||
| 	llvrw.invoked = llvrwReadUndefined | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return llvrw.err | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) WriteArray() (ArrayWriter, error) { | ||||
| 	llvrw.invoked = llvrwWriteArray | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return nil, llvrw.err | ||||
| 	} | ||||
| 	return llvrw, nil | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) WriteBinary(b []byte) error { | ||||
| 	llvrw.invoked = llvrwWriteBinary | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return llvrw.err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) WriteBinaryWithSubtype(b []byte, btype byte) error { | ||||
| 	llvrw.invoked = llvrwWriteBinaryWithSubtype | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return llvrw.err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) WriteBoolean(bool) error { | ||||
| 	llvrw.invoked = llvrwWriteBoolean | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return llvrw.err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) WriteCodeWithScope(code string) (DocumentWriter, error) { | ||||
| 	llvrw.invoked = llvrwWriteCodeWithScope | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return nil, llvrw.err | ||||
| 	} | ||||
| 	return llvrw, nil | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) WriteDBPointer(ns string, oid primitive.ObjectID) error { | ||||
| 	llvrw.invoked = llvrwWriteDBPointer | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return llvrw.err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) WriteDateTime(dt int64) error { | ||||
| 	llvrw.invoked = llvrwWriteDateTime | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return llvrw.err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) WriteDecimal128(primitive.Decimal128) error { | ||||
| 	llvrw.invoked = llvrwWriteDecimal128 | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return llvrw.err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) WriteDouble(float64) error { | ||||
| 	llvrw.invoked = llvrwWriteDouble | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return llvrw.err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) WriteInt32(int32) error { | ||||
| 	llvrw.invoked = llvrwWriteInt32 | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return llvrw.err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) WriteInt64(int64) error { | ||||
| 	llvrw.invoked = llvrwWriteInt64 | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return llvrw.err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) WriteJavascript(code string) error { | ||||
| 	llvrw.invoked = llvrwWriteJavascript | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return llvrw.err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) WriteMaxKey() error { | ||||
| 	llvrw.invoked = llvrwWriteMaxKey | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return llvrw.err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) WriteMinKey() error { | ||||
| 	llvrw.invoked = llvrwWriteMinKey | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return llvrw.err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) WriteNull() error { | ||||
| 	llvrw.invoked = llvrwWriteNull | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return llvrw.err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) WriteObjectID(primitive.ObjectID) error { | ||||
| 	llvrw.invoked = llvrwWriteObjectID | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return llvrw.err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) WriteRegex(pattern string, options string) error { | ||||
| 	llvrw.invoked = llvrwWriteRegex | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return llvrw.err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) WriteString(string) error { | ||||
| 	llvrw.invoked = llvrwWriteString | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return llvrw.err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) WriteDocument() (DocumentWriter, error) { | ||||
| 	llvrw.invoked = llvrwWriteDocument | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return nil, llvrw.err | ||||
| 	} | ||||
| 	return llvrw, nil | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) WriteSymbol(symbol string) error { | ||||
| 	llvrw.invoked = llvrwWriteSymbol | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return llvrw.err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) WriteTimestamp(t uint32, i uint32) error { | ||||
| 	llvrw.invoked = llvrwWriteTimestamp | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return llvrw.err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) WriteUndefined() error { | ||||
| 	llvrw.invoked = llvrwWriteUndefined | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return llvrw.err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) ReadElement() (string, ValueReader, error) { | ||||
| 	llvrw.invoked = llvrwReadElement | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return "", nil, llvrw.err | ||||
| 	} | ||||
|  | ||||
| 	return "", llvrw, nil | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) WriteDocumentElement(string) (ValueWriter, error) { | ||||
| 	llvrw.invoked = llvrwWriteDocumentElement | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return nil, llvrw.err | ||||
| 	} | ||||
|  | ||||
| 	return llvrw, nil | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) WriteDocumentEnd() error { | ||||
| 	llvrw.invoked = llvrwWriteDocumentEnd | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return llvrw.err | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) ReadValue() (ValueReader, error) { | ||||
| 	llvrw.invoked = llvrwReadValue | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return nil, llvrw.err | ||||
| 	} | ||||
|  | ||||
| 	return llvrw, nil | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) WriteArrayElement() (ValueWriter, error) { | ||||
| 	llvrw.invoked = llvrwWriteArrayElement | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return nil, llvrw.err | ||||
| 	} | ||||
|  | ||||
| 	return llvrw, nil | ||||
| } | ||||
|  | ||||
| func (llvrw *TestValueReaderWriter) WriteArrayEnd() error { | ||||
| 	llvrw.invoked = llvrwWriteArrayEnd | ||||
| 	if llvrw.errAfter == llvrw.invoked { | ||||
| 		return llvrw.err | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
							
								
								
									
										606
									
								
								mongo/bson/bsonrw/value_writer.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										606
									
								
								mongo/bson/bsonrw/value_writer.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,606 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsonrw | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"math" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
|  | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsontype" | ||||
| 	"go.mongodb.org/mongo-driver/bson/primitive" | ||||
| 	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore" | ||||
| ) | ||||
|  | ||||
| var _ ValueWriter = (*valueWriter)(nil) | ||||
|  | ||||
| var vwPool = sync.Pool{ | ||||
| 	New: func() interface{} { | ||||
| 		return new(valueWriter) | ||||
| 	}, | ||||
| } | ||||
|  | ||||
| // BSONValueWriterPool is a pool for BSON ValueWriters. | ||||
| type BSONValueWriterPool struct { | ||||
| 	pool sync.Pool | ||||
| } | ||||
|  | ||||
| // NewBSONValueWriterPool creates a new pool for ValueWriter instances that write to BSON. | ||||
| func NewBSONValueWriterPool() *BSONValueWriterPool { | ||||
| 	return &BSONValueWriterPool{ | ||||
| 		pool: sync.Pool{ | ||||
| 			New: func() interface{} { | ||||
| 				return new(valueWriter) | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Get retrieves a BSON ValueWriter from the pool and resets it to use w as the destination. | ||||
| func (bvwp *BSONValueWriterPool) Get(w io.Writer) ValueWriter { | ||||
| 	vw := bvwp.pool.Get().(*valueWriter) | ||||
|  | ||||
| 	// TODO: Having to call reset here with the same buffer doesn't really make sense. | ||||
| 	vw.reset(vw.buf) | ||||
| 	vw.buf = vw.buf[:0] | ||||
| 	vw.w = w | ||||
| 	return vw | ||||
| } | ||||
|  | ||||
| // GetAtModeElement retrieves a ValueWriterFlusher from the pool and resets it to use w as the destination. | ||||
| func (bvwp *BSONValueWriterPool) GetAtModeElement(w io.Writer) ValueWriterFlusher { | ||||
| 	vw := bvwp.Get(w).(*valueWriter) | ||||
| 	vw.push(mElement) | ||||
| 	return vw | ||||
| } | ||||
|  | ||||
| // Put inserts a ValueWriter into the pool. If the ValueWriter is not a BSON ValueWriter, nothing | ||||
| // happens and ok will be false. | ||||
| func (bvwp *BSONValueWriterPool) Put(vw ValueWriter) (ok bool) { | ||||
| 	bvw, ok := vw.(*valueWriter) | ||||
| 	if !ok { | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	bvwp.pool.Put(bvw) | ||||
| 	return true | ||||
| } | ||||
|  | ||||
| // This is here so that during testing we can change it and not require | ||||
| // allocating a 4GB slice. | ||||
| var maxSize = math.MaxInt32 | ||||
|  | ||||
| var errNilWriter = errors.New("cannot create a ValueWriter from a nil io.Writer") | ||||
|  | ||||
| type errMaxDocumentSizeExceeded struct { | ||||
| 	size int64 | ||||
| } | ||||
|  | ||||
| func (mdse errMaxDocumentSizeExceeded) Error() string { | ||||
| 	return fmt.Sprintf("document size (%d) is larger than the max int32", mdse.size) | ||||
| } | ||||
|  | ||||
| type vwMode int | ||||
|  | ||||
| const ( | ||||
| 	_ vwMode = iota | ||||
| 	vwTopLevel | ||||
| 	vwDocument | ||||
| 	vwArray | ||||
| 	vwValue | ||||
| 	vwElement | ||||
| 	vwCodeWithScope | ||||
| ) | ||||
|  | ||||
| func (vm vwMode) String() string { | ||||
| 	var str string | ||||
|  | ||||
| 	switch vm { | ||||
| 	case vwTopLevel: | ||||
| 		str = "TopLevel" | ||||
| 	case vwDocument: | ||||
| 		str = "DocumentMode" | ||||
| 	case vwArray: | ||||
| 		str = "ArrayMode" | ||||
| 	case vwValue: | ||||
| 		str = "ValueMode" | ||||
| 	case vwElement: | ||||
| 		str = "ElementMode" | ||||
| 	case vwCodeWithScope: | ||||
| 		str = "CodeWithScopeMode" | ||||
| 	default: | ||||
| 		str = "UnknownMode" | ||||
| 	} | ||||
|  | ||||
| 	return str | ||||
| } | ||||
|  | ||||
| type vwState struct { | ||||
| 	mode   mode | ||||
| 	key    string | ||||
| 	arrkey int | ||||
| 	start  int32 | ||||
| } | ||||
|  | ||||
| type valueWriter struct { | ||||
| 	w   io.Writer | ||||
| 	buf []byte | ||||
|  | ||||
| 	stack []vwState | ||||
| 	frame int64 | ||||
| } | ||||
|  | ||||
| func (vw *valueWriter) advanceFrame() { | ||||
| 	if vw.frame+1 >= int64(len(vw.stack)) { // We need to grow the stack | ||||
| 		length := len(vw.stack) | ||||
| 		if length+1 >= cap(vw.stack) { | ||||
| 			// double it | ||||
| 			buf := make([]vwState, 2*cap(vw.stack)+1) | ||||
| 			copy(buf, vw.stack) | ||||
| 			vw.stack = buf | ||||
| 		} | ||||
| 		vw.stack = vw.stack[:length+1] | ||||
| 	} | ||||
| 	vw.frame++ | ||||
| } | ||||
|  | ||||
| func (vw *valueWriter) push(m mode) { | ||||
| 	vw.advanceFrame() | ||||
|  | ||||
| 	// Clean the stack | ||||
| 	vw.stack[vw.frame].mode = m | ||||
| 	vw.stack[vw.frame].key = "" | ||||
| 	vw.stack[vw.frame].arrkey = 0 | ||||
| 	vw.stack[vw.frame].start = 0 | ||||
|  | ||||
| 	vw.stack[vw.frame].mode = m | ||||
| 	switch m { | ||||
| 	case mDocument, mArray, mCodeWithScope: | ||||
| 		vw.reserveLength() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (vw *valueWriter) reserveLength() { | ||||
| 	vw.stack[vw.frame].start = int32(len(vw.buf)) | ||||
| 	vw.buf = append(vw.buf, 0x00, 0x00, 0x00, 0x00) | ||||
| } | ||||
|  | ||||
| func (vw *valueWriter) pop() { | ||||
| 	switch vw.stack[vw.frame].mode { | ||||
| 	case mElement, mValue: | ||||
| 		vw.frame-- | ||||
| 	case mDocument, mArray, mCodeWithScope: | ||||
| 		vw.frame -= 2 // we pop twice to jump over the mElement: mDocument -> mElement -> mDocument/mTopLevel/etc... | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // NewBSONValueWriter creates a ValueWriter that writes BSON to w. | ||||
| // | ||||
| // This ValueWriter will only write entire documents to the io.Writer and it | ||||
| // will buffer the document as it is built. | ||||
| func NewBSONValueWriter(w io.Writer) (ValueWriter, error) { | ||||
| 	if w == nil { | ||||
| 		return nil, errNilWriter | ||||
| 	} | ||||
| 	return newValueWriter(w), nil | ||||
| } | ||||
|  | ||||
| func newValueWriter(w io.Writer) *valueWriter { | ||||
| 	vw := new(valueWriter) | ||||
| 	stack := make([]vwState, 1, 5) | ||||
| 	stack[0] = vwState{mode: mTopLevel} | ||||
| 	vw.w = w | ||||
| 	vw.stack = stack | ||||
|  | ||||
| 	return vw | ||||
| } | ||||
|  | ||||
| func newValueWriterFromSlice(buf []byte) *valueWriter { | ||||
| 	vw := new(valueWriter) | ||||
| 	stack := make([]vwState, 1, 5) | ||||
| 	stack[0] = vwState{mode: mTopLevel} | ||||
| 	vw.stack = stack | ||||
| 	vw.buf = buf | ||||
|  | ||||
| 	return vw | ||||
| } | ||||
|  | ||||
| func (vw *valueWriter) reset(buf []byte) { | ||||
| 	if vw.stack == nil { | ||||
| 		vw.stack = make([]vwState, 1, 5) | ||||
| 	} | ||||
| 	vw.stack = vw.stack[:1] | ||||
| 	vw.stack[0] = vwState{mode: mTopLevel} | ||||
| 	vw.buf = buf | ||||
| 	vw.frame = 0 | ||||
| 	vw.w = nil | ||||
| } | ||||
|  | ||||
| func (vw *valueWriter) invalidTransitionError(destination mode, name string, modes []mode) error { | ||||
| 	te := TransitionError{ | ||||
| 		name:        name, | ||||
| 		current:     vw.stack[vw.frame].mode, | ||||
| 		destination: destination, | ||||
| 		modes:       modes, | ||||
| 		action:      "write", | ||||
| 	} | ||||
| 	if vw.frame != 0 { | ||||
| 		te.parent = vw.stack[vw.frame-1].mode | ||||
| 	} | ||||
| 	return te | ||||
| } | ||||
|  | ||||
| func (vw *valueWriter) writeElementHeader(t bsontype.Type, destination mode, callerName string, addmodes ...mode) error { | ||||
| 	switch vw.stack[vw.frame].mode { | ||||
| 	case mElement: | ||||
| 		key := vw.stack[vw.frame].key | ||||
| 		if !isValidCString(key) { | ||||
| 			return errors.New("BSON element key cannot contain null bytes") | ||||
| 		} | ||||
|  | ||||
| 		vw.buf = bsoncore.AppendHeader(vw.buf, t, key) | ||||
| 	case mValue: | ||||
| 		// TODO: Do this with a cache of the first 1000 or so array keys. | ||||
| 		vw.buf = bsoncore.AppendHeader(vw.buf, t, strconv.Itoa(vw.stack[vw.frame].arrkey)) | ||||
| 	default: | ||||
| 		modes := []mode{mElement, mValue} | ||||
| 		if addmodes != nil { | ||||
| 			modes = append(modes, addmodes...) | ||||
| 		} | ||||
| 		return vw.invalidTransitionError(destination, callerName, modes) | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (vw *valueWriter) WriteValueBytes(t bsontype.Type, b []byte) error { | ||||
| 	if err := vw.writeElementHeader(t, mode(0), "WriteValueBytes"); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	vw.buf = append(vw.buf, b...) | ||||
| 	vw.pop() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (vw *valueWriter) WriteArray() (ArrayWriter, error) { | ||||
| 	if err := vw.writeElementHeader(bsontype.Array, mArray, "WriteArray"); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	vw.push(mArray) | ||||
|  | ||||
| 	return vw, nil | ||||
| } | ||||
|  | ||||
| func (vw *valueWriter) WriteBinary(b []byte) error { | ||||
| 	return vw.WriteBinaryWithSubtype(b, 0x00) | ||||
| } | ||||
|  | ||||
| func (vw *valueWriter) WriteBinaryWithSubtype(b []byte, btype byte) error { | ||||
| 	if err := vw.writeElementHeader(bsontype.Binary, mode(0), "WriteBinaryWithSubtype"); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	vw.buf = bsoncore.AppendBinary(vw.buf, btype, b) | ||||
| 	vw.pop() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (vw *valueWriter) WriteBoolean(b bool) error { | ||||
| 	if err := vw.writeElementHeader(bsontype.Boolean, mode(0), "WriteBoolean"); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	vw.buf = bsoncore.AppendBoolean(vw.buf, b) | ||||
| 	vw.pop() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (vw *valueWriter) WriteCodeWithScope(code string) (DocumentWriter, error) { | ||||
| 	if err := vw.writeElementHeader(bsontype.CodeWithScope, mCodeWithScope, "WriteCodeWithScope"); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	// CodeWithScope is a different than other types because we need an extra | ||||
| 	// frame on the stack. In the EndDocument code, we write the document | ||||
| 	// length, pop, write the code with scope length, and pop. To simplify the | ||||
| 	// pop code, we push a spacer frame that we'll always jump over. | ||||
| 	vw.push(mCodeWithScope) | ||||
| 	vw.buf = bsoncore.AppendString(vw.buf, code) | ||||
| 	vw.push(mSpacer) | ||||
| 	vw.push(mDocument) | ||||
|  | ||||
| 	return vw, nil | ||||
| } | ||||
|  | ||||
| func (vw *valueWriter) WriteDBPointer(ns string, oid primitive.ObjectID) error { | ||||
| 	if err := vw.writeElementHeader(bsontype.DBPointer, mode(0), "WriteDBPointer"); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	vw.buf = bsoncore.AppendDBPointer(vw.buf, ns, oid) | ||||
| 	vw.pop() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (vw *valueWriter) WriteDateTime(dt int64) error { | ||||
| 	if err := vw.writeElementHeader(bsontype.DateTime, mode(0), "WriteDateTime"); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	vw.buf = bsoncore.AppendDateTime(vw.buf, dt) | ||||
| 	vw.pop() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (vw *valueWriter) WriteDecimal128(d128 primitive.Decimal128) error { | ||||
| 	if err := vw.writeElementHeader(bsontype.Decimal128, mode(0), "WriteDecimal128"); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	vw.buf = bsoncore.AppendDecimal128(vw.buf, d128) | ||||
| 	vw.pop() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (vw *valueWriter) WriteDouble(f float64) error { | ||||
| 	if err := vw.writeElementHeader(bsontype.Double, mode(0), "WriteDouble"); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	vw.buf = bsoncore.AppendDouble(vw.buf, f) | ||||
| 	vw.pop() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (vw *valueWriter) WriteInt32(i32 int32) error { | ||||
| 	if err := vw.writeElementHeader(bsontype.Int32, mode(0), "WriteInt32"); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	vw.buf = bsoncore.AppendInt32(vw.buf, i32) | ||||
| 	vw.pop() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (vw *valueWriter) WriteInt64(i64 int64) error { | ||||
| 	if err := vw.writeElementHeader(bsontype.Int64, mode(0), "WriteInt64"); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	vw.buf = bsoncore.AppendInt64(vw.buf, i64) | ||||
| 	vw.pop() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (vw *valueWriter) WriteJavascript(code string) error { | ||||
| 	if err := vw.writeElementHeader(bsontype.JavaScript, mode(0), "WriteJavascript"); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	vw.buf = bsoncore.AppendJavaScript(vw.buf, code) | ||||
| 	vw.pop() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (vw *valueWriter) WriteMaxKey() error { | ||||
| 	if err := vw.writeElementHeader(bsontype.MaxKey, mode(0), "WriteMaxKey"); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	vw.pop() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (vw *valueWriter) WriteMinKey() error { | ||||
| 	if err := vw.writeElementHeader(bsontype.MinKey, mode(0), "WriteMinKey"); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	vw.pop() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (vw *valueWriter) WriteNull() error { | ||||
| 	if err := vw.writeElementHeader(bsontype.Null, mode(0), "WriteNull"); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	vw.pop() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (vw *valueWriter) WriteObjectID(oid primitive.ObjectID) error { | ||||
| 	if err := vw.writeElementHeader(bsontype.ObjectID, mode(0), "WriteObjectID"); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	vw.buf = bsoncore.AppendObjectID(vw.buf, oid) | ||||
| 	vw.pop() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (vw *valueWriter) WriteRegex(pattern string, options string) error { | ||||
| 	if !isValidCString(pattern) || !isValidCString(options) { | ||||
| 		return errors.New("BSON regex values cannot contain null bytes") | ||||
| 	} | ||||
| 	if err := vw.writeElementHeader(bsontype.Regex, mode(0), "WriteRegex"); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	vw.buf = bsoncore.AppendRegex(vw.buf, pattern, sortStringAlphebeticAscending(options)) | ||||
| 	vw.pop() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (vw *valueWriter) WriteString(s string) error { | ||||
| 	if err := vw.writeElementHeader(bsontype.String, mode(0), "WriteString"); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	vw.buf = bsoncore.AppendString(vw.buf, s) | ||||
| 	vw.pop() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (vw *valueWriter) WriteDocument() (DocumentWriter, error) { | ||||
| 	if vw.stack[vw.frame].mode == mTopLevel { | ||||
| 		vw.reserveLength() | ||||
| 		return vw, nil | ||||
| 	} | ||||
| 	if err := vw.writeElementHeader(bsontype.EmbeddedDocument, mDocument, "WriteDocument", mTopLevel); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	vw.push(mDocument) | ||||
| 	return vw, nil | ||||
| } | ||||
|  | ||||
| func (vw *valueWriter) WriteSymbol(symbol string) error { | ||||
| 	if err := vw.writeElementHeader(bsontype.Symbol, mode(0), "WriteSymbol"); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	vw.buf = bsoncore.AppendSymbol(vw.buf, symbol) | ||||
| 	vw.pop() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (vw *valueWriter) WriteTimestamp(t uint32, i uint32) error { | ||||
| 	if err := vw.writeElementHeader(bsontype.Timestamp, mode(0), "WriteTimestamp"); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	vw.buf = bsoncore.AppendTimestamp(vw.buf, t, i) | ||||
| 	vw.pop() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (vw *valueWriter) WriteUndefined() error { | ||||
| 	if err := vw.writeElementHeader(bsontype.Undefined, mode(0), "WriteUndefined"); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	vw.pop() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (vw *valueWriter) WriteDocumentElement(key string) (ValueWriter, error) { | ||||
| 	switch vw.stack[vw.frame].mode { | ||||
| 	case mTopLevel, mDocument: | ||||
| 	default: | ||||
| 		return nil, vw.invalidTransitionError(mElement, "WriteDocumentElement", []mode{mTopLevel, mDocument}) | ||||
| 	} | ||||
|  | ||||
| 	vw.push(mElement) | ||||
| 	vw.stack[vw.frame].key = key | ||||
|  | ||||
| 	return vw, nil | ||||
| } | ||||
|  | ||||
| func (vw *valueWriter) WriteDocumentEnd() error { | ||||
| 	switch vw.stack[vw.frame].mode { | ||||
| 	case mTopLevel, mDocument: | ||||
| 	default: | ||||
| 		return fmt.Errorf("incorrect mode to end document: %s", vw.stack[vw.frame].mode) | ||||
| 	} | ||||
|  | ||||
| 	vw.buf = append(vw.buf, 0x00) | ||||
|  | ||||
| 	err := vw.writeLength() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	if vw.stack[vw.frame].mode == mTopLevel { | ||||
| 		if err = vw.Flush(); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	vw.pop() | ||||
|  | ||||
| 	if vw.stack[vw.frame].mode == mCodeWithScope { | ||||
| 		// We ignore the error here because of the guarantee of writeLength. | ||||
| 		// See the docs for writeLength for more info. | ||||
| 		_ = vw.writeLength() | ||||
| 		vw.pop() | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (vw *valueWriter) Flush() error { | ||||
| 	if vw.w == nil { | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	if _, err := vw.w.Write(vw.buf); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	// reset buffer | ||||
| 	vw.buf = vw.buf[:0] | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (vw *valueWriter) WriteArrayElement() (ValueWriter, error) { | ||||
| 	if vw.stack[vw.frame].mode != mArray { | ||||
| 		return nil, vw.invalidTransitionError(mValue, "WriteArrayElement", []mode{mArray}) | ||||
| 	} | ||||
|  | ||||
| 	arrkey := vw.stack[vw.frame].arrkey | ||||
| 	vw.stack[vw.frame].arrkey++ | ||||
|  | ||||
| 	vw.push(mValue) | ||||
| 	vw.stack[vw.frame].arrkey = arrkey | ||||
|  | ||||
| 	return vw, nil | ||||
| } | ||||
|  | ||||
| func (vw *valueWriter) WriteArrayEnd() error { | ||||
| 	if vw.stack[vw.frame].mode != mArray { | ||||
| 		return fmt.Errorf("incorrect mode to end array: %s", vw.stack[vw.frame].mode) | ||||
| 	} | ||||
|  | ||||
| 	vw.buf = append(vw.buf, 0x00) | ||||
|  | ||||
| 	err := vw.writeLength() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	vw.pop() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // NOTE: We assume that if we call writeLength more than once the same function | ||||
| // within the same function without altering the vw.buf that this method will | ||||
| // not return an error. If this changes ensure that the following methods are | ||||
| // updated: | ||||
| // | ||||
| // - WriteDocumentEnd | ||||
| func (vw *valueWriter) writeLength() error { | ||||
| 	length := len(vw.buf) | ||||
| 	if length > maxSize { | ||||
| 		return errMaxDocumentSizeExceeded{size: int64(len(vw.buf))} | ||||
| 	} | ||||
| 	length = length - int(vw.stack[vw.frame].start) | ||||
| 	start := vw.stack[vw.frame].start | ||||
|  | ||||
| 	vw.buf[start+0] = byte(length) | ||||
| 	vw.buf[start+1] = byte(length >> 8) | ||||
| 	vw.buf[start+2] = byte(length >> 16) | ||||
| 	vw.buf[start+3] = byte(length >> 24) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func isValidCString(cs string) bool { | ||||
| 	return !strings.ContainsRune(cs, '\x00') | ||||
| } | ||||
							
								
								
									
										368
									
								
								mongo/bson/bsonrw/value_writer_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										368
									
								
								mongo/bson/bsonrw/value_writer_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,368 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsonrw | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io/ioutil" | ||||
| 	"math" | ||||
| 	"reflect" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
|  | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsontype" | ||||
| 	"go.mongodb.org/mongo-driver/bson/primitive" | ||||
| 	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore" | ||||
| ) | ||||
|  | ||||
| func TestNewBSONValueWriter(t *testing.T) { | ||||
| 	_, got := NewBSONValueWriter(nil) | ||||
| 	want := errNilWriter | ||||
| 	if !compareErrors(got, want) { | ||||
| 		t.Errorf("Returned error did not match what was expected. got %v; want %v", got, want) | ||||
| 	} | ||||
|  | ||||
| 	vw, got := NewBSONValueWriter(errWriter{}) | ||||
| 	want = nil | ||||
| 	if !compareErrors(got, want) { | ||||
| 		t.Errorf("Returned error did not match what was expected. got %v; want %v", got, want) | ||||
| 	} | ||||
| 	if vw == nil { | ||||
| 		t.Errorf("Expected non-nil ValueWriter to be returned from NewBSONValueWriter") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestValueWriter(t *testing.T) { | ||||
| 	header := []byte{0x00, 0x00, 0x00, 0x00} | ||||
| 	oid := primitive.ObjectID{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C} | ||||
| 	testCases := []struct { | ||||
| 		name   string | ||||
| 		fn     interface{} | ||||
| 		params []interface{} | ||||
| 		want   []byte | ||||
| 	}{ | ||||
| 		{ | ||||
| 			"WriteBinary", | ||||
| 			(*valueWriter).WriteBinary, | ||||
| 			[]interface{}{[]byte{0x01, 0x02, 0x03}}, | ||||
| 			bsoncore.AppendBinaryElement(header, "foo", 0x00, []byte{0x01, 0x02, 0x03}), | ||||
| 		}, | ||||
| 		{ | ||||
| 			"WriteBinaryWithSubtype (not 0x02)", | ||||
| 			(*valueWriter).WriteBinaryWithSubtype, | ||||
| 			[]interface{}{[]byte{0x01, 0x02, 0x03}, byte(0xFF)}, | ||||
| 			bsoncore.AppendBinaryElement(header, "foo", 0xFF, []byte{0x01, 0x02, 0x03}), | ||||
| 		}, | ||||
| 		{ | ||||
| 			"WriteBinaryWithSubtype (0x02)", | ||||
| 			(*valueWriter).WriteBinaryWithSubtype, | ||||
| 			[]interface{}{[]byte{0x01, 0x02, 0x03}, byte(0x02)}, | ||||
| 			bsoncore.AppendBinaryElement(header, "foo", 0x02, []byte{0x01, 0x02, 0x03}), | ||||
| 		}, | ||||
| 		{ | ||||
| 			"WriteBoolean", | ||||
| 			(*valueWriter).WriteBoolean, | ||||
| 			[]interface{}{true}, | ||||
| 			bsoncore.AppendBooleanElement(header, "foo", true), | ||||
| 		}, | ||||
| 		{ | ||||
| 			"WriteDBPointer", | ||||
| 			(*valueWriter).WriteDBPointer, | ||||
| 			[]interface{}{"bar", oid}, | ||||
| 			bsoncore.AppendDBPointerElement(header, "foo", "bar", oid), | ||||
| 		}, | ||||
| 		{ | ||||
| 			"WriteDateTime", | ||||
| 			(*valueWriter).WriteDateTime, | ||||
| 			[]interface{}{int64(12345678)}, | ||||
| 			bsoncore.AppendDateTimeElement(header, "foo", 12345678), | ||||
| 		}, | ||||
| 		{ | ||||
| 			"WriteDecimal128", | ||||
| 			(*valueWriter).WriteDecimal128, | ||||
| 			[]interface{}{primitive.NewDecimal128(10, 20)}, | ||||
| 			bsoncore.AppendDecimal128Element(header, "foo", primitive.NewDecimal128(10, 20)), | ||||
| 		}, | ||||
| 		{ | ||||
| 			"WriteDouble", | ||||
| 			(*valueWriter).WriteDouble, | ||||
| 			[]interface{}{float64(3.14159)}, | ||||
| 			bsoncore.AppendDoubleElement(header, "foo", 3.14159), | ||||
| 		}, | ||||
| 		{ | ||||
| 			"WriteInt32", | ||||
| 			(*valueWriter).WriteInt32, | ||||
| 			[]interface{}{int32(123456)}, | ||||
| 			bsoncore.AppendInt32Element(header, "foo", 123456), | ||||
| 		}, | ||||
| 		{ | ||||
| 			"WriteInt64", | ||||
| 			(*valueWriter).WriteInt64, | ||||
| 			[]interface{}{int64(1234567890)}, | ||||
| 			bsoncore.AppendInt64Element(header, "foo", 1234567890), | ||||
| 		}, | ||||
| 		{ | ||||
| 			"WriteJavascript", | ||||
| 			(*valueWriter).WriteJavascript, | ||||
| 			[]interface{}{"var foo = 'bar';"}, | ||||
| 			bsoncore.AppendJavaScriptElement(header, "foo", "var foo = 'bar';"), | ||||
| 		}, | ||||
| 		{ | ||||
| 			"WriteMaxKey", | ||||
| 			(*valueWriter).WriteMaxKey, | ||||
| 			[]interface{}{}, | ||||
| 			bsoncore.AppendMaxKeyElement(header, "foo"), | ||||
| 		}, | ||||
| 		{ | ||||
| 			"WriteMinKey", | ||||
| 			(*valueWriter).WriteMinKey, | ||||
| 			[]interface{}{}, | ||||
| 			bsoncore.AppendMinKeyElement(header, "foo"), | ||||
| 		}, | ||||
| 		{ | ||||
| 			"WriteNull", | ||||
| 			(*valueWriter).WriteNull, | ||||
| 			[]interface{}{}, | ||||
| 			bsoncore.AppendNullElement(header, "foo"), | ||||
| 		}, | ||||
| 		{ | ||||
| 			"WriteObjectID", | ||||
| 			(*valueWriter).WriteObjectID, | ||||
| 			[]interface{}{oid}, | ||||
| 			bsoncore.AppendObjectIDElement(header, "foo", oid), | ||||
| 		}, | ||||
| 		{ | ||||
| 			"WriteRegex", | ||||
| 			(*valueWriter).WriteRegex, | ||||
| 			[]interface{}{"bar", "baz"}, | ||||
| 			bsoncore.AppendRegexElement(header, "foo", "bar", "abz"), | ||||
| 		}, | ||||
| 		{ | ||||
| 			"WriteString", | ||||
| 			(*valueWriter).WriteString, | ||||
| 			[]interface{}{"hello, world!"}, | ||||
| 			bsoncore.AppendStringElement(header, "foo", "hello, world!"), | ||||
| 		}, | ||||
| 		{ | ||||
| 			"WriteSymbol", | ||||
| 			(*valueWriter).WriteSymbol, | ||||
| 			[]interface{}{"symbollolz"}, | ||||
| 			bsoncore.AppendSymbolElement(header, "foo", "symbollolz"), | ||||
| 		}, | ||||
| 		{ | ||||
| 			"WriteTimestamp", | ||||
| 			(*valueWriter).WriteTimestamp, | ||||
| 			[]interface{}{uint32(10), uint32(20)}, | ||||
| 			bsoncore.AppendTimestampElement(header, "foo", 10, 20), | ||||
| 		}, | ||||
| 		{ | ||||
| 			"WriteUndefined", | ||||
| 			(*valueWriter).WriteUndefined, | ||||
| 			[]interface{}{}, | ||||
| 			bsoncore.AppendUndefinedElement(header, "foo"), | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for _, tc := range testCases { | ||||
| 		t.Run(tc.name, func(t *testing.T) { | ||||
| 			fn := reflect.ValueOf(tc.fn) | ||||
| 			if fn.Kind() != reflect.Func { | ||||
| 				t.Fatalf("fn must be of kind Func but it is a %v", fn.Kind()) | ||||
| 			} | ||||
| 			if fn.Type().NumIn() != len(tc.params)+1 || fn.Type().In(0) != reflect.TypeOf((*valueWriter)(nil)) { | ||||
| 				t.Fatalf("fn must have at least one parameter and the first parameter must be a *valueWriter") | ||||
| 			} | ||||
| 			if fn.Type().NumOut() != 1 || fn.Type().Out(0) != reflect.TypeOf((*error)(nil)).Elem() { | ||||
| 				t.Fatalf("fn must have one return value and it must be an error.") | ||||
| 			} | ||||
| 			params := make([]reflect.Value, 1, len(tc.params)+1) | ||||
| 			vw := newValueWriter(ioutil.Discard) | ||||
| 			params[0] = reflect.ValueOf(vw) | ||||
| 			for _, param := range tc.params { | ||||
| 				params = append(params, reflect.ValueOf(param)) | ||||
| 			} | ||||
| 			_, err := vw.WriteDocument() | ||||
| 			noerr(t, err) | ||||
| 			_, err = vw.WriteDocumentElement("foo") | ||||
| 			noerr(t, err) | ||||
|  | ||||
| 			results := fn.Call(params) | ||||
| 			if !results[0].IsValid() { | ||||
| 				err = results[0].Interface().(error) | ||||
| 			} else { | ||||
| 				err = nil | ||||
| 			} | ||||
| 			noerr(t, err) | ||||
| 			got := vw.buf | ||||
| 			want := tc.want | ||||
| 			if !bytes.Equal(got, want) { | ||||
| 				t.Errorf("Bytes are not equal.\n\tgot %v\n\twant %v", got, want) | ||||
| 			} | ||||
|  | ||||
| 			t.Run("incorrect transition", func(t *testing.T) { | ||||
| 				vw = newValueWriter(ioutil.Discard) | ||||
| 				results := fn.Call(params) | ||||
| 				got := results[0].Interface().(error) | ||||
| 				fnName := tc.name | ||||
| 				if strings.Contains(fnName, "WriteBinary") { | ||||
| 					fnName = "WriteBinaryWithSubtype" | ||||
| 				} | ||||
| 				want := TransitionError{current: mTopLevel, name: fnName, modes: []mode{mElement, mValue}, | ||||
| 					action: "write"} | ||||
| 				if !compareErrors(got, want) { | ||||
| 					t.Errorf("Errors do not match. got %v; want %v", got, want) | ||||
| 				} | ||||
| 			}) | ||||
| 		}) | ||||
| 	} | ||||
|  | ||||
| 	t.Run("WriteArray", func(t *testing.T) { | ||||
| 		vw := newValueWriter(ioutil.Discard) | ||||
| 		vw.push(mArray) | ||||
| 		want := TransitionError{current: mArray, destination: mArray, parent: mTopLevel, | ||||
| 			name: "WriteArray", modes: []mode{mElement, mValue}, action: "write"} | ||||
| 		_, got := vw.WriteArray() | ||||
| 		if !compareErrors(got, want) { | ||||
| 			t.Errorf("Did not get expected error. got %v; want %v", got, want) | ||||
| 		} | ||||
| 	}) | ||||
| 	t.Run("WriteCodeWithScope", func(t *testing.T) { | ||||
| 		vw := newValueWriter(ioutil.Discard) | ||||
| 		vw.push(mArray) | ||||
| 		want := TransitionError{current: mArray, destination: mCodeWithScope, parent: mTopLevel, | ||||
| 			name: "WriteCodeWithScope", modes: []mode{mElement, mValue}, action: "write"} | ||||
| 		_, got := vw.WriteCodeWithScope("") | ||||
| 		if !compareErrors(got, want) { | ||||
| 			t.Errorf("Did not get expected error. got %v; want %v", got, want) | ||||
| 		} | ||||
| 	}) | ||||
| 	t.Run("WriteDocument", func(t *testing.T) { | ||||
| 		vw := newValueWriter(ioutil.Discard) | ||||
| 		vw.push(mArray) | ||||
| 		want := TransitionError{current: mArray, destination: mDocument, parent: mTopLevel, | ||||
| 			name: "WriteDocument", modes: []mode{mElement, mValue, mTopLevel}, action: "write"} | ||||
| 		_, got := vw.WriteDocument() | ||||
| 		if !compareErrors(got, want) { | ||||
| 			t.Errorf("Did not get expected error. got %v; want %v", got, want) | ||||
| 		} | ||||
| 	}) | ||||
| 	t.Run("WriteDocumentElement", func(t *testing.T) { | ||||
| 		vw := newValueWriter(ioutil.Discard) | ||||
| 		vw.push(mElement) | ||||
| 		want := TransitionError{current: mElement, | ||||
| 			destination: mElement, | ||||
| 			parent:      mTopLevel, | ||||
| 			name:        "WriteDocumentElement", | ||||
| 			modes:       []mode{mTopLevel, mDocument}, | ||||
| 			action:      "write"} | ||||
| 		_, got := vw.WriteDocumentElement("") | ||||
| 		if !compareErrors(got, want) { | ||||
| 			t.Errorf("Did not get expected error. got %v; want %v", got, want) | ||||
| 		} | ||||
| 	}) | ||||
| 	t.Run("WriteDocumentEnd", func(t *testing.T) { | ||||
| 		vw := newValueWriter(ioutil.Discard) | ||||
| 		vw.push(mElement) | ||||
| 		want := fmt.Errorf("incorrect mode to end document: %s", mElement) | ||||
| 		got := vw.WriteDocumentEnd() | ||||
| 		if !compareErrors(got, want) { | ||||
| 			t.Errorf("Did not get expected error. got %v; want %v", got, want) | ||||
| 		} | ||||
| 		vw.pop() | ||||
| 		vw.buf = append(vw.buf, make([]byte, 1023)...) | ||||
| 		maxSize = 512 | ||||
| 		want = errMaxDocumentSizeExceeded{size: 1024} | ||||
| 		got = vw.WriteDocumentEnd() | ||||
| 		if !compareErrors(got, want) { | ||||
| 			t.Errorf("Did not get expected error. got %v; want %v", got, want) | ||||
| 		} | ||||
| 		maxSize = math.MaxInt32 | ||||
| 		want = errors.New("what a nice fake error we have here") | ||||
| 		vw.w = errWriter{err: want} | ||||
| 		got = vw.WriteDocumentEnd() | ||||
| 		if !compareErrors(got, want) { | ||||
| 			t.Errorf("Did not get expected error. got %v; want %v", got, want) | ||||
| 		} | ||||
| 	}) | ||||
| 	t.Run("WriteArrayElement", func(t *testing.T) { | ||||
| 		vw := newValueWriter(ioutil.Discard) | ||||
| 		vw.push(mElement) | ||||
| 		want := TransitionError{current: mElement, | ||||
| 			destination: mValue, | ||||
| 			parent:      mTopLevel, | ||||
| 			name:        "WriteArrayElement", | ||||
| 			modes:       []mode{mArray}, | ||||
| 			action:      "write"} | ||||
| 		_, got := vw.WriteArrayElement() | ||||
| 		if !compareErrors(got, want) { | ||||
| 			t.Errorf("Did not get expected error. got %v; want %v", got, want) | ||||
| 		} | ||||
| 	}) | ||||
| 	t.Run("WriteArrayEnd", func(t *testing.T) { | ||||
| 		vw := newValueWriter(ioutil.Discard) | ||||
| 		vw.push(mElement) | ||||
| 		want := fmt.Errorf("incorrect mode to end array: %s", mElement) | ||||
| 		got := vw.WriteArrayEnd() | ||||
| 		if !compareErrors(got, want) { | ||||
| 			t.Errorf("Did not get expected error. got %v; want %v", got, want) | ||||
| 		} | ||||
| 		vw.push(mArray) | ||||
| 		vw.buf = append(vw.buf, make([]byte, 1019)...) | ||||
| 		maxSize = 512 | ||||
| 		want = errMaxDocumentSizeExceeded{size: 1024} | ||||
| 		got = vw.WriteArrayEnd() | ||||
| 		if !compareErrors(got, want) { | ||||
| 			t.Errorf("Did not get expected error. got %v; want %v", got, want) | ||||
| 		} | ||||
| 		maxSize = math.MaxInt32 | ||||
| 	}) | ||||
|  | ||||
| 	t.Run("WriteBytes", func(t *testing.T) { | ||||
| 		t.Run("writeElementHeader error", func(t *testing.T) { | ||||
| 			vw := newValueWriterFromSlice(nil) | ||||
| 			want := TransitionError{current: mTopLevel, destination: mode(0), | ||||
| 				name: "WriteValueBytes", modes: []mode{mElement, mValue}, action: "write"} | ||||
| 			got := vw.WriteValueBytes(bsontype.EmbeddedDocument, nil) | ||||
| 			if !compareErrors(got, want) { | ||||
| 				t.Errorf("Did not received expected error. got %v; want %v", got, want) | ||||
| 			} | ||||
| 		}) | ||||
| 		t.Run("success", func(t *testing.T) { | ||||
| 			index, doc := bsoncore.ReserveLength(nil) | ||||
| 			doc = bsoncore.AppendStringElement(doc, "hello", "world") | ||||
| 			doc = append(doc, 0x00) | ||||
| 			doc = bsoncore.UpdateLength(doc, index, int32(len(doc))) | ||||
|  | ||||
| 			index, want := bsoncore.ReserveLength(nil) | ||||
| 			want = bsoncore.AppendDocumentElement(want, "foo", doc) | ||||
| 			want = append(want, 0x00) | ||||
| 			want = bsoncore.UpdateLength(want, index, int32(len(want))) | ||||
|  | ||||
| 			vw := newValueWriterFromSlice(make([]byte, 0, 512)) | ||||
| 			_, err := vw.WriteDocument() | ||||
| 			noerr(t, err) | ||||
| 			_, err = vw.WriteDocumentElement("foo") | ||||
| 			noerr(t, err) | ||||
| 			err = vw.WriteValueBytes(bsontype.EmbeddedDocument, doc) | ||||
| 			noerr(t, err) | ||||
| 			err = vw.WriteDocumentEnd() | ||||
| 			noerr(t, err) | ||||
| 			got := vw.buf | ||||
| 			if !bytes.Equal(got, want) { | ||||
| 				t.Errorf("Bytes are not equal. got %v; want %v", got, want) | ||||
| 			} | ||||
| 		}) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| type errWriter struct { | ||||
| 	err error | ||||
| } | ||||
|  | ||||
| func (ew errWriter) Write([]byte) (int, error) { return 0, ew.err } | ||||
							
								
								
									
										78
									
								
								mongo/bson/bsonrw/writer.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										78
									
								
								mongo/bson/bsonrw/writer.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,78 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsonrw | ||||
|  | ||||
| import ( | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsontype" | ||||
| 	"go.mongodb.org/mongo-driver/bson/primitive" | ||||
| ) | ||||
|  | ||||
| // ArrayWriter is the interface used to create a BSON or BSON adjacent array. | ||||
| // Callers must ensure they call WriteArrayEnd when they have finished creating | ||||
| // the array. | ||||
| type ArrayWriter interface { | ||||
| 	WriteArrayElement() (ValueWriter, error) | ||||
| 	WriteArrayEnd() error | ||||
| } | ||||
|  | ||||
| // DocumentWriter is the interface used to create a BSON or BSON adjacent | ||||
| // document. Callers must ensure they call WriteDocumentEnd when they have | ||||
| // finished creating the document. | ||||
| type DocumentWriter interface { | ||||
| 	WriteDocumentElement(string) (ValueWriter, error) | ||||
| 	WriteDocumentEnd() error | ||||
| } | ||||
|  | ||||
| // ValueWriter is the interface used to write BSON values. Implementations of | ||||
| // this interface handle creating BSON or BSON adjacent representations of the | ||||
| // values. | ||||
| type ValueWriter interface { | ||||
| 	WriteArray() (ArrayWriter, error) | ||||
| 	WriteBinary(b []byte) error | ||||
| 	WriteBinaryWithSubtype(b []byte, btype byte) error | ||||
| 	WriteBoolean(bool) error | ||||
| 	WriteCodeWithScope(code string) (DocumentWriter, error) | ||||
| 	WriteDBPointer(ns string, oid primitive.ObjectID) error | ||||
| 	WriteDateTime(dt int64) error | ||||
| 	WriteDecimal128(primitive.Decimal128) error | ||||
| 	WriteDouble(float64) error | ||||
| 	WriteInt32(int32) error | ||||
| 	WriteInt64(int64) error | ||||
| 	WriteJavascript(code string) error | ||||
| 	WriteMaxKey() error | ||||
| 	WriteMinKey() error | ||||
| 	WriteNull() error | ||||
| 	WriteObjectID(primitive.ObjectID) error | ||||
| 	WriteRegex(pattern, options string) error | ||||
| 	WriteString(string) error | ||||
| 	WriteDocument() (DocumentWriter, error) | ||||
| 	WriteSymbol(symbol string) error | ||||
| 	WriteTimestamp(t, i uint32) error | ||||
| 	WriteUndefined() error | ||||
| } | ||||
|  | ||||
| // ValueWriterFlusher is a superset of ValueWriter that exposes functionality to flush to the underlying buffer. | ||||
| type ValueWriterFlusher interface { | ||||
| 	ValueWriter | ||||
| 	Flush() error | ||||
| } | ||||
|  | ||||
| // BytesWriter is the interface used to write BSON bytes to a ValueWriter. | ||||
| // This interface is meant to be a superset of ValueWriter, so that types that | ||||
| // implement ValueWriter may also implement this interface. | ||||
| type BytesWriter interface { | ||||
| 	WriteValueBytes(t bsontype.Type, b []byte) error | ||||
| } | ||||
|  | ||||
| // SliceWriter allows a pointer to a slice of bytes to be used as an io.Writer. | ||||
| type SliceWriter []byte | ||||
|  | ||||
| func (sw *SliceWriter) Write(p []byte) (int, error) { | ||||
| 	written := len(p) | ||||
| 	*sw = append(*sw, p...) | ||||
| 	return written, nil | ||||
| } | ||||
							
								
								
									
										97
									
								
								mongo/bson/bsontype/bsontype.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										97
									
								
								mongo/bson/bsontype/bsontype.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,97 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| // Package bsontype is a utility package that contains types for each BSON type and the | ||||
| // a stringifier for the Type to enable easier debugging when working with BSON. | ||||
| package bsontype // import "go.mongodb.org/mongo-driver/bson/bsontype" | ||||
|  | ||||
| // These constants uniquely refer to each BSON type. | ||||
| const ( | ||||
| 	Double           Type = 0x01 | ||||
| 	String           Type = 0x02 | ||||
| 	EmbeddedDocument Type = 0x03 | ||||
| 	Array            Type = 0x04 | ||||
| 	Binary           Type = 0x05 | ||||
| 	Undefined        Type = 0x06 | ||||
| 	ObjectID         Type = 0x07 | ||||
| 	Boolean          Type = 0x08 | ||||
| 	DateTime         Type = 0x09 | ||||
| 	Null             Type = 0x0A | ||||
| 	Regex            Type = 0x0B | ||||
| 	DBPointer        Type = 0x0C | ||||
| 	JavaScript       Type = 0x0D | ||||
| 	Symbol           Type = 0x0E | ||||
| 	CodeWithScope    Type = 0x0F | ||||
| 	Int32            Type = 0x10 | ||||
| 	Timestamp        Type = 0x11 | ||||
| 	Int64            Type = 0x12 | ||||
| 	Decimal128       Type = 0x13 | ||||
| 	MinKey           Type = 0xFF | ||||
| 	MaxKey           Type = 0x7F | ||||
|  | ||||
| 	BinaryGeneric     byte = 0x00 | ||||
| 	BinaryFunction    byte = 0x01 | ||||
| 	BinaryBinaryOld   byte = 0x02 | ||||
| 	BinaryUUIDOld     byte = 0x03 | ||||
| 	BinaryUUID        byte = 0x04 | ||||
| 	BinaryMD5         byte = 0x05 | ||||
| 	BinaryEncrypted   byte = 0x06 | ||||
| 	BinaryColumn      byte = 0x07 | ||||
| 	BinaryUserDefined byte = 0x80 | ||||
| ) | ||||
|  | ||||
| // Type represents a BSON type. | ||||
| type Type byte | ||||
|  | ||||
| // String returns the string representation of the BSON type's name. | ||||
| func (bt Type) String() string { | ||||
| 	switch bt { | ||||
| 	case '\x01': | ||||
| 		return "double" | ||||
| 	case '\x02': | ||||
| 		return "string" | ||||
| 	case '\x03': | ||||
| 		return "embedded document" | ||||
| 	case '\x04': | ||||
| 		return "array" | ||||
| 	case '\x05': | ||||
| 		return "binary" | ||||
| 	case '\x06': | ||||
| 		return "undefined" | ||||
| 	case '\x07': | ||||
| 		return "objectID" | ||||
| 	case '\x08': | ||||
| 		return "boolean" | ||||
| 	case '\x09': | ||||
| 		return "UTC datetime" | ||||
| 	case '\x0A': | ||||
| 		return "null" | ||||
| 	case '\x0B': | ||||
| 		return "regex" | ||||
| 	case '\x0C': | ||||
| 		return "dbPointer" | ||||
| 	case '\x0D': | ||||
| 		return "javascript" | ||||
| 	case '\x0E': | ||||
| 		return "symbol" | ||||
| 	case '\x0F': | ||||
| 		return "code with scope" | ||||
| 	case '\x10': | ||||
| 		return "32-bit integer" | ||||
| 	case '\x11': | ||||
| 		return "timestamp" | ||||
| 	case '\x12': | ||||
| 		return "64-bit integer" | ||||
| 	case '\x13': | ||||
| 		return "128-bit decimal" | ||||
| 	case '\xFF': | ||||
| 		return "min key" | ||||
| 	case '\x7F': | ||||
| 		return "max key" | ||||
| 	default: | ||||
| 		return "invalid" | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										49
									
								
								mongo/bson/bsontype/bsontype_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										49
									
								
								mongo/bson/bsontype/bsontype_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,49 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bsontype | ||||
|  | ||||
| import "testing" | ||||
|  | ||||
| func TestType(t *testing.T) { | ||||
| 	testCases := []struct { | ||||
| 		name string | ||||
| 		t    Type | ||||
| 		want string | ||||
| 	}{ | ||||
| 		{"double", Double, "double"}, | ||||
| 		{"string", String, "string"}, | ||||
| 		{"embedded document", EmbeddedDocument, "embedded document"}, | ||||
| 		{"array", Array, "array"}, | ||||
| 		{"binary", Binary, "binary"}, | ||||
| 		{"undefined", Undefined, "undefined"}, | ||||
| 		{"objectID", ObjectID, "objectID"}, | ||||
| 		{"boolean", Boolean, "boolean"}, | ||||
| 		{"UTC datetime", DateTime, "UTC datetime"}, | ||||
| 		{"null", Null, "null"}, | ||||
| 		{"regex", Regex, "regex"}, | ||||
| 		{"dbPointer", DBPointer, "dbPointer"}, | ||||
| 		{"javascript", JavaScript, "javascript"}, | ||||
| 		{"symbol", Symbol, "symbol"}, | ||||
| 		{"code with scope", CodeWithScope, "code with scope"}, | ||||
| 		{"32-bit integer", Int32, "32-bit integer"}, | ||||
| 		{"timestamp", Timestamp, "timestamp"}, | ||||
| 		{"64-bit integer", Int64, "64-bit integer"}, | ||||
| 		{"128-bit decimal", Decimal128, "128-bit decimal"}, | ||||
| 		{"min key", MinKey, "min key"}, | ||||
| 		{"max key", MaxKey, "max key"}, | ||||
| 		{"invalid", (0), "invalid"}, | ||||
| 	} | ||||
|  | ||||
| 	for _, tc := range testCases { | ||||
| 		t.Run(tc.name, func(t *testing.T) { | ||||
| 			got := tc.t.String() | ||||
| 			if got != tc.want { | ||||
| 				t.Errorf("String outputs do not match. got %s; want %s", got, tc.want) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										141
									
								
								mongo/bson/decoder.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										141
									
								
								mongo/bson/decoder.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,141 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bson | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"sync" | ||||
|  | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsoncodec" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsonrw" | ||||
| ) | ||||
|  | ||||
| // ErrDecodeToNil is the error returned when trying to decode to a nil value | ||||
| var ErrDecodeToNil = errors.New("cannot Decode to nil value") | ||||
|  | ||||
| // This pool is used to keep the allocations of Decoders down. This is only used for the Marshal* | ||||
| // methods and is not consumable from outside of this package. The Decoders retrieved from this pool | ||||
| // must have both Reset and SetRegistry called on them. | ||||
| var decPool = sync.Pool{ | ||||
| 	New: func() interface{} { | ||||
| 		return new(Decoder) | ||||
| 	}, | ||||
| } | ||||
|  | ||||
| // A Decoder reads and decodes BSON documents from a stream. It reads from a bsonrw.ValueReader as | ||||
| // the source of BSON data. | ||||
| type Decoder struct { | ||||
| 	dc bsoncodec.DecodeContext | ||||
| 	vr bsonrw.ValueReader | ||||
|  | ||||
| 	// We persist defaultDocumentM and defaultDocumentD on the Decoder to prevent overwriting from | ||||
| 	// (*Decoder).SetContext. | ||||
| 	defaultDocumentM bool | ||||
| 	defaultDocumentD bool | ||||
| } | ||||
|  | ||||
| // NewDecoder returns a new decoder that uses the DefaultRegistry to read from vr. | ||||
| func NewDecoder(vr bsonrw.ValueReader) (*Decoder, error) { | ||||
| 	if vr == nil { | ||||
| 		return nil, errors.New("cannot create a new Decoder with a nil ValueReader") | ||||
| 	} | ||||
|  | ||||
| 	return &Decoder{ | ||||
| 		dc: bsoncodec.DecodeContext{Registry: DefaultRegistry}, | ||||
| 		vr: vr, | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| // NewDecoderWithContext returns a new decoder that uses DecodeContext dc to read from vr. | ||||
| func NewDecoderWithContext(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader) (*Decoder, error) { | ||||
| 	if dc.Registry == nil { | ||||
| 		dc.Registry = DefaultRegistry | ||||
| 	} | ||||
| 	if vr == nil { | ||||
| 		return nil, errors.New("cannot create a new Decoder with a nil ValueReader") | ||||
| 	} | ||||
|  | ||||
| 	return &Decoder{ | ||||
| 		dc: dc, | ||||
| 		vr: vr, | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| // Decode reads the next BSON document from the stream and decodes it into the | ||||
| // value pointed to by val. | ||||
| // | ||||
| // The documentation for Unmarshal contains details about of BSON into a Go | ||||
| // value. | ||||
| func (d *Decoder) Decode(val interface{}) error { | ||||
| 	if unmarshaler, ok := val.(Unmarshaler); ok { | ||||
| 		// TODO(skriptble): Reuse a []byte here and use the AppendDocumentBytes method. | ||||
| 		buf, err := bsonrw.Copier{}.CopyDocumentToBytes(d.vr) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		return unmarshaler.UnmarshalBSON(buf) | ||||
| 	} | ||||
|  | ||||
| 	rval := reflect.ValueOf(val) | ||||
| 	switch rval.Kind() { | ||||
| 	case reflect.Ptr: | ||||
| 		if rval.IsNil() { | ||||
| 			return ErrDecodeToNil | ||||
| 		} | ||||
| 		rval = rval.Elem() | ||||
| 	case reflect.Map: | ||||
| 		if rval.IsNil() { | ||||
| 			return ErrDecodeToNil | ||||
| 		} | ||||
| 	default: | ||||
| 		return fmt.Errorf("argument to Decode must be a pointer or a map, but got %v", rval) | ||||
| 	} | ||||
| 	decoder, err := d.dc.LookupDecoder(rval.Type()) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if d.defaultDocumentM { | ||||
| 		d.dc.DefaultDocumentM() | ||||
| 	} | ||||
| 	if d.defaultDocumentD { | ||||
| 		d.dc.DefaultDocumentD() | ||||
| 	} | ||||
| 	return decoder.DecodeValue(d.dc, d.vr, rval) | ||||
| } | ||||
|  | ||||
| // Reset will reset the state of the decoder, using the same *DecodeContext used in | ||||
| // the original construction but using vr for reading. | ||||
| func (d *Decoder) Reset(vr bsonrw.ValueReader) error { | ||||
| 	d.vr = vr | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // SetRegistry replaces the current registry of the decoder with r. | ||||
| func (d *Decoder) SetRegistry(r *bsoncodec.Registry) error { | ||||
| 	d.dc.Registry = r | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // SetContext replaces the current registry of the decoder with dc. | ||||
| func (d *Decoder) SetContext(dc bsoncodec.DecodeContext) error { | ||||
| 	d.dc = dc | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // DefaultDocumentM will decode empty documents using the primitive.M type. This behavior is restricted to data typed as | ||||
| // "interface{}" or "map[string]interface{}". | ||||
| func (d *Decoder) DefaultDocumentM() { | ||||
| 	d.defaultDocumentM = true | ||||
| } | ||||
|  | ||||
| // DefaultDocumentD will decode empty documents using the primitive.D type. This behavior is restricted to data typed as | ||||
| // "interface{}" or "map[string]interface{}". | ||||
| func (d *Decoder) DefaultDocumentD() { | ||||
| 	d.defaultDocumentD = true | ||||
| } | ||||
							
								
								
									
										435
									
								
								mongo/bson/decoder_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										435
									
								
								mongo/bson/decoder_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,435 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bson | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"errors" | ||||
| 	"reflect" | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/google/go-cmp/cmp" | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsoncodec" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsonrw" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsonrw/bsonrwtest" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsontype" | ||||
| 	"go.mongodb.org/mongo-driver/bson/primitive" | ||||
| 	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore" | ||||
| ) | ||||
|  | ||||
| func TestBasicDecode(t *testing.T) { | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	for _, tc := range unmarshalingTestCases() { | ||||
| 		tc := tc | ||||
|  | ||||
| 		t.Run(tc.name, func(t *testing.T) { | ||||
| 			t.Parallel() | ||||
|  | ||||
| 			got := reflect.New(tc.sType).Elem() | ||||
| 			vr := bsonrw.NewBSONDocumentReader(tc.data) | ||||
| 			reg := DefaultRegistry | ||||
| 			decoder, err := reg.LookupDecoder(reflect.TypeOf(got)) | ||||
| 			noerr(t, err) | ||||
| 			err = decoder.DecodeValue(bsoncodec.DecodeContext{Registry: reg}, vr, got) | ||||
| 			noerr(t, err) | ||||
| 			assert.Equal(t, tc.want, got.Addr().Interface(), "Results do not match.") | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestDecoderv2(t *testing.T) { | ||||
| 	t.Parallel() | ||||
|  | ||||
| 	t.Run("Decode", func(t *testing.T) { | ||||
| 		t.Parallel() | ||||
|  | ||||
| 		for _, tc := range unmarshalingTestCases() { | ||||
| 			tc := tc | ||||
|  | ||||
| 			t.Run(tc.name, func(t *testing.T) { | ||||
| 				t.Parallel() | ||||
|  | ||||
| 				got := reflect.New(tc.sType).Interface() | ||||
| 				vr := bsonrw.NewBSONDocumentReader(tc.data) | ||||
| 				dec, err := NewDecoderWithContext(bsoncodec.DecodeContext{Registry: DefaultRegistry}, vr) | ||||
| 				noerr(t, err) | ||||
| 				err = dec.Decode(got) | ||||
| 				noerr(t, err) | ||||
| 				assert.Equal(t, tc.want, got, "Results do not match.") | ||||
| 			}) | ||||
| 		} | ||||
| 		t.Run("lookup error", func(t *testing.T) { | ||||
| 			t.Parallel() | ||||
|  | ||||
| 			type certainlydoesntexistelsewhereihope func(string, string) string | ||||
| 			// Avoid unused code lint error. | ||||
| 			_ = certainlydoesntexistelsewhereihope(func(string, string) string { return "" }) | ||||
|  | ||||
| 			cdeih := func(string, string) string { return "certainlydoesntexistelsewhereihope" } | ||||
| 			dec, err := NewDecoder(bsonrw.NewBSONDocumentReader([]byte{})) | ||||
| 			noerr(t, err) | ||||
| 			want := bsoncodec.ErrNoDecoder{Type: reflect.TypeOf(cdeih)} | ||||
| 			got := dec.Decode(&cdeih) | ||||
| 			assert.Equal(t, want, got, "Received unexpected error.") | ||||
| 		}) | ||||
| 		t.Run("Unmarshaler", func(t *testing.T) { | ||||
| 			t.Parallel() | ||||
|  | ||||
| 			testCases := []struct { | ||||
| 				name    string | ||||
| 				err     error | ||||
| 				vr      bsonrw.ValueReader | ||||
| 				invoked bool | ||||
| 			}{ | ||||
| 				{ | ||||
| 					"error", | ||||
| 					errors.New("Unmarshaler error"), | ||||
| 					&bsonrwtest.ValueReaderWriter{BSONType: bsontype.EmbeddedDocument, Err: bsonrw.ErrEOD, ErrAfter: bsonrwtest.ReadElement}, | ||||
| 					true, | ||||
| 				}, | ||||
| 				{ | ||||
| 					"copy error", | ||||
| 					errors.New("copy error"), | ||||
| 					&bsonrwtest.ValueReaderWriter{Err: errors.New("copy error"), ErrAfter: bsonrwtest.ReadDocument}, | ||||
| 					false, | ||||
| 				}, | ||||
| 				{ | ||||
| 					"success", | ||||
| 					nil, | ||||
| 					&bsonrwtest.ValueReaderWriter{BSONType: bsontype.EmbeddedDocument, Err: bsonrw.ErrEOD, ErrAfter: bsonrwtest.ReadElement}, | ||||
| 					true, | ||||
| 				}, | ||||
| 			} | ||||
|  | ||||
| 			for _, tc := range testCases { | ||||
| 				tc := tc | ||||
|  | ||||
| 				t.Run(tc.name, func(t *testing.T) { | ||||
| 					t.Parallel() | ||||
|  | ||||
| 					unmarshaler := &testUnmarshaler{err: tc.err} | ||||
| 					dec, err := NewDecoder(tc.vr) | ||||
| 					noerr(t, err) | ||||
| 					got := dec.Decode(unmarshaler) | ||||
| 					want := tc.err | ||||
| 					if !compareErrors(got, want) { | ||||
| 						t.Errorf("Did not receive expected error. got %v; want %v", got, want) | ||||
| 					} | ||||
| 					if unmarshaler.invoked != tc.invoked { | ||||
| 						if tc.invoked { | ||||
| 							t.Error("Expected to have UnmarshalBSON invoked, but it wasn't.") | ||||
| 						} else { | ||||
| 							t.Error("Expected UnmarshalBSON to not be invoked, but it was.") | ||||
| 						} | ||||
| 					} | ||||
| 				}) | ||||
| 			} | ||||
|  | ||||
| 			t.Run("Unmarshaler/success bsonrw.ValueReader", func(t *testing.T) { | ||||
| 				t.Parallel() | ||||
|  | ||||
| 				want := bsoncore.BuildDocument(nil, bsoncore.AppendDoubleElement(nil, "pi", 3.14159)) | ||||
| 				unmarshaler := &testUnmarshaler{} | ||||
| 				vr := bsonrw.NewBSONDocumentReader(want) | ||||
| 				dec, err := NewDecoder(vr) | ||||
| 				noerr(t, err) | ||||
| 				err = dec.Decode(unmarshaler) | ||||
| 				noerr(t, err) | ||||
| 				got := unmarshaler.data | ||||
| 				if !bytes.Equal(got, want) { | ||||
| 					t.Errorf("Did not unmarshal properly. got %v; want %v", got, want) | ||||
| 				} | ||||
| 			}) | ||||
| 		}) | ||||
| 	}) | ||||
| 	t.Run("NewDecoder", func(t *testing.T) { | ||||
| 		t.Parallel() | ||||
|  | ||||
| 		t.Run("error", func(t *testing.T) { | ||||
| 			t.Parallel() | ||||
|  | ||||
| 			_, got := NewDecoder(nil) | ||||
| 			want := errors.New("cannot create a new Decoder with a nil ValueReader") | ||||
| 			if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) { | ||||
| 				t.Errorf("Was expecting error but got different error. got %v; want %v", got, want) | ||||
| 			} | ||||
| 		}) | ||||
| 		t.Run("success", func(t *testing.T) { | ||||
| 			t.Parallel() | ||||
|  | ||||
| 			got, err := NewDecoder(bsonrw.NewBSONDocumentReader([]byte{})) | ||||
| 			noerr(t, err) | ||||
| 			if got == nil { | ||||
| 				t.Errorf("Was expecting a non-nil Decoder, but got <nil>") | ||||
| 			} | ||||
| 		}) | ||||
| 	}) | ||||
| 	t.Run("NewDecoderWithContext", func(t *testing.T) { | ||||
| 		t.Parallel() | ||||
|  | ||||
| 		t.Run("errors", func(t *testing.T) { | ||||
| 			t.Parallel() | ||||
|  | ||||
| 			dc := bsoncodec.DecodeContext{Registry: DefaultRegistry} | ||||
| 			_, got := NewDecoderWithContext(dc, nil) | ||||
| 			want := errors.New("cannot create a new Decoder with a nil ValueReader") | ||||
| 			if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) { | ||||
| 				t.Errorf("Was expecting error but got different error. got %v; want %v", got, want) | ||||
| 			} | ||||
| 		}) | ||||
| 		t.Run("success", func(t *testing.T) { | ||||
| 			t.Parallel() | ||||
|  | ||||
| 			got, err := NewDecoderWithContext(bsoncodec.DecodeContext{}, bsonrw.NewBSONDocumentReader([]byte{})) | ||||
| 			noerr(t, err) | ||||
| 			if got == nil { | ||||
| 				t.Errorf("Was expecting a non-nil Decoder, but got <nil>") | ||||
| 			} | ||||
| 			dc := bsoncodec.DecodeContext{Registry: DefaultRegistry} | ||||
| 			got, err = NewDecoderWithContext(dc, bsonrw.NewBSONDocumentReader([]byte{})) | ||||
| 			noerr(t, err) | ||||
| 			if got == nil { | ||||
| 				t.Errorf("Was expecting a non-nil Decoder, but got <nil>") | ||||
| 			} | ||||
| 		}) | ||||
| 	}) | ||||
| 	t.Run("Decode doesn't zero struct", func(t *testing.T) { | ||||
| 		t.Parallel() | ||||
|  | ||||
| 		type foo struct { | ||||
| 			Item  string | ||||
| 			Qty   int | ||||
| 			Bonus int | ||||
| 		} | ||||
| 		var got foo | ||||
| 		got.Item = "apple" | ||||
| 		got.Bonus = 2 | ||||
| 		data := docToBytes(D{{"item", "canvas"}, {"qty", 4}}) | ||||
| 		vr := bsonrw.NewBSONDocumentReader(data) | ||||
| 		dec, err := NewDecoder(vr) | ||||
| 		noerr(t, err) | ||||
| 		err = dec.Decode(&got) | ||||
| 		noerr(t, err) | ||||
| 		want := foo{Item: "canvas", Qty: 4, Bonus: 2} | ||||
| 		assert.Equal(t, want, got, "Results do not match.") | ||||
| 	}) | ||||
| 	t.Run("Reset", func(t *testing.T) { | ||||
| 		t.Parallel() | ||||
|  | ||||
| 		vr1, vr2 := bsonrw.NewBSONDocumentReader([]byte{}), bsonrw.NewBSONDocumentReader([]byte{}) | ||||
| 		dc := bsoncodec.DecodeContext{Registry: DefaultRegistry} | ||||
| 		dec, err := NewDecoderWithContext(dc, vr1) | ||||
| 		noerr(t, err) | ||||
| 		if dec.vr != vr1 { | ||||
| 			t.Errorf("Decoder should use the value reader provided. got %v; want %v", dec.vr, vr1) | ||||
| 		} | ||||
| 		err = dec.Reset(vr2) | ||||
| 		noerr(t, err) | ||||
| 		if dec.vr != vr2 { | ||||
| 			t.Errorf("Decoder should use the value reader provided. got %v; want %v", dec.vr, vr2) | ||||
| 		} | ||||
| 	}) | ||||
| 	t.Run("SetContext", func(t *testing.T) { | ||||
| 		t.Parallel() | ||||
|  | ||||
| 		dc1 := bsoncodec.DecodeContext{Registry: DefaultRegistry} | ||||
| 		dc2 := bsoncodec.DecodeContext{Registry: NewRegistryBuilder().Build()} | ||||
| 		dec, err := NewDecoderWithContext(dc1, bsonrw.NewBSONDocumentReader([]byte{})) | ||||
| 		noerr(t, err) | ||||
| 		if !reflect.DeepEqual(dec.dc, dc1) { | ||||
| 			t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.dc, dc1) | ||||
| 		} | ||||
| 		err = dec.SetContext(dc2) | ||||
| 		noerr(t, err) | ||||
| 		if !reflect.DeepEqual(dec.dc, dc2) { | ||||
| 			t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.dc, dc2) | ||||
| 		} | ||||
| 	}) | ||||
| 	t.Run("SetRegistry", func(t *testing.T) { | ||||
| 		t.Parallel() | ||||
|  | ||||
| 		r1, r2 := DefaultRegistry, NewRegistryBuilder().Build() | ||||
| 		dc1 := bsoncodec.DecodeContext{Registry: r1} | ||||
| 		dc2 := bsoncodec.DecodeContext{Registry: r2} | ||||
| 		dec, err := NewDecoder(bsonrw.NewBSONDocumentReader([]byte{})) | ||||
| 		noerr(t, err) | ||||
| 		if !reflect.DeepEqual(dec.dc, dc1) { | ||||
| 			t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.dc, dc1) | ||||
| 		} | ||||
| 		err = dec.SetRegistry(r2) | ||||
| 		noerr(t, err) | ||||
| 		if !reflect.DeepEqual(dec.dc, dc2) { | ||||
| 			t.Errorf("Decoder should use the Registry provided. got %v; want %v", dec.dc, dc2) | ||||
| 		} | ||||
| 	}) | ||||
| 	t.Run("DecodeToNil", func(t *testing.T) { | ||||
| 		t.Parallel() | ||||
|  | ||||
| 		data := docToBytes(D{{"item", "canvas"}, {"qty", 4}}) | ||||
| 		vr := bsonrw.NewBSONDocumentReader(data) | ||||
| 		dec, err := NewDecoder(vr) | ||||
| 		noerr(t, err) | ||||
|  | ||||
| 		var got *D | ||||
| 		err = dec.Decode(got) | ||||
| 		if err != ErrDecodeToNil { | ||||
| 			t.Fatalf("Decode error mismatch; expected %v, got %v", ErrDecodeToNil, err) | ||||
| 		} | ||||
| 	}) | ||||
| 	t.Run("DefaultDocuemntD embedded map as empty interface", func(t *testing.T) { | ||||
| 		t.Parallel() | ||||
|  | ||||
| 		type someMap map[string]interface{} | ||||
|  | ||||
| 		in := make(someMap) | ||||
| 		in["foo"] = map[string]interface{}{"bar": "baz"} | ||||
|  | ||||
| 		bytes, err := Marshal(in) | ||||
| 		if err != nil { | ||||
| 			t.Fatal(err) | ||||
| 		} | ||||
|  | ||||
| 		var bsonOut someMap | ||||
| 		dec, err := NewDecoder(bsonrw.NewBSONDocumentReader(bytes)) | ||||
| 		if err != nil { | ||||
| 			t.Fatal(err) | ||||
| 		} | ||||
| 		dec.DefaultDocumentM() | ||||
| 		if err := dec.Decode(&bsonOut); err != nil { | ||||
| 			t.Fatal(err) | ||||
| 		} | ||||
|  | ||||
| 		// Ensure that interface{}-typed top-level data is converted to the document type. | ||||
| 		bsonOutType := reflect.TypeOf(bsonOut) | ||||
| 		inType := reflect.TypeOf(in) | ||||
| 		assert.Equal(t, inType, bsonOutType, "expected %v to equal %v", inType.String(), bsonOutType.String()) | ||||
|  | ||||
| 		// Ensure that the embedded type is a primitive map. | ||||
| 		mType := reflect.TypeOf(primitive.M{}) | ||||
| 		bsonFooOutType := reflect.TypeOf(bsonOut["foo"]) | ||||
| 		assert.Equal(t, mType, bsonFooOutType, "expected %v to equal %v", mType.String(), bsonFooOutType.String()) | ||||
| 	}) | ||||
| 	t.Run("DefaultDocuemntD for decoding into interface{} alias", func(t *testing.T) { | ||||
| 		t.Parallel() | ||||
|  | ||||
| 		var in interface{} = map[string]interface{}{"bar": "baz"} | ||||
|  | ||||
| 		bytes, err := Marshal(in) | ||||
| 		if err != nil { | ||||
| 			t.Fatal(err) | ||||
| 		} | ||||
|  | ||||
| 		var bsonOut interface{} | ||||
| 		dec, err := NewDecoder(bsonrw.NewBSONDocumentReader(bytes)) | ||||
| 		if err != nil { | ||||
| 			t.Fatal(err) | ||||
| 		} | ||||
| 		dec.DefaultDocumentD() | ||||
| 		if err := dec.Decode(&bsonOut); err != nil { | ||||
| 			t.Fatal(err) | ||||
| 		} | ||||
|  | ||||
| 		// Ensure that interface{}-typed top-level data is converted to the document type. | ||||
| 		dType := reflect.TypeOf(primitive.D{}) | ||||
| 		bsonOutType := reflect.TypeOf(bsonOut) | ||||
| 		assert.Equal(t, dType, bsonOutType, | ||||
| 			"expected %v to equal %v", dType.String(), bsonOutType.String()) | ||||
| 	}) | ||||
| 	t.Run("DefaultDocuemntD for decoding into non-interface{} alias", func(t *testing.T) { | ||||
| 		t.Parallel() | ||||
|  | ||||
| 		var in interface{} = map[string]interface{}{"bar": "baz"} | ||||
|  | ||||
| 		bytes, err := Marshal(in) | ||||
| 		if err != nil { | ||||
| 			t.Fatal(err) | ||||
| 		} | ||||
|  | ||||
| 		var bsonOut struct{} | ||||
| 		dec, err := NewDecoder(bsonrw.NewBSONDocumentReader(bytes)) | ||||
| 		if err != nil { | ||||
| 			t.Fatal(err) | ||||
| 		} | ||||
| 		dec.DefaultDocumentD() | ||||
| 		if err := dec.Decode(&bsonOut); err != nil { | ||||
| 			t.Fatal(err) | ||||
| 		} | ||||
|  | ||||
| 		// Ensure that typed top-level data is not converted to the document type. | ||||
| 		dType := reflect.TypeOf(primitive.D{}) | ||||
| 		bsonOutType := reflect.TypeOf(bsonOut) | ||||
| 		assert.NotEqual(t, dType, bsonOutType, | ||||
| 			"expected %v to not equal %v", dType.String(), bsonOutType.String()) | ||||
| 	}) | ||||
| 	t.Run("DefaultDocumentD for deep struct values", func(t *testing.T) { | ||||
| 		t.Parallel() | ||||
|  | ||||
| 		type emb struct { | ||||
| 			Foo map[int]interface{} `bson:"foo"` | ||||
| 		} | ||||
|  | ||||
| 		objID := primitive.NewObjectID() | ||||
|  | ||||
| 		in := emb{ | ||||
| 			Foo: map[int]interface{}{ | ||||
| 				1: map[string]interface{}{"bar": "baz"}, | ||||
| 				2: map[int]interface{}{ | ||||
| 					3: map[string]interface{}{"bar": "baz"}, | ||||
| 				}, | ||||
| 				4: map[primitive.ObjectID]interface{}{ | ||||
| 					objID: map[string]interface{}{"bar": "baz"}, | ||||
| 				}, | ||||
| 			}, | ||||
| 		} | ||||
|  | ||||
| 		bytes, err := Marshal(in) | ||||
| 		if err != nil { | ||||
| 			t.Fatal(err) | ||||
| 		} | ||||
|  | ||||
| 		dec, err := NewDecoder(bsonrw.NewBSONDocumentReader(bytes)) | ||||
| 		if err != nil { | ||||
| 			t.Fatal(err) | ||||
| 		} | ||||
|  | ||||
| 		dec.DefaultDocumentD() | ||||
|  | ||||
| 		var out emb | ||||
| 		if err := dec.Decode(&out); err != nil { | ||||
| 			t.Fatal(err) | ||||
| 		} | ||||
|  | ||||
| 		mType := reflect.TypeOf(primitive.M{}) | ||||
| 		bsonOutType := reflect.TypeOf(out) | ||||
| 		assert.NotEqual(t, mType, bsonOutType, | ||||
| 			"expected %v to not equal %v", mType.String(), bsonOutType.String()) | ||||
|  | ||||
| 		want := emb{ | ||||
| 			Foo: map[int]interface{}{ | ||||
| 				1: primitive.D{{Key: "bar", Value: "baz"}}, | ||||
| 				2: primitive.D{{Key: "3", Value: primitive.D{{Key: "bar", Value: "baz"}}}}, | ||||
| 				4: primitive.D{{Key: objID.Hex(), Value: primitive.D{{Key: "bar", Value: "baz"}}}}, | ||||
| 			}, | ||||
| 		} | ||||
|  | ||||
| 		assert.Equal(t, want, out, "expected %v, got %v", want, out) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| type testUnmarshaler struct { | ||||
| 	invoked bool | ||||
| 	err     error | ||||
| 	data    []byte | ||||
| } | ||||
|  | ||||
| func (tu *testUnmarshaler) UnmarshalBSON(d []byte) error { | ||||
| 	tu.invoked = true | ||||
| 	tu.data = d | ||||
| 	return tu.err | ||||
| } | ||||
							
								
								
									
										141
									
								
								mongo/bson/doc.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										141
									
								
								mongo/bson/doc.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,141 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| // Package bson is a library for reading, writing, and manipulating BSON. BSON is a binary serialization format used to | ||||
| // store documents and make remote procedure calls in MongoDB. The BSON specification is located at https://bsonspec.org. | ||||
| // The BSON library handles marshalling and unmarshalling of values through a configurable codec system. For a description | ||||
| // of the codec system and examples of registering custom codecs, see the bsoncodec package. | ||||
| // | ||||
| // # Raw BSON | ||||
| // | ||||
| // The Raw family of types is used to validate and retrieve elements from a slice of bytes. This | ||||
| // type is most useful when you want do lookups on BSON bytes without unmarshaling it into another | ||||
| // type. | ||||
| // | ||||
| // Example: | ||||
| // | ||||
| //	var raw bson.Raw = ... // bytes from somewhere | ||||
| //	err := raw.Validate() | ||||
| //	if err != nil { return err } | ||||
| //	val := raw.Lookup("foo") | ||||
| //	i32, ok := val.Int32OK() | ||||
| //	// do something with i32... | ||||
| // | ||||
| // # Native Go Types | ||||
| // | ||||
| // The D and M types defined in this package can be used to build representations of BSON using native Go types. D is a | ||||
| // slice and M is a map. For more information about the use cases for these types, see the documentation on the type | ||||
| // definitions. | ||||
| // | ||||
| // Note that a D should not be constructed with duplicate key names, as that can cause undefined server behavior. | ||||
| // | ||||
| // Example: | ||||
| // | ||||
| //	bson.D{{"foo", "bar"}, {"hello", "world"}, {"pi", 3.14159}} | ||||
| //	bson.M{"foo": "bar", "hello": "world", "pi": 3.14159} | ||||
| // | ||||
| // When decoding BSON to a D or M, the following type mappings apply when unmarshalling: | ||||
| // | ||||
| //  1. BSON int32 unmarshals to an int32. | ||||
| //  2. BSON int64 unmarshals to an int64. | ||||
| //  3. BSON double unmarshals to a float64. | ||||
| //  4. BSON string unmarshals to a string. | ||||
| //  5. BSON boolean unmarshals to a bool. | ||||
| //  6. BSON embedded document unmarshals to the parent type (i.e. D for a D, M for an M). | ||||
| //  7. BSON array unmarshals to a bson.A. | ||||
| //  8. BSON ObjectId unmarshals to a primitive.ObjectID. | ||||
| //  9. BSON datetime unmarshals to a primitive.DateTime. | ||||
| //  10. BSON binary unmarshals to a primitive.Binary. | ||||
| //  11. BSON regular expression unmarshals to a primitive.Regex. | ||||
| //  12. BSON JavaScript unmarshals to a primitive.JavaScript. | ||||
| //  13. BSON code with scope unmarshals to a primitive.CodeWithScope. | ||||
| //  14. BSON timestamp unmarshals to an primitive.Timestamp. | ||||
| //  15. BSON 128-bit decimal unmarshals to an primitive.Decimal128. | ||||
| //  16. BSON min key unmarshals to an primitive.MinKey. | ||||
| //  17. BSON max key unmarshals to an primitive.MaxKey. | ||||
| //  18. BSON undefined unmarshals to a primitive.Undefined. | ||||
| //  19. BSON null unmarshals to nil. | ||||
| //  20. BSON DBPointer unmarshals to a primitive.DBPointer. | ||||
| //  21. BSON symbol unmarshals to a primitive.Symbol. | ||||
| // | ||||
| // The above mappings also apply when marshalling a D or M to BSON. Some other useful marshalling mappings are: | ||||
| // | ||||
| //  1. time.Time marshals to a BSON datetime. | ||||
| //  2. int8, int16, and int32 marshal to a BSON int32. | ||||
| //  3. int marshals to a BSON int32 if the value is between math.MinInt32 and math.MaxInt32, inclusive, and a BSON int64 | ||||
| //     otherwise. | ||||
| //  4. int64 marshals to BSON int64. | ||||
| //  5. uint8 and uint16 marshal to a BSON int32. | ||||
| //  6. uint, uint32, and uint64 marshal to a BSON int32 if the value is between math.MinInt32 and math.MaxInt32, | ||||
| //     inclusive, and BSON int64 otherwise. | ||||
| //  7. BSON null and undefined values will unmarshal into the zero value of a field (e.g. unmarshalling a BSON null or | ||||
| //     undefined value into a string will yield the empty string.). | ||||
| // | ||||
| // # Structs | ||||
| // | ||||
| // Structs can be marshalled/unmarshalled to/from BSON or Extended JSON. When transforming structs to/from BSON or Extended | ||||
| // JSON, the following rules apply: | ||||
| // | ||||
| //  1. Only exported fields in structs will be marshalled or unmarshalled. | ||||
| // | ||||
| //  2. When marshalling a struct, each field will be lowercased to generate the key for the corresponding BSON element. | ||||
| //     For example, a struct field named "Foo" will generate key "foo". This can be overridden via a struct tag (e.g. | ||||
| //     `bson:"fooField"` to generate key "fooField" instead). | ||||
| // | ||||
| //  3. An embedded struct field is marshalled as a subdocument. The key will be the lowercased name of the field's type. | ||||
| // | ||||
| //  4. A pointer field is marshalled as the underlying type if the pointer is non-nil. If the pointer is nil, it is | ||||
| //     marshalled as a BSON null value. | ||||
| // | ||||
| //  5. When unmarshalling, a field of type interface{} will follow the D/M type mappings listed above. BSON documents | ||||
| //     unmarshalled into an interface{} field will be unmarshalled as a D. | ||||
| // | ||||
| // The encoding of each struct field can be customized by the "bson" struct tag. | ||||
| // | ||||
| // This tag behavior is configurable, and different struct tag behavior can be configured by initializing a new | ||||
| // bsoncodec.StructCodec with the desired tag parser and registering that StructCodec onto the Registry. By default, JSON tags | ||||
| // are not honored, but that can be enabled by creating a StructCodec with JSONFallbackStructTagParser, like below: | ||||
| // | ||||
| // Example: | ||||
| // | ||||
| //	structcodec, _ := bsoncodec.NewStructCodec(bsoncodec.JSONFallbackStructTagParser) | ||||
| // | ||||
| // The bson tag gives the name of the field, possibly followed by a comma-separated list of options. | ||||
| // The name may be empty in order to specify options without overriding the default field name. The following options can be used | ||||
| // to configure behavior: | ||||
| // | ||||
| //  1. omitempty: If the omitempty struct tag is specified on a field, the field will not be marshalled if it is set to | ||||
| //     the zero value. Fields with language primitive types such as integers, booleans, and strings are considered empty if | ||||
| //     their value is equal to the zero value for the type (i.e. 0 for integers, false for booleans, and "" for strings). | ||||
| //     Slices, maps, and arrays are considered empty if they are of length zero. Interfaces and pointers are considered | ||||
| //     empty if their value is nil. By default, structs are only considered empty if the struct type implements the | ||||
| //     bsoncodec.Zeroer interface and the IsZero method returns true. Struct fields whose types do not implement Zeroer are | ||||
| //     never considered empty and will be marshalled as embedded documents. | ||||
| //     NOTE: It is recommended that this tag be used for all slice and map fields. | ||||
| // | ||||
| //  2. minsize: If the minsize struct tag is specified on a field of type int64, uint, uint32, or uint64 and the value of | ||||
| //     the field can fit in a signed int32, the field will be serialized as a BSON int32 rather than a BSON int64. For other | ||||
| //     types, this tag is ignored. | ||||
| // | ||||
| //  3. truncate: If the truncate struct tag is specified on a field with a non-float numeric type, BSON doubles unmarshalled | ||||
| //     into that field will be truncated at the decimal point. For example, if 3.14 is unmarshalled into a field of type int, | ||||
| //     it will be unmarshalled as 3. If this tag is not specified, the decoder will throw an error if the value cannot be | ||||
| //     decoded without losing precision. For float64 or non-numeric types, this tag is ignored. | ||||
| // | ||||
| //  4. inline: If the inline struct tag is specified for a struct or map field, the field will be "flattened" when | ||||
| //     marshalling and "un-flattened" when unmarshalling. This means that all of the fields in that struct/map will be | ||||
| //     pulled up one level and will become top-level fields rather than being fields in a nested document. For example, if a | ||||
| //     map field named "Map" with value map[string]interface{}{"foo": "bar"} is inlined, the resulting document will be | ||||
| //     {"foo": "bar"} instead of {"map": {"foo": "bar"}}. There can only be one inlined map field in a struct. If there are | ||||
| //     duplicated fields in the resulting document when an inlined struct is marshalled, the inlined field will be overwritten. | ||||
| //     If there are duplicated fields in the resulting document when an inlined map is marshalled, an error will be returned. | ||||
| //     This tag can be used with fields that are pointers to structs. If an inlined pointer field is nil, it will not be | ||||
| //     marshalled. For fields that are not maps or structs, this tag is ignored. | ||||
| // | ||||
| // # Marshalling and Unmarshalling | ||||
| // | ||||
| // Manually marshalling and unmarshalling can be done with the Marshal and Unmarshal family of functions. | ||||
| package bson | ||||
							
								
								
									
										99
									
								
								mongo/bson/encoder.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										99
									
								
								mongo/bson/encoder.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,99 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bson | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"reflect" | ||||
| 	"sync" | ||||
|  | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsoncodec" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsonrw" | ||||
| ) | ||||
|  | ||||
| // This pool is used to keep the allocations of Encoders down. This is only used for the Marshal* | ||||
| // methods and is not consumable from outside of this package. The Encoders retrieved from this pool | ||||
| // must have both Reset and SetRegistry called on them. | ||||
| var encPool = sync.Pool{ | ||||
| 	New: func() interface{} { | ||||
| 		return new(Encoder) | ||||
| 	}, | ||||
| } | ||||
|  | ||||
| // An Encoder writes a serialization format to an output stream. It writes to a bsonrw.ValueWriter | ||||
| // as the destination of BSON data. | ||||
| type Encoder struct { | ||||
| 	ec bsoncodec.EncodeContext | ||||
| 	vw bsonrw.ValueWriter | ||||
| } | ||||
|  | ||||
| // NewEncoder returns a new encoder that uses the DefaultRegistry to write to vw. | ||||
| func NewEncoder(vw bsonrw.ValueWriter) (*Encoder, error) { | ||||
| 	if vw == nil { | ||||
| 		return nil, errors.New("cannot create a new Encoder with a nil ValueWriter") | ||||
| 	} | ||||
|  | ||||
| 	return &Encoder{ | ||||
| 		ec: bsoncodec.EncodeContext{Registry: DefaultRegistry}, | ||||
| 		vw: vw, | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| // NewEncoderWithContext returns a new encoder that uses EncodeContext ec to write to vw. | ||||
| func NewEncoderWithContext(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter) (*Encoder, error) { | ||||
| 	if ec.Registry == nil { | ||||
| 		ec = bsoncodec.EncodeContext{Registry: DefaultRegistry} | ||||
| 	} | ||||
| 	if vw == nil { | ||||
| 		return nil, errors.New("cannot create a new Encoder with a nil ValueWriter") | ||||
| 	} | ||||
|  | ||||
| 	return &Encoder{ | ||||
| 		ec: ec, | ||||
| 		vw: vw, | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| // Encode writes the BSON encoding of val to the stream. | ||||
| // | ||||
| // The documentation for Marshal contains details about the conversion of Go | ||||
| // values to BSON. | ||||
| func (e *Encoder) Encode(val interface{}) error { | ||||
| 	if marshaler, ok := val.(Marshaler); ok { | ||||
| 		// TODO(skriptble): Should we have a MarshalAppender interface so that we can have []byte reuse? | ||||
| 		buf, err := marshaler.MarshalBSON() | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		return bsonrw.Copier{}.CopyDocumentFromBytes(e.vw, buf) | ||||
| 	} | ||||
|  | ||||
| 	encoder, err := e.ec.LookupEncoder(reflect.TypeOf(val)) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	return encoder.EncodeValue(e.ec, e.vw, reflect.ValueOf(val)) | ||||
| } | ||||
|  | ||||
| // Reset will reset the state of the encoder, using the same *EncodeContext used in | ||||
| // the original construction but using vw. | ||||
| func (e *Encoder) Reset(vw bsonrw.ValueWriter) error { | ||||
| 	e.vw = vw | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // SetRegistry replaces the current registry of the encoder with r. | ||||
| func (e *Encoder) SetRegistry(r *bsoncodec.Registry) error { | ||||
| 	e.ec.Registry = r | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // SetContext replaces the current EncodeContext of the encoder with er. | ||||
| func (e *Encoder) SetContext(ec bsoncodec.EncodeContext) error { | ||||
| 	e.ec = ec | ||||
| 	return nil | ||||
| } | ||||
							
								
								
									
										135
									
								
								mongo/bson/encoder_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										135
									
								
								mongo/bson/encoder_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,135 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bson | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"errors" | ||||
| 	"reflect" | ||||
| 	"testing" | ||||
|  | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsoncodec" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsonrw" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsonrw/bsonrwtest" | ||||
| ) | ||||
|  | ||||
| func TestBasicEncode(t *testing.T) { | ||||
| 	for _, tc := range marshalingTestCases { | ||||
| 		t.Run(tc.name, func(t *testing.T) { | ||||
| 			got := make(bsonrw.SliceWriter, 0, 1024) | ||||
| 			vw, err := bsonrw.NewBSONValueWriter(&got) | ||||
| 			noerr(t, err) | ||||
| 			reg := DefaultRegistry | ||||
| 			encoder, err := reg.LookupEncoder(reflect.TypeOf(tc.val)) | ||||
| 			noerr(t, err) | ||||
| 			err = encoder.EncodeValue(bsoncodec.EncodeContext{Registry: reg}, vw, reflect.ValueOf(tc.val)) | ||||
| 			noerr(t, err) | ||||
|  | ||||
| 			if !bytes.Equal(got, tc.want) { | ||||
| 				t.Errorf("Bytes are not equal. got %v; want %v", got, tc.want) | ||||
| 				t.Errorf("Bytes:\n%v\n%v", got, tc.want) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestEncoderEncode(t *testing.T) { | ||||
| 	for _, tc := range marshalingTestCases { | ||||
| 		t.Run(tc.name, func(t *testing.T) { | ||||
| 			got := make(bsonrw.SliceWriter, 0, 1024) | ||||
| 			vw, err := bsonrw.NewBSONValueWriter(&got) | ||||
| 			noerr(t, err) | ||||
| 			enc, err := NewEncoder(vw) | ||||
| 			noerr(t, err) | ||||
| 			err = enc.Encode(tc.val) | ||||
| 			noerr(t, err) | ||||
|  | ||||
| 			if !bytes.Equal(got, tc.want) { | ||||
| 				t.Errorf("Bytes are not equal. got %v; want %v", got, tc.want) | ||||
| 				t.Errorf("Bytes:\n%v\n%v", got, tc.want) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
|  | ||||
| 	t.Run("Marshaler", func(t *testing.T) { | ||||
| 		testCases := []struct { | ||||
| 			name    string | ||||
| 			buf     []byte | ||||
| 			err     error | ||||
| 			wanterr error | ||||
| 			vw      bsonrw.ValueWriter | ||||
| 		}{ | ||||
| 			{ | ||||
| 				"error", | ||||
| 				nil, | ||||
| 				errors.New("Marshaler error"), | ||||
| 				errors.New("Marshaler error"), | ||||
| 				&bsonrwtest.ValueReaderWriter{}, | ||||
| 			}, | ||||
| 			{ | ||||
| 				"copy error", | ||||
| 				[]byte{0x05, 0x00, 0x00, 0x00, 0x00}, | ||||
| 				nil, | ||||
| 				errors.New("copy error"), | ||||
| 				&bsonrwtest.ValueReaderWriter{Err: errors.New("copy error"), ErrAfter: bsonrwtest.WriteDocument}, | ||||
| 			}, | ||||
| 			{ | ||||
| 				"success", | ||||
| 				[]byte{0x07, 0x00, 0x00, 0x00, 0x0A, 0x00, 0x00}, | ||||
| 				nil, | ||||
| 				nil, | ||||
| 				nil, | ||||
| 			}, | ||||
| 		} | ||||
|  | ||||
| 		for _, tc := range testCases { | ||||
| 			t.Run(tc.name, func(t *testing.T) { | ||||
| 				marshaler := testMarshaler{buf: tc.buf, err: tc.err} | ||||
|  | ||||
| 				var vw bsonrw.ValueWriter | ||||
| 				var err error | ||||
| 				b := make(bsonrw.SliceWriter, 0, 100) | ||||
| 				compareVW := false | ||||
| 				if tc.vw != nil { | ||||
| 					vw = tc.vw | ||||
| 				} else { | ||||
| 					compareVW = true | ||||
| 					vw, err = bsonrw.NewBSONValueWriter(&b) | ||||
| 					noerr(t, err) | ||||
| 				} | ||||
| 				enc, err := NewEncoder(vw) | ||||
| 				noerr(t, err) | ||||
| 				got := enc.Encode(marshaler) | ||||
| 				want := tc.wanterr | ||||
| 				if !compareErrors(got, want) { | ||||
| 					t.Errorf("Did not receive expected error. got %v; want %v", got, want) | ||||
| 				} | ||||
| 				if compareVW { | ||||
| 					buf := b | ||||
| 					if !bytes.Equal(buf, tc.buf) { | ||||
| 						t.Errorf("Copied bytes do not match. got %v; want %v", buf, tc.buf) | ||||
| 					} | ||||
| 				} | ||||
| 			}) | ||||
| 		} | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| type testMarshaler struct { | ||||
| 	buf []byte | ||||
| 	err error | ||||
| } | ||||
|  | ||||
| func (tm testMarshaler) MarshalBSON() ([]byte, error) { return tm.buf, tm.err } | ||||
|  | ||||
| func docToBytes(d interface{}) []byte { | ||||
| 	b, err := Marshal(d) | ||||
| 	if err != nil { | ||||
| 		panic(err) | ||||
| 	} | ||||
| 	return b | ||||
| } | ||||
							
								
								
									
										47
									
								
								mongo/bson/extjson_prose_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										47
									
								
								mongo/bson/extjson_prose_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,47 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bson | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"testing" | ||||
|  | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsonrw" | ||||
| 	"go.mongodb.org/mongo-driver/internal/testutil/assert" | ||||
| ) | ||||
|  | ||||
| func TestExtJSON(t *testing.T) { | ||||
| 	timestampNegativeInt32Err := fmt.Errorf("$timestamp i number should be uint32: -1") | ||||
| 	timestampNegativeInt64Err := fmt.Errorf("$timestamp i number should be uint32: -2147483649") | ||||
| 	timestampLargeValueErr := fmt.Errorf("$timestamp i number should be uint32: 4294967296") | ||||
|  | ||||
| 	testCases := []struct { | ||||
| 		name      string | ||||
| 		input     string | ||||
| 		canonical bool | ||||
| 		err       error | ||||
| 	}{ | ||||
| 		{"timestamp - negative int32 value", `{"":{"$timestamp":{"t":0,"i":-1}}}`, false, timestampNegativeInt32Err}, | ||||
| 		{"timestamp - negative int64 value", `{"":{"$timestamp":{"t":0,"i":-2147483649}}}`, false, timestampNegativeInt64Err}, | ||||
| 		{"timestamp - value overflows uint32", `{"":{"$timestamp":{"t":0,"i":4294967296}}}`, false, timestampLargeValueErr}, | ||||
| 		{"top level key is not treated as special", `{"$code": "foo"}`, false, nil}, | ||||
| 		{"escaped single quote errors", `{"f\'oo": "bar"}`, false, bsonrw.ErrInvalidJSON}, | ||||
| 	} | ||||
| 	for _, tc := range testCases { | ||||
| 		t.Run(tc.name, func(t *testing.T) { | ||||
| 			var res Raw | ||||
| 			err := UnmarshalExtJSON([]byte(tc.input), tc.canonical, &res) | ||||
| 			if tc.err == nil { | ||||
| 				assert.Nil(t, err, "UnmarshalExtJSON error: %v", err) | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			assert.NotNil(t, err, "expected error %v, got nil", tc.err) | ||||
| 			assert.Equal(t, tc.err.Error(), err.Error(), "expected error %v, got %v", tc.err, err) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										40
									
								
								mongo/bson/fuzz_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										40
									
								
								mongo/bson/fuzz_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,40 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2023-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bson | ||||
|  | ||||
| import ( | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| func FuzzDecode(f *testing.F) { | ||||
| 	seedBSONCorpus(f) | ||||
|  | ||||
| 	f.Fuzz(func(t *testing.T, data []byte) { | ||||
| 		for _, typ := range []func() interface{}{ | ||||
| 			func() interface{} { return new(D) }, | ||||
| 			func() interface{} { return new([]E) }, | ||||
| 			func() interface{} { return new(M) }, | ||||
| 			func() interface{} { return new(interface{}) }, | ||||
| 			func() interface{} { return make(map[string]interface{}) }, | ||||
| 			func() interface{} { return new([]interface{}) }, | ||||
| 		} { | ||||
| 			i := typ() | ||||
| 			if err := Unmarshal(data, i); err != nil { | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			encoded, err := Marshal(i) | ||||
| 			if err != nil { | ||||
| 				t.Fatal("failed to marshal", err) | ||||
| 			} | ||||
|  | ||||
| 			if err := Unmarshal(encoded, i); err != nil { | ||||
| 				t.Fatal("failed to unmarshal", err) | ||||
| 			} | ||||
| 		} | ||||
| 	}) | ||||
| } | ||||
							
								
								
									
										248
									
								
								mongo/bson/marshal.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										248
									
								
								mongo/bson/marshal.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,248 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bson | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"encoding/json" | ||||
|  | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsoncodec" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsonrw" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsontype" | ||||
| ) | ||||
|  | ||||
| const defaultDstCap = 256 | ||||
|  | ||||
| var bvwPool = bsonrw.NewBSONValueWriterPool() | ||||
| var extjPool = bsonrw.NewExtJSONValueWriterPool() | ||||
|  | ||||
| // Marshaler is an interface implemented by types that can marshal themselves | ||||
| // into a BSON document represented as bytes. The bytes returned must be a valid | ||||
| // BSON document if the error is nil. | ||||
| type Marshaler interface { | ||||
| 	MarshalBSON() ([]byte, error) | ||||
| } | ||||
|  | ||||
| // ValueMarshaler is an interface implemented by types that can marshal | ||||
| // themselves into a BSON value as bytes. The type must be the valid type for | ||||
| // the bytes returned. The bytes and byte type together must be valid if the | ||||
| // error is nil. | ||||
| type ValueMarshaler interface { | ||||
| 	MarshalBSONValue() (bsontype.Type, []byte, error) | ||||
| } | ||||
|  | ||||
| // Marshal returns the BSON encoding of val as a BSON document. If val is not a type that can be transformed into a | ||||
| // document, MarshalValue should be used instead. | ||||
| // | ||||
| // Marshal will use the default registry created by NewRegistry to recursively | ||||
| // marshal val into a []byte. Marshal will inspect struct tags and alter the | ||||
| // marshaling process accordingly. | ||||
| func Marshal(val interface{}) ([]byte, error) { | ||||
| 	return MarshalWithRegistry(DefaultRegistry, val) | ||||
| } | ||||
|  | ||||
| // MarshalAppend will encode val as a BSON document and append the bytes to dst. If dst is not large enough to hold the | ||||
| // bytes, it will be grown. If val is not a type that can be transformed into a document, MarshalValueAppend should be | ||||
| // used instead. | ||||
| func MarshalAppend(dst []byte, val interface{}) ([]byte, error) { | ||||
| 	return MarshalAppendWithRegistry(DefaultRegistry, dst, val) | ||||
| } | ||||
|  | ||||
| // MarshalWithRegistry returns the BSON encoding of val as a BSON document. If val is not a type that can be transformed | ||||
| // into a document, MarshalValueWithRegistry should be used instead. | ||||
| func MarshalWithRegistry(r *bsoncodec.Registry, val interface{}) ([]byte, error) { | ||||
| 	dst := make([]byte, 0) | ||||
| 	return MarshalAppendWithRegistry(r, dst, val) | ||||
| } | ||||
|  | ||||
| // MarshalWithContext returns the BSON encoding of val as a BSON document using EncodeContext ec. If val is not a type | ||||
| // that can be transformed into a document, MarshalValueWithContext should be used instead. | ||||
| func MarshalWithContext(ec bsoncodec.EncodeContext, val interface{}) ([]byte, error) { | ||||
| 	dst := make([]byte, 0) | ||||
| 	return MarshalAppendWithContext(ec, dst, val) | ||||
| } | ||||
|  | ||||
| // MarshalAppendWithRegistry will encode val as a BSON document using Registry r and append the bytes to dst. If dst is | ||||
| // not large enough to hold the bytes, it will be grown. If val is not a type that can be transformed into a document, | ||||
| // MarshalValueAppendWithRegistry should be used instead. | ||||
| func MarshalAppendWithRegistry(r *bsoncodec.Registry, dst []byte, val interface{}) ([]byte, error) { | ||||
| 	return MarshalAppendWithContext(bsoncodec.EncodeContext{Registry: r}, dst, val) | ||||
| } | ||||
|  | ||||
| // MarshalAppendWithContext will encode val as a BSON document using Registry r and EncodeContext ec and append the | ||||
| // bytes to dst. If dst is not large enough to hold the bytes, it will be grown. If val is not a type that can be | ||||
| // transformed into a document, MarshalValueAppendWithContext should be used instead. | ||||
| func MarshalAppendWithContext(ec bsoncodec.EncodeContext, dst []byte, val interface{}) ([]byte, error) { | ||||
| 	sw := new(bsonrw.SliceWriter) | ||||
| 	*sw = dst | ||||
| 	vw := bvwPool.Get(sw) | ||||
| 	defer bvwPool.Put(vw) | ||||
|  | ||||
| 	enc := encPool.Get().(*Encoder) | ||||
| 	defer encPool.Put(enc) | ||||
|  | ||||
| 	err := enc.Reset(vw) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	err = enc.SetContext(ec) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	err = enc.Encode(val) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return *sw, nil | ||||
| } | ||||
|  | ||||
| // MarshalValue returns the BSON encoding of val. | ||||
| // | ||||
| // MarshalValue will use bson.DefaultRegistry to transform val into a BSON value. If val is a struct, this function will | ||||
| // inspect struct tags and alter the marshalling process accordingly. | ||||
| func MarshalValue(val interface{}) (bsontype.Type, []byte, error) { | ||||
| 	return MarshalValueWithRegistry(DefaultRegistry, val) | ||||
| } | ||||
|  | ||||
| // MarshalValueAppend will append the BSON encoding of val to dst. If dst is not large enough to hold the BSON encoding | ||||
| // of val, dst will be grown. | ||||
| func MarshalValueAppend(dst []byte, val interface{}) (bsontype.Type, []byte, error) { | ||||
| 	return MarshalValueAppendWithRegistry(DefaultRegistry, dst, val) | ||||
| } | ||||
|  | ||||
| // MarshalValueWithRegistry returns the BSON encoding of val using Registry r. | ||||
| func MarshalValueWithRegistry(r *bsoncodec.Registry, val interface{}) (bsontype.Type, []byte, error) { | ||||
| 	dst := make([]byte, 0) | ||||
| 	return MarshalValueAppendWithRegistry(r, dst, val) | ||||
| } | ||||
|  | ||||
| // MarshalValueWithContext returns the BSON encoding of val using EncodeContext ec. | ||||
| func MarshalValueWithContext(ec bsoncodec.EncodeContext, val interface{}) (bsontype.Type, []byte, error) { | ||||
| 	dst := make([]byte, 0) | ||||
| 	return MarshalValueAppendWithContext(ec, dst, val) | ||||
| } | ||||
|  | ||||
| // MarshalValueAppendWithRegistry will append the BSON encoding of val to dst using Registry r. If dst is not large | ||||
| // enough to hold the BSON encoding of val, dst will be grown. | ||||
| func MarshalValueAppendWithRegistry(r *bsoncodec.Registry, dst []byte, val interface{}) (bsontype.Type, []byte, error) { | ||||
| 	return MarshalValueAppendWithContext(bsoncodec.EncodeContext{Registry: r}, dst, val) | ||||
| } | ||||
|  | ||||
| // MarshalValueAppendWithContext will append the BSON encoding of val to dst using EncodeContext ec. If dst is not large | ||||
| // enough to hold the BSON encoding of val, dst will be grown. | ||||
| func MarshalValueAppendWithContext(ec bsoncodec.EncodeContext, dst []byte, val interface{}) (bsontype.Type, []byte, error) { | ||||
| 	// get a ValueWriter configured to write to dst | ||||
| 	sw := new(bsonrw.SliceWriter) | ||||
| 	*sw = dst | ||||
| 	vwFlusher := bvwPool.GetAtModeElement(sw) | ||||
|  | ||||
| 	// get an Encoder and encode the value | ||||
| 	enc := encPool.Get().(*Encoder) | ||||
| 	defer encPool.Put(enc) | ||||
| 	if err := enc.Reset(vwFlusher); err != nil { | ||||
| 		return 0, nil, err | ||||
| 	} | ||||
| 	if err := enc.SetContext(ec); err != nil { | ||||
| 		return 0, nil, err | ||||
| 	} | ||||
| 	if err := enc.Encode(val); err != nil { | ||||
| 		return 0, nil, err | ||||
| 	} | ||||
|  | ||||
| 	// flush the bytes written because we cannot guarantee that a full document has been written | ||||
| 	// after the flush, *sw will be in the format | ||||
| 	// [value type, 0 (null byte to indicate end of empty element name), value bytes..] | ||||
| 	if err := vwFlusher.Flush(); err != nil { | ||||
| 		return 0, nil, err | ||||
| 	} | ||||
| 	buffer := *sw | ||||
| 	return bsontype.Type(buffer[0]), buffer[2:], nil | ||||
| } | ||||
|  | ||||
| // MarshalExtJSON returns the extended JSON encoding of val. | ||||
| func MarshalExtJSON(val interface{}, canonical, escapeHTML bool) ([]byte, error) { | ||||
| 	return MarshalExtJSONWithRegistry(DefaultRegistry, val, canonical, escapeHTML) | ||||
| } | ||||
|  | ||||
| // MarshalExtJSONAppend will append the extended JSON encoding of val to dst. | ||||
| // If dst is not large enough to hold the extended JSON encoding of val, dst | ||||
| // will be grown. | ||||
| func MarshalExtJSONAppend(dst []byte, val interface{}, canonical, escapeHTML bool) ([]byte, error) { | ||||
| 	return MarshalExtJSONAppendWithRegistry(DefaultRegistry, dst, val, canonical, escapeHTML) | ||||
| } | ||||
|  | ||||
| // MarshalExtJSONWithRegistry returns the extended JSON encoding of val using Registry r. | ||||
| func MarshalExtJSONWithRegistry(r *bsoncodec.Registry, val interface{}, canonical, escapeHTML bool) ([]byte, error) { | ||||
| 	dst := make([]byte, 0, defaultDstCap) | ||||
| 	return MarshalExtJSONAppendWithContext(bsoncodec.EncodeContext{Registry: r}, dst, val, canonical, escapeHTML) | ||||
| } | ||||
|  | ||||
| // MarshalExtJSONWithContext returns the extended JSON encoding of val using Registry r. | ||||
| func MarshalExtJSONWithContext(ec bsoncodec.EncodeContext, val interface{}, canonical, escapeHTML bool) ([]byte, error) { | ||||
| 	dst := make([]byte, 0, defaultDstCap) | ||||
| 	return MarshalExtJSONAppendWithContext(ec, dst, val, canonical, escapeHTML) | ||||
| } | ||||
|  | ||||
| // MarshalExtJSONAppendWithRegistry will append the extended JSON encoding of | ||||
| // val to dst using Registry r. If dst is not large enough to hold the BSON | ||||
| // encoding of val, dst will be grown. | ||||
| func MarshalExtJSONAppendWithRegistry(r *bsoncodec.Registry, dst []byte, val interface{}, canonical, escapeHTML bool) ([]byte, error) { | ||||
| 	return MarshalExtJSONAppendWithContext(bsoncodec.EncodeContext{Registry: r}, dst, val, canonical, escapeHTML) | ||||
| } | ||||
|  | ||||
| // MarshalExtJSONAppendWithContext will append the extended JSON encoding of | ||||
| // val to dst using Registry r. If dst is not large enough to hold the BSON | ||||
| // encoding of val, dst will be grown. | ||||
| func MarshalExtJSONAppendWithContext(ec bsoncodec.EncodeContext, dst []byte, val interface{}, canonical, escapeHTML bool) ([]byte, error) { | ||||
| 	sw := new(bsonrw.SliceWriter) | ||||
| 	*sw = dst | ||||
| 	ejvw := extjPool.Get(sw, canonical, escapeHTML) | ||||
| 	defer extjPool.Put(ejvw) | ||||
|  | ||||
| 	enc := encPool.Get().(*Encoder) | ||||
| 	defer encPool.Put(enc) | ||||
|  | ||||
| 	err := enc.Reset(ejvw) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	err = enc.SetContext(ec) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	err = enc.Encode(val) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return *sw, nil | ||||
| } | ||||
|  | ||||
| // IndentExtJSON will prefix and indent the provided extended JSON src and append it to dst. | ||||
| func IndentExtJSON(dst *bytes.Buffer, src []byte, prefix, indent string) error { | ||||
| 	return json.Indent(dst, src, prefix, indent) | ||||
| } | ||||
|  | ||||
| // MarshalExtJSONIndent returns the extended JSON encoding of val with each line with prefixed | ||||
| // and indented. | ||||
| func MarshalExtJSONIndent(val interface{}, canonical, escapeHTML bool, prefix, indent string) ([]byte, error) { | ||||
| 	marshaled, err := MarshalExtJSON(val, canonical, escapeHTML) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	var buf bytes.Buffer | ||||
| 	err = IndentExtJSON(&buf, marshaled, prefix, indent) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return buf.Bytes(), nil | ||||
| } | ||||
							
								
								
									
										382
									
								
								mongo/bson/marshal_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										382
									
								
								mongo/bson/marshal_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,382 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bson | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/google/go-cmp/cmp" | ||||
| 	"github.com/stretchr/testify/require" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsoncodec" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsonrw" | ||||
| 	"go.mongodb.org/mongo-driver/bson/primitive" | ||||
| 	"go.mongodb.org/mongo-driver/internal/testutil/assert" | ||||
| 	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore" | ||||
| ) | ||||
|  | ||||
| var tInt32 = reflect.TypeOf(int32(0)) | ||||
|  | ||||
| func TestMarshalAppendWithRegistry(t *testing.T) { | ||||
| 	for _, tc := range marshalingTestCases { | ||||
| 		t.Run(tc.name, func(t *testing.T) { | ||||
| 			dst := make([]byte, 0, 1024) | ||||
| 			var reg *bsoncodec.Registry | ||||
| 			if tc.reg != nil { | ||||
| 				reg = tc.reg | ||||
| 			} else { | ||||
| 				reg = DefaultRegistry | ||||
| 			} | ||||
| 			got, err := MarshalAppendWithRegistry(reg, dst, tc.val) | ||||
| 			noerr(t, err) | ||||
|  | ||||
| 			if !bytes.Equal(got, tc.want) { | ||||
| 				t.Errorf("Bytes are not equal. got %v; want %v", got, tc.want) | ||||
| 				t.Errorf("Bytes:\n%v\n%v", got, tc.want) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestMarshalAppendWithContext(t *testing.T) { | ||||
| 	for _, tc := range marshalingTestCases { | ||||
| 		t.Run(tc.name, func(t *testing.T) { | ||||
| 			dst := make([]byte, 0, 1024) | ||||
| 			var reg *bsoncodec.Registry | ||||
| 			if tc.reg != nil { | ||||
| 				reg = tc.reg | ||||
| 			} else { | ||||
| 				reg = DefaultRegistry | ||||
| 			} | ||||
| 			ec := bsoncodec.EncodeContext{Registry: reg} | ||||
| 			got, err := MarshalAppendWithContext(ec, dst, tc.val) | ||||
| 			noerr(t, err) | ||||
|  | ||||
| 			if !bytes.Equal(got, tc.want) { | ||||
| 				t.Errorf("Bytes are not equal. got %v; want %v", got, tc.want) | ||||
| 				t.Errorf("Bytes:\n%v\n%v", got, tc.want) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestMarshalWithRegistry(t *testing.T) { | ||||
| 	for _, tc := range marshalingTestCases { | ||||
| 		t.Run(tc.name, func(t *testing.T) { | ||||
| 			var reg *bsoncodec.Registry | ||||
| 			if tc.reg != nil { | ||||
| 				reg = tc.reg | ||||
| 			} else { | ||||
| 				reg = DefaultRegistry | ||||
| 			} | ||||
| 			got, err := MarshalWithRegistry(reg, tc.val) | ||||
| 			noerr(t, err) | ||||
|  | ||||
| 			if !bytes.Equal(got, tc.want) { | ||||
| 				t.Errorf("Bytes are not equal. got %v; want %v", got, tc.want) | ||||
| 				t.Errorf("Bytes:\n%v\n%v", got, tc.want) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestMarshalWithContext(t *testing.T) { | ||||
| 	for _, tc := range marshalingTestCases { | ||||
| 		t.Run(tc.name, func(t *testing.T) { | ||||
| 			var reg *bsoncodec.Registry | ||||
| 			if tc.reg != nil { | ||||
| 				reg = tc.reg | ||||
| 			} else { | ||||
| 				reg = DefaultRegistry | ||||
| 			} | ||||
| 			ec := bsoncodec.EncodeContext{Registry: reg} | ||||
| 			got, err := MarshalWithContext(ec, tc.val) | ||||
| 			noerr(t, err) | ||||
|  | ||||
| 			if !bytes.Equal(got, tc.want) { | ||||
| 				t.Errorf("Bytes are not equal. got %v; want %v", got, tc.want) | ||||
| 				t.Errorf("Bytes:\n%v\n%v", got, tc.want) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestMarshalAppend(t *testing.T) { | ||||
| 	for _, tc := range marshalingTestCases { | ||||
| 		t.Run(tc.name, func(t *testing.T) { | ||||
| 			if tc.reg != nil { | ||||
| 				t.Skip() // test requires custom registry | ||||
| 			} | ||||
| 			dst := make([]byte, 0, 1024) | ||||
| 			got, err := MarshalAppend(dst, tc.val) | ||||
| 			noerr(t, err) | ||||
|  | ||||
| 			if !bytes.Equal(got, tc.want) { | ||||
| 				t.Errorf("Bytes are not equal. got %v; want %v", got, tc.want) | ||||
| 				t.Errorf("Bytes:\n%v\n%v", got, tc.want) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestMarshalExtJSONAppendWithContext(t *testing.T) { | ||||
| 	t.Run("MarshalExtJSONAppendWithContext", func(t *testing.T) { | ||||
| 		dst := make([]byte, 0, 1024) | ||||
| 		type teststruct struct{ Foo int } | ||||
| 		val := teststruct{1} | ||||
| 		ec := bsoncodec.EncodeContext{Registry: DefaultRegistry} | ||||
| 		got, err := MarshalExtJSONAppendWithContext(ec, dst, val, true, false) | ||||
| 		noerr(t, err) | ||||
| 		want := []byte(`{"foo":{"$numberInt":"1"}}`) | ||||
| 		if !bytes.Equal(got, want) { | ||||
| 			t.Errorf("Bytes are not equal. got %v; want %v", got, want) | ||||
| 			t.Errorf("Bytes:\n%s\n%s", got, want) | ||||
| 		} | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func TestMarshalExtJSONWithContext(t *testing.T) { | ||||
| 	t.Run("MarshalExtJSONWithContext", func(t *testing.T) { | ||||
| 		type teststruct struct{ Foo int } | ||||
| 		val := teststruct{1} | ||||
| 		ec := bsoncodec.EncodeContext{Registry: DefaultRegistry} | ||||
| 		got, err := MarshalExtJSONWithContext(ec, val, true, false) | ||||
| 		noerr(t, err) | ||||
| 		want := []byte(`{"foo":{"$numberInt":"1"}}`) | ||||
| 		if !bytes.Equal(got, want) { | ||||
| 			t.Errorf("Bytes are not equal. got %v; want %v", got, want) | ||||
| 			t.Errorf("Bytes:\n%s\n%s", got, want) | ||||
| 		} | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func TestMarshal_roundtripFromBytes(t *testing.T) { | ||||
| 	before := []byte{ | ||||
| 		// length | ||||
| 		0x1c, 0x0, 0x0, 0x0, | ||||
|  | ||||
| 		// --- begin array --- | ||||
|  | ||||
| 		// type - document | ||||
| 		0x3, | ||||
| 		// key - "foo" | ||||
| 		0x66, 0x6f, 0x6f, 0x0, | ||||
|  | ||||
| 		// length | ||||
| 		0x12, 0x0, 0x0, 0x0, | ||||
| 		// type - string | ||||
| 		0x2, | ||||
| 		// key - "bar" | ||||
| 		0x62, 0x61, 0x72, 0x0, | ||||
| 		// value - string length | ||||
| 		0x4, 0x0, 0x0, 0x0, | ||||
| 		// value - "baz" | ||||
| 		0x62, 0x61, 0x7a, 0x0, | ||||
|  | ||||
| 		// null terminator | ||||
| 		0x0, | ||||
|  | ||||
| 		// --- end array --- | ||||
|  | ||||
| 		// null terminator | ||||
| 		0x0, | ||||
| 	} | ||||
|  | ||||
| 	var doc D | ||||
| 	require.NoError(t, Unmarshal(before, &doc)) | ||||
|  | ||||
| 	after, err := Marshal(doc) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	require.True(t, bytes.Equal(before, after)) | ||||
| } | ||||
|  | ||||
| func TestMarshal_roundtripFromDoc(t *testing.T) { | ||||
| 	before := D{ | ||||
| 		{"foo", "bar"}, | ||||
| 		{"baz", int64(-27)}, | ||||
| 		{"bing", A{nil, primitive.Regex{Pattern: "word", Options: "i"}}}, | ||||
| 	} | ||||
|  | ||||
| 	b, err := Marshal(before) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	var after D | ||||
| 	require.NoError(t, Unmarshal(b, &after)) | ||||
|  | ||||
| 	if !cmp.Equal(after, before) { | ||||
| 		t.Errorf("Documents to not match. got %v; want %v", after, before) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestCachingEncodersNotSharedAcrossRegistries(t *testing.T) { | ||||
| 	// Encoders that have caches for recursive encoder lookup should not be shared across Registry instances. Otherwise, | ||||
| 	// the first EncodeValue call would cache an encoder and a subsequent call would see that encoder even if a | ||||
| 	// different Registry is used. | ||||
|  | ||||
| 	// Create a custom Registry that negates int32 values when encoding. | ||||
| 	var encodeInt32 bsoncodec.ValueEncoderFunc = func(_ bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { | ||||
| 		if val.Kind() != reflect.Int32 { | ||||
| 			return fmt.Errorf("expected kind to be int32, got %v", val.Kind()) | ||||
| 		} | ||||
|  | ||||
| 		return vw.WriteInt32(int32(val.Int()) * -1) | ||||
| 	} | ||||
| 	customReg := NewRegistryBuilder(). | ||||
| 		RegisterTypeEncoder(tInt32, encodeInt32). | ||||
| 		Build() | ||||
|  | ||||
| 	// Helper function to run the test and make assertions. The provided original value should result in the document | ||||
| 	// {"x": {$numberInt: 1}} when marshalled with the default registry. | ||||
| 	verifyResults := func(t *testing.T, original interface{}) { | ||||
| 		// Marshal using the default and custom registries. Assert that the result is {x: 1} and {x: -1}, respectively. | ||||
|  | ||||
| 		first, err := Marshal(original) | ||||
| 		assert.Nil(t, err, "Marshal error: %v", err) | ||||
| 		expectedFirst := Raw(bsoncore.BuildDocumentFromElements( | ||||
| 			nil, | ||||
| 			bsoncore.AppendInt32Element(nil, "x", 1), | ||||
| 		)) | ||||
| 		assert.Equal(t, expectedFirst, Raw(first), "expected document %v, got %v", expectedFirst, Raw(first)) | ||||
|  | ||||
| 		second, err := MarshalWithRegistry(customReg, original) | ||||
| 		assert.Nil(t, err, "Marshal error: %v", err) | ||||
| 		expectedSecond := Raw(bsoncore.BuildDocumentFromElements( | ||||
| 			nil, | ||||
| 			bsoncore.AppendInt32Element(nil, "x", -1), | ||||
| 		)) | ||||
| 		assert.Equal(t, expectedSecond, Raw(second), "expected document %v, got %v", expectedSecond, Raw(second)) | ||||
| 	} | ||||
|  | ||||
| 	t.Run("struct", func(t *testing.T) { | ||||
| 		type Struct struct { | ||||
| 			X int32 | ||||
| 		} | ||||
| 		verifyResults(t, Struct{ | ||||
| 			X: 1, | ||||
| 		}) | ||||
| 	}) | ||||
| 	t.Run("pointer", func(t *testing.T) { | ||||
| 		i32 := int32(1) | ||||
| 		verifyResults(t, M{ | ||||
| 			"x": &i32, | ||||
| 		}) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func TestNullBytes(t *testing.T) { | ||||
| 	t.Run("element keys", func(t *testing.T) { | ||||
| 		doc := D{{"a\x00", "foobar"}} | ||||
| 		res, err := Marshal(doc) | ||||
| 		want := errors.New("BSON element key cannot contain null bytes") | ||||
| 		assert.Equal(t, want, err, "expected Marshal error %v, got error %v with result %q", want, err, Raw(res)) | ||||
| 	}) | ||||
|  | ||||
| 	t.Run("regex values", func(t *testing.T) { | ||||
| 		wantErr := errors.New("BSON regex values cannot contain null bytes") | ||||
|  | ||||
| 		testCases := []struct { | ||||
| 			name    string | ||||
| 			pattern string | ||||
| 			options string | ||||
| 		}{ | ||||
| 			{"null bytes in pattern", "a\x00", "i"}, | ||||
| 			{"null bytes in options", "pattern", "i\x00"}, | ||||
| 		} | ||||
| 		for _, tc := range testCases { | ||||
| 			t.Run(tc.name, func(t *testing.T) { | ||||
| 				regex := primitive.Regex{ | ||||
| 					Pattern: tc.pattern, | ||||
| 					Options: tc.options, | ||||
| 				} | ||||
| 				res, err := Marshal(D{{"foo", regex}}) | ||||
| 				assert.Equal(t, wantErr, err, "expected Marshal error %v, got error %v with result %q", wantErr, err, Raw(res)) | ||||
| 			}) | ||||
| 		} | ||||
| 	}) | ||||
|  | ||||
| 	t.Run("sub document field name", func(t *testing.T) { | ||||
| 		doc := D{{"foo", D{{"foobar", D{{"a\x00", "foobar"}}}}}} | ||||
| 		res, err := Marshal(doc) | ||||
| 		wantErr := errors.New("BSON element key cannot contain null bytes") | ||||
| 		assert.Equal(t, wantErr, err, "expected Marshal error %v, got error %v with result %q", wantErr, err, Raw(res)) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func TestMarshalExtJSONIndent(t *testing.T) { | ||||
| 	type indentTestCase struct { | ||||
| 		name            string | ||||
| 		val             interface{} | ||||
| 		expectedExtJSON string | ||||
| 	} | ||||
|  | ||||
| 	// expectedExtJSON must be written as below because single-quoted | ||||
| 	// literal strings capture undesired code formatting tabs | ||||
| 	testCases := []indentTestCase{ | ||||
| 		{ | ||||
| 			"empty val", | ||||
| 			struct{}{}, | ||||
| 			`{}`, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"embedded struct", | ||||
| 			struct { | ||||
| 				Embedded interface{} `json:"embedded"` | ||||
| 				Foo      string      `json:"foo"` | ||||
| 			}{ | ||||
| 				Embedded: struct { | ||||
| 					Name string `json:"name"` | ||||
| 					Word string `json:"word"` | ||||
| 				}{ | ||||
| 					Name: "test", | ||||
| 					Word: "word", | ||||
| 				}, | ||||
| 				Foo: "bar", | ||||
| 			}, | ||||
| 			"{\n\t\"embedded\": {\n\t\t\"name\": \"test\",\n\t\t\"word\": \"word\"\n\t},\n\t\"foo\": \"bar\"\n}", | ||||
| 		}, | ||||
| 		{ | ||||
| 			"date struct", | ||||
| 			struct { | ||||
| 				Foo  string    `json:"foo"` | ||||
| 				Date time.Time `json:"date"` | ||||
| 			}{ | ||||
| 				Foo:  "bar", | ||||
| 				Date: time.Date(2000, time.January, 1, 12, 0, 0, 0, time.UTC), | ||||
| 			}, | ||||
| 			"{\n\t\"foo\": \"bar\",\n\t\"date\": {\n\t\t\"$date\": {\n\t\t\t\"$numberLong\": \"946728000000\"\n\t\t}\n\t}\n}", | ||||
| 		}, | ||||
| 		{ | ||||
| 			"float struct", | ||||
| 			struct { | ||||
| 				Foo   string  `json:"foo"` | ||||
| 				Float float32 `json:"float"` | ||||
| 			}{ | ||||
| 				Foo:   "bar", | ||||
| 				Float: 3.14, | ||||
| 			}, | ||||
| 			"{\n\t\"foo\": \"bar\",\n\t\"float\": {\n\t\t\"$numberDouble\": \"3.140000104904175\"\n\t}\n}", | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for _, tc := range testCases { | ||||
| 		tc := tc | ||||
| 		t.Run(tc.name, func(t *testing.T) { | ||||
| 			t.Parallel() | ||||
| 			extJSONBytes, err := MarshalExtJSONIndent(tc.val, true, false, "", "\t") | ||||
| 			assert.Nil(t, err, "Marshal indent error: %v", err) | ||||
|  | ||||
| 			expectedExtJSONBytes := []byte(tc.expectedExtJSON) | ||||
|  | ||||
| 			assert.Equal(t, expectedExtJSONBytes, extJSONBytes, "expected:\n%s\ngot:\n%s", expectedExtJSONBytes, extJSONBytes) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										168
									
								
								mongo/bson/marshal_value_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										168
									
								
								mongo/bson/marshal_value_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,168 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bson | ||||
|  | ||||
| import ( | ||||
| 	"io" | ||||
| 	"testing" | ||||
|  | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsoncodec" | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsontype" | ||||
| 	"go.mongodb.org/mongo-driver/bson/primitive" | ||||
| 	"go.mongodb.org/mongo-driver/internal/testutil/assert" | ||||
| 	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore" | ||||
| ) | ||||
|  | ||||
| // helper type for testing MarshalValue that implements io.Reader | ||||
| type marshalValueInterfaceInner struct { | ||||
| 	Foo int | ||||
| } | ||||
|  | ||||
| var _ io.Reader = marshalValueInterfaceInner{} | ||||
|  | ||||
| func (marshalValueInterfaceInner) Read([]byte) (int, error) { | ||||
| 	return 0, nil | ||||
| } | ||||
|  | ||||
| // helper type for testing MarshalValue that contains an interface | ||||
| type marshalValueInterfaceOuter struct { | ||||
| 	Reader io.Reader | ||||
| } | ||||
|  | ||||
| // helper type for testing MarshalValue that implements ValueMarshaler | ||||
| type marshalValueMarshaler struct { | ||||
| 	Foo int | ||||
| } | ||||
|  | ||||
| var _ ValueMarshaler = marshalValueMarshaler{} | ||||
|  | ||||
| func (mvi marshalValueMarshaler) MarshalBSONValue() (bsontype.Type, []byte, error) { | ||||
| 	return bsontype.Int32, bsoncore.AppendInt32(nil, int32(mvi.Foo)), nil | ||||
| } | ||||
|  | ||||
| type marshalValueStruct struct { | ||||
| 	Foo int | ||||
| } | ||||
|  | ||||
| type marshalValueTestCase struct { | ||||
| 	name          string | ||||
| 	val           interface{} | ||||
| 	expectedType  bsontype.Type | ||||
| 	expectedBytes []byte | ||||
| } | ||||
|  | ||||
| func TestMarshalValue(t *testing.T) { | ||||
| 	oid := primitive.NewObjectID() | ||||
| 	regex := primitive.Regex{Pattern: "pattern", Options: "imx"} | ||||
| 	dbPointer := primitive.DBPointer{DB: "db", Pointer: primitive.NewObjectID()} | ||||
| 	codeWithScope := primitive.CodeWithScope{Code: "code", Scope: D{{"a", "b"}}} | ||||
| 	idx, scopeCore := bsoncore.AppendDocumentStart(nil) | ||||
| 	scopeCore = bsoncore.AppendStringElement(scopeCore, "a", "b") | ||||
| 	scopeCore, _ = bsoncore.AppendDocumentEnd(scopeCore, idx) | ||||
| 	decimal128 := primitive.NewDecimal128(5, 10) | ||||
| 	interfaceTest := marshalValueInterfaceOuter{ | ||||
| 		Reader: marshalValueInterfaceInner{ | ||||
| 			Foo: 10, | ||||
| 		}, | ||||
| 	} | ||||
| 	interfaceCore, err := Marshal(interfaceTest) | ||||
| 	assert.Nil(t, err, "Marshal error: %v", err) | ||||
| 	structTest := marshalValueStruct{Foo: 10} | ||||
| 	structCore, err := Marshal(structTest) | ||||
| 	assert.Nil(t, err, "Marshal error: %v", err) | ||||
|  | ||||
| 	marshalValueTestCases := []marshalValueTestCase{ | ||||
| 		{"double", 3.14, bsontype.Double, bsoncore.AppendDouble(nil, 3.14)}, | ||||
| 		{"string", "hello world", bsontype.String, bsoncore.AppendString(nil, "hello world")}, | ||||
| 		{"binary", primitive.Binary{1, []byte{1, 2}}, bsontype.Binary, bsoncore.AppendBinary(nil, 1, []byte{1, 2})}, | ||||
| 		{"undefined", primitive.Undefined{}, bsontype.Undefined, []byte{}}, | ||||
| 		{"object id", oid, bsontype.ObjectID, bsoncore.AppendObjectID(nil, oid)}, | ||||
| 		{"boolean", true, bsontype.Boolean, bsoncore.AppendBoolean(nil, true)}, | ||||
| 		{"datetime", primitive.DateTime(5), bsontype.DateTime, bsoncore.AppendDateTime(nil, 5)}, | ||||
| 		{"null", primitive.Null{}, bsontype.Null, []byte{}}, | ||||
| 		{"regex", regex, bsontype.Regex, bsoncore.AppendRegex(nil, regex.Pattern, regex.Options)}, | ||||
| 		{"dbpointer", dbPointer, bsontype.DBPointer, bsoncore.AppendDBPointer(nil, dbPointer.DB, dbPointer.Pointer)}, | ||||
| 		{"javascript", primitive.JavaScript("js"), bsontype.JavaScript, bsoncore.AppendJavaScript(nil, "js")}, | ||||
| 		{"symbol", primitive.Symbol("symbol"), bsontype.Symbol, bsoncore.AppendSymbol(nil, "symbol")}, | ||||
| 		{"code with scope", codeWithScope, bsontype.CodeWithScope, bsoncore.AppendCodeWithScope(nil, "code", scopeCore)}, | ||||
| 		{"int32", 5, bsontype.Int32, bsoncore.AppendInt32(nil, 5)}, | ||||
| 		{"int64", int64(5), bsontype.Int64, bsoncore.AppendInt64(nil, 5)}, | ||||
| 		{"timestamp", primitive.Timestamp{T: 1, I: 5}, bsontype.Timestamp, bsoncore.AppendTimestamp(nil, 1, 5)}, | ||||
| 		{"decimal128", decimal128, bsontype.Decimal128, bsoncore.AppendDecimal128(nil, decimal128)}, | ||||
| 		{"min key", primitive.MinKey{}, bsontype.MinKey, []byte{}}, | ||||
| 		{"max key", primitive.MaxKey{}, bsontype.MaxKey, []byte{}}, | ||||
| 		{"struct", structTest, bsontype.EmbeddedDocument, structCore}, | ||||
| 		{"interface", interfaceTest, bsontype.EmbeddedDocument, interfaceCore}, | ||||
| 		{"D", D{{"foo", 10}}, bsontype.EmbeddedDocument, structCore}, | ||||
| 		{"M", M{"foo": 10}, bsontype.EmbeddedDocument, structCore}, | ||||
| 		{"ValueMarshaler", marshalValueMarshaler{Foo: 10}, bsontype.Int32, bsoncore.AppendInt32(nil, 10)}, | ||||
| 	} | ||||
|  | ||||
| 	t.Run("MarshalValue", func(t *testing.T) { | ||||
| 		for _, tc := range marshalValueTestCases { | ||||
| 			t.Run(tc.name, func(t *testing.T) { | ||||
| 				valueType, valueBytes, err := MarshalValue(tc.val) | ||||
| 				assert.Nil(t, err, "MarshalValue error: %v", err) | ||||
| 				compareMarshalValueResults(t, tc, valueType, valueBytes) | ||||
| 			}) | ||||
| 		} | ||||
| 	}) | ||||
| 	t.Run("MarshalValueAppend", func(t *testing.T) { | ||||
| 		for _, tc := range marshalValueTestCases { | ||||
| 			t.Run(tc.name, func(t *testing.T) { | ||||
| 				valueType, valueBytes, err := MarshalValueAppend(nil, tc.val) | ||||
| 				assert.Nil(t, err, "MarshalValueAppend error: %v", err) | ||||
| 				compareMarshalValueResults(t, tc, valueType, valueBytes) | ||||
| 			}) | ||||
| 		} | ||||
| 	}) | ||||
| 	t.Run("MarshalValueWithRegistry", func(t *testing.T) { | ||||
| 		for _, tc := range marshalValueTestCases { | ||||
| 			t.Run(tc.name, func(t *testing.T) { | ||||
| 				valueType, valueBytes, err := MarshalValueWithRegistry(DefaultRegistry, tc.val) | ||||
| 				assert.Nil(t, err, "MarshalValueWithRegistry error: %v", err) | ||||
| 				compareMarshalValueResults(t, tc, valueType, valueBytes) | ||||
| 			}) | ||||
| 		} | ||||
| 	}) | ||||
| 	t.Run("MarshalValueWithContext", func(t *testing.T) { | ||||
| 		ec := bsoncodec.EncodeContext{Registry: DefaultRegistry} | ||||
| 		for _, tc := range marshalValueTestCases { | ||||
| 			t.Run(tc.name, func(t *testing.T) { | ||||
| 				valueType, valueBytes, err := MarshalValueWithContext(ec, tc.val) | ||||
| 				assert.Nil(t, err, "MarshalValueWithContext error: %v", err) | ||||
| 				compareMarshalValueResults(t, tc, valueType, valueBytes) | ||||
| 			}) | ||||
| 		} | ||||
| 	}) | ||||
| 	t.Run("MarshalValueAppendWithRegistry", func(t *testing.T) { | ||||
| 		for _, tc := range marshalValueTestCases { | ||||
| 			t.Run(tc.name, func(t *testing.T) { | ||||
| 				valueType, valueBytes, err := MarshalValueAppendWithRegistry(DefaultRegistry, nil, tc.val) | ||||
| 				assert.Nil(t, err, "MarshalValueAppendWithRegistry error: %v", err) | ||||
| 				compareMarshalValueResults(t, tc, valueType, valueBytes) | ||||
| 			}) | ||||
| 		} | ||||
| 	}) | ||||
| 	t.Run("MarshalValueAppendWithContext", func(t *testing.T) { | ||||
| 		ec := bsoncodec.EncodeContext{Registry: DefaultRegistry} | ||||
| 		for _, tc := range marshalValueTestCases { | ||||
| 			t.Run(tc.name, func(t *testing.T) { | ||||
| 				valueType, valueBytes, err := MarshalValueAppendWithContext(ec, nil, tc.val) | ||||
| 				assert.Nil(t, err, "MarshalValueWithContext error: %v", err) | ||||
| 				compareMarshalValueResults(t, tc, valueType, valueBytes) | ||||
| 			}) | ||||
| 		} | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func compareMarshalValueResults(t *testing.T, tc marshalValueTestCase, gotType bsontype.Type, gotBytes []byte) { | ||||
| 	t.Helper() | ||||
| 	expectedValue := RawValue{Type: tc.expectedType, Value: tc.expectedBytes} | ||||
| 	gotValue := RawValue{Type: gotType, Value: gotBytes} | ||||
| 	assert.Equal(t, expectedValue, gotValue, "value mismatch; expected %s, got %s", expectedValue, gotValue) | ||||
| } | ||||
							
								
								
									
										29
									
								
								mongo/bson/marshaling_cases_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								mongo/bson/marshaling_cases_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,29 @@ | ||||
| // Copyright (C) MongoDB, Inc. 2017-present. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); you may | ||||
| // not use this file except in compliance with the License. You may obtain | ||||
| // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
| package bson | ||||
|  | ||||
| import ( | ||||
| 	"go.mongodb.org/mongo-driver/bson/bsoncodec" | ||||
| ) | ||||
|  | ||||
| type marshalingTestCase struct { | ||||
| 	name string | ||||
| 	reg  *bsoncodec.Registry | ||||
| 	val  interface{} | ||||
| 	want []byte | ||||
| } | ||||
|  | ||||
| var marshalingTestCases = []marshalingTestCase{ | ||||
| 	{ | ||||
| 		"small struct", | ||||
| 		nil, | ||||
| 		struct { | ||||
| 			Foo bool | ||||
| 		}{Foo: true}, | ||||
| 		docToBytes(D{{"foo", true}}), | ||||
| 	}, | ||||
| } | ||||
							
								
								
									
										1712
									
								
								mongo/bson/mgocompat/bson_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1712
									
								
								mongo/bson/mgocompat/bson_test.go
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user