Compare commits

...

100 Commits

Author SHA1 Message Date
9d88ab3a2b mongo data 2023-06-22 15:06:07 +02:00
2ec88e81f3 Patch mongo (add omitalways) 2023-06-18 16:04:34 +02:00
d471d7c396 Copied mongo repo (to patch it) 2023-06-18 15:52:17 +02:00
21d241f9b1 v0.0.163 2023-06-18 01:16:52 +02:00
2569c165f8 v0.0.162 2023-06-11 16:38:47 +02:00
ee262a94fb v0.0.161 2023-06-11 16:35:20 +02:00
7977c0e59c Added rfctime.Date type 2023-06-10 19:13:15 +02:00
ceff0161c6 v0.0.159 2023-06-10 18:35:56 +02:00
a30da61419 v0.0.158 2023-06-10 16:28:50 +02:00
b613b122e3 v0.0.157 2023-06-10 16:22:14 +02:00
d017530444 v0.0.156 2023-06-10 00:19:17 +02:00
8de83cc290 v0.0.155 2023-06-08 16:26:06 +02:00
603ec82b83 v0.0.154 2023-06-08 16:24:53 +02:00
93c4cf31a8 v0.0.153 2023-06-08 16:24:15 +02:00
dc2d8a9103 v0.0.152 2023-06-08 16:17:01 +02:00
6589e8d5cd v0.0.151 2023-06-07 17:57:03 +02:00
0006c6859d v0.0.150 2023-06-07 17:48:36 +02:00
827b3fc1b7 v0.0.149 2023-06-07 17:45:45 +02:00
f7dce4a102 v0.0.148 2023-06-07 17:22:38 +02:00
45d4fd7101 v0.0.147 2023-06-07 16:58:17 +02:00
c7df9d2264 v0.0.146 2023-06-07 12:59:15 +02:00
d0954bf133 v0.0.145 2023-06-07 12:45:48 +02:00
8affa81bb9 v0.0.144 2023-06-07 12:39:21 +02:00
fe9ebf0bab v0.0.143 2023-06-07 12:36:41 +02:00
a4b5f33d15 v0.0.142 2023-06-07 11:28:07 +02:00
e89e2c18f2 v0.0.141 2023-06-07 10:56:11 +02:00
b16d5152c7 v0.0.140 2023-06-07 10:42:56 +02:00
5fb2f8a312 v0.0.139 2023-06-06 21:40:34 +02:00
2ad820be8d v0.0.138 2023-06-06 21:33:49 +02:00
555096102a v0.0.137 2023-06-06 21:30:22 +02:00
d76d7b5cb9 v0.0.136 2023-06-06 21:26:12 +02:00
6622c9003d v0.0.135 2023-06-06 21:24:13 +02:00
b02e1d2e85 v0.0.134 2023-06-06 21:22:44 +02:00
c338d23070 v0.0.133 2023-06-06 21:18:40 +02:00
1fbae343a4 Fix RFC3339 serialization 2023-06-06 11:26:46 +02:00
31418bf0e6 v0.0.130 2023-06-05 13:30:32 +02:00
6d45f6f667 v0.0.129 2023-06-05 13:24:52 +02:00
f610a2202c v0.0.128 2023-06-02 09:44:31 +02:00
2807299d46 v0.0.127 2023-05-28 22:55:06 +02:00
e872dbccec v0.0.126 2023-05-28 19:53:30 +02:00
9daf71e2ed v0.0.125 2023-05-28 19:41:24 +02:00
fe278f7772 v0.0.124 2023-05-28 18:21:02 +02:00
8ebda6fb3a v0.0.123 2023-05-25 18:20:31 +02:00
b0d3ce8c1c v0.0.122 2023-05-24 22:01:29 +02:00
021465e524 v0.0.121 2023-05-24 21:55:21 +02:00
cf9c73aa4a v0.0.120 2023-05-24 21:42:10 +02:00
0652bf22dc v0.0.119 2023-05-24 21:32:00 +02:00
b196adffc7 v0.0.118 2023-05-09 11:33:01 +02:00
717065e62d v0.0.117 2023-05-09 09:57:05 +02:00
e7b2b040b2 v0.0.116 2023-05-05 18:22:15 +02:00
05d0f9e469 v0.0.115 2023-05-05 18:18:20 +02:00
ccd03e50c8 v0.0.114 2023-05-05 18:17:15 +02:00
1c77c2b8e8 v0.0.113 2023-05-05 18:05:58 +02:00
9f6f967299 v0.0.112 2023-05-05 18:00:25 +02:00
18c83f0f76 v0.0.111 2023-05-05 17:57:21 +02:00
a64f336e24 v0.0.110 2023-05-05 17:47:30 +02:00
14bbd205f8 v0.0.109 2023-05-05 15:04:08 +02:00
cecfb0d788 v0.0.108 2023-05-05 14:43:40 +02:00
a445e6f623 v0.0.107 2023-04-26 11:35:28 +02:00
0aa6310971 v0.0.106 2023-04-26 11:34:46 +02:00
2f66ab1cf0 v0.0.105 2023-04-23 19:31:48 +02:00
304e779470 v0.0.104 2023-04-23 14:54:23 +02:00
5e295d65c5 v0.0.103 2023-04-20 14:35:55 +02:00
ef3705937c gojson: added MarshalSafeCollections 2023-04-20 14:34:57 +02:00
d780c7965f added gojson as a go/json fork (tag go1.20.2) 2023-04-20 14:30:24 +02:00
c13db6802e v0.0.102 2023-04-13 14:40:07 +02:00
c5e23ab451 v0.0.101 2023-04-08 19:39:13 +02:00
c266d9204b v0.0.100 2023-04-04 17:10:38 +02:00
2550691e2e v0.0.99 2023-03-31 13:33:06 +02:00
ca24e1d5bf v0.0.98 2023-03-29 20:25:03 +02:00
b156052e6f v0.0.97 2023-03-29 19:53:53 +02:00
dda2418255 v0.0.96 2023-03-29 19:53:10 +02:00
8e40deae6a add git-pull to Makefile 2023-03-28 16:30:56 +02:00
289b9f47a2 v0.0.95 2023-03-28 16:29:16 +02:00
007c44df85 v0.0.94 2023-03-21 16:00:15 +01:00
a6252f0743 v0.0.93 2023-03-15 15:41:55 +01:00
86c01659d7 base58 2023-03-15 14:00:48 +01:00
62acddda5e v0.0.91 2023-03-11 14:38:19 +01:00
ee325f67fd v0.0.90 2023-03-09 14:51:53 +01:00
dba0cd229e v0.0.89 2023-03-07 10:43:30 +01:00
ec4dba173f v0.0.88 2023-02-16 13:27:34 +01:00
22ce2d26f3 v0.0.87 2023-02-16 13:22:15 +01:00
4fd768e573 v0.0.86 2023-02-14 17:18:58 +01:00
bf16a8165f v0.0.85 2023-02-14 16:25:45 +01:00
9f5612248a fix fd0 read error on long stdout output (scanner buffer was too small) 2023-02-13 01:41:33 +01:00
4a2b830252 added more tests to cmdrunner (reproduce another ?? cmdrunner bug...) 2023-02-09 16:49:33 +01:00
c492c80881 v0.0.83 2023-02-09 15:06:37 +01:00
26dd16d021 v0.0.82 2023-02-09 15:01:54 +01:00
b0b43de8ca v0.0.81 2023-02-09 11:27:49 +01:00
94f72e4ddf v0.0.80 2023-02-09 11:16:23 +01:00
df4388e6dc v0.0.79 2023-02-08 18:55:51 +01:00
fd33b43f31 v0.0.78 2023-02-03 01:05:36 +01:00
be4de07eb8 v0.0.77 2023-02-03 00:59:54 +01:00
36ed474bfe v0.0.76 2023-01-31 23:46:35 +01:00
fdc590c8c3 v0.0.75 2023-01-31 22:41:12 +01:00
1990e5d32d v0.0.74 2023-01-31 11:01:45 +01:00
72883cf6bd v0.0.73 2023-01-31 10:56:30 +01:00
ff08d5f180 v0.0.72 2023-01-30 19:55:55 +01:00
72d6b538f7 v0.0.71 2023-01-29 22:28:08 +01:00
48dd30fb94 v0.0.70 2023-01-29 22:07:28 +01:00
640 changed files with 157515 additions and 453 deletions

6
.idea/goext.iml generated
View File

@@ -1,6 +1,10 @@
<?xml version="1.0" encoding="UTF-8"?> <?xml version="1.0" encoding="UTF-8"?>
<module type="WEB_MODULE" version="4"> <module type="WEB_MODULE" version="4">
<component name="Go" enabled="true" /> <component name="Go" enabled="true">
<buildTags>
<option name="goVersion" value="1.19" />
</buildTags>
</component>
<component name="NewModuleRootManager"> <component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" /> <content url="file://$MODULE_DIR$" />
<orderEntry type="inheritedJdk" /> <orderEntry type="inheritedJdk" />

View File

@@ -1,9 +1,16 @@
.PHONY: run test version update-mongo
run: run:
echo "This is a library - can't be run" && false echo "This is a library - can't be run" && false
test: test:
go test ./... # go test ./...
which gotestsum || go install gotest.tools/gotestsum@latest
gotestsum --format "testname" -- -tags="timetzdata sqlite_fts5 sqlite_foreign_keys" "./test"
version: version:
_data/version.sh _data/version.sh
update-mongo:
_data/update-mongo.sh

View File

@@ -5,4 +5,37 @@ A collection of general & useful library methods
This should not have any heavy dependencies (gin, mongo, etc) and add missing basic language features... This should not have any heavy dependencies (gin, mongo, etc) and add missing basic language features...
Potentially needs `export GOPRIVATE="gogs.mikescher.com"` Potentially needs `export GOPRIVATE="gogs.mikescher.com"`
### Packages:
| Name | Maintainer | Description |
|--------------|------------|---------------------------------------------------------------------------------------------------------------|
| langext | Mike | General uttility/helper functions, (everything thats missing from go standard library) |
| mathext | Mike | Utility/Helper functions for math |
| cryptext | Mike | Utility/Helper functions for encryption |
| syncext | Mike | Utility/Helper funtions for multi-threading / mutex / channels |
| dataext | Mike | Various useful data structures |
| zipext | Mike | Utility for zip/gzip/tar etc |
| reflectext | Mike | Utility for golagn reflection |
| | | |
| mongoext | Mike | Utility/Helper functions for mongodb |
| cursortoken | Mike | MongoDB cursortoken implementation |
| | | |
| totpext | Mike | Implementation of TOTP (2-Factor-Auth) |
| termext | Mike | Utilities for terminals (mostly color output) |
| confext | Mike | Parses environment configuration into structs |
| cmdext | Mike | Runner for external commands/processes |
| | | |
| sq | Mike | Utility functions for sql based databases |
| tst | Mike | Utility functions for unit tests |
| | | |
| rfctime | Mike | Classes for time seriallization, with different marshallign method for mongo and json |
| gojson | Mike | Same interface for marshalling/unmarshalling as go/json, except with proper serialization of null arrays/maps |
| | | |
| bfcodegen | Mike | Various codegen tools (run via go generate) |
| | | |
| rext | Mike | Regex Wrapper, wraps regexp with a better interface |
| wmo | Mike | Mongo Wrapper, wraps mongodb with a better interface |
| | | |

13
TODO.md Normal file
View File

@@ -0,0 +1,13 @@
- cronext
- cursortoken
- typed/geenric mongo wrapper
- error package
- rfctime.DateOnly
- rfctime.HMSTimeOnly
- rfctime.NanoTimeOnly

80
_data/mongo.patch Normal file
View File

@@ -0,0 +1,80 @@
diff --git a/mongo/bson/bsoncodec/struct_codec.go b/mongo/bson/bsoncodec/struct_codec.go
--- a/mongo/bson/bsoncodec/struct_codec.go
+++ b/mongo/bson/bsoncodec/struct_codec.go
@@ -122,6 +122,10 @@ func (sc *StructCodec) EncodeValue(r EncodeContext, vw bsonrw.ValueWriter, val r
}
var rv reflect.Value
for _, desc := range sd.fl {
+ if desc.omitAlways {
+ continue
+ }
+
if desc.inline == nil {
rv = val.Field(desc.idx)
} else {
@@ -400,15 +404,16 @@ type structDescription struct {
}
type fieldDescription struct {
- name string // BSON key name
- fieldName string // struct field name
- idx int
- omitEmpty bool
- minSize bool
- truncate bool
- inline []int
- encoder ValueEncoder
- decoder ValueDecoder
+ name string // BSON key name
+ fieldName string // struct field name
+ idx int
+ omitEmpty bool
+ omitAlways bool
+ minSize bool
+ truncate bool
+ inline []int
+ encoder ValueEncoder
+ decoder ValueDecoder
}
type byIndex []fieldDescription
@@ -491,6 +496,7 @@ func (sc *StructCodec) describeStruct(r *Registry, t reflect.Type) (*structDescr
}
description.name = stags.Name
description.omitEmpty = stags.OmitEmpty
+ description.omitAlways = stags.OmitAlways
description.minSize = stags.MinSize
description.truncate = stags.Truncate
diff --git a/mongo/bson/bsoncodec/struct_tag_parser.go b/mongo/bson/bsoncodec/struct_tag_parser.go
--- a/mongo/bson/bsoncodec/struct_tag_parser.go
+++ b/mongo/bson/bsoncodec/struct_tag_parser.go
@@ -52,12 +52,13 @@ func (stpf StructTagParserFunc) ParseStructTags(sf reflect.StructField) (StructT
//
// TODO(skriptble): Add tags for undefined as nil and for null as nil.
type StructTags struct {
- Name string
- OmitEmpty bool
- MinSize bool
- Truncate bool
- Inline bool
- Skip bool
+ Name string
+ OmitEmpty bool
+ OmitAlways bool
+ MinSize bool
+ Truncate bool
+ Inline bool
+ Skip bool
}
// DefaultStructTagParser is the StructTagParser used by the StructCodec by default.
@@ -108,6 +109,8 @@ func parseTags(key string, tag string) (StructTags, error) {
switch str {
case "omitempty":
st.OmitEmpty = true
+ case "omitalways":
+ st.OmitAlways = true
case "minsize":
st.MinSize = true
case "truncate":

95
_data/update-mongo.sh Executable file
View File

@@ -0,0 +1,95 @@
#!/bin/bash
set -o nounset # disallow usage of unset vars ( set -u )
set -o errexit # Exit immediately if a pipeline returns non-zero. ( set -e )
set -o errtrace # Allow the above trap be inherited by all functions in the script. ( set -E )
set -o pipefail # Return value of a pipeline is the value of the last (rightmost) command to exit with a non-zero status
IFS=$'\n\t' # Set $IFS to only newline and tab.
dir="/tmp/mongo_repo_$( uuidgen )"
echo ""
echo "> Clone https://github.dev/mongodb/mongo-go-driver"
echo ""
git clone "https://github.com/mongodb/mongo-go-driver" "$dir"
pushd "$dir"
git fetch --tags
latestTag="$( git describe --tags `git rev-list --tags --max-count=1` )"
git -c "advice.detachedHead=false" checkout $latestTag
latestSHA="$( git rev-parse HEAD )"
popd
existingTag=$( cat mongoPatchVersion.go | grep -oP "(?<=const MongoCloneTag = \")([A-Za-z0-9.]+)(?=\")" )
existingSHA=$( cat mongoPatchVersion.go | grep -oP "(?<=const MongoCloneCommit = \")([A-Za-z0-9.]+)(?=\")" )
echo "===================================="
echo "ID (online) $latestSHA"
echo "ID (local) $existingSHA"
echo "Tag (online) $latestTag"
echo "Tag (local) $existingTag"
echo "===================================="
if [[ "$latestTag" == "$existingTag" ]]; then
echo "Nothing to do"
rm -rf "$dir"
exit 0
fi
echo ""
echo "> Copy repository"
echo ""
rm -rf mongo
cp -r "$dir" "mongo"
rm -rf "$dir"
echo ""
echo "> Clean repository"
echo ""
rm -rf "mongo/.git"
rm -rf "mongo/.evergreen"
rm -rf "mongo/cmd"
rm -rf "mongo/docs"
rm -rf "mongo/etc"
rm -rf "mongo/examples"
rm -rf "mongo/testdata"
rm -rf "mongo/benchmark"
rm -rf "mongo/vendor"
rm -rf "mongo/internal/test"
rm -rf "mongo/go.mod"
rm -rf "mongo/go.sum"
echo ""
echo "> Update mongoPatchVersion.go"
echo ""
{
printf "package goext\n"
printf "\n"
printf "// %s\n" "$( date +"%Y-%m-%d %H:%M:%S%z" )"
printf "\n"
printf "const MongoCloneTag = \"%s\"\n" "$latestTag"
printf "const MongoCloneCommit = \"%s\"\n" "$latestSHA"
} > mongoPatchVersion.go
echo ""
echo "> Patch mongo"
echo ""
git apply -v _data/mongo.patch
echo ""
echo "Done."

View File

@@ -7,6 +7,24 @@ set -o pipefail # Return value of a pipeline is the value of the last (rightmos
IFS=$'\n\t' # Set $IFS to only newline and tab. IFS=$'\n\t' # Set $IFS to only newline and tab.
function black() { echo -e "\x1B[30m $1 \x1B[0m"; }
function red() { echo -e "\x1B[31m $1 \x1B[0m"; }
function green() { echo -e "\x1B[32m $1 \x1B[0m"; }
function yellow(){ echo -e "\x1B[33m $1 \x1B[0m"; }
function blue() { echo -e "\x1B[34m $1 \x1B[0m"; }
function purple(){ echo -e "\x1B[35m $1 \x1B[0m"; }
function cyan() { echo -e "\x1B[36m $1 \x1B[0m"; }
function white() { echo -e "\x1B[37m $1 \x1B[0m"; }
if [ "$( git rev-parse --abbrev-ref HEAD )" != "master" ]; then
>&2 red "[ERROR] Can only create versions of <master>"
exit 1
fi
git pull --ff
go get -u ./...
curr_vers=$(git describe --tags --abbrev=0 | sed 's/v//g') curr_vers=$(git describe --tags --abbrev=0 | sed 's/v//g')
next_ver=$(echo "$curr_vers" | awk -F. -v OFS=. 'NF==1{print ++$NF}; NF>1{if(length($NF+1)>length($NF))$(NF-1)++; $NF=sprintf("%0*d", length($NF), ($NF+1)%(10^length($NF))); print}') next_ver=$(echo "$curr_vers" | awk -F. -v OFS=. 'NF==1{print ++$NF}; NF>1{if(length($NF+1)>length($NF))$(NF-1)++; $NF=sprintf("%0*d", length($NF), ($NF+1)%(10^length($NF))); print}')
@@ -16,9 +34,17 @@ echo "> Current Version: ${curr_vers}"
echo "> Next Version: ${next_ver}" echo "> Next Version: ${next_ver}"
echo "" echo ""
printf "package goext\n\nconst GoextVersion = \"%s\"\n\nconst GoextVersionTimestamp = \"%s\"\n" "${next_ver}" "$( date +"%Y-%m-%dT%H:%M:%S%z" )" > "goextVersion.go"
git add --verbose . git add --verbose .
git commit -a -m "v${next_ver}" msg="v${next_ver}"
if [ $# -gt 0 ]; then
msg="$1"
fi
git commit -a -m "${msg}"
git tag "v${next_ver}" git tag "v${next_ver}"

364
bfcodegen/enum-generate.go Normal file
View File

@@ -0,0 +1,364 @@
package bfcodegen
import (
"errors"
"fmt"
"gogs.mikescher.com/BlackForestBytes/goext"
"gogs.mikescher.com/BlackForestBytes/goext/cmdext"
"gogs.mikescher.com/BlackForestBytes/goext/cryptext"
"gogs.mikescher.com/BlackForestBytes/goext/langext"
"gogs.mikescher.com/BlackForestBytes/goext/rext"
"io"
"os"
"path"
"path/filepath"
"regexp"
"strings"
"time"
)
type EnumDefVal struct {
VarName string
Value string
Description *string
}
type EnumDef struct {
File string
FileRelative string
EnumTypeName string
Type string
Values []EnumDefVal
}
var rexPackage = rext.W(regexp.MustCompile("^package\\s+(?P<name>[A-Za-z0-9_]+)\\s*$"))
var rexEnumDef = rext.W(regexp.MustCompile("^\\s*type\\s+(?P<name>[A-Za-z0-9_]+)\\s+(?P<type>[A-Za-z0-9_]+)\\s*//\\s*(@enum:type).*$"))
var rexValueDef = rext.W(regexp.MustCompile("^\\s*(?P<name>[A-Za-z0-9_]+)\\s+(?P<type>[A-Za-z0-9_]+)\\s*=\\s*(?P<value>(\"[A-Za-z0-9_:]+\"|[0-9]+))\\s*(//(?P<descr>.*))?.*$"))
var rexChecksumConst = rext.W(regexp.MustCompile("const ChecksumGenerator = \"(?P<cs>[A-Za-z0-9_]*)\""))
func GenerateEnumSpecs(sourceDir string, destFile string) error {
files, err := os.ReadDir(sourceDir)
if err != nil {
return err
}
oldChecksum := "N/A"
if _, err := os.Stat(destFile); !os.IsNotExist(err) {
content, err := os.ReadFile(destFile)
if err != nil {
return err
}
if m, ok := rexChecksumConst.MatchFirst(string(content)); ok {
oldChecksum = m.GroupByName("cs").Value()
}
}
files = langext.ArrFilter(files, func(v os.DirEntry) bool { return v.Name() != path.Base(destFile) })
files = langext.ArrFilter(files, func(v os.DirEntry) bool { return strings.HasSuffix(v.Name(), ".go") })
langext.SortBy(files, func(v os.DirEntry) string { return v.Name() })
newChecksumStr := goext.GoextVersion
for _, f := range files {
content, err := os.ReadFile(path.Join(sourceDir, f.Name()))
if err != nil {
return err
}
newChecksumStr += "\n" + f.Name() + "\t" + cryptext.BytesSha256(content)
}
newChecksum := cryptext.BytesSha256([]byte(newChecksumStr))
if newChecksum != oldChecksum {
fmt.Printf("[EnumGenerate] Checksum has changed ( %s -> %s ), will generate new file\n\n", oldChecksum, newChecksum)
} else {
fmt.Printf("[EnumGenerate] Checksum unchanged ( %s ), nothing to do\n", oldChecksum)
return nil
}
allEnums := make([]EnumDef, 0)
pkgname := ""
for _, f := range files {
fmt.Printf("========= %s =========\n\n", f.Name())
fileEnums, pn, err := processFile(sourceDir, path.Join(sourceDir, f.Name()))
if err != nil {
return err
}
fmt.Printf("\n")
allEnums = append(allEnums, fileEnums...)
if pn != "" {
pkgname = pn
}
}
if pkgname == "" {
return errors.New("no package name found in any file")
}
err = os.WriteFile(destFile, []byte(fmtOutput(newChecksum, allEnums, pkgname)), 0o755)
if err != nil {
return err
}
res, err := cmdext.RunCommand("go", []string{"fmt", destFile}, langext.Ptr(2*time.Second))
if err != nil {
return err
}
if res.CommandTimedOut {
fmt.Println(res.StdCombined)
return errors.New("go fmt timed out")
}
if res.ExitCode != 0 {
fmt.Println(res.StdCombined)
return errors.New("go fmt did not succeed")
}
return nil
}
func processFile(basedir string, fn string) ([]EnumDef, string, error) {
file, err := os.Open(fn)
if err != nil {
return nil, "", err
}
defer func() { _ = file.Close() }()
bin, err := io.ReadAll(file)
if err != nil {
return nil, "", err
}
lines := strings.Split(string(bin), "\n")
enums := make([]EnumDef, 0)
pkgname := ""
for i, line := range lines {
if i == 0 && strings.HasPrefix(line, "// Code generated by") {
break
}
if match, ok := rexPackage.MatchFirst(line); i == 0 && ok {
pkgname = match.GroupByName("name").Value()
continue
}
if match, ok := rexEnumDef.MatchFirst(line); ok {
rfp, err := filepath.Rel(basedir, fn)
if err != nil {
return nil, "", err
}
def := EnumDef{
File: fn,
FileRelative: rfp,
EnumTypeName: match.GroupByName("name").Value(),
Type: match.GroupByName("type").Value(),
Values: make([]EnumDefVal, 0),
}
enums = append(enums, def)
fmt.Printf("Found enum definition { '%s' -> '%s' }\n", def.EnumTypeName, def.Type)
}
if match, ok := rexValueDef.MatchFirst(line); ok {
typename := match.GroupByName("type").Value()
def := EnumDefVal{
VarName: match.GroupByName("name").Value(),
Value: match.GroupByName("value").Value(),
Description: match.GroupByNameOrEmpty("descr").ValueOrNil(),
}
found := false
for i, v := range enums {
if v.EnumTypeName == typename {
enums[i].Values = append(enums[i].Values, def)
found = true
if def.Description != nil {
fmt.Printf("Found enum value [%s] for '%s' ('%s')\n", def.Value, def.VarName, *def.Description)
} else {
fmt.Printf("Found enum value [%s] for '%s'\n", def.Value, def.VarName)
}
break
}
}
if !found {
fmt.Printf("Found non-enum value [%s] for '%s' ( looks like enum value, but no matching @enum:type )\n", def.Value, def.VarName)
}
}
}
return enums, pkgname, nil
}
func fmtOutput(cs string, enums []EnumDef, pkgname string) string {
str := "// Code generated by enum-generate.go DO NOT EDIT.\n"
str += "\n"
str += "package " + pkgname + "\n"
str += "\n"
str += "import \"gogs.mikescher.com/BlackForestBytes/goext/langext\"" + "\n"
str += "\n"
str += "const ChecksumGenerator = \"" + cs + "\"" + "\n"
str += "\n"
str += "type Enum interface {" + "\n"
str += " Valid() bool" + "\n"
str += " ValuesAny() []any" + "\n"
str += " ValuesMeta() []EnumMetaValue" + "\n"
str += " VarName() string" + "\n"
str += "}" + "\n"
str += "" + "\n"
str += "type StringEnum interface {" + "\n"
str += " Enum" + "\n"
str += " String() string" + "\n"
str += "}" + "\n"
str += "" + "\n"
str += "type DescriptionEnum interface {" + "\n"
str += " Enum" + "\n"
str += " Description() string" + "\n"
str += "}" + "\n"
str += "\n"
str += "type EnumMetaValue struct {" + "\n"
str += " VarName string `json:\"varName\"`" + "\n"
str += " Value any `json:\"value\"`" + "\n"
str += " Description *string `json:\"description\"`" + "\n"
str += "}" + "\n"
str += "\n"
for _, enumdef := range enums {
hasDescr := langext.ArrAll(enumdef.Values, func(val EnumDefVal) bool { return val.Description != nil })
hasStr := enumdef.Type == "string"
str += "// ================================ " + enumdef.EnumTypeName + " ================================" + "\n"
str += "//" + "\n"
str += "// File: " + enumdef.FileRelative + "\n"
str += "// StringEnum: " + langext.Conditional(hasStr, "true", "false") + "\n"
str += "// DescrEnum: " + langext.Conditional(hasDescr, "true", "false") + "\n"
str += "//" + "\n"
str += "" + "\n"
str += "var __" + enumdef.EnumTypeName + "Values = []" + enumdef.EnumTypeName + "{" + "\n"
for _, v := range enumdef.Values {
str += " " + v.VarName + "," + "\n"
}
str += "}" + "\n"
str += "" + "\n"
if hasDescr {
str += "var __" + enumdef.EnumTypeName + "Descriptions = map[" + enumdef.EnumTypeName + "]string{" + "\n"
for _, v := range enumdef.Values {
str += " " + v.VarName + ": \"" + strings.TrimSpace(*v.Description) + "\"," + "\n"
}
str += "}" + "\n"
str += "" + "\n"
}
str += "var __" + enumdef.EnumTypeName + "Varnames = map[" + enumdef.EnumTypeName + "]string{" + "\n"
for _, v := range enumdef.Values {
str += " " + v.VarName + ": \"" + v.VarName + "\"," + "\n"
}
str += "}" + "\n"
str += "" + "\n"
str += "func (e " + enumdef.EnumTypeName + ") Valid() bool {" + "\n"
str += " return langext.InArray(e, __" + enumdef.EnumTypeName + "Values)" + "\n"
str += "}" + "\n"
str += "" + "\n"
str += "func (e " + enumdef.EnumTypeName + ") Values() []" + enumdef.EnumTypeName + " {" + "\n"
str += " return __" + enumdef.EnumTypeName + "Values" + "\n"
str += "}" + "\n"
str += "" + "\n"
str += "func (e " + enumdef.EnumTypeName + ") ValuesAny() []any {" + "\n"
str += " return langext.ArrCastToAny(__" + enumdef.EnumTypeName + "Values)" + "\n"
str += "}" + "\n"
str += "" + "\n"
str += "func (e " + enumdef.EnumTypeName + ") ValuesMeta() []EnumMetaValue {" + "\n"
str += " return []EnumMetaValue{" + "\n"
for _, v := range enumdef.Values {
if hasDescr {
str += " " + fmt.Sprintf("EnumMetaValue{VarName: \"%s\", Value: %s, Description: langext.Ptr(\"%s\")},", v.VarName, v.VarName, strings.TrimSpace(*v.Description)) + "\n"
} else {
str += " " + fmt.Sprintf("EnumMetaValue{VarName: \"%s\", Value: %s, Description: nil},", v.VarName, v.VarName) + "\n"
}
}
str += " }" + "\n"
str += "}" + "\n"
str += "" + "\n"
if hasStr {
str += "func (e " + enumdef.EnumTypeName + ") String() string {" + "\n"
str += " return string(e)" + "\n"
str += "}" + "\n"
str += "" + "\n"
}
if hasDescr {
str += "func (e " + enumdef.EnumTypeName + ") Description() string {" + "\n"
str += " if d, ok := __" + enumdef.EnumTypeName + "Descriptions[e]; ok {" + "\n"
str += " return d" + "\n"
str += " }" + "\n"
str += " return \"\"" + "\n"
str += "}" + "\n"
str += "" + "\n"
}
str += "func (e " + enumdef.EnumTypeName + ") VarName() string {" + "\n"
str += " if d, ok := __" + enumdef.EnumTypeName + "Varnames[e]; ok {" + "\n"
str += " return d" + "\n"
str += " }" + "\n"
str += " return \"\"" + "\n"
str += "}" + "\n"
str += "" + "\n"
str += "func Parse" + enumdef.EnumTypeName + "(vv string) (" + enumdef.EnumTypeName + ", bool) {" + "\n"
str += " for _, ev := range __" + enumdef.EnumTypeName + "Values {" + "\n"
str += " if string(ev) == vv {" + "\n"
str += " return ev, true" + "\n"
str += " }" + "\n"
str += " }" + "\n"
str += " return \"\", false" + "\n"
str += "}" + "\n"
str += "" + "\n"
str += "func " + enumdef.EnumTypeName + "Values() []" + enumdef.EnumTypeName + " {" + "\n"
str += " return __" + enumdef.EnumTypeName + "Values" + "\n"
str += "}" + "\n"
str += "" + "\n"
str += "func " + enumdef.EnumTypeName + "ValuesMeta() []EnumMetaValue {" + "\n"
str += " return []EnumMetaValue{" + "\n"
for _, v := range enumdef.Values {
if hasDescr {
str += " " + fmt.Sprintf("EnumMetaValue{VarName: \"%s\", Value: %s, Description: langext.Ptr(\"%s\")},", v.VarName, v.VarName, strings.TrimSpace(*v.Description)) + "\n"
} else {
str += " " + fmt.Sprintf("EnumMetaValue{VarName: \"%s\", Value: %s, Description: nil},", v.VarName, v.VarName) + "\n"
}
}
str += " }" + "\n"
str += "}" + "\n"
str += "" + "\n"
}
return str
}

View File

@@ -0,0 +1,15 @@
package bfcodegen
import (
"testing"
)
func TestApplyEnvOverridesSimple(t *testing.T) {
err := GenerateEnumSpecs("/home/mike/Code/reiff/badennet/bnet-backend/models", "/home/mike/Code/reiff/badennet/bnet-backend/models/enums_gen.go")
if err != nil {
t.Error(err)
t.Fail()
}
}

View File

@@ -2,22 +2,29 @@ package cmdext
import ( import (
"fmt" "fmt"
"gogs.mikescher.com/BlackForestBytes/goext/langext"
"time" "time"
) )
type CommandRunner struct { type CommandRunner struct {
program string program string
args []string args []string
timeout *time.Duration timeout *time.Duration
env []string env []string
listener []CommandListener
enforceExitCodes *[]int
enforceNoTimeout bool
} }
func Runner(program string) *CommandRunner { func Runner(program string) *CommandRunner {
return &CommandRunner{ return &CommandRunner{
program: program, program: program,
args: make([]string, 0), args: make([]string, 0),
timeout: nil, timeout: nil,
env: make([]string, 0), env: make([]string, 0),
listener: make([]CommandListener, 0),
enforceExitCodes: nil,
enforceNoTimeout: false,
} }
} }
@@ -51,6 +58,36 @@ func (r *CommandRunner) Envs(env []string) *CommandRunner {
return r return r
} }
func (r *CommandRunner) EnsureExitcode(arg ...int) *CommandRunner {
r.enforceExitCodes = langext.Ptr(langext.ForceArray(arg))
return r
}
func (r *CommandRunner) FailOnExitCode() *CommandRunner {
r.enforceExitCodes = langext.Ptr([]int{0})
return r
}
func (r *CommandRunner) FailOnTimeout() *CommandRunner {
r.enforceNoTimeout = true
return r
}
func (r *CommandRunner) Listen(lstr CommandListener) *CommandRunner {
r.listener = append(r.listener, lstr)
return r
}
func (r *CommandRunner) ListenStdout(lstr func(string)) *CommandRunner {
r.listener = append(r.listener, genericCommandListener{_readStdoutLine: &lstr})
return r
}
func (r *CommandRunner) ListenStderr(lstr func(string)) *CommandRunner {
r.listener = append(r.listener, genericCommandListener{_readStderrLine: &lstr})
return r
}
func (r *CommandRunner) Run() (CommandResult, error) { func (r *CommandRunner) Run() (CommandResult, error) {
return run(*r) return run(*r)
} }

View File

@@ -1,11 +1,17 @@
package cmdext package cmdext
import ( import (
"bufio" "errors"
"gogs.mikescher.com/BlackForestBytes/goext/langext"
"gogs.mikescher.com/BlackForestBytes/goext/mathext"
"gogs.mikescher.com/BlackForestBytes/goext/syncext"
"os/exec" "os/exec"
"time" "time"
) )
var ErrExitCode = errors.New("process exited with an unexpected exitcode")
var ErrTimeout = errors.New("process did not exit after the specified timeout")
type CommandResult struct { type CommandResult struct {
StdOut string StdOut string
StdErr string StdErr string
@@ -16,7 +22,8 @@ type CommandResult struct {
func run(opt CommandRunner) (CommandResult, error) { func run(opt CommandRunner) (CommandResult, error) {
cmd := exec.Command(opt.program, opt.args...) cmd := exec.Command(opt.program, opt.args...)
cmd.Env = append(cmd.Env, opt.env)
cmd.Env = append(cmd.Env, opt.env...)
stdoutPipe, err := cmd.StdoutPipe() stdoutPipe, err := cmd.StdoutPipe()
if err != nil { if err != nil {
@@ -28,51 +35,43 @@ func run(opt CommandRunner) (CommandResult, error) {
return CommandResult{}, err return CommandResult{}, err
} }
preader := pipeReader{
lineBufferSize: langext.Ptr(128 * 1024 * 1024), // 128MB max size of a single line, is hopefully enough....
stdout: stdoutPipe,
stderr: stderrPipe,
}
err = cmd.Start() err = cmd.Start()
if err != nil { if err != nil {
return CommandResult{}, err return CommandResult{}, err
} }
errch := make(chan error, 1) type resultObj struct {
go func() { errch <- cmd.Wait() }() stdout string
stderr string
stdcombined string
err error
}
combch := make(chan string, 32) outputChan := make(chan resultObj)
stopCombch := make(chan bool)
stdout := ""
go func() { go func() {
scanner := bufio.NewScanner(stdoutPipe) // we need to first fully read the pipes and then call Wait
for scanner.Scan() { // see https://pkg.go.dev/os/exec#Cmd.StdoutPipe
txt := scanner.Text()
stdout += txt
combch <- txt
}
}()
stderr := "" stdout, stderr, stdcombined, err := preader.Read(opt.listener)
go func() { if err != nil {
scanner := bufio.NewScanner(stderrPipe) outputChan <- resultObj{stdout, stderr, stdcombined, err}
for scanner.Scan() { _ = cmd.Process.Kill()
txt := scanner.Text() return
stderr += txt
combch <- txt
} }
}()
defer func() { err = cmd.Wait()
stopCombch <- true if err != nil {
}() outputChan <- resultObj{stdout, stderr, stdcombined, err}
} else {
stdcombined := "" outputChan <- resultObj{stdout, stderr, stdcombined, nil}
go func() {
for {
select {
case txt := <-combch:
stdcombined += txt
case <-stopCombch:
return
}
} }
}() }()
var timeoutChan <-chan time.Time = make(chan time.Time, 1) var timeoutChan <-chan time.Time = make(chan time.Time, 1)
@@ -84,33 +83,72 @@ func run(opt CommandRunner) (CommandResult, error) {
case <-timeoutChan: case <-timeoutChan:
_ = cmd.Process.Kill() _ = cmd.Process.Kill()
return CommandResult{ for _, lstr := range opt.listener {
StdOut: stdout, lstr.Timeout()
StdErr: stderr, }
StdCombined: stdcombined,
ExitCode: -1,
CommandTimedOut: true,
}, nil
case err := <-errch: if fallback, ok := syncext.ReadChannelWithTimeout(outputChan, mathext.Min(32*time.Millisecond, *opt.timeout)); ok {
if exiterr, ok := err.(*exec.ExitError); ok { // most of the time the cmd.Process.Kill() should also ahve finished the pipereader
return CommandResult{ // and we can at least return the already collected stdout, stderr, etc
StdOut: stdout, res := CommandResult{
StdErr: stderr, StdOut: fallback.stdout,
StdCombined: stdcombined, StdErr: fallback.stderr,
ExitCode: exiterr.ExitCode(), StdCombined: fallback.stdcombined,
ExitCode: -1,
CommandTimedOut: true,
}
if opt.enforceNoTimeout {
return res, ErrTimeout
}
return res, nil
} else {
res := CommandResult{
StdOut: "",
StdErr: "",
StdCombined: "",
ExitCode: -1,
CommandTimedOut: true,
}
if opt.enforceNoTimeout {
return res, ErrTimeout
}
return res, nil
}
case outobj := <-outputChan:
if exiterr, ok := outobj.err.(*exec.ExitError); ok {
excode := exiterr.ExitCode()
for _, lstr := range opt.listener {
lstr.Finished(excode)
}
res := CommandResult{
StdOut: outobj.stdout,
StdErr: outobj.stderr,
StdCombined: outobj.stdcombined,
ExitCode: excode,
CommandTimedOut: false, CommandTimedOut: false,
}, nil }
if opt.enforceExitCodes != nil && !langext.InArray(excode, *opt.enforceExitCodes) {
return res, ErrExitCode
}
return res, nil
} else if err != nil { } else if err != nil {
return CommandResult{}, err return CommandResult{}, err
} else { } else {
return CommandResult{ for _, lstr := range opt.listener {
StdOut: stdout, lstr.Finished(0)
StdErr: stderr, }
StdCombined: stdcombined, res := CommandResult{
StdOut: outobj.stdout,
StdErr: outobj.stderr,
StdCombined: outobj.stdcombined,
ExitCode: 0, ExitCode: 0,
CommandTimedOut: false, CommandTimedOut: false,
}, nil }
if opt.enforceExitCodes != nil && !langext.InArray(0, *opt.enforceExitCodes) {
return res, ErrExitCode
}
return res, nil
} }
} }
} }

323
cmdext/cmdrunner_test.go Normal file
View File

@@ -0,0 +1,323 @@
package cmdext
import (
"fmt"
"testing"
"time"
)
func TestStdout(t *testing.T) {
res1, err := Runner("printf").Arg("hello").Run()
if err != nil {
t.Errorf("%v", err)
}
if res1.CommandTimedOut {
t.Errorf("Timeout")
}
if res1.ExitCode != 0 {
t.Errorf("res1.ExitCode == %v", res1.ExitCode)
}
if res1.StdErr != "" {
t.Errorf("res1.StdErr == '%v'", res1.StdErr)
}
if res1.StdOut != "hello" {
t.Errorf("res1.StdOut == '%v'", res1.StdOut)
}
if res1.StdCombined != "hello\n" {
t.Errorf("res1.StdCombined == '%v'", res1.StdCombined)
}
}
func TestStderr(t *testing.T) {
res1, err := Runner("python").Arg("-c").Arg("import sys; print(\"error\", file=sys.stderr, end='')").Run()
if err != nil {
t.Errorf("%v", err)
}
if res1.CommandTimedOut {
t.Errorf("Timeout")
}
if res1.ExitCode != 0 {
t.Errorf("res1.ExitCode == %v", res1.ExitCode)
}
if res1.StdErr != "error" {
t.Errorf("res1.StdErr == '%v'", res1.StdErr)
}
if res1.StdOut != "" {
t.Errorf("res1.StdOut == '%v'", res1.StdOut)
}
if res1.StdCombined != "error\n" {
t.Errorf("res1.StdCombined == '%v'", res1.StdCombined)
}
}
func TestStdcombined(t *testing.T) {
res1, err := Runner("python").
Arg("-c").
Arg("import sys; import time; print(\"1\", file=sys.stderr, flush=True); time.sleep(0.1); print(\"2\", file=sys.stdout, flush=True); time.sleep(0.1); print(\"3\", file=sys.stderr, flush=True)").
Run()
if err != nil {
t.Errorf("%v", err)
}
if res1.CommandTimedOut {
t.Errorf("Timeout")
}
if res1.ExitCode != 0 {
t.Errorf("res1.ExitCode == %v", res1.ExitCode)
}
if res1.StdErr != "1\n3\n" {
t.Errorf("res1.StdErr == '%v'", res1.StdErr)
}
if res1.StdOut != "2\n" {
t.Errorf("res1.StdOut == '%v'", res1.StdOut)
}
if res1.StdCombined != "1\n2\n3\n" {
t.Errorf("res1.StdCombined == '%v'", res1.StdCombined)
}
}
func TestPartialRead(t *testing.T) {
res1, err := Runner("python").
Arg("-c").
Arg("import sys; import time; print(\"first message\", flush=True); time.sleep(5); print(\"cant see me\", flush=True);").
Timeout(100 * time.Millisecond).
Run()
if err != nil {
t.Errorf("%v", err)
}
if !res1.CommandTimedOut {
t.Errorf("!CommandTimedOut")
}
if res1.StdErr != "" {
t.Errorf("res1.StdErr == '%v'", res1.StdErr)
}
if res1.StdOut != "first message\n" {
t.Errorf("res1.StdOut == '%v'", res1.StdOut)
}
if res1.StdCombined != "first message\n" {
t.Errorf("res1.StdCombined == '%v'", res1.StdCombined)
}
}
func TestPartialReadStderr(t *testing.T) {
res1, err := Runner("python").
Arg("-c").
Arg("import sys; import time; print(\"first message\", file=sys.stderr, flush=True); time.sleep(5); print(\"cant see me\", file=sys.stderr, flush=True);").
Timeout(100 * time.Millisecond).
Run()
if err != nil {
t.Errorf("%v", err)
}
if !res1.CommandTimedOut {
t.Errorf("!CommandTimedOut")
}
if res1.StdErr != "first message\n" {
t.Errorf("res1.StdErr == '%v'", res1.StdErr)
}
if res1.StdOut != "" {
t.Errorf("res1.StdOut == '%v'", res1.StdOut)
}
if res1.StdCombined != "first message\n" {
t.Errorf("res1.StdCombined == '%v'", res1.StdCombined)
}
}
func TestReadUnflushedStdout(t *testing.T) {
res1, err := Runner("python").Arg("-c").Arg("import sys; print(\"message101\", file=sys.stdout, end='')").Run()
if err != nil {
t.Errorf("%v", err)
}
if res1.CommandTimedOut {
t.Errorf("Timeout")
}
if res1.ExitCode != 0 {
t.Errorf("res1.ExitCode == %v", res1.ExitCode)
}
if res1.StdErr != "" {
t.Errorf("res1.StdErr == '%v'", res1.StdErr)
}
if res1.StdOut != "message101" {
t.Errorf("res1.StdOut == '%v'", res1.StdOut)
}
if res1.StdCombined != "message101\n" {
t.Errorf("res1.StdCombined == '%v'", res1.StdCombined)
}
}
func TestReadUnflushedStderr(t *testing.T) {
res1, err := Runner("python").Arg("-c").Arg("import sys; print(\"message101\", file=sys.stderr, end='')").Run()
if err != nil {
t.Errorf("%v", err)
}
if res1.CommandTimedOut {
t.Errorf("Timeout")
}
if res1.ExitCode != 0 {
t.Errorf("res1.ExitCode == %v", res1.ExitCode)
}
if res1.StdErr != "message101" {
t.Errorf("res1.StdErr == '%v'", res1.StdErr)
}
if res1.StdOut != "" {
t.Errorf("res1.StdOut == '%v'", res1.StdOut)
}
if res1.StdCombined != "message101\n" {
t.Errorf("res1.StdCombined == '%v'", res1.StdCombined)
}
}
func TestPartialReadUnflushed(t *testing.T) {
t.SkipNow()
res1, err := Runner("python").
Arg("-c").
Arg("import sys; import time; print(\"first message\", end=''); time.sleep(5); print(\"cant see me\", end='');").
Timeout(100 * time.Millisecond).
Run()
if err != nil {
t.Errorf("%v", err)
}
if !res1.CommandTimedOut {
t.Errorf("!CommandTimedOut")
}
if res1.StdErr != "" {
t.Errorf("res1.StdErr == '%v'", res1.StdErr)
}
if res1.StdOut != "first message" {
t.Errorf("res1.StdOut == '%v'", res1.StdOut)
}
if res1.StdCombined != "first message" {
t.Errorf("res1.StdCombined == '%v'", res1.StdCombined)
}
}
func TestPartialReadUnflushedStderr(t *testing.T) {
t.SkipNow()
res1, err := Runner("python").
Arg("-c").
Arg("import sys; import time; print(\"first message\", file=sys.stderr, end=''); time.sleep(5); print(\"cant see me\", file=sys.stderr, end='');").
Timeout(100 * time.Millisecond).
Run()
if err != nil {
t.Errorf("%v", err)
}
if !res1.CommandTimedOut {
t.Errorf("!CommandTimedOut")
}
if res1.StdErr != "first message" {
t.Errorf("res1.StdErr == '%v'", res1.StdErr)
}
if res1.StdOut != "" {
t.Errorf("res1.StdOut == '%v'", res1.StdOut)
}
if res1.StdCombined != "first message" {
t.Errorf("res1.StdCombined == '%v'", res1.StdCombined)
}
}
func TestListener(t *testing.T) {
res1, err := Runner("python").
Arg("-c").
Arg("import sys;" +
"import time;" +
"print(\"message 1\", flush=True);" +
"time.sleep(1);" +
"print(\"message 2\", flush=True);" +
"time.sleep(1);" +
"print(\"message 3\", flush=True);" +
"time.sleep(1);" +
"print(\"message 4\", file=sys.stderr, flush=True);" +
"time.sleep(1);" +
"print(\"message 5\", flush=True);" +
"time.sleep(1);" +
"print(\"final\");").
ListenStdout(func(s string) { fmt.Printf("@@STDOUT <<- %v (%v)\n", s, time.Now().Format(time.RFC3339Nano)) }).
ListenStderr(func(s string) { fmt.Printf("@@STDERR <<- %v (%v)\n", s, time.Now().Format(time.RFC3339Nano)) }).
Timeout(10 * time.Second).
Run()
if err != nil {
panic(err)
}
if res1.CommandTimedOut {
t.Errorf("Timeout")
}
if res1.ExitCode != 0 {
t.Errorf("res1.ExitCode == %v", res1.ExitCode)
}
}
func TestLongStdout(t *testing.T) {
res1, err := Runner("python").
Arg("-c").
Arg("import sys; import time; print(\"X\" * 125001 + \"\\n\"); print(\"Y\" * 125001 + \"\\n\"); print(\"Z\" * 125001 + \"\\n\");").
Timeout(5000 * time.Millisecond).
Run()
if err != nil {
t.Errorf("%v", err)
}
if res1.CommandTimedOut {
t.Errorf("Timeout")
}
if res1.ExitCode != 0 {
t.Errorf("res1.ExitCode == %v", res1.ExitCode)
}
if res1.StdErr != "" {
t.Errorf("res1.StdErr == '%v'", res1.StdErr)
}
if len(res1.StdOut) != 375009 {
t.Errorf("len(res1.StdOut) == '%v'", len(res1.StdOut))
}
}
func TestFailOnTimeout(t *testing.T) {
_, err := Runner("sleep").Arg("2").Timeout(200 * time.Millisecond).FailOnTimeout().Run()
if err != ErrTimeout {
t.Errorf("wrong err := %v", err)
}
}
func TestFailOnExitcode(t *testing.T) {
_, err := Runner("false").Timeout(200 * time.Millisecond).FailOnExitCode().Run()
if err != ErrExitCode {
t.Errorf("wrong err := %v", err)
}
}
func TestEnsureExitcode1(t *testing.T) {
_, err := Runner("false").Timeout(200 * time.Millisecond).EnsureExitcode(1).Run()
if err != nil {
t.Errorf("wrong err := %v", err)
}
}
func TestEnsureExitcode2(t *testing.T) {
_, err := Runner("false").Timeout(200*time.Millisecond).EnsureExitcode(0, 2, 3).Run()
if err != ErrExitCode {
t.Errorf("wrong err := %v", err)
}
}

57
cmdext/listener.go Normal file
View File

@@ -0,0 +1,57 @@
package cmdext
type CommandListener interface {
ReadRawStdout([]byte)
ReadRawStderr([]byte)
ReadStdoutLine(string)
ReadStderrLine(string)
Finished(int)
Timeout()
}
type genericCommandListener struct {
_readRawStdout *func([]byte)
_readRawStderr *func([]byte)
_readStdoutLine *func(string)
_readStderrLine *func(string)
_finished *func(int)
_timeout *func()
}
func (g genericCommandListener) ReadRawStdout(v []byte) {
if g._readRawStdout != nil {
(*g._readRawStdout)(v)
}
}
func (g genericCommandListener) ReadRawStderr(v []byte) {
if g._readRawStderr != nil {
(*g._readRawStderr)(v)
}
}
func (g genericCommandListener) ReadStdoutLine(v string) {
if g._readStdoutLine != nil {
(*g._readStdoutLine)(v)
}
}
func (g genericCommandListener) ReadStderrLine(v string) {
if g._readStderrLine != nil {
(*g._readStderrLine)(v)
}
}
func (g genericCommandListener) Finished(v int) {
if g._finished != nil {
(*g._finished)(v)
}
}
func (g genericCommandListener) Timeout() {
if g._timeout != nil {
(*g._timeout)()
}
}

158
cmdext/pipereader.go Normal file
View File

@@ -0,0 +1,158 @@
package cmdext
import (
"bufio"
"gogs.mikescher.com/BlackForestBytes/goext/syncext"
"io"
"sync"
)
type pipeReader struct {
lineBufferSize *int
stdout io.ReadCloser
stderr io.ReadCloser
}
// Read ready stdout and stdin until finished
// also splits both pipes into lines and calld the listener
func (pr *pipeReader) Read(listener []CommandListener) (string, string, string, error) {
type combevt struct {
line string
stop bool
}
errch := make(chan error, 8)
wg := sync.WaitGroup{}
// [1] read raw stdout
wg.Add(1)
stdoutBufferReader, stdoutBufferWriter := io.Pipe()
stdout := ""
go func() {
buf := make([]byte, 128)
for true {
n, out := pr.stdout.Read(buf)
if n > 0 {
txt := string(buf[:n])
stdout += txt
_, _ = stdoutBufferWriter.Write(buf[:n])
for _, lstr := range listener {
lstr.ReadRawStdout(buf[:n])
}
}
if out == io.EOF {
break
}
if out != nil {
errch <- out
break
}
}
_ = stdoutBufferWriter.Close()
wg.Done()
}()
// [2] read raw stderr
wg.Add(1)
stderrBufferReader, stderrBufferWriter := io.Pipe()
stderr := ""
go func() {
buf := make([]byte, 128)
for true {
n, err := pr.stderr.Read(buf)
if n > 0 {
txt := string(buf[:n])
stderr += txt
_, _ = stderrBufferWriter.Write(buf[:n])
for _, lstr := range listener {
lstr.ReadRawStderr(buf[:n])
}
}
if err == io.EOF {
break
}
if err != nil {
errch <- err
break
}
}
_ = stderrBufferWriter.Close()
wg.Done()
}()
combch := make(chan combevt, 32)
// [3] collect stdout line-by-line
wg.Add(1)
go func() {
scanner := bufio.NewScanner(stdoutBufferReader)
if pr.lineBufferSize != nil {
scanner.Buffer([]byte{}, *pr.lineBufferSize)
}
for scanner.Scan() {
txt := scanner.Text()
for _, lstr := range listener {
lstr.ReadStdoutLine(txt)
}
combch <- combevt{txt, false}
}
if err := scanner.Err(); err != nil {
errch <- err
}
combch <- combevt{"", true}
wg.Done()
}()
// [4] collect stderr line-by-line
wg.Add(1)
go func() {
scanner := bufio.NewScanner(stderrBufferReader)
if pr.lineBufferSize != nil {
scanner.Buffer([]byte{}, *pr.lineBufferSize)
}
for scanner.Scan() {
txt := scanner.Text()
for _, lstr := range listener {
lstr.ReadStderrLine(txt)
}
combch <- combevt{txt, false}
}
if err := scanner.Err(); err != nil {
errch <- err
}
combch <- combevt{"", true}
wg.Done()
}()
// [5] combine stdcombined
wg.Add(1)
stdcombined := ""
go func() {
stopctr := 0
for stopctr < 2 {
vvv := <-combch
if vvv.stop {
stopctr++
} else {
stdcombined += vvv.line + "\n" // this comes from bufio.Scanner and has no newlines...
}
}
wg.Done()
}()
// wait for all (5) goroutines to finish
wg.Wait()
if err, ok := syncext.ReadNonBlocking(errch); ok {
return "", "", "", err
}
return stdout, stderr, stdcombined, nil
}

View File

@@ -8,6 +8,7 @@ import (
"os" "os"
"reflect" "reflect"
"strconv" "strconv"
"strings"
"time" "time"
) )
@@ -22,10 +23,10 @@ import (
// //
// sub-structs are recursively parsed (if they have an env tag) and the env-variable keys are delimited by the delim parameter // sub-structs are recursively parsed (if they have an env tag) and the env-variable keys are delimited by the delim parameter
// sub-structs with `env:""` are also parsed, but the delimited is skipped (they are handled as if they were one level higher) // sub-structs with `env:""` are also parsed, but the delimited is skipped (they are handled as if they were one level higher)
func ApplyEnvOverrides[T any](c *T, delim string) error { func ApplyEnvOverrides[T any](prefix string, c *T, delim string) error {
rval := reflect.ValueOf(c).Elem() rval := reflect.ValueOf(c).Elem()
return processEnvOverrides(rval, delim, "") return processEnvOverrides(rval, delim, prefix)
} }
func processEnvOverrides(rval reflect.Value, delim string, prefix string) error { func processEnvOverrides(rval reflect.Value, delim string, prefix string) error {
@@ -70,103 +71,128 @@ func processEnvOverrides(rval reflect.Value, delim string, prefix string) error
continue continue
} }
if rvfield.Type() == reflect.TypeOf("") { if rvfield.Type().Kind() == reflect.Pointer {
rvfield.Set(reflect.ValueOf(envval)) newval, err := parseEnvToValue(envval, fullEnvKey, rvfield.Type().Elem())
if err != nil {
return err
}
// converts reflect.Value to pointer
ptrval := reflect.New(rvfield.Type().Elem())
ptrval.Elem().Set(newval)
rvfield.Set(ptrval)
fmt.Printf("[CONF] Overwrite config '%s' with '%s'\n", fullEnvKey, envval) fmt.Printf("[CONF] Overwrite config '%s' with '%s'\n", fullEnvKey, envval)
} else if rvfield.Type() == reflect.TypeOf(int(0)) {
envint, err := strconv.ParseInt(envval, 10, bits.UintSize)
if err != nil {
return errors.New(fmt.Sprintf("Failed to parse env-config variable '%s' to int (value := '%s')", fullEnvKey, envval))
}
rvfield.Set(reflect.ValueOf(int(envint)))
fmt.Printf("[CONF] Overwrite config '%s' with '%s'\n", fullEnvKey, envval)
} else if rvfield.Type() == reflect.TypeOf(int64(0)) {
envint, err := strconv.ParseInt(envval, 10, 64)
if err != nil {
return errors.New(fmt.Sprintf("Failed to parse env-config variable '%s' to int64 (value := '%s')", fullEnvKey, envval))
}
rvfield.Set(reflect.ValueOf(int64(envint)))
fmt.Printf("[CONF] Overwrite config '%s' with '%s'\n", fullEnvKey, envval)
} else if rvfield.Type() == reflect.TypeOf(int32(0)) {
envint, err := strconv.ParseInt(envval, 10, 32)
if err != nil {
return errors.New(fmt.Sprintf("Failed to parse env-config variable '%s' to int32 (value := '%s')", fullEnvKey, envval))
}
rvfield.Set(reflect.ValueOf(int32(envint)))
fmt.Printf("[CONF] Overwrite config '%s' with '%s'\n", fullEnvKey, envval)
} else if rvfield.Type() == reflect.TypeOf(int8(0)) {
envint, err := strconv.ParseInt(envval, 10, 8)
if err != nil {
return errors.New(fmt.Sprintf("Failed to parse env-config variable '%s' to int32 (value := '%s')", fullEnvKey, envval))
}
rvfield.Set(reflect.ValueOf(int8(envint)))
fmt.Printf("[CONF] Overwrite config '%s' with '%s'\n", fullEnvKey, envval)
} else if rvfield.Type() == reflect.TypeOf(time.Duration(0)) {
dur, err := timeext.ParseDurationShortString(envval)
if err != nil {
return errors.New(fmt.Sprintf("Failed to parse env-config variable '%s' to duration (value := '%s')", fullEnvKey, envval))
}
rvfield.Set(reflect.ValueOf(dur))
fmt.Printf("[CONF] Overwrite config '%s' with '%s'\n", fullEnvKey, dur.String())
} else if rvfield.Type() == reflect.TypeOf(time.UnixMilli(0)) {
tim, err := time.Parse(time.RFC3339Nano, envval)
if err != nil {
return errors.New(fmt.Sprintf("Failed to parse env-config variable '%s' to time.time (value := '%s')", fullEnvKey, envval))
}
rvfield.Set(reflect.ValueOf(tim))
fmt.Printf("[CONF] Overwrite config '%s' with '%s'\n", fullEnvKey, tim.String())
} else if rvfield.Type().ConvertibleTo(reflect.TypeOf(int(0))) {
envint, err := strconv.ParseInt(envval, 10, 8)
if err != nil {
return errors.New(fmt.Sprintf("Failed to parse env-config variable '%s' to <%s, ,int> (value := '%s')", rvfield.Type().Name(), fullEnvKey, envval))
}
envcvl := reflect.ValueOf(envint).Convert(rvfield.Type())
rvfield.Set(envcvl)
fmt.Printf("[CONF] Overwrite config '%s' with '%v'\n", fullEnvKey, envcvl.Interface())
} else if rvfield.Type().ConvertibleTo(reflect.TypeOf("")) {
envcvl := reflect.ValueOf(envval).Convert(rvfield.Type())
rvfield.Set(envcvl)
fmt.Printf("[CONF] Overwrite config '%s' with '%v'\n", fullEnvKey, envcvl.Interface())
} else { } else {
return errors.New(fmt.Sprintf("Unknown kind/type in config: [ %s | %s ]", rvfield.Kind().String(), rvfield.Type().String()))
newval, err := parseEnvToValue(envval, fullEnvKey, rvfield.Type())
if err != nil {
return err
}
rvfield.Set(newval)
fmt.Printf("[CONF] Overwrite config '%s' with '%s'\n", fullEnvKey, envval)
} }
} }
return nil return nil
} }
func parseEnvToValue(envval string, fullEnvKey string, rvtype reflect.Type) (reflect.Value, error) {
if rvtype == reflect.TypeOf("") {
return reflect.ValueOf(envval), nil
} else if rvtype == reflect.TypeOf(int(0)) {
envint, err := strconv.ParseInt(envval, 10, bits.UintSize)
if err != nil {
return reflect.Value{}, errors.New(fmt.Sprintf("Failed to parse env-config variable '%s' to int (value := '%s')", fullEnvKey, envval))
}
return reflect.ValueOf(int(envint)), nil
} else if rvtype == reflect.TypeOf(int64(0)) {
envint, err := strconv.ParseInt(envval, 10, 64)
if err != nil {
return reflect.Value{}, errors.New(fmt.Sprintf("Failed to parse env-config variable '%s' to int64 (value := '%s')", fullEnvKey, envval))
}
return reflect.ValueOf(int64(envint)), nil
} else if rvtype == reflect.TypeOf(int32(0)) {
envint, err := strconv.ParseInt(envval, 10, 32)
if err != nil {
return reflect.Value{}, errors.New(fmt.Sprintf("Failed to parse env-config variable '%s' to int32 (value := '%s')", fullEnvKey, envval))
}
return reflect.ValueOf(int32(envint)), nil
} else if rvtype == reflect.TypeOf(int8(0)) {
envint, err := strconv.ParseInt(envval, 10, 8)
if err != nil {
return reflect.Value{}, errors.New(fmt.Sprintf("Failed to parse env-config variable '%s' to int32 (value := '%s')", fullEnvKey, envval))
}
return reflect.ValueOf(int8(envint)), nil
} else if rvtype == reflect.TypeOf(time.Duration(0)) {
dur, err := timeext.ParseDurationShortString(envval)
if err != nil {
return reflect.Value{}, errors.New(fmt.Sprintf("Failed to parse env-config variable '%s' to duration (value := '%s')", fullEnvKey, envval))
}
return reflect.ValueOf(dur), nil
} else if rvtype == reflect.TypeOf(time.UnixMilli(0)) {
tim, err := time.Parse(time.RFC3339Nano, envval)
if err != nil {
return reflect.Value{}, errors.New(fmt.Sprintf("Failed to parse env-config variable '%s' to time.time (value := '%s')", fullEnvKey, envval))
}
return reflect.ValueOf(tim), nil
} else if rvtype.ConvertibleTo(reflect.TypeOf(int(0))) {
envint, err := strconv.ParseInt(envval, 10, 8)
if err != nil {
return reflect.Value{}, errors.New(fmt.Sprintf("Failed to parse env-config variable '%s' to <%s, ,int> (value := '%s')", rvtype.Name(), fullEnvKey, envval))
}
envcvl := reflect.ValueOf(envint).Convert(rvtype)
return envcvl, nil
} else if rvtype.ConvertibleTo(reflect.TypeOf(false)) {
if strings.TrimSpace(strings.ToLower(envval)) == "true" {
return reflect.ValueOf(true).Convert(rvtype), nil
} else if strings.TrimSpace(strings.ToLower(envval)) == "false" {
return reflect.ValueOf(false).Convert(rvtype), nil
} else if strings.TrimSpace(strings.ToLower(envval)) == "1" {
return reflect.ValueOf(true).Convert(rvtype), nil
} else if strings.TrimSpace(strings.ToLower(envval)) == "0" {
return reflect.ValueOf(false).Convert(rvtype), nil
} else {
return reflect.Value{}, errors.New(fmt.Sprintf("Failed to parse env-config variable '%s' to <%s, ,bool> (value := '%s')", rvtype.Name(), fullEnvKey, envval))
}
} else if rvtype.ConvertibleTo(reflect.TypeOf("")) {
envcvl := reflect.ValueOf(envval).Convert(rvtype)
return envcvl, nil
} else {
return reflect.Value{}, errors.New(fmt.Sprintf("Unknown kind/type in config: [ %s | %s ]", rvtype.Kind().String(), rvtype.String()))
}
}

View File

@@ -2,6 +2,7 @@ package confext
import ( import (
"gogs.mikescher.com/BlackForestBytes/goext/timeext" "gogs.mikescher.com/BlackForestBytes/goext/timeext"
"gogs.mikescher.com/BlackForestBytes/goext/tst"
"testing" "testing"
"time" "time"
) )
@@ -41,13 +42,13 @@ func TestApplyEnvOverridesNoop(t *testing.T) {
output := input output := input
err := ApplyEnvOverrides(&output, ".") err := ApplyEnvOverrides("", &output, ".")
if err != nil { if err != nil {
t.Errorf("%v", err) t.Errorf("%v", err)
t.FailNow() t.FailNow()
} }
assertEqual(t, input, output) tst.AssertEqual(t, input, output)
} }
func TestApplyEnvOverridesSimple(t *testing.T) { func TestApplyEnvOverridesSimple(t *testing.T) {
@@ -67,6 +68,7 @@ func TestApplyEnvOverridesSimple(t *testing.T) {
V7 aliasstring `env:"TEST_V7"` V7 aliasstring `env:"TEST_V7"`
V8 time.Duration `env:"TEST_V8"` V8 time.Duration `env:"TEST_V8"`
V9 time.Time `env:"TEST_V9"` V9 time.Time `env:"TEST_V9"`
VA bool `env:"TEST_VA"`
} }
data := testdata{ data := testdata{
@@ -81,6 +83,7 @@ func TestApplyEnvOverridesSimple(t *testing.T) {
V7: "7", V7: "7",
V8: 9, V8: 9,
V9: time.Unix(1671102873, 0), V9: time.Unix(1671102873, 0),
VA: false,
} }
t.Setenv("TEST_V1", "846") t.Setenv("TEST_V1", "846")
@@ -92,22 +95,24 @@ func TestApplyEnvOverridesSimple(t *testing.T) {
t.Setenv("TEST_V7", "AAAAAA") t.Setenv("TEST_V7", "AAAAAA")
t.Setenv("TEST_V8", "1min4s") t.Setenv("TEST_V8", "1min4s")
t.Setenv("TEST_V9", "2009-11-10T23:00:00Z") t.Setenv("TEST_V9", "2009-11-10T23:00:00Z")
t.Setenv("TEST_VA", "true")
err := ApplyEnvOverrides(&data, ".") err := ApplyEnvOverrides("", &data, ".")
if err != nil { if err != nil {
t.Errorf("%v", err) t.Errorf("%v", err)
t.FailNow() t.FailNow()
} }
assertEqual(t, data.V1, 846) tst.AssertEqual(t, data.V1, 846)
assertEqual(t, data.V2, "hello_world") tst.AssertEqual(t, data.V2, "hello_world")
assertEqual(t, data.V3, 6) tst.AssertEqual(t, data.V3, 6)
assertEqual(t, data.V4, 333) tst.AssertEqual(t, data.V4, 333)
assertEqual(t, data.V5, -937) tst.AssertEqual(t, data.V5, -937)
assertEqual(t, data.V6, 70) tst.AssertEqual(t, data.V6, 70)
assertEqual(t, data.V7, "AAAAAA") tst.AssertEqual(t, data.V7, "AAAAAA")
assertEqual(t, data.V8, time.Second*64) tst.AssertEqual(t, data.V8, time.Second*64)
assertEqual(t, data.V9, time.Unix(1257894000, 0).UTC()) tst.AssertEqual(t, data.V9, time.Unix(1257894000, 0).UTC())
tst.AssertEqual(t, data.VA, true)
} }
func TestApplyEnvOverridesRecursive(t *testing.T) { func TestApplyEnvOverridesRecursive(t *testing.T) {
@@ -182,35 +187,83 @@ func TestApplyEnvOverridesRecursive(t *testing.T) {
t.Setenv("SUB_V3", "33min") t.Setenv("SUB_V3", "33min")
t.Setenv("SUB_V4", "2044-01-01T00:00:00Z") t.Setenv("SUB_V4", "2044-01-01T00:00:00Z")
err := ApplyEnvOverrides(&data, "_") err := ApplyEnvOverrides("", &data, "_")
if err != nil { if err != nil {
t.Errorf("%v", err) t.Errorf("%v", err)
t.FailNow() t.FailNow()
} }
assertEqual(t, data.V1, 999) tst.AssertEqual(t, data.V1, 999)
assertEqual(t, data.VX, "2") tst.AssertEqual(t, data.VX, "2")
assertEqual(t, data.V5, "no") tst.AssertEqual(t, data.V5, "no")
assertEqual(t, data.Sub1.V1, 3) tst.AssertEqual(t, data.Sub1.V1, 3)
assertEqual(t, data.Sub1.VX, "4") tst.AssertEqual(t, data.Sub1.VX, "4")
assertEqual(t, data.Sub1.V2, "5") tst.AssertEqual(t, data.Sub1.V2, "5")
assertEqual(t, data.Sub1.V8, time.Second*6) tst.AssertEqual(t, data.Sub1.V8, time.Second*6)
assertEqual(t, data.Sub1.V9, time.Unix(947206861, 0).UTC()) tst.AssertEqual(t, data.Sub1.V9, time.Unix(947206861, 0).UTC())
assertEqual(t, data.Sub2.V1, 846) tst.AssertEqual(t, data.Sub2.V1, 846)
assertEqual(t, data.Sub2.VX, "9") tst.AssertEqual(t, data.Sub2.VX, "9")
assertEqual(t, data.Sub2.V2, "222_hello_world") tst.AssertEqual(t, data.Sub2.V2, "222_hello_world")
assertEqual(t, data.Sub2.V8, time.Second*64) tst.AssertEqual(t, data.Sub2.V8, time.Second*64)
assertEqual(t, data.Sub2.V9, time.Unix(1257894000, 0).UTC()) tst.AssertEqual(t, data.Sub2.V9, time.Unix(1257894000, 0).UTC())
assertEqual(t, data.Sub3.V1, 33846) tst.AssertEqual(t, data.Sub3.V1, 33846)
assertEqual(t, data.Sub3.VX, "14") tst.AssertEqual(t, data.Sub3.VX, "14")
assertEqual(t, data.Sub3.V2, "33_hello_world") tst.AssertEqual(t, data.Sub3.V2, "33_hello_world")
assertEqual(t, data.Sub3.V8, time.Second*1984) tst.AssertEqual(t, data.Sub3.V8, time.Second*1984)
assertEqual(t, data.Sub3.V9, time.Unix(2015276400, 0).UTC()) tst.AssertEqual(t, data.Sub3.V9, time.Unix(2015276400, 0).UTC())
assertEqual(t, data.Sub4.V1, 11) tst.AssertEqual(t, data.Sub4.V1, 11)
assertEqual(t, data.Sub4.VX, "19") tst.AssertEqual(t, data.Sub4.VX, "19")
assertEqual(t, data.Sub4.V2, "22") tst.AssertEqual(t, data.Sub4.V2, "22")
assertEqual(t, data.Sub4.V8, time.Second*1980) tst.AssertEqual(t, data.Sub4.V8, time.Second*1980)
assertEqual(t, data.Sub4.V9, time.Unix(2335219200, 0).UTC()) tst.AssertEqual(t, data.Sub4.V9, time.Unix(2335219200, 0).UTC())
}
func TestApplyEnvOverridesPointer(t *testing.T) {
type aliasint int
type aliasstring string
type testdata struct {
V1 *int `env:"TEST_V1"`
VX *string ``
V2 *string `env:"TEST_V2"`
V3 *int8 `env:"TEST_V3"`
V4 *int32 `env:"TEST_V4"`
V5 *int64 `env:"TEST_V5"`
V6 *aliasint `env:"TEST_V6"`
VY *aliasint ``
V7 *aliasstring `env:"TEST_V7"`
V8 *time.Duration `env:"TEST_V8"`
V9 *time.Time `env:"TEST_V9"`
}
data := testdata{}
t.Setenv("TEST_V1", "846")
t.Setenv("TEST_V2", "hello_world")
t.Setenv("TEST_V3", "6")
t.Setenv("TEST_V4", "333")
t.Setenv("TEST_V5", "-937")
t.Setenv("TEST_V6", "070")
t.Setenv("TEST_V7", "AAAAAA")
t.Setenv("TEST_V8", "1min4s")
t.Setenv("TEST_V9", "2009-11-10T23:00:00Z")
err := ApplyEnvOverrides("", &data, ".")
if err != nil {
t.Errorf("%v", err)
t.FailNow()
}
tst.AssertDeRefEqual(t, data.V1, 846)
tst.AssertDeRefEqual(t, data.V2, "hello_world")
tst.AssertDeRefEqual(t, data.V3, 6)
tst.AssertDeRefEqual(t, data.V4, 333)
tst.AssertDeRefEqual(t, data.V5, -937)
tst.AssertDeRefEqual(t, data.V6, 70)
tst.AssertDeRefEqual(t, data.V7, "AAAAAA")
tst.AssertDeRefEqual(t, data.V8, time.Second*64)
tst.AssertDeRefEqual(t, data.V9, time.Unix(1257894000, 0).UTC())
} }
func assertEqual[T comparable](t *testing.T, actual T, expected T) { func assertEqual[T comparable](t *testing.T, actual T, expected T) {
@@ -218,3 +271,12 @@ func assertEqual[T comparable](t *testing.T, actual T, expected T) {
t.Errorf("values differ: Actual: '%v', Expected: '%v'", actual, expected) t.Errorf("values differ: Actual: '%v', Expected: '%v'", actual, expected)
} }
} }
func assertPtrEqual[T comparable](t *testing.T, actual *T, expected T) {
if actual == nil {
t.Errorf("values differ: Actual: NIL, Expected: '%v'", expected)
}
if *actual != expected {
t.Errorf("values differ: Actual: '%v', Expected: '%v'", actual, expected)
}
}

View File

@@ -1,10 +1,13 @@
package cryptext package cryptext
import ( import (
"bytes"
"crypto/aes" "crypto/aes"
"crypto/cipher" "crypto/cipher"
"crypto/rand" "crypto/rand"
"encoding/base64" "crypto/sha256"
"encoding/base32"
"encoding/json"
"errors" "errors"
"golang.org/x/crypto/scrypt" "golang.org/x/crypto/scrypt"
"io" "io"
@@ -12,35 +15,90 @@ import (
// https://stackoverflow.com/a/18819040/1761622 // https://stackoverflow.com/a/18819040/1761622
func EncryptAESSimple(password, text []byte) ([]byte, error) { type aesPayload struct {
Salt []byte `json:"s"`
key, err := scrypt.Key(password, nil, 32768, 8, 1, 32) // this is not 100% correct, rounds too low and salt is missing IV []byte `json:"i"`
if err != nil { Data []byte `json:"d"`
return nil, err Rounds int `json:"r"`
} Version uint `json:"v"`
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
b := base64.StdEncoding.EncodeToString(text)
ciphertext := make([]byte, aes.BlockSize+len(b))
iv := ciphertext[:aes.BlockSize]
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
return nil, err
}
cfb := cipher.NewCFBEncrypter(block, iv)
cfb.XORKeyStream(ciphertext[aes.BlockSize:], []byte(b))
return ciphertext, nil
} }
func DecryptAESSimple(password, text []byte) ([]byte, error) { func EncryptAESSimple(password []byte, data []byte, rounds int) (string, error) {
key, err := scrypt.Key(password, nil, 32768, 8, 1, 32) // this is not 100% correct, rounds too low and salt is missing salt := make([]byte, 8)
_, err := io.ReadFull(rand.Reader, salt)
if err != nil {
return "", err
}
key, err := scrypt.Key(password, salt, rounds, 8, 1, 32)
if err != nil {
return "", err
}
block, err := aes.NewCipher(key)
if err != nil {
return "", err
}
h := sha256.New()
h.Write(data)
checksum := h.Sum(nil)
if len(checksum) != 32 {
return "", errors.New("wrong cs size")
}
ciphertext := make([]byte, 32+len(data))
iv := make([]byte, aes.BlockSize)
_, err = io.ReadFull(rand.Reader, iv)
if err != nil {
return "", err
}
combinedData := make([]byte, 0, 32+len(data))
combinedData = append(combinedData, checksum...)
combinedData = append(combinedData, data...)
cfb := cipher.NewCFBEncrypter(block, iv)
cfb.XORKeyStream(ciphertext, combinedData)
pl := aesPayload{
Salt: salt,
IV: iv,
Data: ciphertext,
Version: 1,
Rounds: rounds,
}
jbin, err := json.Marshal(pl)
if err != nil {
return "", err
}
res := base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(jbin)
return res, nil
}
func DecryptAESSimple(password []byte, encText string) ([]byte, error) {
jbin, err := base32.StdEncoding.WithPadding(base32.NoPadding).DecodeString(encText)
if err != nil {
return nil, err
}
var pl aesPayload
err = json.Unmarshal(jbin, &pl)
if err != nil {
return nil, err
}
if pl.Version != 1 {
return nil, errors.New("unsupported version")
}
key, err := scrypt.Key(password, pl.Salt, pl.Rounds, 8, 1, 32) // this is not 100% correct, rounds too low and salt is missing
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -50,18 +108,24 @@ func DecryptAESSimple(password, text []byte) ([]byte, error) {
return nil, err return nil, err
} }
if len(text) < aes.BlockSize { dest := make([]byte, len(pl.Data))
return nil, errors.New("ciphertext too short")
cfb := cipher.NewCFBDecrypter(block, pl.IV)
cfb.XORKeyStream(dest, pl.Data)
if len(dest) < 32 {
return nil, errors.New("payload too small")
} }
iv := text[:aes.BlockSize] chck := dest[:32]
text = text[aes.BlockSize:] data := dest[32:]
cfb := cipher.NewCFBDecrypter(block, iv)
cfb.XORKeyStream(text, text)
data, err := base64.StdEncoding.DecodeString(string(text)) h := sha256.New()
if err != nil { h.Write(data)
return nil, err chck2 := h.Sum(nil)
if !bytes.Equal(chck, chck2) {
return nil, errors.New("checksum mismatch")
} }
return data, nil return data, nil

View File

@@ -1,6 +1,10 @@
package cryptext package cryptext
import "testing" import (
"fmt"
"gogs.mikescher.com/BlackForestBytes/goext/tst"
"testing"
)
func TestEncryptAESSimple(t *testing.T) { func TestEncryptAESSimple(t *testing.T) {
@@ -8,15 +12,25 @@ func TestEncryptAESSimple(t *testing.T) {
str1 := []byte("Hello World") str1 := []byte("Hello World")
str2, err := EncryptAESSimple(pw, str1) str2, err := EncryptAESSimple(pw, str1, 512)
if err != nil { if err != nil {
panic(err) panic(err)
} }
fmt.Printf("%s\n", str2)
str3, err := DecryptAESSimple(pw, str2) str3, err := DecryptAESSimple(pw, str2)
if err != nil { if err != nil {
panic(err) panic(err)
} }
assertEqual(t, string(str1), string(str3)) tst.AssertEqual(t, string(str1), string(str3))
str4, err := EncryptAESSimple(pw, str3, 512)
if err != nil {
panic(err)
}
tst.AssertNotEqual(t, string(str2), string(str4))
} }

View File

@@ -1,25 +1,20 @@
package cryptext package cryptext
import ( import (
"gogs.mikescher.com/BlackForestBytes/goext/tst"
"testing" "testing"
) )
func TestStrSha256(t *testing.T) { func TestStrSha256(t *testing.T) {
assertEqual(t, StrSha256(""), "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855") tst.AssertEqual(t, StrSha256(""), "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855")
assertEqual(t, StrSha256("0"), "5feceb66ffc86f38d952786c6d696c79c2dbc239dd4e91b46729d73a27fb57e9") tst.AssertEqual(t, StrSha256("0"), "5feceb66ffc86f38d952786c6d696c79c2dbc239dd4e91b46729d73a27fb57e9")
assertEqual(t, StrSha256("80085"), "b3786e141d65638ad8a98173e26b5f6a53c927737b23ff31fb1843937250f44b") tst.AssertEqual(t, StrSha256("80085"), "b3786e141d65638ad8a98173e26b5f6a53c927737b23ff31fb1843937250f44b")
assertEqual(t, StrSha256("Hello World"), "a591a6d40bf420404a011733cfb7b190d62c65bf0bcda32b57b277d9ad9f146e") tst.AssertEqual(t, StrSha256("Hello World"), "a591a6d40bf420404a011733cfb7b190d62c65bf0bcda32b57b277d9ad9f146e")
} }
func TestBytesSha256(t *testing.T) { func TestBytesSha256(t *testing.T) {
assertEqual(t, BytesSha256([]byte{}), "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855") tst.AssertEqual(t, BytesSha256([]byte{}), "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855")
assertEqual(t, BytesSha256([]byte{0}), "6e340b9cffb37a989ca544e6bb780a2c78901d3fb33738768511a30617afa01d") tst.AssertEqual(t, BytesSha256([]byte{0}), "6e340b9cffb37a989ca544e6bb780a2c78901d3fb33738768511a30617afa01d")
assertEqual(t, BytesSha256([]byte{128}), "76be8b528d0075f7aae98d6fa57a6d3c83ae480a8469e668d7b0af968995ac71") tst.AssertEqual(t, BytesSha256([]byte{128}), "76be8b528d0075f7aae98d6fa57a6d3c83ae480a8469e668d7b0af968995ac71")
assertEqual(t, BytesSha256([]byte{0, 1, 2, 4, 8, 16, 32, 64, 128, 255}), "55016a318ba538e00123c736b2a8b6db368d00e7e25727547655b653e5853603") tst.AssertEqual(t, BytesSha256([]byte{0, 1, 2, 4, 8, 16, 32, 64, 128, 255}), "55016a318ba538e00123c736b2a8b6db368d00e7e25727547655b653e5853603")
}
func assertEqual(t *testing.T, actual string, expected string) {
if actual != expected {
t.Errorf("values differ: Actual: '%v', Expected: '%v'", actual, expected)
}
} }

8
cursortoken/direction.go Normal file
View File

@@ -0,0 +1,8 @@
package cursortoken
type SortDirection string //@enum:type
const (
SortASC SortDirection = "ASC"
SortDESC SortDirection = "DESC"
)

10
cursortoken/filter.go Normal file
View File

@@ -0,0 +1,10 @@
package cursortoken
import (
"go.mongodb.org/mongo-driver/mongo"
)
type Filter interface {
FilterQuery() mongo.Pipeline
Pagination() (string, SortDirection, string, SortDirection)
}

184
cursortoken/token.go Normal file
View File

@@ -0,0 +1,184 @@
package cursortoken
import (
"encoding/base32"
"encoding/json"
"errors"
"go.mongodb.org/mongo-driver/bson/primitive"
"strings"
"time"
)
type Mode string
const (
CTMStart Mode = "START"
CTMNormal Mode = "NORMAL"
CTMEnd Mode = "END"
)
type Extra struct {
Timestamp *time.Time
Id *string
Page *int
PageSize *int
}
type CursorToken struct {
Mode Mode
ValuePrimary string
ValueSecondary string
Direction SortDirection
DirectionSecondary SortDirection
PageSize int
Extra Extra
}
type cursorTokenSerialize struct {
ValuePrimary *string `json:"v1,omitempty"`
ValueSecondary *string `json:"v2,omitempty"`
Direction *SortDirection `json:"dir,omitempty"`
DirectionSecondary *SortDirection `json:"dir2,omitempty"`
PageSize *int `json:"size,omitempty"`
ExtraTimestamp *time.Time `json:"ts,omitempty"`
ExtraId *string `json:"id,omitempty"`
ExtraPage *int `json:"pg,omitempty"`
ExtraPageSize *int `json:"sz,omitempty"`
}
func Start() CursorToken {
return CursorToken{
Mode: CTMStart,
ValuePrimary: "",
ValueSecondary: "",
Direction: "",
DirectionSecondary: "",
PageSize: 0,
Extra: Extra{},
}
}
func End() CursorToken {
return CursorToken{
Mode: CTMEnd,
ValuePrimary: "",
ValueSecondary: "",
Direction: "",
DirectionSecondary: "",
PageSize: 0,
Extra: Extra{},
}
}
func (c *CursorToken) Token() string {
if c.Mode == CTMStart {
return "@start"
}
if c.Mode == CTMEnd {
return "@end"
}
// We kinda manually implement omitempty for the CursorToken here
// because omitempty does not work for time.Time and otherwise we would always
// get weird time values when decoding a token that initially didn't have an Timestamp set
// For this usecase we treat Unix=0 as an empty timestamp
sertok := cursorTokenSerialize{}
if c.ValuePrimary != "" {
sertok.ValuePrimary = &c.ValuePrimary
}
if c.ValueSecondary != "" {
sertok.ValueSecondary = &c.ValueSecondary
}
if c.Direction != "" {
sertok.Direction = &c.Direction
}
if c.DirectionSecondary != "" {
sertok.DirectionSecondary = &c.DirectionSecondary
}
if c.PageSize != 0 {
sertok.PageSize = &c.PageSize
}
sertok.ExtraTimestamp = c.Extra.Timestamp
sertok.ExtraId = c.Extra.Id
sertok.ExtraPage = c.Extra.Page
sertok.ExtraPageSize = c.Extra.PageSize
body, err := json.Marshal(sertok)
if err != nil {
panic(err)
}
return "tok_" + base32.StdEncoding.EncodeToString(body)
}
func Decode(tok string) (CursorToken, error) {
if tok == "" {
return Start(), nil
}
if strings.ToLower(tok) == "@start" {
return Start(), nil
}
if strings.ToLower(tok) == "@end" {
return End(), nil
}
if !strings.HasPrefix(tok, "tok_") {
return CursorToken{}, errors.New("could not decode token, missing prefix")
}
body, err := base32.StdEncoding.DecodeString(tok[len("tok_"):])
if err != nil {
return CursorToken{}, err
}
var tokenDeserialize cursorTokenSerialize
err = json.Unmarshal(body, &tokenDeserialize)
if err != nil {
return CursorToken{}, err
}
token := CursorToken{Mode: CTMNormal}
if tokenDeserialize.ValuePrimary != nil {
token.ValuePrimary = *tokenDeserialize.ValuePrimary
}
if tokenDeserialize.ValueSecondary != nil {
token.ValueSecondary = *tokenDeserialize.ValueSecondary
}
if tokenDeserialize.Direction != nil {
token.Direction = *tokenDeserialize.Direction
}
if tokenDeserialize.DirectionSecondary != nil {
token.DirectionSecondary = *tokenDeserialize.DirectionSecondary
}
if tokenDeserialize.PageSize != nil {
token.PageSize = *tokenDeserialize.PageSize
}
token.Extra.Timestamp = tokenDeserialize.ExtraTimestamp
token.Extra.Id = tokenDeserialize.ExtraId
token.Extra.Page = tokenDeserialize.ExtraPage
token.Extra.PageSize = tokenDeserialize.ExtraPageSize
return token, nil
}
func (c *CursorToken) ValuePrimaryObjectId() (primitive.ObjectID, bool) {
if oid, err := primitive.ObjectIDFromHex(c.ValuePrimary); err == nil {
return oid, true
} else {
return primitive.ObjectID{}, false
}
}
func (c *CursorToken) ValueSecondaryObjectId() (primitive.ObjectID, bool) {
if oid, err := primitive.ObjectIDFromHex(c.ValueSecondary); err == nil {
return oid, true
} else {
return primitive.ObjectID{}, false
}
}

View File

@@ -12,7 +12,7 @@ func init() {
} }
func TestResultCache1(t *testing.T) { func TestResultCache1(t *testing.T) {
cache := NewLRUMap[string](8) cache := NewLRUMap[string, string](8)
verifyLRUList(cache, t) verifyLRUList(cache, t)
key := randomKey() key := randomKey()
@@ -50,7 +50,7 @@ func TestResultCache1(t *testing.T) {
} }
func TestResultCache2(t *testing.T) { func TestResultCache2(t *testing.T) {
cache := NewLRUMap[string](8) cache := NewLRUMap[string, string](8)
verifyLRUList(cache, t) verifyLRUList(cache, t)
key1 := "key1" key1 := "key1"
@@ -150,7 +150,7 @@ func TestResultCache2(t *testing.T) {
} }
func TestResultCache3(t *testing.T) { func TestResultCache3(t *testing.T) {
cache := NewLRUMap[string](8) cache := NewLRUMap[string, string](8)
verifyLRUList(cache, t) verifyLRUList(cache, t)
key1 := "key1" key1 := "key1"
@@ -173,7 +173,7 @@ func TestResultCache3(t *testing.T) {
} }
// does a basic consistency check over the internal cache representation // does a basic consistency check over the internal cache representation
func verifyLRUList[TData any](cache *LRUMap[TData], t *testing.T) { func verifyLRUList[TKey comparable, TData any](cache *LRUMap[TKey, TData], t *testing.T) {
size := 0 size := 0
tailFound := false tailFound := false

View File

@@ -2,6 +2,7 @@ package dataext
import ( import (
"gogs.mikescher.com/BlackForestBytes/goext/langext" "gogs.mikescher.com/BlackForestBytes/goext/langext"
"gogs.mikescher.com/BlackForestBytes/goext/tst"
"testing" "testing"
) )
@@ -43,10 +44,10 @@ func TestObjectMerge(t *testing.T) {
valueMerge := ObjectMerge(valueA, valueB) valueMerge := ObjectMerge(valueA, valueB)
assertPtrEqual(t, "Field1", valueMerge.Field1, valueB.Field1) tst.AssertIdentPtrEqual(t, "Field1", valueMerge.Field1, valueB.Field1)
assertPtrEqual(t, "Field2", valueMerge.Field2, valueA.Field2) tst.AssertIdentPtrEqual(t, "Field2", valueMerge.Field2, valueA.Field2)
assertPtrEqual(t, "Field3", valueMerge.Field3, valueB.Field3) tst.AssertIdentPtrEqual(t, "Field3", valueMerge.Field3, valueB.Field3)
assertPtrEqual(t, "Field4", valueMerge.Field4, nil) tst.AssertIdentPtrEqual(t, "Field4", valueMerge.Field4, nil)
} }

View File

@@ -1,8 +1,8 @@
package dataext package dataext
import ( import (
"encoding/hex"
"gogs.mikescher.com/BlackForestBytes/goext/langext" "gogs.mikescher.com/BlackForestBytes/goext/langext"
"gogs.mikescher.com/BlackForestBytes/goext/tst"
"testing" "testing"
) )
@@ -18,14 +18,14 @@ func noErrStructHash(t *testing.T, dat any, opt ...StructHashOptions) []byte {
func TestStructHashSimple(t *testing.T) { func TestStructHashSimple(t *testing.T) {
assertEqual(t, "209bf774af36cc3a045c152d9f1269ef3684ad819c1359ee73ff0283a308fefa", noErrStructHash(t, "Hello")) tst.AssertHexEqual(t, "209bf774af36cc3a045c152d9f1269ef3684ad819c1359ee73ff0283a308fefa", noErrStructHash(t, "Hello"))
assertEqual(t, "c32f3626b981ae2997db656f3acad3f1dc9d30ef6b6d14296c023e391b25f71a", noErrStructHash(t, 0)) tst.AssertHexEqual(t, "c32f3626b981ae2997db656f3acad3f1dc9d30ef6b6d14296c023e391b25f71a", noErrStructHash(t, 0))
assertEqual(t, "01b781b03e9586b257d387057dfc70d9f06051e7d3c1e709a57e13cc8daf3e35", noErrStructHash(t, []byte{})) tst.AssertHexEqual(t, "01b781b03e9586b257d387057dfc70d9f06051e7d3c1e709a57e13cc8daf3e35", noErrStructHash(t, []byte{}))
assertEqual(t, "93e1dcd45c732fe0079b0fb3204c7c803f0921835f6bfee2e6ff263e73eed53c", noErrStructHash(t, []int{})) tst.AssertHexEqual(t, "93e1dcd45c732fe0079b0fb3204c7c803f0921835f6bfee2e6ff263e73eed53c", noErrStructHash(t, []int{}))
assertEqual(t, "54f637a376aad55b3160d98ebbcae8099b70d91b9400df23fb3709855d59800a", noErrStructHash(t, []int{1, 2, 3})) tst.AssertHexEqual(t, "54f637a376aad55b3160d98ebbcae8099b70d91b9400df23fb3709855d59800a", noErrStructHash(t, []int{1, 2, 3}))
assertEqual(t, "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", noErrStructHash(t, nil)) tst.AssertHexEqual(t, "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", noErrStructHash(t, nil))
assertEqual(t, "349a7db91aa78fd30bbaa7c7f9c7bfb2fcfe72869b4861162a96713a852f60d3", noErrStructHash(t, []any{1, "", nil})) tst.AssertHexEqual(t, "349a7db91aa78fd30bbaa7c7f9c7bfb2fcfe72869b4861162a96713a852f60d3", noErrStructHash(t, []any{1, "", nil}))
assertEqual(t, "ca51aab87808bf0062a4a024de6aac0c2bad54275cc857a4944569f89fd245ad", noErrStructHash(t, struct{}{})) tst.AssertHexEqual(t, "ca51aab87808bf0062a4a024de6aac0c2bad54275cc857a4944569f89fd245ad", noErrStructHash(t, struct{}{}))
} }
@@ -37,13 +37,13 @@ func TestStructHashSimpleStruct(t *testing.T) {
F3 *int F3 *int
} }
assertEqual(t, "a90bff751c70c738bb5cfc9b108e783fa9c19c0bc9273458e0aaee6e74aa1b92", noErrStructHash(t, t0{ tst.AssertHexEqual(t, "a90bff751c70c738bb5cfc9b108e783fa9c19c0bc9273458e0aaee6e74aa1b92", noErrStructHash(t, t0{
F1: 10, F1: 10,
F2: []string{"1", "2", "3"}, F2: []string{"1", "2", "3"},
F3: nil, F3: nil,
})) }))
assertEqual(t, "5d09090dc34ac59dd645f197a255f653387723de3afa1b614721ea5a081c675f", noErrStructHash(t, t0{ tst.AssertHexEqual(t, "5d09090dc34ac59dd645f197a255f653387723de3afa1b614721ea5a081c675f", noErrStructHash(t, t0{
F1: 10, F1: 10,
F2: []string{"1", "2", "3"}, F2: []string{"1", "2", "3"},
F3: langext.Ptr(99), F3: langext.Ptr(99),
@@ -64,7 +64,7 @@ func TestStructHashLayeredStruct(t *testing.T) {
SV3 t1_1 SV3 t1_1
} }
assertEqual(t, "fd4ca071fb40a288fee4b7a3dfdaab577b30cb8f80f81ec511e7afd72dc3b469", noErrStructHash(t, t1_2{ tst.AssertHexEqual(t, "fd4ca071fb40a288fee4b7a3dfdaab577b30cb8f80f81ec511e7afd72dc3b469", noErrStructHash(t, t1_2{
SV1: nil, SV1: nil,
SV2: nil, SV2: nil,
SV3: t1_1{ SV3: t1_1{
@@ -73,7 +73,7 @@ func TestStructHashLayeredStruct(t *testing.T) {
F15: false, F15: false,
}, },
})) }))
assertEqual(t, "3fbf7c67d8121deda075cc86319a4e32d71744feb2cebf89b43bc682f072a029", noErrStructHash(t, t1_2{ tst.AssertHexEqual(t, "3fbf7c67d8121deda075cc86319a4e32d71744feb2cebf89b43bc682f072a029", noErrStructHash(t, t1_2{
SV1: nil, SV1: nil,
SV2: &t1_1{}, SV2: &t1_1{},
SV3: t1_1{ SV3: t1_1{
@@ -82,7 +82,7 @@ func TestStructHashLayeredStruct(t *testing.T) {
F15: true, F15: true,
}, },
})) }))
assertEqual(t, "b1791ccd1b346c3ede5bbffda85555adcd8216b93ffca23f14fe175ec47c5104", noErrStructHash(t, t1_2{ tst.AssertHexEqual(t, "b1791ccd1b346c3ede5bbffda85555adcd8216b93ffca23f14fe175ec47c5104", noErrStructHash(t, t1_2{
SV1: &t1_1{}, SV1: &t1_1{},
SV2: &t1_1{}, SV2: &t1_1{},
SV3: t1_1{ SV3: t1_1{
@@ -101,7 +101,7 @@ func TestStructHashMap(t *testing.T) {
F2 map[string]int F2 map[string]int
} }
assertEqual(t, "d50c53ad1fafb448c33fddd5aca01a86a2edf669ce2ecab07ba6fe877951d824", noErrStructHash(t, t0{ tst.AssertHexEqual(t, "d50c53ad1fafb448c33fddd5aca01a86a2edf669ce2ecab07ba6fe877951d824", noErrStructHash(t, t0{
F1: 10, F1: 10,
F2: map[string]int{ F2: map[string]int{
"x": 1, "x": 1,
@@ -110,7 +110,7 @@ func TestStructHashMap(t *testing.T) {
}, },
})) }))
assertEqual(t, "d50c53ad1fafb448c33fddd5aca01a86a2edf669ce2ecab07ba6fe877951d824", noErrStructHash(t, t0{ tst.AssertHexEqual(t, "d50c53ad1fafb448c33fddd5aca01a86a2edf669ce2ecab07ba6fe877951d824", noErrStructHash(t, t0{
F1: 10, F1: 10,
F2: map[string]int{ F2: map[string]int{
"a": 99, "a": 99,
@@ -128,16 +128,9 @@ func TestStructHashMap(t *testing.T) {
m3["x"] = 1 m3["x"] = 1
m3["a"] = 2 m3["a"] = 2
assertEqual(t, "d50c53ad1fafb448c33fddd5aca01a86a2edf669ce2ecab07ba6fe877951d824", noErrStructHash(t, t0{ tst.AssertHexEqual(t, "d50c53ad1fafb448c33fddd5aca01a86a2edf669ce2ecab07ba6fe877951d824", noErrStructHash(t, t0{
F1: 10, F1: 10,
F2: m3, F2: m3,
})) }))
} }
func assertEqual(t *testing.T, expected string, actual []byte) {
actualStr := hex.EncodeToString(actual)
if actualStr != expected {
t.Errorf("values differ: Actual: '%v', Expected: '%v'", actualStr, expected)
}
}

27
go.mod
View File

@@ -3,11 +3,30 @@ module gogs.mikescher.com/BlackForestBytes/goext
go 1.19 go 1.19
require ( require (
golang.org/x/sys v0.3.0 github.com/golang/snappy v0.0.4
golang.org/x/term v0.3.0 github.com/google/go-cmp v0.5.9
github.com/jmoiron/sqlx v1.3.5
github.com/klauspost/compress v1.16.6
github.com/kr/pretty v0.1.0
github.com/montanaflynn/stats v0.7.1
github.com/pkg/errors v0.9.1
github.com/stretchr/testify v1.8.4
github.com/tidwall/pretty v1.0.0
github.com/xdg-go/scram v1.1.2
github.com/xdg-go/stringprep v1.0.4
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a
go.mongodb.org/mongo-driver v1.11.7
golang.org/x/crypto v0.10.0
golang.org/x/sync v0.3.0
golang.org/x/sys v0.9.0
golang.org/x/term v0.9.0
) )
require ( require (
github.com/jmoiron/sqlx v1.3.5 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
golang.org/x/crypto v0.4.0 // indirect github.com/kr/text v0.1.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/xdg-go/pbkdf2 v1.0.0 // indirect
golang.org/x/text v0.10.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
) )

105
go.sum
View File

@@ -1,15 +1,100 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM=
github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/jmoiron/sqlx v1.3.5 h1:vFFPA71p1o5gAeqtEAwLU4dnX2napprKtHr7PYIcN3g= github.com/jmoiron/sqlx v1.3.5 h1:vFFPA71p1o5gAeqtEAwLU4dnX2napprKtHr7PYIcN3g=
github.com/jmoiron/sqlx v1.3.5/go.mod h1:nRVWtLre0KfCLJvgxzCsLVMogSvQ1zNJtpYr2Ccp0mQ= github.com/jmoiron/sqlx v1.3.5/go.mod h1:nRVWtLre0KfCLJvgxzCsLVMogSvQ1zNJtpYr2Ccp0mQ=
github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk=
github.com/klauspost/compress v1.16.6 h1:91SKEy4K37vkp255cJ8QesJhjyRO0hn9i9G0GoUwLsk=
github.com/klauspost/compress v1.16.6/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/lib/pq v1.2.0 h1:LXpIM/LZ5xGFhOpXAQUIMM1HdyqzVYM13zNdjCEEcA0=
github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/mattn/go-sqlite3 v1.14.6 h1:dNPt6NO46WmLVt2DLNpwczCmdV5boIZ6g/tlDrlRUbg=
github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
golang.org/x/crypto v0.4.0 h1:UVQgzMY87xqpKNgb+kDsll2Igd33HszWHFLmpaRMq/8= github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc=
golang.org/x/crypto v0.4.0/go.mod h1:3quD/ATkf6oY+rnes5c3ExXTbLc8mueNue5/DoinL80= github.com/montanaflynn/stats v0.7.1 h1:etflOAAHORrCC44V+aR6Ftzort912ZU+YLiSTuV8eaE=
golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U= github.com/montanaflynn/stats v0.7.1/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
golang.org/x/sys v0.3.0 h1:w8ZOecv6NaNa/zC8944JTU3vz4u6Lagfk4RPQxv92NQ= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
golang.org/x/term v0.1.0 h1:g6Z6vPFA9dYBAF7DWcH6sCcOntplXsDKcliusYijMlw= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
golang.org/x/term v0.3.0 h1:qoo4akIqOcDME5bhc/NgxUdovd6BSS2uMsVjB56q1xI= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
golang.org/x/term v0.3.0/go.mod h1:q750SLmJuPmVoN1blW3UFBPREJfb1KmY3vwxfr+nFDA= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/tidwall/pretty v1.0.0 h1:HsD+QiTn7sK6flMKIvNmpqz1qrpP3Ps6jOKIKMooyg4=
github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk=
github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c=
github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI=
github.com/xdg-go/scram v1.1.1/go.mod h1:RaEWvsqvNKKvBPvcKeFjrG2cJqOkHTiyTpzz23ni57g=
github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY=
github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4=
github.com/xdg-go/stringprep v1.0.3/go.mod h1:W3f5j4i+9rC0kuIEJL0ky1VpHXQU3ocBgklLGvcBnW8=
github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8=
github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM=
github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7JulP+udvsHwJoVG1YGAP6VLg4y9I5dyZdqmA=
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a h1:fZHgsYlfvtyqToslyjUt3VOPF4J7aK/3MPcK7xp3PDk=
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a/go.mod h1:ul22v+Nro/R083muKhosV54bj5niojjWZvU8xrevuH4=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
go.mongodb.org/mongo-driver v1.11.7 h1:LIwYxASDLGUg/8wOhgOOZhX8tQa/9tgZPgzZoVqJvcs=
go.mongodb.org/mongo-driver v1.11.7/go.mod h1:G9TgswdsWjX4tmDA5zfs2+6AEPpYJwqblyjsfuh8oXY=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20200302210943-78000ba7a073/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.10.0 h1:LKqV2xt9+kDzSTfOhx4FrkEBcMrAgHSYgzywV9zcGmM=
golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s=
golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.9.0 h1:GRRCnKYhdQrD8kfRAdQ6Zcw1P0OcELxGLKJvtjVMZ28=
golang.org/x/term v0.9.0/go.mod h1:M6DEAAIenWoTxdKrOltXcmDY3rSplQUkrvaDU5FcQyo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
golang.org/x/text v0.10.0 h1:UpjohKhiEgNc0CSauXmwYftY1+LlaC75SJwh0SgCX58=
golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

5
goextVersion.go Normal file
View File

@@ -0,0 +1,5 @@
package goext
const GoextVersion = "0.0.166"
const GoextVersionTimestamp = "2023-06-19T10:25:41+0200"

27
gojson/LICENSE Normal file
View File

@@ -0,0 +1,27 @@
Copyright (c) 2009 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

12
gojson/README.md Normal file
View File

@@ -0,0 +1,12 @@
JSON serializer which serializes nil-Arrays as `[]` and nil-maps als `{}`.
Idea from: https://github.com/homelight/json
Forked from https://github.com/golang/go/tree/547e8e22fe565d65d1fd4d6e71436a5a855447b0/src/encoding/json ( tag go1.20.2 )
Added:
- `MarshalSafeCollections()` method
- `Encoder.nilSafeSlices` and `Encoder.nilSafeMaps` fields

1311
gojson/decode.go Normal file

File diff suppressed because it is too large Load Diff

2574
gojson/decode_test.go Normal file

File diff suppressed because it is too large Load Diff

1459
gojson/encode.go Normal file

File diff suppressed because it is too large Load Diff

1285
gojson/encode_test.go Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,73 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json_test
import (
"encoding/json"
"fmt"
"log"
"strings"
)
type Animal int
const (
Unknown Animal = iota
Gopher
Zebra
)
func (a *Animal) UnmarshalJSON(b []byte) error {
var s string
if err := json.Unmarshal(b, &s); err != nil {
return err
}
switch strings.ToLower(s) {
default:
*a = Unknown
case "gopher":
*a = Gopher
case "zebra":
*a = Zebra
}
return nil
}
func (a Animal) MarshalJSON() ([]byte, error) {
var s string
switch a {
default:
s = "unknown"
case Gopher:
s = "gopher"
case Zebra:
s = "zebra"
}
return json.Marshal(s)
}
func Example_customMarshalJSON() {
blob := `["gopher","armadillo","zebra","unknown","gopher","bee","gopher","zebra"]`
var zoo []Animal
if err := json.Unmarshal([]byte(blob), &zoo); err != nil {
log.Fatal(err)
}
census := make(map[Animal]int)
for _, animal := range zoo {
census[animal] += 1
}
fmt.Printf("Zoo Census:\n* Gophers: %d\n* Zebras: %d\n* Unknown: %d\n",
census[Gopher], census[Zebra], census[Unknown])
// Output:
// Zoo Census:
// * Gophers: 3
// * Zebras: 2
// * Unknown: 3
}

310
gojson/example_test.go Normal file
View File

@@ -0,0 +1,310 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json_test
import (
"bytes"
"encoding/json"
"fmt"
"io"
"log"
"os"
"strings"
)
func ExampleMarshal() {
type ColorGroup struct {
ID int
Name string
Colors []string
}
group := ColorGroup{
ID: 1,
Name: "Reds",
Colors: []string{"Crimson", "Red", "Ruby", "Maroon"},
}
b, err := json.Marshal(group)
if err != nil {
fmt.Println("error:", err)
}
os.Stdout.Write(b)
// Output:
// {"ID":1,"Name":"Reds","Colors":["Crimson","Red","Ruby","Maroon"]}
}
func ExampleUnmarshal() {
var jsonBlob = []byte(`[
{"Name": "Platypus", "Order": "Monotremata"},
{"Name": "Quoll", "Order": "Dasyuromorphia"}
]`)
type Animal struct {
Name string
Order string
}
var animals []Animal
err := json.Unmarshal(jsonBlob, &animals)
if err != nil {
fmt.Println("error:", err)
}
fmt.Printf("%+v", animals)
// Output:
// [{Name:Platypus Order:Monotremata} {Name:Quoll Order:Dasyuromorphia}]
}
// This example uses a Decoder to decode a stream of distinct JSON values.
func ExampleDecoder() {
const jsonStream = `
{"Name": "Ed", "Text": "Knock knock."}
{"Name": "Sam", "Text": "Who's there?"}
{"Name": "Ed", "Text": "Go fmt."}
{"Name": "Sam", "Text": "Go fmt who?"}
{"Name": "Ed", "Text": "Go fmt yourself!"}
`
type Message struct {
Name, Text string
}
dec := json.NewDecoder(strings.NewReader(jsonStream))
for {
var m Message
if err := dec.Decode(&m); err == io.EOF {
break
} else if err != nil {
log.Fatal(err)
}
fmt.Printf("%s: %s\n", m.Name, m.Text)
}
// Output:
// Ed: Knock knock.
// Sam: Who's there?
// Ed: Go fmt.
// Sam: Go fmt who?
// Ed: Go fmt yourself!
}
// This example uses a Decoder to decode a stream of distinct JSON values.
func ExampleDecoder_Token() {
const jsonStream = `
{"Message": "Hello", "Array": [1, 2, 3], "Null": null, "Number": 1.234}
`
dec := json.NewDecoder(strings.NewReader(jsonStream))
for {
t, err := dec.Token()
if err == io.EOF {
break
}
if err != nil {
log.Fatal(err)
}
fmt.Printf("%T: %v", t, t)
if dec.More() {
fmt.Printf(" (more)")
}
fmt.Printf("\n")
}
// Output:
// json.Delim: { (more)
// string: Message (more)
// string: Hello (more)
// string: Array (more)
// json.Delim: [ (more)
// float64: 1 (more)
// float64: 2 (more)
// float64: 3
// json.Delim: ] (more)
// string: Null (more)
// <nil>: <nil> (more)
// string: Number (more)
// float64: 1.234
// json.Delim: }
}
// This example uses a Decoder to decode a streaming array of JSON objects.
func ExampleDecoder_Decode_stream() {
const jsonStream = `
[
{"Name": "Ed", "Text": "Knock knock."},
{"Name": "Sam", "Text": "Who's there?"},
{"Name": "Ed", "Text": "Go fmt."},
{"Name": "Sam", "Text": "Go fmt who?"},
{"Name": "Ed", "Text": "Go fmt yourself!"}
]
`
type Message struct {
Name, Text string
}
dec := json.NewDecoder(strings.NewReader(jsonStream))
// read open bracket
t, err := dec.Token()
if err != nil {
log.Fatal(err)
}
fmt.Printf("%T: %v\n", t, t)
// while the array contains values
for dec.More() {
var m Message
// decode an array value (Message)
err := dec.Decode(&m)
if err != nil {
log.Fatal(err)
}
fmt.Printf("%v: %v\n", m.Name, m.Text)
}
// read closing bracket
t, err = dec.Token()
if err != nil {
log.Fatal(err)
}
fmt.Printf("%T: %v\n", t, t)
// Output:
// json.Delim: [
// Ed: Knock knock.
// Sam: Who's there?
// Ed: Go fmt.
// Sam: Go fmt who?
// Ed: Go fmt yourself!
// json.Delim: ]
}
// This example uses RawMessage to delay parsing part of a JSON message.
func ExampleRawMessage_unmarshal() {
type Color struct {
Space string
Point json.RawMessage // delay parsing until we know the color space
}
type RGB struct {
R uint8
G uint8
B uint8
}
type YCbCr struct {
Y uint8
Cb int8
Cr int8
}
var j = []byte(`[
{"Space": "YCbCr", "Point": {"Y": 255, "Cb": 0, "Cr": -10}},
{"Space": "RGB", "Point": {"R": 98, "G": 218, "B": 255}}
]`)
var colors []Color
err := json.Unmarshal(j, &colors)
if err != nil {
log.Fatalln("error:", err)
}
for _, c := range colors {
var dst any
switch c.Space {
case "RGB":
dst = new(RGB)
case "YCbCr":
dst = new(YCbCr)
}
err := json.Unmarshal(c.Point, dst)
if err != nil {
log.Fatalln("error:", err)
}
fmt.Println(c.Space, dst)
}
// Output:
// YCbCr &{255 0 -10}
// RGB &{98 218 255}
}
// This example uses RawMessage to use a precomputed JSON during marshal.
func ExampleRawMessage_marshal() {
h := json.RawMessage(`{"precomputed": true}`)
c := struct {
Header *json.RawMessage `json:"header"`
Body string `json:"body"`
}{Header: &h, Body: "Hello Gophers!"}
b, err := json.MarshalIndent(&c, "", "\t")
if err != nil {
fmt.Println("error:", err)
}
os.Stdout.Write(b)
// Output:
// {
// "header": {
// "precomputed": true
// },
// "body": "Hello Gophers!"
// }
}
func ExampleIndent() {
type Road struct {
Name string
Number int
}
roads := []Road{
{"Diamond Fork", 29},
{"Sheep Creek", 51},
}
b, err := json.Marshal(roads)
if err != nil {
log.Fatal(err)
}
var out bytes.Buffer
json.Indent(&out, b, "=", "\t")
out.WriteTo(os.Stdout)
// Output:
// [
// = {
// = "Name": "Diamond Fork",
// = "Number": 29
// = },
// = {
// = "Name": "Sheep Creek",
// = "Number": 51
// = }
// =]
}
func ExampleMarshalIndent() {
data := map[string]int{
"a": 1,
"b": 2,
}
b, err := json.MarshalIndent(data, "<prefix>", "<indent>")
if err != nil {
log.Fatal(err)
}
fmt.Println(string(b))
// Output:
// {
// <prefix><indent>"a": 1,
// <prefix><indent>"b": 2
// <prefix>}
}
func ExampleValid() {
goodJSON := `{"example": 1}`
badJSON := `{"example":2:]}}`
fmt.Println(json.Valid([]byte(goodJSON)), json.Valid([]byte(badJSON)))
// Output:
// true false
}
func ExampleHTMLEscape() {
var out bytes.Buffer
json.HTMLEscape(&out, []byte(`{"Name":"<b>HTML content</b>"}`))
out.WriteTo(os.Stdout)
// Output:
//{"Name":"\u003cb\u003eHTML content\u003c/b\u003e"}
}

View File

@@ -0,0 +1,67 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json_test
import (
"encoding/json"
"fmt"
"log"
"strings"
)
type Size int
const (
Unrecognized Size = iota
Small
Large
)
func (s *Size) UnmarshalText(text []byte) error {
switch strings.ToLower(string(text)) {
default:
*s = Unrecognized
case "small":
*s = Small
case "large":
*s = Large
}
return nil
}
func (s Size) MarshalText() ([]byte, error) {
var name string
switch s {
default:
name = "unrecognized"
case Small:
name = "small"
case Large:
name = "large"
}
return []byte(name), nil
}
func Example_textMarshalJSON() {
blob := `["small","regular","large","unrecognized","small","normal","small","large"]`
var inventory []Size
if err := json.Unmarshal([]byte(blob), &inventory); err != nil {
log.Fatal(err)
}
counts := make(map[Size]int)
for _, size := range inventory {
counts[size] += 1
}
fmt.Printf("Inventory Counts:\n* Small: %d\n* Large: %d\n* Unrecognized: %d\n",
counts[Small], counts[Large], counts[Unrecognized])
// Output:
// Inventory Counts:
// * Small: 3
// * Large: 2
// * Unrecognized: 3
}

141
gojson/fold.go Normal file
View File

@@ -0,0 +1,141 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
import (
"bytes"
"unicode/utf8"
)
const (
caseMask = ^byte(0x20) // Mask to ignore case in ASCII.
kelvin = '\u212a'
smallLongEss = '\u017f'
)
// foldFunc returns one of four different case folding equivalence
// functions, from most general (and slow) to fastest:
//
// 1) bytes.EqualFold, if the key s contains any non-ASCII UTF-8
// 2) equalFoldRight, if s contains special folding ASCII ('k', 'K', 's', 'S')
// 3) asciiEqualFold, no special, but includes non-letters (including _)
// 4) simpleLetterEqualFold, no specials, no non-letters.
//
// The letters S and K are special because they map to 3 runes, not just 2:
// - S maps to s and to U+017F 'ſ' Latin small letter long s
// - k maps to K and to U+212A '' Kelvin sign
//
// See https://play.golang.org/p/tTxjOc0OGo
//
// The returned function is specialized for matching against s and
// should only be given s. It's not curried for performance reasons.
func foldFunc(s []byte) func(s, t []byte) bool {
nonLetter := false
special := false // special letter
for _, b := range s {
if b >= utf8.RuneSelf {
return bytes.EqualFold
}
upper := b & caseMask
if upper < 'A' || upper > 'Z' {
nonLetter = true
} else if upper == 'K' || upper == 'S' {
// See above for why these letters are special.
special = true
}
}
if special {
return equalFoldRight
}
if nonLetter {
return asciiEqualFold
}
return simpleLetterEqualFold
}
// equalFoldRight is a specialization of bytes.EqualFold when s is
// known to be all ASCII (including punctuation), but contains an 's',
// 'S', 'k', or 'K', requiring a Unicode fold on the bytes in t.
// See comments on foldFunc.
func equalFoldRight(s, t []byte) bool {
for _, sb := range s {
if len(t) == 0 {
return false
}
tb := t[0]
if tb < utf8.RuneSelf {
if sb != tb {
sbUpper := sb & caseMask
if 'A' <= sbUpper && sbUpper <= 'Z' {
if sbUpper != tb&caseMask {
return false
}
} else {
return false
}
}
t = t[1:]
continue
}
// sb is ASCII and t is not. t must be either kelvin
// sign or long s; sb must be s, S, k, or K.
tr, size := utf8.DecodeRune(t)
switch sb {
case 's', 'S':
if tr != smallLongEss {
return false
}
case 'k', 'K':
if tr != kelvin {
return false
}
default:
return false
}
t = t[size:]
}
return len(t) == 0
}
// asciiEqualFold is a specialization of bytes.EqualFold for use when
// s is all ASCII (but may contain non-letters) and contains no
// special-folding letters.
// See comments on foldFunc.
func asciiEqualFold(s, t []byte) bool {
if len(s) != len(t) {
return false
}
for i, sb := range s {
tb := t[i]
if sb == tb {
continue
}
if ('a' <= sb && sb <= 'z') || ('A' <= sb && sb <= 'Z') {
if sb&caseMask != tb&caseMask {
return false
}
} else {
return false
}
}
return true
}
// simpleLetterEqualFold is a specialization of bytes.EqualFold for
// use when s is all ASCII letters (no underscores, etc) and also
// doesn't contain 'k', 'K', 's', or 'S'.
// See comments on foldFunc.
func simpleLetterEqualFold(s, t []byte) bool {
if len(s) != len(t) {
return false
}
for i, b := range s {
if b&caseMask != t[i]&caseMask {
return false
}
}
return true
}

110
gojson/fold_test.go Normal file
View File

@@ -0,0 +1,110 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
import (
"bytes"
"strings"
"testing"
"unicode/utf8"
)
var foldTests = []struct {
fn func(s, t []byte) bool
s, t string
want bool
}{
{equalFoldRight, "", "", true},
{equalFoldRight, "a", "a", true},
{equalFoldRight, "", "a", false},
{equalFoldRight, "a", "", false},
{equalFoldRight, "a", "A", true},
{equalFoldRight, "AB", "ab", true},
{equalFoldRight, "AB", "ac", false},
{equalFoldRight, "sbkKc", "ſbKc", true},
{equalFoldRight, "SbKkc", "ſbKc", true},
{equalFoldRight, "SbKkc", "ſbKK", false},
{equalFoldRight, "e", "é", false},
{equalFoldRight, "s", "S", true},
{simpleLetterEqualFold, "", "", true},
{simpleLetterEqualFold, "abc", "abc", true},
{simpleLetterEqualFold, "abc", "ABC", true},
{simpleLetterEqualFold, "abc", "ABCD", false},
{simpleLetterEqualFold, "abc", "xxx", false},
{asciiEqualFold, "a_B", "A_b", true},
{asciiEqualFold, "aa@", "aa`", false}, // verify 0x40 and 0x60 aren't case-equivalent
}
func TestFold(t *testing.T) {
for i, tt := range foldTests {
if got := tt.fn([]byte(tt.s), []byte(tt.t)); got != tt.want {
t.Errorf("%d. %q, %q = %v; want %v", i, tt.s, tt.t, got, tt.want)
}
truth := strings.EqualFold(tt.s, tt.t)
if truth != tt.want {
t.Errorf("strings.EqualFold doesn't agree with case %d", i)
}
}
}
func TestFoldAgainstUnicode(t *testing.T) {
var buf1, buf2 []byte
var runes []rune
for i := 0x20; i <= 0x7f; i++ {
runes = append(runes, rune(i))
}
runes = append(runes, kelvin, smallLongEss)
funcs := []struct {
name string
fold func(s, t []byte) bool
letter bool // must be ASCII letter
simple bool // must be simple ASCII letter (not 'S' or 'K')
}{
{
name: "equalFoldRight",
fold: equalFoldRight,
},
{
name: "asciiEqualFold",
fold: asciiEqualFold,
simple: true,
},
{
name: "simpleLetterEqualFold",
fold: simpleLetterEqualFold,
simple: true,
letter: true,
},
}
for _, ff := range funcs {
for _, r := range runes {
if r >= utf8.RuneSelf {
continue
}
if ff.letter && !isASCIILetter(byte(r)) {
continue
}
if ff.simple && (r == 's' || r == 'S' || r == 'k' || r == 'K') {
continue
}
for _, r2 := range runes {
buf1 = append(utf8.AppendRune(append(buf1[:0], 'x'), r), 'x')
buf2 = append(utf8.AppendRune(append(buf2[:0], 'x'), r2), 'x')
want := bytes.EqualFold(buf1, buf2)
if got := ff.fold(buf1, buf2); got != want {
t.Errorf("%s(%q, %q) = %v; want %v", ff.name, buf1, buf2, got, want)
}
}
}
}
}
func isASCIILetter(b byte) bool {
return ('A' <= b && b <= 'Z') || ('a' <= b && b <= 'z')
}

42
gojson/fuzz.go Normal file
View File

@@ -0,0 +1,42 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build gofuzz
package json
import (
"fmt"
)
func Fuzz(data []byte) (score int) {
for _, ctor := range []func() any{
func() any { return new(any) },
func() any { return new(map[string]any) },
func() any { return new([]any) },
} {
v := ctor()
err := Unmarshal(data, v)
if err != nil {
continue
}
score = 1
m, err := Marshal(v)
if err != nil {
fmt.Printf("v=%#v\n", v)
panic(err)
}
u := ctor()
err = Unmarshal(m, u)
if err != nil {
fmt.Printf("v=%#v\n", v)
fmt.Printf("m=%s\n", m)
panic(err)
}
}
return
}

83
gojson/fuzz_test.go Normal file
View File

@@ -0,0 +1,83 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
import (
"bytes"
"io"
"testing"
)
func FuzzUnmarshalJSON(f *testing.F) {
f.Add([]byte(`{
"object": {
"slice": [
1,
2.0,
"3",
[4],
{5: {}}
]
},
"slice": [[]],
"string": ":)",
"int": 1e5,
"float": 3e-9"
}`))
f.Fuzz(func(t *testing.T, b []byte) {
for _, typ := range []func() interface{}{
func() interface{} { return new(interface{}) },
func() interface{} { return new(map[string]interface{}) },
func() interface{} { return new([]interface{}) },
} {
i := typ()
if err := Unmarshal(b, i); err != nil {
return
}
encoded, err := Marshal(i)
if err != nil {
t.Fatalf("failed to marshal: %s", err)
}
if err := Unmarshal(encoded, i); err != nil {
t.Fatalf("failed to roundtrip: %s", err)
}
}
})
}
func FuzzDecoderToken(f *testing.F) {
f.Add([]byte(`{
"object": {
"slice": [
1,
2.0,
"3",
[4],
{5: {}}
]
},
"slice": [[]],
"string": ":)",
"int": 1e5,
"float": 3e-9"
}`))
f.Fuzz(func(t *testing.T, b []byte) {
r := bytes.NewReader(b)
d := NewDecoder(r)
for {
_, err := d.Token()
if err != nil {
if err == io.EOF {
break
}
return
}
}
})
}

44
gojson/gionic.go Normal file
View File

@@ -0,0 +1,44 @@
package json
import (
"net/http"
)
// Render interface is copied from github.com/gin-gonic/gin@v1.8.1/render/render.go
type Render interface {
// Render writes data with custom ContentType.
Render(http.ResponseWriter) error
// WriteContentType writes custom ContentType.
WriteContentType(w http.ResponseWriter)
}
type GoJsonRender struct {
Data any
NilSafeSlices bool
NilSafeMaps bool
Indent *IndentOpt
}
func (r GoJsonRender) Render(w http.ResponseWriter) error {
header := w.Header()
if val := header["Content-Type"]; len(val) == 0 {
header["Content-Type"] = []string{"application/json; charset=utf-8"}
}
jsonBytes, err := MarshalSafeCollections(r.Data, r.NilSafeSlices, r.NilSafeMaps, r.Indent)
if err != nil {
panic(err)
}
_, err = w.Write(jsonBytes)
if err != nil {
panic(err)
}
return nil
}
func (r GoJsonRender) WriteContentType(w http.ResponseWriter) {
header := w.Header()
if val := header["Content-Type"]; len(val) == 0 {
header["Content-Type"] = []string{"application/json; charset=utf-8"}
}
}

143
gojson/indent.go Normal file
View File

@@ -0,0 +1,143 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
import (
"bytes"
)
// Compact appends to dst the JSON-encoded src with
// insignificant space characters elided.
func Compact(dst *bytes.Buffer, src []byte) error {
return compact(dst, src, false)
}
func compact(dst *bytes.Buffer, src []byte, escape bool) error {
origLen := dst.Len()
scan := newScanner()
defer freeScanner(scan)
start := 0
for i, c := range src {
if escape && (c == '<' || c == '>' || c == '&') {
if start < i {
dst.Write(src[start:i])
}
dst.WriteString(`\u00`)
dst.WriteByte(hex[c>>4])
dst.WriteByte(hex[c&0xF])
start = i + 1
}
// Convert U+2028 and U+2029 (E2 80 A8 and E2 80 A9).
if escape && c == 0xE2 && i+2 < len(src) && src[i+1] == 0x80 && src[i+2]&^1 == 0xA8 {
if start < i {
dst.Write(src[start:i])
}
dst.WriteString(`\u202`)
dst.WriteByte(hex[src[i+2]&0xF])
start = i + 3
}
v := scan.step(scan, c)
if v >= scanSkipSpace {
if v == scanError {
break
}
if start < i {
dst.Write(src[start:i])
}
start = i + 1
}
}
if scan.eof() == scanError {
dst.Truncate(origLen)
return scan.err
}
if start < len(src) {
dst.Write(src[start:])
}
return nil
}
func newline(dst *bytes.Buffer, prefix, indent string, depth int) {
dst.WriteByte('\n')
dst.WriteString(prefix)
for i := 0; i < depth; i++ {
dst.WriteString(indent)
}
}
// Indent appends to dst an indented form of the JSON-encoded src.
// Each element in a JSON object or array begins on a new,
// indented line beginning with prefix followed by one or more
// copies of indent according to the indentation nesting.
// The data appended to dst does not begin with the prefix nor
// any indentation, to make it easier to embed inside other formatted JSON data.
// Although leading space characters (space, tab, carriage return, newline)
// at the beginning of src are dropped, trailing space characters
// at the end of src are preserved and copied to dst.
// For example, if src has no trailing spaces, neither will dst;
// if src ends in a trailing newline, so will dst.
func Indent(dst *bytes.Buffer, src []byte, prefix, indent string) error {
origLen := dst.Len()
scan := newScanner()
defer freeScanner(scan)
needIndent := false
depth := 0
for _, c := range src {
scan.bytes++
v := scan.step(scan, c)
if v == scanSkipSpace {
continue
}
if v == scanError {
break
}
if needIndent && v != scanEndObject && v != scanEndArray {
needIndent = false
depth++
newline(dst, prefix, indent, depth)
}
// Emit semantically uninteresting bytes
// (in particular, punctuation in strings) unmodified.
if v == scanContinue {
dst.WriteByte(c)
continue
}
// Add spacing around real punctuation.
switch c {
case '{', '[':
// delay indent so that empty object and array are formatted as {} and [].
needIndent = true
dst.WriteByte(c)
case ',':
dst.WriteByte(c)
newline(dst, prefix, indent, depth)
case ':':
dst.WriteByte(c)
dst.WriteByte(' ')
case '}', ']':
if needIndent {
// suppress indent in empty object/array
needIndent = false
} else {
depth--
newline(dst, prefix, indent, depth)
}
dst.WriteByte(c)
default:
dst.WriteByte(c)
}
}
if scan.eof() == scanError {
dst.Truncate(origLen)
return scan.err
}
return nil
}

118
gojson/number_test.go Normal file
View File

@@ -0,0 +1,118 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
import (
"regexp"
"testing"
)
func TestNumberIsValid(t *testing.T) {
// From: https://stackoverflow.com/a/13340826
var jsonNumberRegexp = regexp.MustCompile(`^-?(?:0|[1-9]\d*)(?:\.\d+)?(?:[eE][+-]?\d+)?$`)
validTests := []string{
"0",
"-0",
"1",
"-1",
"0.1",
"-0.1",
"1234",
"-1234",
"12.34",
"-12.34",
"12E0",
"12E1",
"12e34",
"12E-0",
"12e+1",
"12e-34",
"-12E0",
"-12E1",
"-12e34",
"-12E-0",
"-12e+1",
"-12e-34",
"1.2E0",
"1.2E1",
"1.2e34",
"1.2E-0",
"1.2e+1",
"1.2e-34",
"-1.2E0",
"-1.2E1",
"-1.2e34",
"-1.2E-0",
"-1.2e+1",
"-1.2e-34",
"0E0",
"0E1",
"0e34",
"0E-0",
"0e+1",
"0e-34",
"-0E0",
"-0E1",
"-0e34",
"-0E-0",
"-0e+1",
"-0e-34",
}
for _, test := range validTests {
if !isValidNumber(test) {
t.Errorf("%s should be valid", test)
}
var f float64
if err := Unmarshal([]byte(test), &f); err != nil {
t.Errorf("%s should be valid but Unmarshal failed: %v", test, err)
}
if !jsonNumberRegexp.MatchString(test) {
t.Errorf("%s should be valid but regexp does not match", test)
}
}
invalidTests := []string{
"",
"invalid",
"1.0.1",
"1..1",
"-1-2",
"012a42",
"01.2",
"012",
"12E12.12",
"1e2e3",
"1e+-2",
"1e--23",
"1e",
"e1",
"1e+",
"1ea",
"1a",
"1.a",
"1.",
"01",
"1.e1",
}
for _, test := range invalidTests {
if isValidNumber(test) {
t.Errorf("%s should be invalid", test)
}
var f float64
if err := Unmarshal([]byte(test), &f); err == nil {
t.Errorf("%s should be invalid but unmarshal wrote %v", test, f)
}
if jsonNumberRegexp.MatchString(test) {
t.Errorf("%s should be invalid but matches regexp", test)
}
}
}

610
gojson/scanner.go Normal file
View File

@@ -0,0 +1,610 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
// JSON value parser state machine.
// Just about at the limit of what is reasonable to write by hand.
// Some parts are a bit tedious, but overall it nicely factors out the
// otherwise common code from the multiple scanning functions
// in this package (Compact, Indent, checkValid, etc).
//
// This file starts with two simple examples using the scanner
// before diving into the scanner itself.
import (
"strconv"
"sync"
)
// Valid reports whether data is a valid JSON encoding.
func Valid(data []byte) bool {
scan := newScanner()
defer freeScanner(scan)
return checkValid(data, scan) == nil
}
// checkValid verifies that data is valid JSON-encoded data.
// scan is passed in for use by checkValid to avoid an allocation.
// checkValid returns nil or a SyntaxError.
func checkValid(data []byte, scan *scanner) error {
scan.reset()
for _, c := range data {
scan.bytes++
if scan.step(scan, c) == scanError {
return scan.err
}
}
if scan.eof() == scanError {
return scan.err
}
return nil
}
// A SyntaxError is a description of a JSON syntax error.
// Unmarshal will return a SyntaxError if the JSON can't be parsed.
type SyntaxError struct {
msg string // description of error
Offset int64 // error occurred after reading Offset bytes
}
func (e *SyntaxError) Error() string { return e.msg }
// A scanner is a JSON scanning state machine.
// Callers call scan.reset and then pass bytes in one at a time
// by calling scan.step(&scan, c) for each byte.
// The return value, referred to as an opcode, tells the
// caller about significant parsing events like beginning
// and ending literals, objects, and arrays, so that the
// caller can follow along if it wishes.
// The return value scanEnd indicates that a single top-level
// JSON value has been completed, *before* the byte that
// just got passed in. (The indication must be delayed in order
// to recognize the end of numbers: is 123 a whole value or
// the beginning of 12345e+6?).
type scanner struct {
// The step is a func to be called to execute the next transition.
// Also tried using an integer constant and a single func
// with a switch, but using the func directly was 10% faster
// on a 64-bit Mac Mini, and it's nicer to read.
step func(*scanner, byte) int
// Reached end of top-level value.
endTop bool
// Stack of what we're in the middle of - array values, object keys, object values.
parseState []int
// Error that happened, if any.
err error
// total bytes consumed, updated by decoder.Decode (and deliberately
// not set to zero by scan.reset)
bytes int64
}
var scannerPool = sync.Pool{
New: func() any {
return &scanner{}
},
}
func newScanner() *scanner {
scan := scannerPool.Get().(*scanner)
// scan.reset by design doesn't set bytes to zero
scan.bytes = 0
scan.reset()
return scan
}
func freeScanner(scan *scanner) {
// Avoid hanging on to too much memory in extreme cases.
if len(scan.parseState) > 1024 {
scan.parseState = nil
}
scannerPool.Put(scan)
}
// These values are returned by the state transition functions
// assigned to scanner.state and the method scanner.eof.
// They give details about the current state of the scan that
// callers might be interested to know about.
// It is okay to ignore the return value of any particular
// call to scanner.state: if one call returns scanError,
// every subsequent call will return scanError too.
const (
// Continue.
scanContinue = iota // uninteresting byte
scanBeginLiteral // end implied by next result != scanContinue
scanBeginObject // begin object
scanObjectKey // just finished object key (string)
scanObjectValue // just finished non-last object value
scanEndObject // end object (implies scanObjectValue if possible)
scanBeginArray // begin array
scanArrayValue // just finished array value
scanEndArray // end array (implies scanArrayValue if possible)
scanSkipSpace // space byte; can skip; known to be last "continue" result
// Stop.
scanEnd // top-level value ended *before* this byte; known to be first "stop" result
scanError // hit an error, scanner.err.
)
// These values are stored in the parseState stack.
// They give the current state of a composite value
// being scanned. If the parser is inside a nested value
// the parseState describes the nested state, outermost at entry 0.
const (
parseObjectKey = iota // parsing object key (before colon)
parseObjectValue // parsing object value (after colon)
parseArrayValue // parsing array value
)
// This limits the max nesting depth to prevent stack overflow.
// This is permitted by https://tools.ietf.org/html/rfc7159#section-9
const maxNestingDepth = 10000
// reset prepares the scanner for use.
// It must be called before calling s.step.
func (s *scanner) reset() {
s.step = stateBeginValue
s.parseState = s.parseState[0:0]
s.err = nil
s.endTop = false
}
// eof tells the scanner that the end of input has been reached.
// It returns a scan status just as s.step does.
func (s *scanner) eof() int {
if s.err != nil {
return scanError
}
if s.endTop {
return scanEnd
}
s.step(s, ' ')
if s.endTop {
return scanEnd
}
if s.err == nil {
s.err = &SyntaxError{"unexpected end of JSON input", s.bytes}
}
return scanError
}
// pushParseState pushes a new parse state p onto the parse stack.
// an error state is returned if maxNestingDepth was exceeded, otherwise successState is returned.
func (s *scanner) pushParseState(c byte, newParseState int, successState int) int {
s.parseState = append(s.parseState, newParseState)
if len(s.parseState) <= maxNestingDepth {
return successState
}
return s.error(c, "exceeded max depth")
}
// popParseState pops a parse state (already obtained) off the stack
// and updates s.step accordingly.
func (s *scanner) popParseState() {
n := len(s.parseState) - 1
s.parseState = s.parseState[0:n]
if n == 0 {
s.step = stateEndTop
s.endTop = true
} else {
s.step = stateEndValue
}
}
func isSpace(c byte) bool {
return c <= ' ' && (c == ' ' || c == '\t' || c == '\r' || c == '\n')
}
// stateBeginValueOrEmpty is the state after reading `[`.
func stateBeginValueOrEmpty(s *scanner, c byte) int {
if isSpace(c) {
return scanSkipSpace
}
if c == ']' {
return stateEndValue(s, c)
}
return stateBeginValue(s, c)
}
// stateBeginValue is the state at the beginning of the input.
func stateBeginValue(s *scanner, c byte) int {
if isSpace(c) {
return scanSkipSpace
}
switch c {
case '{':
s.step = stateBeginStringOrEmpty
return s.pushParseState(c, parseObjectKey, scanBeginObject)
case '[':
s.step = stateBeginValueOrEmpty
return s.pushParseState(c, parseArrayValue, scanBeginArray)
case '"':
s.step = stateInString
return scanBeginLiteral
case '-':
s.step = stateNeg
return scanBeginLiteral
case '0': // beginning of 0.123
s.step = state0
return scanBeginLiteral
case 't': // beginning of true
s.step = stateT
return scanBeginLiteral
case 'f': // beginning of false
s.step = stateF
return scanBeginLiteral
case 'n': // beginning of null
s.step = stateN
return scanBeginLiteral
}
if '1' <= c && c <= '9' { // beginning of 1234.5
s.step = state1
return scanBeginLiteral
}
return s.error(c, "looking for beginning of value")
}
// stateBeginStringOrEmpty is the state after reading `{`.
func stateBeginStringOrEmpty(s *scanner, c byte) int {
if isSpace(c) {
return scanSkipSpace
}
if c == '}' {
n := len(s.parseState)
s.parseState[n-1] = parseObjectValue
return stateEndValue(s, c)
}
return stateBeginString(s, c)
}
// stateBeginString is the state after reading `{"key": value,`.
func stateBeginString(s *scanner, c byte) int {
if isSpace(c) {
return scanSkipSpace
}
if c == '"' {
s.step = stateInString
return scanBeginLiteral
}
return s.error(c, "looking for beginning of object key string")
}
// stateEndValue is the state after completing a value,
// such as after reading `{}` or `true` or `["x"`.
func stateEndValue(s *scanner, c byte) int {
n := len(s.parseState)
if n == 0 {
// Completed top-level before the current byte.
s.step = stateEndTop
s.endTop = true
return stateEndTop(s, c)
}
if isSpace(c) {
s.step = stateEndValue
return scanSkipSpace
}
ps := s.parseState[n-1]
switch ps {
case parseObjectKey:
if c == ':' {
s.parseState[n-1] = parseObjectValue
s.step = stateBeginValue
return scanObjectKey
}
return s.error(c, "after object key")
case parseObjectValue:
if c == ',' {
s.parseState[n-1] = parseObjectKey
s.step = stateBeginString
return scanObjectValue
}
if c == '}' {
s.popParseState()
return scanEndObject
}
return s.error(c, "after object key:value pair")
case parseArrayValue:
if c == ',' {
s.step = stateBeginValue
return scanArrayValue
}
if c == ']' {
s.popParseState()
return scanEndArray
}
return s.error(c, "after array element")
}
return s.error(c, "")
}
// stateEndTop is the state after finishing the top-level value,
// such as after reading `{}` or `[1,2,3]`.
// Only space characters should be seen now.
func stateEndTop(s *scanner, c byte) int {
if !isSpace(c) {
// Complain about non-space byte on next call.
s.error(c, "after top-level value")
}
return scanEnd
}
// stateInString is the state after reading `"`.
func stateInString(s *scanner, c byte) int {
if c == '"' {
s.step = stateEndValue
return scanContinue
}
if c == '\\' {
s.step = stateInStringEsc
return scanContinue
}
if c < 0x20 {
return s.error(c, "in string literal")
}
return scanContinue
}
// stateInStringEsc is the state after reading `"\` during a quoted string.
func stateInStringEsc(s *scanner, c byte) int {
switch c {
case 'b', 'f', 'n', 'r', 't', '\\', '/', '"':
s.step = stateInString
return scanContinue
case 'u':
s.step = stateInStringEscU
return scanContinue
}
return s.error(c, "in string escape code")
}
// stateInStringEscU is the state after reading `"\u` during a quoted string.
func stateInStringEscU(s *scanner, c byte) int {
if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' {
s.step = stateInStringEscU1
return scanContinue
}
// numbers
return s.error(c, "in \\u hexadecimal character escape")
}
// stateInStringEscU1 is the state after reading `"\u1` during a quoted string.
func stateInStringEscU1(s *scanner, c byte) int {
if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' {
s.step = stateInStringEscU12
return scanContinue
}
// numbers
return s.error(c, "in \\u hexadecimal character escape")
}
// stateInStringEscU12 is the state after reading `"\u12` during a quoted string.
func stateInStringEscU12(s *scanner, c byte) int {
if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' {
s.step = stateInStringEscU123
return scanContinue
}
// numbers
return s.error(c, "in \\u hexadecimal character escape")
}
// stateInStringEscU123 is the state after reading `"\u123` during a quoted string.
func stateInStringEscU123(s *scanner, c byte) int {
if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' {
s.step = stateInString
return scanContinue
}
// numbers
return s.error(c, "in \\u hexadecimal character escape")
}
// stateNeg is the state after reading `-` during a number.
func stateNeg(s *scanner, c byte) int {
if c == '0' {
s.step = state0
return scanContinue
}
if '1' <= c && c <= '9' {
s.step = state1
return scanContinue
}
return s.error(c, "in numeric literal")
}
// state1 is the state after reading a non-zero integer during a number,
// such as after reading `1` or `100` but not `0`.
func state1(s *scanner, c byte) int {
if '0' <= c && c <= '9' {
s.step = state1
return scanContinue
}
return state0(s, c)
}
// state0 is the state after reading `0` during a number.
func state0(s *scanner, c byte) int {
if c == '.' {
s.step = stateDot
return scanContinue
}
if c == 'e' || c == 'E' {
s.step = stateE
return scanContinue
}
return stateEndValue(s, c)
}
// stateDot is the state after reading the integer and decimal point in a number,
// such as after reading `1.`.
func stateDot(s *scanner, c byte) int {
if '0' <= c && c <= '9' {
s.step = stateDot0
return scanContinue
}
return s.error(c, "after decimal point in numeric literal")
}
// stateDot0 is the state after reading the integer, decimal point, and subsequent
// digits of a number, such as after reading `3.14`.
func stateDot0(s *scanner, c byte) int {
if '0' <= c && c <= '9' {
return scanContinue
}
if c == 'e' || c == 'E' {
s.step = stateE
return scanContinue
}
return stateEndValue(s, c)
}
// stateE is the state after reading the mantissa and e in a number,
// such as after reading `314e` or `0.314e`.
func stateE(s *scanner, c byte) int {
if c == '+' || c == '-' {
s.step = stateESign
return scanContinue
}
return stateESign(s, c)
}
// stateESign is the state after reading the mantissa, e, and sign in a number,
// such as after reading `314e-` or `0.314e+`.
func stateESign(s *scanner, c byte) int {
if '0' <= c && c <= '9' {
s.step = stateE0
return scanContinue
}
return s.error(c, "in exponent of numeric literal")
}
// stateE0 is the state after reading the mantissa, e, optional sign,
// and at least one digit of the exponent in a number,
// such as after reading `314e-2` or `0.314e+1` or `3.14e0`.
func stateE0(s *scanner, c byte) int {
if '0' <= c && c <= '9' {
return scanContinue
}
return stateEndValue(s, c)
}
// stateT is the state after reading `t`.
func stateT(s *scanner, c byte) int {
if c == 'r' {
s.step = stateTr
return scanContinue
}
return s.error(c, "in literal true (expecting 'r')")
}
// stateTr is the state after reading `tr`.
func stateTr(s *scanner, c byte) int {
if c == 'u' {
s.step = stateTru
return scanContinue
}
return s.error(c, "in literal true (expecting 'u')")
}
// stateTru is the state after reading `tru`.
func stateTru(s *scanner, c byte) int {
if c == 'e' {
s.step = stateEndValue
return scanContinue
}
return s.error(c, "in literal true (expecting 'e')")
}
// stateF is the state after reading `f`.
func stateF(s *scanner, c byte) int {
if c == 'a' {
s.step = stateFa
return scanContinue
}
return s.error(c, "in literal false (expecting 'a')")
}
// stateFa is the state after reading `fa`.
func stateFa(s *scanner, c byte) int {
if c == 'l' {
s.step = stateFal
return scanContinue
}
return s.error(c, "in literal false (expecting 'l')")
}
// stateFal is the state after reading `fal`.
func stateFal(s *scanner, c byte) int {
if c == 's' {
s.step = stateFals
return scanContinue
}
return s.error(c, "in literal false (expecting 's')")
}
// stateFals is the state after reading `fals`.
func stateFals(s *scanner, c byte) int {
if c == 'e' {
s.step = stateEndValue
return scanContinue
}
return s.error(c, "in literal false (expecting 'e')")
}
// stateN is the state after reading `n`.
func stateN(s *scanner, c byte) int {
if c == 'u' {
s.step = stateNu
return scanContinue
}
return s.error(c, "in literal null (expecting 'u')")
}
// stateNu is the state after reading `nu`.
func stateNu(s *scanner, c byte) int {
if c == 'l' {
s.step = stateNul
return scanContinue
}
return s.error(c, "in literal null (expecting 'l')")
}
// stateNul is the state after reading `nul`.
func stateNul(s *scanner, c byte) int {
if c == 'l' {
s.step = stateEndValue
return scanContinue
}
return s.error(c, "in literal null (expecting 'l')")
}
// stateError is the state after reaching a syntax error,
// such as after reading `[1}` or `5.1.2`.
func stateError(s *scanner, c byte) int {
return scanError
}
// error records an error and switches to the error state.
func (s *scanner) error(c byte, context string) int {
s.step = stateError
s.err = &SyntaxError{"invalid character " + quoteChar(c) + " " + context, s.bytes}
return scanError
}
// quoteChar formats c as a quoted character literal.
func quoteChar(c byte) string {
// special cases - different from quoted strings
if c == '\'' {
return `'\''`
}
if c == '"' {
return `'"'`
}
// use quoted string with different quotation marks
s := strconv.Quote(string(c))
return "'" + s[1:len(s)-1] + "'"
}

301
gojson/scanner_test.go Normal file
View File

@@ -0,0 +1,301 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
import (
"bytes"
"math"
"math/rand"
"reflect"
"testing"
)
var validTests = []struct {
data string
ok bool
}{
{`foo`, false},
{`}{`, false},
{`{]`, false},
{`{}`, true},
{`{"foo":"bar"}`, true},
{`{"foo":"bar","bar":{"baz":["qux"]}}`, true},
}
func TestValid(t *testing.T) {
for _, tt := range validTests {
if ok := Valid([]byte(tt.data)); ok != tt.ok {
t.Errorf("Valid(%#q) = %v, want %v", tt.data, ok, tt.ok)
}
}
}
// Tests of simple examples.
type example struct {
compact string
indent string
}
var examples = []example{
{`1`, `1`},
{`{}`, `{}`},
{`[]`, `[]`},
{`{"":2}`, "{\n\t\"\": 2\n}"},
{`[3]`, "[\n\t3\n]"},
{`[1,2,3]`, "[\n\t1,\n\t2,\n\t3\n]"},
{`{"x":1}`, "{\n\t\"x\": 1\n}"},
{ex1, ex1i},
{"{\"\":\"<>&\u2028\u2029\"}", "{\n\t\"\": \"<>&\u2028\u2029\"\n}"}, // See golang.org/issue/34070
}
var ex1 = `[true,false,null,"x",1,1.5,0,-5e+2]`
var ex1i = `[
true,
false,
null,
"x",
1,
1.5,
0,
-5e+2
]`
func TestCompact(t *testing.T) {
var buf bytes.Buffer
for _, tt := range examples {
buf.Reset()
if err := Compact(&buf, []byte(tt.compact)); err != nil {
t.Errorf("Compact(%#q): %v", tt.compact, err)
} else if s := buf.String(); s != tt.compact {
t.Errorf("Compact(%#q) = %#q, want original", tt.compact, s)
}
buf.Reset()
if err := Compact(&buf, []byte(tt.indent)); err != nil {
t.Errorf("Compact(%#q): %v", tt.indent, err)
continue
} else if s := buf.String(); s != tt.compact {
t.Errorf("Compact(%#q) = %#q, want %#q", tt.indent, s, tt.compact)
}
}
}
func TestCompactSeparators(t *testing.T) {
// U+2028 and U+2029 should be escaped inside strings.
// They should not appear outside strings.
tests := []struct {
in, compact string
}{
{"{\"\u2028\": 1}", "{\"\u2028\":1}"},
{"{\"\u2029\" :2}", "{\"\u2029\":2}"},
}
for _, tt := range tests {
var buf bytes.Buffer
if err := Compact(&buf, []byte(tt.in)); err != nil {
t.Errorf("Compact(%q): %v", tt.in, err)
} else if s := buf.String(); s != tt.compact {
t.Errorf("Compact(%q) = %q, want %q", tt.in, s, tt.compact)
}
}
}
func TestIndent(t *testing.T) {
var buf bytes.Buffer
for _, tt := range examples {
buf.Reset()
if err := Indent(&buf, []byte(tt.indent), "", "\t"); err != nil {
t.Errorf("Indent(%#q): %v", tt.indent, err)
} else if s := buf.String(); s != tt.indent {
t.Errorf("Indent(%#q) = %#q, want original", tt.indent, s)
}
buf.Reset()
if err := Indent(&buf, []byte(tt.compact), "", "\t"); err != nil {
t.Errorf("Indent(%#q): %v", tt.compact, err)
continue
} else if s := buf.String(); s != tt.indent {
t.Errorf("Indent(%#q) = %#q, want %#q", tt.compact, s, tt.indent)
}
}
}
// Tests of a large random structure.
func TestCompactBig(t *testing.T) {
initBig()
var buf bytes.Buffer
if err := Compact(&buf, jsonBig); err != nil {
t.Fatalf("Compact: %v", err)
}
b := buf.Bytes()
if !bytes.Equal(b, jsonBig) {
t.Error("Compact(jsonBig) != jsonBig")
diff(t, b, jsonBig)
return
}
}
func TestIndentBig(t *testing.T) {
t.Parallel()
initBig()
var buf bytes.Buffer
if err := Indent(&buf, jsonBig, "", "\t"); err != nil {
t.Fatalf("Indent1: %v", err)
}
b := buf.Bytes()
if len(b) == len(jsonBig) {
// jsonBig is compact (no unnecessary spaces);
// indenting should make it bigger
t.Fatalf("Indent(jsonBig) did not get bigger")
}
// should be idempotent
var buf1 bytes.Buffer
if err := Indent(&buf1, b, "", "\t"); err != nil {
t.Fatalf("Indent2: %v", err)
}
b1 := buf1.Bytes()
if !bytes.Equal(b1, b) {
t.Error("Indent(Indent(jsonBig)) != Indent(jsonBig)")
diff(t, b1, b)
return
}
// should get back to original
buf1.Reset()
if err := Compact(&buf1, b); err != nil {
t.Fatalf("Compact: %v", err)
}
b1 = buf1.Bytes()
if !bytes.Equal(b1, jsonBig) {
t.Error("Compact(Indent(jsonBig)) != jsonBig")
diff(t, b1, jsonBig)
return
}
}
type indentErrorTest struct {
in string
err error
}
var indentErrorTests = []indentErrorTest{
{`{"X": "foo", "Y"}`, &SyntaxError{"invalid character '}' after object key", 17}},
{`{"X": "foo" "Y": "bar"}`, &SyntaxError{"invalid character '\"' after object key:value pair", 13}},
}
func TestIndentErrors(t *testing.T) {
for i, tt := range indentErrorTests {
slice := make([]uint8, 0)
buf := bytes.NewBuffer(slice)
if err := Indent(buf, []uint8(tt.in), "", ""); err != nil {
if !reflect.DeepEqual(err, tt.err) {
t.Errorf("#%d: Indent: %#v", i, err)
continue
}
}
}
}
func diff(t *testing.T, a, b []byte) {
for i := 0; ; i++ {
if i >= len(a) || i >= len(b) || a[i] != b[i] {
j := i - 10
if j < 0 {
j = 0
}
t.Errorf("diverge at %d: «%s» vs «%s»", i, trim(a[j:]), trim(b[j:]))
return
}
}
}
func trim(b []byte) []byte {
if len(b) > 20 {
return b[0:20]
}
return b
}
// Generate a random JSON object.
var jsonBig []byte
func initBig() {
n := 10000
if testing.Short() {
n = 100
}
b, err := Marshal(genValue(n))
if err != nil {
panic(err)
}
jsonBig = b
}
func genValue(n int) any {
if n > 1 {
switch rand.Intn(2) {
case 0:
return genArray(n)
case 1:
return genMap(n)
}
}
switch rand.Intn(3) {
case 0:
return rand.Intn(2) == 0
case 1:
return rand.NormFloat64()
case 2:
return genString(30)
}
panic("unreachable")
}
func genString(stddev float64) string {
n := int(math.Abs(rand.NormFloat64()*stddev + stddev/2))
c := make([]rune, n)
for i := range c {
f := math.Abs(rand.NormFloat64()*64 + 32)
if f > 0x10ffff {
f = 0x10ffff
}
c[i] = rune(f)
}
return string(c)
}
func genArray(n int) []any {
f := int(math.Abs(rand.NormFloat64()) * math.Min(10, float64(n/2)))
if f > n {
f = n
}
if f < 1 {
f = 1
}
x := make([]any, f)
for i := range x {
x[i] = genValue(((i+1)*n)/f - (i*n)/f)
}
return x
}
func genMap(n int) map[string]any {
f := int(math.Abs(rand.NormFloat64()) * math.Min(10, float64(n/2)))
if f > n {
f = n
}
if n > 0 && f == 0 {
f = 1
}
x := make(map[string]any)
for i := 0; i < f; i++ {
x[genString(10)] = genValue(((i+1)*n)/f - (i*n)/f)
}
return x
}

524
gojson/stream.go Normal file
View File

@@ -0,0 +1,524 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
import (
"bytes"
"errors"
"io"
)
// A Decoder reads and decodes JSON values from an input stream.
type Decoder struct {
r io.Reader
buf []byte
d decodeState
scanp int // start of unread data in buf
scanned int64 // amount of data already scanned
scan scanner
err error
tokenState int
tokenStack []int
}
// NewDecoder returns a new decoder that reads from r.
//
// The decoder introduces its own buffering and may
// read data from r beyond the JSON values requested.
func NewDecoder(r io.Reader) *Decoder {
return &Decoder{r: r}
}
// UseNumber causes the Decoder to unmarshal a number into an interface{} as a
// Number instead of as a float64.
func (dec *Decoder) UseNumber() { dec.d.useNumber = true }
// DisallowUnknownFields causes the Decoder to return an error when the destination
// is a struct and the input contains object keys which do not match any
// non-ignored, exported fields in the destination.
func (dec *Decoder) DisallowUnknownFields() { dec.d.disallowUnknownFields = true }
// Decode reads the next JSON-encoded value from its
// input and stores it in the value pointed to by v.
//
// See the documentation for Unmarshal for details about
// the conversion of JSON into a Go value.
func (dec *Decoder) Decode(v any) error {
if dec.err != nil {
return dec.err
}
if err := dec.tokenPrepareForDecode(); err != nil {
return err
}
if !dec.tokenValueAllowed() {
return &SyntaxError{msg: "not at beginning of value", Offset: dec.InputOffset()}
}
// Read whole value into buffer.
n, err := dec.readValue()
if err != nil {
return err
}
dec.d.init(dec.buf[dec.scanp : dec.scanp+n])
dec.scanp += n
// Don't save err from unmarshal into dec.err:
// the connection is still usable since we read a complete JSON
// object from it before the error happened.
err = dec.d.unmarshal(v)
// fixup token streaming state
dec.tokenValueEnd()
return err
}
// Buffered returns a reader of the data remaining in the Decoder's
// buffer. The reader is valid until the next call to Decode.
func (dec *Decoder) Buffered() io.Reader {
return bytes.NewReader(dec.buf[dec.scanp:])
}
// readValue reads a JSON value into dec.buf.
// It returns the length of the encoding.
func (dec *Decoder) readValue() (int, error) {
dec.scan.reset()
scanp := dec.scanp
var err error
Input:
// help the compiler see that scanp is never negative, so it can remove
// some bounds checks below.
for scanp >= 0 {
// Look in the buffer for a new value.
for ; scanp < len(dec.buf); scanp++ {
c := dec.buf[scanp]
dec.scan.bytes++
switch dec.scan.step(&dec.scan, c) {
case scanEnd:
// scanEnd is delayed one byte so we decrement
// the scanner bytes count by 1 to ensure that
// this value is correct in the next call of Decode.
dec.scan.bytes--
break Input
case scanEndObject, scanEndArray:
// scanEnd is delayed one byte.
// We might block trying to get that byte from src,
// so instead invent a space byte.
if stateEndValue(&dec.scan, ' ') == scanEnd {
scanp++
break Input
}
case scanError:
dec.err = dec.scan.err
return 0, dec.scan.err
}
}
// Did the last read have an error?
// Delayed until now to allow buffer scan.
if err != nil {
if err == io.EOF {
if dec.scan.step(&dec.scan, ' ') == scanEnd {
break Input
}
if nonSpace(dec.buf) {
err = io.ErrUnexpectedEOF
}
}
dec.err = err
return 0, err
}
n := scanp - dec.scanp
err = dec.refill()
scanp = dec.scanp + n
}
return scanp - dec.scanp, nil
}
func (dec *Decoder) refill() error {
// Make room to read more into the buffer.
// First slide down data already consumed.
if dec.scanp > 0 {
dec.scanned += int64(dec.scanp)
n := copy(dec.buf, dec.buf[dec.scanp:])
dec.buf = dec.buf[:n]
dec.scanp = 0
}
// Grow buffer if not large enough.
const minRead = 512
if cap(dec.buf)-len(dec.buf) < minRead {
newBuf := make([]byte, len(dec.buf), 2*cap(dec.buf)+minRead)
copy(newBuf, dec.buf)
dec.buf = newBuf
}
// Read. Delay error for next iteration (after scan).
n, err := dec.r.Read(dec.buf[len(dec.buf):cap(dec.buf)])
dec.buf = dec.buf[0 : len(dec.buf)+n]
return err
}
func nonSpace(b []byte) bool {
for _, c := range b {
if !isSpace(c) {
return true
}
}
return false
}
// An Encoder writes JSON values to an output stream.
type Encoder struct {
w io.Writer
err error
escapeHTML bool
nilSafeSlices bool
nilSafeMaps bool
indentBuf *bytes.Buffer
indentPrefix string
indentValue string
}
// NewEncoder returns a new encoder that writes to w.
func NewEncoder(w io.Writer) *Encoder {
return &Encoder{w: w, escapeHTML: true}
}
// Encode writes the JSON encoding of v to the stream,
// followed by a newline character.
//
// See the documentation for Marshal for details about the
// conversion of Go values to JSON.
func (enc *Encoder) Encode(v any) error {
if enc.err != nil {
return enc.err
}
e := newEncodeState()
defer encodeStatePool.Put(e)
err := e.marshal(v, encOpts{escapeHTML: enc.escapeHTML, nilSafeMaps: enc.nilSafeMaps, nilSafeSlices: enc.nilSafeSlices})
if err != nil {
return err
}
// Terminate each value with a newline.
// This makes the output look a little nicer
// when debugging, and some kind of space
// is required if the encoded value was a number,
// so that the reader knows there aren't more
// digits coming.
e.WriteByte('\n')
b := e.Bytes()
if enc.indentPrefix != "" || enc.indentValue != "" {
if enc.indentBuf == nil {
enc.indentBuf = new(bytes.Buffer)
}
enc.indentBuf.Reset()
err = Indent(enc.indentBuf, b, enc.indentPrefix, enc.indentValue)
if err != nil {
return err
}
b = enc.indentBuf.Bytes()
}
if _, err = enc.w.Write(b); err != nil {
enc.err = err
}
return err
}
// SetIndent instructs the encoder to format each subsequent encoded
// value as if indented by the package-level function Indent(dst, src, prefix, indent).
// Calling SetIndent("", "") disables indentation.
func (enc *Encoder) SetIndent(prefix, indent string) {
enc.indentPrefix = prefix
enc.indentValue = indent
}
// SetNilSafeCollection specifies whether to represent nil slices and maps as
// '[]' or '{}' respectfully (flag on) instead of 'null' (default) when marshaling json.
func (enc *Encoder) SetNilSafeCollection(nilSafeSlices bool, nilSafeMaps bool) {
enc.nilSafeSlices = nilSafeSlices
enc.nilSafeMaps = nilSafeMaps
}
// SetEscapeHTML specifies whether problematic HTML characters
// should be escaped inside JSON quoted strings.
// The default behavior is to escape &, <, and > to \u0026, \u003c, and \u003e
// to avoid certain safety problems that can arise when embedding JSON in HTML.
//
// In non-HTML settings where the escaping interferes with the readability
// of the output, SetEscapeHTML(false) disables this behavior.
func (enc *Encoder) SetEscapeHTML(on bool) {
enc.escapeHTML = on
}
// RawMessage is a raw encoded JSON value.
// It implements Marshaler and Unmarshaler and can
// be used to delay JSON decoding or precompute a JSON encoding.
type RawMessage []byte
// MarshalJSON returns m as the JSON encoding of m.
func (m RawMessage) MarshalJSON() ([]byte, error) {
if m == nil {
return []byte("null"), nil
}
return m, nil
}
// UnmarshalJSON sets *m to a copy of data.
func (m *RawMessage) UnmarshalJSON(data []byte) error {
if m == nil {
return errors.New("json.RawMessage: UnmarshalJSON on nil pointer")
}
*m = append((*m)[0:0], data...)
return nil
}
var _ Marshaler = (*RawMessage)(nil)
var _ Unmarshaler = (*RawMessage)(nil)
// A Token holds a value of one of these types:
//
// Delim, for the four JSON delimiters [ ] { }
// bool, for JSON booleans
// float64, for JSON numbers
// Number, for JSON numbers
// string, for JSON string literals
// nil, for JSON null
type Token any
const (
tokenTopValue = iota
tokenArrayStart
tokenArrayValue
tokenArrayComma
tokenObjectStart
tokenObjectKey
tokenObjectColon
tokenObjectValue
tokenObjectComma
)
// advance tokenstate from a separator state to a value state
func (dec *Decoder) tokenPrepareForDecode() error {
// Note: Not calling peek before switch, to avoid
// putting peek into the standard Decode path.
// peek is only called when using the Token API.
switch dec.tokenState {
case tokenArrayComma:
c, err := dec.peek()
if err != nil {
return err
}
if c != ',' {
return &SyntaxError{"expected comma after array element", dec.InputOffset()}
}
dec.scanp++
dec.tokenState = tokenArrayValue
case tokenObjectColon:
c, err := dec.peek()
if err != nil {
return err
}
if c != ':' {
return &SyntaxError{"expected colon after object key", dec.InputOffset()}
}
dec.scanp++
dec.tokenState = tokenObjectValue
}
return nil
}
func (dec *Decoder) tokenValueAllowed() bool {
switch dec.tokenState {
case tokenTopValue, tokenArrayStart, tokenArrayValue, tokenObjectValue:
return true
}
return false
}
func (dec *Decoder) tokenValueEnd() {
switch dec.tokenState {
case tokenArrayStart, tokenArrayValue:
dec.tokenState = tokenArrayComma
case tokenObjectValue:
dec.tokenState = tokenObjectComma
}
}
// A Delim is a JSON array or object delimiter, one of [ ] { or }.
type Delim rune
func (d Delim) String() string {
return string(d)
}
// Token returns the next JSON token in the input stream.
// At the end of the input stream, Token returns nil, io.EOF.
//
// Token guarantees that the delimiters [ ] { } it returns are
// properly nested and matched: if Token encounters an unexpected
// delimiter in the input, it will return an error.
//
// The input stream consists of basic JSON values—bool, string,
// number, and null—along with delimiters [ ] { } of type Delim
// to mark the start and end of arrays and objects.
// Commas and colons are elided.
func (dec *Decoder) Token() (Token, error) {
for {
c, err := dec.peek()
if err != nil {
return nil, err
}
switch c {
case '[':
if !dec.tokenValueAllowed() {
return dec.tokenError(c)
}
dec.scanp++
dec.tokenStack = append(dec.tokenStack, dec.tokenState)
dec.tokenState = tokenArrayStart
return Delim('['), nil
case ']':
if dec.tokenState != tokenArrayStart && dec.tokenState != tokenArrayComma {
return dec.tokenError(c)
}
dec.scanp++
dec.tokenState = dec.tokenStack[len(dec.tokenStack)-1]
dec.tokenStack = dec.tokenStack[:len(dec.tokenStack)-1]
dec.tokenValueEnd()
return Delim(']'), nil
case '{':
if !dec.tokenValueAllowed() {
return dec.tokenError(c)
}
dec.scanp++
dec.tokenStack = append(dec.tokenStack, dec.tokenState)
dec.tokenState = tokenObjectStart
return Delim('{'), nil
case '}':
if dec.tokenState != tokenObjectStart && dec.tokenState != tokenObjectComma {
return dec.tokenError(c)
}
dec.scanp++
dec.tokenState = dec.tokenStack[len(dec.tokenStack)-1]
dec.tokenStack = dec.tokenStack[:len(dec.tokenStack)-1]
dec.tokenValueEnd()
return Delim('}'), nil
case ':':
if dec.tokenState != tokenObjectColon {
return dec.tokenError(c)
}
dec.scanp++
dec.tokenState = tokenObjectValue
continue
case ',':
if dec.tokenState == tokenArrayComma {
dec.scanp++
dec.tokenState = tokenArrayValue
continue
}
if dec.tokenState == tokenObjectComma {
dec.scanp++
dec.tokenState = tokenObjectKey
continue
}
return dec.tokenError(c)
case '"':
if dec.tokenState == tokenObjectStart || dec.tokenState == tokenObjectKey {
var x string
old := dec.tokenState
dec.tokenState = tokenTopValue
err := dec.Decode(&x)
dec.tokenState = old
if err != nil {
return nil, err
}
dec.tokenState = tokenObjectColon
return x, nil
}
fallthrough
default:
if !dec.tokenValueAllowed() {
return dec.tokenError(c)
}
var x any
if err := dec.Decode(&x); err != nil {
return nil, err
}
return x, nil
}
}
}
func (dec *Decoder) tokenError(c byte) (Token, error) {
var context string
switch dec.tokenState {
case tokenTopValue:
context = " looking for beginning of value"
case tokenArrayStart, tokenArrayValue, tokenObjectValue:
context = " looking for beginning of value"
case tokenArrayComma:
context = " after array element"
case tokenObjectKey:
context = " looking for beginning of object key string"
case tokenObjectColon:
context = " after object key"
case tokenObjectComma:
context = " after object key:value pair"
}
return nil, &SyntaxError{"invalid character " + quoteChar(c) + context, dec.InputOffset()}
}
// More reports whether there is another element in the
// current array or object being parsed.
func (dec *Decoder) More() bool {
c, err := dec.peek()
return err == nil && c != ']' && c != '}'
}
func (dec *Decoder) peek() (byte, error) {
var err error
for {
for i := dec.scanp; i < len(dec.buf); i++ {
c := dec.buf[i]
if isSpace(c) {
continue
}
dec.scanp = i
return c, nil
}
// buffer has been scanned, now report any error
if err != nil {
return 0, err
}
err = dec.refill()
}
}
// InputOffset returns the input stream byte offset of the current decoder position.
// The offset gives the location of the end of the most recently returned token
// and the beginning of the next token.
func (dec *Decoder) InputOffset() int64 {
return dec.scanned + int64(dec.scanp)
}

539
gojson/stream_test.go Normal file
View File

@@ -0,0 +1,539 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
import (
"bytes"
"io"
"log"
"net"
"net/http"
"net/http/httptest"
"reflect"
"runtime/debug"
"strings"
"testing"
)
// Test values for the stream test.
// One of each JSON kind.
var streamTest = []any{
0.1,
"hello",
nil,
true,
false,
[]any{"a", "b", "c"},
map[string]any{"": "Kelvin", "ß": "long s"},
3.14, // another value to make sure something can follow map
}
var streamEncoded = `0.1
"hello"
null
true
false
["a","b","c"]
{"ß":"long s","":"Kelvin"}
3.14
`
func TestEncoder(t *testing.T) {
for i := 0; i <= len(streamTest); i++ {
var buf strings.Builder
enc := NewEncoder(&buf)
// Check that enc.SetIndent("", "") turns off indentation.
enc.SetIndent(">", ".")
enc.SetIndent("", "")
for j, v := range streamTest[0:i] {
if err := enc.Encode(v); err != nil {
t.Fatalf("encode #%d: %v", j, err)
}
}
if have, want := buf.String(), nlines(streamEncoded, i); have != want {
t.Errorf("encoding %d items: mismatch", i)
diff(t, []byte(have), []byte(want))
break
}
}
}
func TestEncoderErrorAndReuseEncodeState(t *testing.T) {
// Disable the GC temporarily to prevent encodeState's in Pool being cleaned away during the test.
percent := debug.SetGCPercent(-1)
defer debug.SetGCPercent(percent)
// Trigger an error in Marshal with cyclic data.
type Dummy struct {
Name string
Next *Dummy
}
dummy := Dummy{Name: "Dummy"}
dummy.Next = &dummy
var buf bytes.Buffer
enc := NewEncoder(&buf)
if err := enc.Encode(dummy); err == nil {
t.Errorf("Encode(dummy) == nil; want error")
}
type Data struct {
A string
I int
}
data := Data{A: "a", I: 1}
if err := enc.Encode(data); err != nil {
t.Errorf("Marshal(%v) = %v", data, err)
}
var data2 Data
if err := Unmarshal(buf.Bytes(), &data2); err != nil {
t.Errorf("Unmarshal(%v) = %v", data2, err)
}
if data2 != data {
t.Errorf("expect: %v, but get: %v", data, data2)
}
}
var streamEncodedIndent = `0.1
"hello"
null
true
false
[
>."a",
>."b",
>."c"
>]
{
>."ß": "long s",
>."": "Kelvin"
>}
3.14
`
func TestEncoderIndent(t *testing.T) {
var buf strings.Builder
enc := NewEncoder(&buf)
enc.SetIndent(">", ".")
for _, v := range streamTest {
enc.Encode(v)
}
if have, want := buf.String(), streamEncodedIndent; have != want {
t.Error("indented encoding mismatch")
diff(t, []byte(have), []byte(want))
}
}
type strMarshaler string
func (s strMarshaler) MarshalJSON() ([]byte, error) {
return []byte(s), nil
}
type strPtrMarshaler string
func (s *strPtrMarshaler) MarshalJSON() ([]byte, error) {
return []byte(*s), nil
}
func TestEncoderSetEscapeHTML(t *testing.T) {
var c C
var ct CText
var tagStruct struct {
Valid int `json:"<>&#! "`
Invalid int `json:"\\"`
}
// This case is particularly interesting, as we force the encoder to
// take the address of the Ptr field to use its MarshalJSON method. This
// is why the '&' is important.
marshalerStruct := &struct {
NonPtr strMarshaler
Ptr strPtrMarshaler
}{`"<str>"`, `"<str>"`}
// https://golang.org/issue/34154
stringOption := struct {
Bar string `json:"bar,string"`
}{`<html>foobar</html>`}
for _, tt := range []struct {
name string
v any
wantEscape string
want string
}{
{"c", c, `"\u003c\u0026\u003e"`, `"<&>"`},
{"ct", ct, `"\"\u003c\u0026\u003e\""`, `"\"<&>\""`},
{`"<&>"`, "<&>", `"\u003c\u0026\u003e"`, `"<&>"`},
{
"tagStruct", tagStruct,
`{"\u003c\u003e\u0026#! ":0,"Invalid":0}`,
`{"<>&#! ":0,"Invalid":0}`,
},
{
`"<str>"`, marshalerStruct,
`{"NonPtr":"\u003cstr\u003e","Ptr":"\u003cstr\u003e"}`,
`{"NonPtr":"<str>","Ptr":"<str>"}`,
},
{
"stringOption", stringOption,
`{"bar":"\"\\u003chtml\\u003efoobar\\u003c/html\\u003e\""}`,
`{"bar":"\"<html>foobar</html>\""}`,
},
} {
var buf strings.Builder
enc := NewEncoder(&buf)
if err := enc.Encode(tt.v); err != nil {
t.Errorf("Encode(%s): %s", tt.name, err)
continue
}
if got := strings.TrimSpace(buf.String()); got != tt.wantEscape {
t.Errorf("Encode(%s) = %#q, want %#q", tt.name, got, tt.wantEscape)
}
buf.Reset()
enc.SetEscapeHTML(false)
if err := enc.Encode(tt.v); err != nil {
t.Errorf("SetEscapeHTML(false) Encode(%s): %s", tt.name, err)
continue
}
if got := strings.TrimSpace(buf.String()); got != tt.want {
t.Errorf("SetEscapeHTML(false) Encode(%s) = %#q, want %#q",
tt.name, got, tt.want)
}
}
}
func TestDecoder(t *testing.T) {
for i := 0; i <= len(streamTest); i++ {
// Use stream without newlines as input,
// just to stress the decoder even more.
// Our test input does not include back-to-back numbers.
// Otherwise stripping the newlines would
// merge two adjacent JSON values.
var buf bytes.Buffer
for _, c := range nlines(streamEncoded, i) {
if c != '\n' {
buf.WriteRune(c)
}
}
out := make([]any, i)
dec := NewDecoder(&buf)
for j := range out {
if err := dec.Decode(&out[j]); err != nil {
t.Fatalf("decode #%d/%d: %v", j, i, err)
}
}
if !reflect.DeepEqual(out, streamTest[0:i]) {
t.Errorf("decoding %d items: mismatch", i)
for j := range out {
if !reflect.DeepEqual(out[j], streamTest[j]) {
t.Errorf("#%d: have %v want %v", j, out[j], streamTest[j])
}
}
break
}
}
}
func TestDecoderBuffered(t *testing.T) {
r := strings.NewReader(`{"Name": "Gopher"} extra `)
var m struct {
Name string
}
d := NewDecoder(r)
err := d.Decode(&m)
if err != nil {
t.Fatal(err)
}
if m.Name != "Gopher" {
t.Errorf("Name = %q; want Gopher", m.Name)
}
rest, err := io.ReadAll(d.Buffered())
if err != nil {
t.Fatal(err)
}
if g, w := string(rest), " extra "; g != w {
t.Errorf("Remaining = %q; want %q", g, w)
}
}
func nlines(s string, n int) string {
if n <= 0 {
return ""
}
for i, c := range s {
if c == '\n' {
if n--; n == 0 {
return s[0 : i+1]
}
}
}
return s
}
func TestRawMessage(t *testing.T) {
var data struct {
X float64
Id RawMessage
Y float32
}
const raw = `["\u0056",null]`
const msg = `{"X":0.1,"Id":["\u0056",null],"Y":0.2}`
err := Unmarshal([]byte(msg), &data)
if err != nil {
t.Fatalf("Unmarshal: %v", err)
}
if string([]byte(data.Id)) != raw {
t.Fatalf("Raw mismatch: have %#q want %#q", []byte(data.Id), raw)
}
b, err := Marshal(&data)
if err != nil {
t.Fatalf("Marshal: %v", err)
}
if string(b) != msg {
t.Fatalf("Marshal: have %#q want %#q", b, msg)
}
}
func TestNullRawMessage(t *testing.T) {
var data struct {
X float64
Id RawMessage
IdPtr *RawMessage
Y float32
}
const msg = `{"X":0.1,"Id":null,"IdPtr":null,"Y":0.2}`
err := Unmarshal([]byte(msg), &data)
if err != nil {
t.Fatalf("Unmarshal: %v", err)
}
if want, got := "null", string(data.Id); want != got {
t.Fatalf("Raw mismatch: have %q, want %q", got, want)
}
if data.IdPtr != nil {
t.Fatalf("Raw pointer mismatch: have non-nil, want nil")
}
b, err := Marshal(&data)
if err != nil {
t.Fatalf("Marshal: %v", err)
}
if string(b) != msg {
t.Fatalf("Marshal: have %#q want %#q", b, msg)
}
}
var blockingTests = []string{
`{"x": 1}`,
`[1, 2, 3]`,
}
func TestBlocking(t *testing.T) {
for _, enc := range blockingTests {
r, w := net.Pipe()
go w.Write([]byte(enc))
var val any
// If Decode reads beyond what w.Write writes above,
// it will block, and the test will deadlock.
if err := NewDecoder(r).Decode(&val); err != nil {
t.Errorf("decoding %s: %v", enc, err)
}
r.Close()
w.Close()
}
}
type tokenStreamCase struct {
json string
expTokens []any
}
type decodeThis struct {
v any
}
var tokenStreamCases = []tokenStreamCase{
// streaming token cases
{json: `10`, expTokens: []any{float64(10)}},
{json: ` [10] `, expTokens: []any{
Delim('['), float64(10), Delim(']')}},
{json: ` [false,10,"b"] `, expTokens: []any{
Delim('['), false, float64(10), "b", Delim(']')}},
{json: `{ "a": 1 }`, expTokens: []any{
Delim('{'), "a", float64(1), Delim('}')}},
{json: `{"a": 1, "b":"3"}`, expTokens: []any{
Delim('{'), "a", float64(1), "b", "3", Delim('}')}},
{json: ` [{"a": 1},{"a": 2}] `, expTokens: []any{
Delim('['),
Delim('{'), "a", float64(1), Delim('}'),
Delim('{'), "a", float64(2), Delim('}'),
Delim(']')}},
{json: `{"obj": {"a": 1}}`, expTokens: []any{
Delim('{'), "obj", Delim('{'), "a", float64(1), Delim('}'),
Delim('}')}},
{json: `{"obj": [{"a": 1}]}`, expTokens: []any{
Delim('{'), "obj", Delim('['),
Delim('{'), "a", float64(1), Delim('}'),
Delim(']'), Delim('}')}},
// streaming tokens with intermittent Decode()
{json: `{ "a": 1 }`, expTokens: []any{
Delim('{'), "a",
decodeThis{float64(1)},
Delim('}')}},
{json: ` [ { "a" : 1 } ] `, expTokens: []any{
Delim('['),
decodeThis{map[string]any{"a": float64(1)}},
Delim(']')}},
{json: ` [{"a": 1},{"a": 2}] `, expTokens: []any{
Delim('['),
decodeThis{map[string]any{"a": float64(1)}},
decodeThis{map[string]any{"a": float64(2)}},
Delim(']')}},
{json: `{ "obj" : [ { "a" : 1 } ] }`, expTokens: []any{
Delim('{'), "obj", Delim('['),
decodeThis{map[string]any{"a": float64(1)}},
Delim(']'), Delim('}')}},
{json: `{"obj": {"a": 1}}`, expTokens: []any{
Delim('{'), "obj",
decodeThis{map[string]any{"a": float64(1)}},
Delim('}')}},
{json: `{"obj": [{"a": 1}]}`, expTokens: []any{
Delim('{'), "obj",
decodeThis{[]any{
map[string]any{"a": float64(1)},
}},
Delim('}')}},
{json: ` [{"a": 1} {"a": 2}] `, expTokens: []any{
Delim('['),
decodeThis{map[string]any{"a": float64(1)}},
decodeThis{&SyntaxError{"expected comma after array element", 11}},
}},
{json: `{ "` + strings.Repeat("a", 513) + `" 1 }`, expTokens: []any{
Delim('{'), strings.Repeat("a", 513),
decodeThis{&SyntaxError{"expected colon after object key", 518}},
}},
{json: `{ "\a" }`, expTokens: []any{
Delim('{'),
&SyntaxError{"invalid character 'a' in string escape code", 3},
}},
{json: ` \a`, expTokens: []any{
&SyntaxError{"invalid character '\\\\' looking for beginning of value", 1},
}},
}
func TestDecodeInStream(t *testing.T) {
for ci, tcase := range tokenStreamCases {
dec := NewDecoder(strings.NewReader(tcase.json))
for i, etk := range tcase.expTokens {
var tk any
var err error
if dt, ok := etk.(decodeThis); ok {
etk = dt.v
err = dec.Decode(&tk)
} else {
tk, err = dec.Token()
}
if experr, ok := etk.(error); ok {
if err == nil || !reflect.DeepEqual(err, experr) {
t.Errorf("case %v: Expected error %#v in %q, but was %#v", ci, experr, tcase.json, err)
}
break
} else if err == io.EOF {
t.Errorf("case %v: Unexpected EOF in %q", ci, tcase.json)
break
} else if err != nil {
t.Errorf("case %v: Unexpected error '%#v' in %q", ci, err, tcase.json)
break
}
if !reflect.DeepEqual(tk, etk) {
t.Errorf(`case %v: %q @ %v expected %T(%v) was %T(%v)`, ci, tcase.json, i, etk, etk, tk, tk)
break
}
}
}
}
// Test from golang.org/issue/11893
func TestHTTPDecoding(t *testing.T) {
const raw = `{ "foo": "bar" }`
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(raw))
}))
defer ts.Close()
res, err := http.Get(ts.URL)
if err != nil {
log.Fatalf("GET failed: %v", err)
}
defer res.Body.Close()
foo := struct {
Foo string
}{}
d := NewDecoder(res.Body)
err = d.Decode(&foo)
if err != nil {
t.Fatalf("Decode: %v", err)
}
if foo.Foo != "bar" {
t.Errorf("decoded %q; want \"bar\"", foo.Foo)
}
// make sure we get the EOF the second time
err = d.Decode(&foo)
if err != io.EOF {
t.Errorf("err = %v; want io.EOF", err)
}
}
func TestEncoderSetNilSafeCollection(t *testing.T) {
var (
nilSlice []interface{}
pNilSlice *[]interface{}
nilMap map[string]interface{}
pNilMap *map[string]interface{}
)
for _, tt := range []struct {
name string
v interface{}
want string
rescuedWant string
}{
{"nilSlice", nilSlice, "null", "[]"},
{"nonNilSlice", []interface{}{}, "[]", "[]"},
{"sliceWithValues", []interface{}{1, 2, 3}, "[1,2,3]", "[1,2,3]"},
{"pNilSlice", pNilSlice, "null", "null"},
{"nilMap", nilMap, "null", "{}"},
{"nonNilMap", map[string]interface{}{}, "{}", "{}"},
{"mapWithValues", map[string]interface{}{"1": 1, "2": 2, "3": 3}, "{\"1\":1,\"2\":2,\"3\":3}", "{\"1\":1,\"2\":2,\"3\":3}"},
{"pNilMap", pNilMap, "null", "null"},
} {
var buf bytes.Buffer
enc := NewEncoder(&buf)
if err := enc.Encode(tt.v); err != nil {
t.Fatalf("Encode(%s): %s", tt.name, err)
}
if got := strings.TrimSpace(buf.String()); got != tt.want {
t.Errorf("Encode(%s) = %#q, want %#q", tt.name, got, tt.want)
}
buf.Reset()
enc.SetNilSafeCollection(true, true)
if err := enc.Encode(tt.v); err != nil {
t.Fatalf("SetNilSafeCollection(true) Encode(%s): %s", tt.name, err)
}
if got := strings.TrimSpace(buf.String()); got != tt.rescuedWant {
t.Errorf("SetNilSafeCollection(true) Encode(%s) = %#q, want %#q",
tt.name, got, tt.want)
}
}
}

218
gojson/tables.go Normal file
View File

@@ -0,0 +1,218 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
import "unicode/utf8"
// safeSet holds the value true if the ASCII character with the given array
// position can be represented inside a JSON string without any further
// escaping.
//
// All values are true except for the ASCII control characters (0-31), the
// double quote ("), and the backslash character ("\").
var safeSet = [utf8.RuneSelf]bool{
' ': true,
'!': true,
'"': false,
'#': true,
'$': true,
'%': true,
'&': true,
'\'': true,
'(': true,
')': true,
'*': true,
'+': true,
',': true,
'-': true,
'.': true,
'/': true,
'0': true,
'1': true,
'2': true,
'3': true,
'4': true,
'5': true,
'6': true,
'7': true,
'8': true,
'9': true,
':': true,
';': true,
'<': true,
'=': true,
'>': true,
'?': true,
'@': true,
'A': true,
'B': true,
'C': true,
'D': true,
'E': true,
'F': true,
'G': true,
'H': true,
'I': true,
'J': true,
'K': true,
'L': true,
'M': true,
'N': true,
'O': true,
'P': true,
'Q': true,
'R': true,
'S': true,
'T': true,
'U': true,
'V': true,
'W': true,
'X': true,
'Y': true,
'Z': true,
'[': true,
'\\': false,
']': true,
'^': true,
'_': true,
'`': true,
'a': true,
'b': true,
'c': true,
'd': true,
'e': true,
'f': true,
'g': true,
'h': true,
'i': true,
'j': true,
'k': true,
'l': true,
'm': true,
'n': true,
'o': true,
'p': true,
'q': true,
'r': true,
's': true,
't': true,
'u': true,
'v': true,
'w': true,
'x': true,
'y': true,
'z': true,
'{': true,
'|': true,
'}': true,
'~': true,
'\u007f': true,
}
// htmlSafeSet holds the value true if the ASCII character with the given
// array position can be safely represented inside a JSON string, embedded
// inside of HTML <script> tags, without any additional escaping.
//
// All values are true except for the ASCII control characters (0-31), the
// double quote ("), the backslash character ("\"), HTML opening and closing
// tags ("<" and ">"), and the ampersand ("&").
var htmlSafeSet = [utf8.RuneSelf]bool{
' ': true,
'!': true,
'"': false,
'#': true,
'$': true,
'%': true,
'&': false,
'\'': true,
'(': true,
')': true,
'*': true,
'+': true,
',': true,
'-': true,
'.': true,
'/': true,
'0': true,
'1': true,
'2': true,
'3': true,
'4': true,
'5': true,
'6': true,
'7': true,
'8': true,
'9': true,
':': true,
';': true,
'<': false,
'=': true,
'>': false,
'?': true,
'@': true,
'A': true,
'B': true,
'C': true,
'D': true,
'E': true,
'F': true,
'G': true,
'H': true,
'I': true,
'J': true,
'K': true,
'L': true,
'M': true,
'N': true,
'O': true,
'P': true,
'Q': true,
'R': true,
'S': true,
'T': true,
'U': true,
'V': true,
'W': true,
'X': true,
'Y': true,
'Z': true,
'[': true,
'\\': false,
']': true,
'^': true,
'_': true,
'`': true,
'a': true,
'b': true,
'c': true,
'd': true,
'e': true,
'f': true,
'g': true,
'h': true,
'i': true,
'j': true,
'k': true,
'l': true,
'm': true,
'n': true,
'o': true,
'p': true,
'q': true,
'r': true,
's': true,
't': true,
'u': true,
'v': true,
'w': true,
'x': true,
'y': true,
'z': true,
'{': true,
'|': true,
'}': true,
'~': true,
'\u007f': true,
}

120
gojson/tagkey_test.go Normal file
View File

@@ -0,0 +1,120 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
import (
"testing"
)
type basicLatin2xTag struct {
V string `json:"$%-/"`
}
type basicLatin3xTag struct {
V string `json:"0123456789"`
}
type basicLatin4xTag struct {
V string `json:"ABCDEFGHIJKLMO"`
}
type basicLatin5xTag struct {
V string `json:"PQRSTUVWXYZ_"`
}
type basicLatin6xTag struct {
V string `json:"abcdefghijklmno"`
}
type basicLatin7xTag struct {
V string `json:"pqrstuvwxyz"`
}
type miscPlaneTag struct {
V string `json:"色は匂へど"`
}
type percentSlashTag struct {
V string `json:"text/html%"` // https://golang.org/issue/2718
}
type punctuationTag struct {
V string `json:"!#$%&()*+-./:;<=>?@[]^_{|}~ "` // https://golang.org/issue/3546
}
type dashTag struct {
V string `json:"-,"`
}
type emptyTag struct {
W string
}
type misnamedTag struct {
X string `jsom:"Misnamed"`
}
type badFormatTag struct {
Y string `:"BadFormat"`
}
type badCodeTag struct {
Z string `json:" !\"#&'()*+,."`
}
type spaceTag struct {
Q string `json:"With space"`
}
type unicodeTag struct {
W string `json:"Ελλάδα"`
}
var structTagObjectKeyTests = []struct {
raw any
value string
key string
}{
{basicLatin2xTag{"2x"}, "2x", "$%-/"},
{basicLatin3xTag{"3x"}, "3x", "0123456789"},
{basicLatin4xTag{"4x"}, "4x", "ABCDEFGHIJKLMO"},
{basicLatin5xTag{"5x"}, "5x", "PQRSTUVWXYZ_"},
{basicLatin6xTag{"6x"}, "6x", "abcdefghijklmno"},
{basicLatin7xTag{"7x"}, "7x", "pqrstuvwxyz"},
{miscPlaneTag{"いろはにほへと"}, "いろはにほへと", "色は匂へど"},
{dashTag{"foo"}, "foo", "-"},
{emptyTag{"Pour Moi"}, "Pour Moi", "W"},
{misnamedTag{"Animal Kingdom"}, "Animal Kingdom", "X"},
{badFormatTag{"Orfevre"}, "Orfevre", "Y"},
{badCodeTag{"Reliable Man"}, "Reliable Man", "Z"},
{percentSlashTag{"brut"}, "brut", "text/html%"},
{punctuationTag{"Union Rags"}, "Union Rags", "!#$%&()*+-./:;<=>?@[]^_{|}~ "},
{spaceTag{"Perreddu"}, "Perreddu", "With space"},
{unicodeTag{"Loukanikos"}, "Loukanikos", "Ελλάδα"},
}
func TestStructTagObjectKey(t *testing.T) {
for _, tt := range structTagObjectKeyTests {
b, err := Marshal(tt.raw)
if err != nil {
t.Fatalf("Marshal(%#q) failed: %v", tt.raw, err)
}
var f any
err = Unmarshal(b, &f)
if err != nil {
t.Fatalf("Unmarshal(%#q) failed: %v", b, err)
}
for i, v := range f.(map[string]any) {
switch i {
case tt.key:
if s, ok := v.(string); !ok || s != tt.value {
t.Fatalf("Unexpected value: %#q, want %v", s, tt.value)
}
default:
t.Fatalf("Unexpected key: %#q, from %#q", i, b)
}
}
}
}

38
gojson/tags.go Normal file
View File

@@ -0,0 +1,38 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
import (
"strings"
)
// tagOptions is the string following a comma in a struct field's "json"
// tag, or the empty string. It does not include the leading comma.
type tagOptions string
// parseTag splits a struct field's json tag into its name and
// comma-separated options.
func parseTag(tag string) (string, tagOptions) {
tag, opt, _ := strings.Cut(tag, ",")
return tag, tagOptions(opt)
}
// Contains reports whether a comma-separated list of options
// contains a particular substr flag. substr must be surrounded by a
// string boundary or commas.
func (o tagOptions) Contains(optionName string) bool {
if len(o) == 0 {
return false
}
s := string(o)
for s != "" {
var name string
name, s, _ = strings.Cut(s, ",")
if name == optionName {
return true
}
}
return false
}

28
gojson/tags_test.go Normal file
View File

@@ -0,0 +1,28 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
import (
"testing"
)
func TestTagParsing(t *testing.T) {
name, opts := parseTag("field,foobar,foo")
if name != "field" {
t.Fatalf("name = %q, want field", name)
}
for _, tt := range []struct {
opt string
want bool
}{
{"foobar", true},
{"foo", true},
{"bar", false},
} {
if opts.Contains(tt.opt) != tt.want {
t.Errorf("Contains(%q) = %v", tt.opt, !tt.want)
}
}
}

BIN
gojson/testdata/code.json.gz vendored Normal file

Binary file not shown.

View File

@@ -1,6 +1,8 @@
package langext package langext
import ( import (
"errors"
"fmt"
"reflect" "reflect"
) )
@@ -217,6 +219,15 @@ func ArrFirst[T any](arr []T, comp func(v T) bool) (T, bool) {
return *new(T), false return *new(T), false
} }
func ArrFirstOrNil[T any](arr []T, comp func(v T) bool) *T {
for _, v := range arr {
if comp(v) {
return Ptr(v)
}
}
return nil
}
func ArrLast[T any](arr []T, comp func(v T) bool) (T, bool) { func ArrLast[T any](arr []T, comp func(v T) bool) (T, bool) {
found := false found := false
result := *new(T) result := *new(T)
@@ -229,6 +240,22 @@ func ArrLast[T any](arr []T, comp func(v T) bool) (T, bool) {
return result, found return result, found
} }
func ArrLastOrNil[T any](arr []T, comp func(v T) bool) *T {
found := false
result := *new(T)
for _, v := range arr {
if comp(v) {
found = true
result = v
}
}
if found {
return Ptr(result)
} else {
return nil
}
}
func ArrFirstIndex[T comparable](arr []T, needle T) int { func ArrFirstIndex[T comparable](arr []T, needle T) int {
for i, v := range arr { for i, v := range arr {
if v == needle { if v == needle {
@@ -265,6 +292,66 @@ func ArrMap[T1 any, T2 any](arr []T1, conv func(v T1) T2) []T2 {
return r return r
} }
func MapMap[TK comparable, TV any, TR any](inmap map[TK]TV, conv func(k TK, v TV) TR) []TR {
r := make([]TR, 0, len(inmap))
for k, v := range inmap {
r = append(r, conv(k, v))
}
return r
}
func MapMapErr[TK comparable, TV any, TR any](inmap map[TK]TV, conv func(k TK, v TV) (TR, error)) ([]TR, error) {
r := make([]TR, 0, len(inmap))
for k, v := range inmap {
elem, err := conv(k, v)
if err != nil {
return nil, err
}
r = append(r, elem)
}
return r, nil
}
func ArrMapExt[T1 any, T2 any](arr []T1, conv func(idx int, v T1) T2) []T2 {
r := make([]T2, len(arr))
for i, v := range arr {
r[i] = conv(i, v)
}
return r
}
func ArrMapErr[T1 any, T2 any](arr []T1, conv func(v T1) (T2, error)) ([]T2, error) {
var err error
r := make([]T2, len(arr))
for i, v := range arr {
r[i], err = conv(v)
if err != nil {
return nil, err
}
}
return r, nil
}
func ArrFilterMap[T1 any, T2 any](arr []T1, filter func(v T1) bool, conv func(v T1) T2) []T2 {
r := make([]T2, 0, len(arr))
for _, v := range arr {
if filter(v) {
r = append(r, conv(v))
}
}
return r
}
func ArrFilter[T any](arr []T, filter func(v T) bool) []T {
r := make([]T, 0, len(arr))
for _, v := range arr {
if filter(v) {
r = append(r, v)
}
}
return r
}
func ArrSum[T NumberConstraint](arr []T) T { func ArrSum[T NumberConstraint](arr []T) T {
var r T = 0 var r T = 0
for _, v := range arr { for _, v := range arr {
@@ -272,3 +359,84 @@ func ArrSum[T NumberConstraint](arr []T) T {
} }
return r return r
} }
func ArrFlatten[T1 any, T2 any](arr []T1, conv func(v T1) []T2) []T2 {
r := make([]T2, 0, len(arr))
for _, v1 := range arr {
r = append(r, conv(v1)...)
}
return r
}
func ArrFlattenDirect[T1 any](arr [][]T1) []T1 {
r := make([]T1, 0, len(arr))
for _, v1 := range arr {
r = append(r, v1...)
}
return r
}
func ArrCastToAny[T1 any](arr []T1) []any {
r := make([]any, len(arr))
for i, v := range arr {
r[i] = any(v)
}
return r
}
func ArrCastSafe[T1 any, T2 any](arr []T1) []T2 {
r := make([]T2, 0, len(arr))
for _, v := range arr {
if vcast, ok := any(v).(T2); ok {
r = append(r, vcast)
}
}
return r
}
func ArrCastErr[T1 any, T2 any](arr []T1) ([]T2, error) {
r := make([]T2, len(arr))
for i, v := range arr {
if vcast, ok := any(v).(T2); ok {
r[i] = vcast
} else {
return nil, errors.New(fmt.Sprintf("Cannot cast element %d of type %T to type %s", i, v, *new(T2)))
}
}
return r, nil
}
func ArrCastPanic[T1 any, T2 any](arr []T1) []T2 {
r := make([]T2, len(arr))
for i, v := range arr {
if vcast, ok := any(v).(T2); ok {
r[i] = vcast
} else {
panic(fmt.Sprintf("Cannot cast element %d of type %T to type %s", i, v, *new(T2)))
}
}
return r
}
func ArrConcat[T any](arr ...[]T) []T {
c := 0
for _, v := range arr {
c += len(v)
}
r := make([]T, c)
i := 0
for _, av := range arr {
for _, v := range av {
r[i] = v
i++
}
}
return r
}
// ArrCopy does a shallow copy of the 'in' array
func ArrCopy[T any](in []T) []T {
out := make([]T, len(in))
copy(out, in)
return out
}

178
langext/base58.go Normal file
View File

@@ -0,0 +1,178 @@
package langext
import (
"bytes"
"errors"
"math/big"
)
// shamelessly stolen from https://github.com/btcsuite/
type B58Encoding struct {
bigRadix [11]*big.Int
bigRadix10 *big.Int
alphabet string
alphabetIdx0 byte
b58 [256]byte
}
var Base58DefaultEncoding = newBase58Encoding("123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz")
var Base58FlickrEncoding = newBase58Encoding("123456789abcdefghijkmnopqrstuvwxyzABCDEFGHJKLMNPQRSTUVWXYZ")
var Base58RippleEncoding = newBase58Encoding("rpshnaf39wBUDNEGHJKLM4PQRST7VWXYZ2bcdeCg65jkm8oFqi1tuvAxyz")
var Base58BitcoinEncoding = newBase58Encoding("123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz")
func newBase58Encoding(alphabet string) *B58Encoding {
bigRadix10 := big.NewInt(58 * 58 * 58 * 58 * 58 * 58 * 58 * 58 * 58 * 58)
enc := &B58Encoding{
alphabet: alphabet,
alphabetIdx0: '1',
bigRadix: [...]*big.Int{
big.NewInt(0),
big.NewInt(58),
big.NewInt(58 * 58),
big.NewInt(58 * 58 * 58),
big.NewInt(58 * 58 * 58 * 58),
big.NewInt(58 * 58 * 58 * 58 * 58),
big.NewInt(58 * 58 * 58 * 58 * 58 * 58),
big.NewInt(58 * 58 * 58 * 58 * 58 * 58 * 58),
big.NewInt(58 * 58 * 58 * 58 * 58 * 58 * 58 * 58),
big.NewInt(58 * 58 * 58 * 58 * 58 * 58 * 58 * 58 * 58),
bigRadix10,
},
bigRadix10: bigRadix10,
}
b58 := make([]byte, 0, 256)
for i := byte(0); i < 32; i++ {
for j := byte(0); j < 8; j++ {
b := i*8 + j
idx := bytes.IndexByte([]byte(alphabet), b)
if idx == -1 {
b58 = append(b58, 255)
} else {
b58 = append(b58, byte(idx))
}
}
}
enc.b58 = *((*[256]byte)(b58))
return enc
}
func (enc *B58Encoding) EncodeString(src string) (string, error) {
v, err := enc.Encode([]byte(src))
if err != nil {
return "", err
}
return string(v), nil
}
func (enc *B58Encoding) Encode(src []byte) ([]byte, error) {
x := new(big.Int)
x.SetBytes(src)
// maximum length of output is log58(2^(8*len(b))) == len(b) * 8 / log(58)
maxlen := int(float64(len(src))*1.365658237309761) + 1
answer := make([]byte, 0, maxlen)
mod := new(big.Int)
for x.Sign() > 0 {
// Calculating with big.Int is slow for each iteration.
// x, mod = x / 58, x % 58
//
// Instead we can try to do as much calculations on int64.
// x, mod = x / 58^10, x % 58^10
//
// Which will give us mod, which is 10 digit base58 number.
// We'll loop that 10 times to convert to the answer.
x.DivMod(x, enc.bigRadix10, mod)
if x.Sign() == 0 {
// When x = 0, we need to ensure we don't add any extra zeros.
m := mod.Int64()
for m > 0 {
answer = append(answer, enc.alphabet[m%58])
m /= 58
}
} else {
m := mod.Int64()
for i := 0; i < 10; i++ {
answer = append(answer, enc.alphabet[m%58])
m /= 58
}
}
}
// leading zero bytes
for _, i := range src {
if i != 0 {
break
}
answer = append(answer, enc.alphabetIdx0)
}
// reverse
alen := len(answer)
for i := 0; i < alen/2; i++ {
answer[i], answer[alen-1-i] = answer[alen-1-i], answer[i]
}
return answer, nil
}
func (enc *B58Encoding) DecodeString(src string) (string, error) {
v, err := enc.Decode([]byte(src))
if err != nil {
return "", err
}
return string(v), nil
}
func (enc *B58Encoding) Decode(src []byte) ([]byte, error) {
answer := big.NewInt(0)
scratch := new(big.Int)
for t := src; len(t) > 0; {
n := len(t)
if n > 10 {
n = 10
}
total := uint64(0)
for _, v := range t[:n] {
if v > 255 {
return []byte{}, errors.New("invalid char in input")
}
tmp := enc.b58[v]
if tmp == 255 {
return []byte{}, errors.New("invalid char in input")
}
total = total*58 + uint64(tmp)
}
answer.Mul(answer, enc.bigRadix[n])
scratch.SetUint64(total)
answer.Add(answer, scratch)
t = t[n:]
}
tmpval := answer.Bytes()
var numZeros int
for numZeros = 0; numZeros < len(src); numZeros++ {
if src[numZeros] != enc.alphabetIdx0 {
break
}
}
flen := numZeros + len(tmpval)
val := make([]byte, flen)
copy(val[numZeros:], tmpval)
return val, nil
}

67
langext/base58_test.go Normal file
View File

@@ -0,0 +1,67 @@
package langext
import (
"testing"
)
func _encStr(t *testing.T, enc *B58Encoding, v string) string {
v, err := enc.EncodeString(v)
if err != nil {
t.Error(err)
}
return v
}
func _decStr(t *testing.T, enc *B58Encoding, v string) string {
v, err := enc.DecodeString(v)
if err != nil {
t.Error(err)
}
return v
}
func TestBase58DefaultEncoding(t *testing.T) {
tst.AssertEqual(t, _encStr(t, Base58DefaultEncoding, "Hello"), "9Ajdvzr")
tst.AssertEqual(t, _encStr(t, Base58DefaultEncoding, "If debugging is the process of removing software bugs, then programming must be the process of putting them in."), "48638SMcJuah5okqPx4kCVf5d8QAdgbdNf28g7ReY13prUENNbMyssjq5GjsrJHF5zeZfqs4uJMUJHr7VbrU4XBUZ2Fw9DVtqtn9N1eXucEWSEZahXV6w4ysGSWqGdpeYTJf1MdDzTg8vfcQViifJjZX")
}
func TestBase58DefaultDecoding(t *testing.T) {
tst.AssertEqual(t, _decStr(t, Base58DefaultEncoding, "9Ajdvzr"), "Hello")
tst.AssertEqual(t, _decStr(t, Base58DefaultEncoding, "48638SMcJuah5okqPx4kCVf5d8QAdgbdNf28g7ReY13prUENNbMyssjq5GjsrJHF5zeZfqs4uJMUJHr7VbrU4XBUZ2Fw9DVtqtn9N1eXucEWSEZahXV6w4ysGSWqGdpeYTJf1MdDzTg8vfcQViifJjZX"), "If debugging is the process of removing software bugs, then programming must be the process of putting them in.")
}
func TestBase58RippleEncoding(t *testing.T) {
tst.AssertEqual(t, _encStr(t, Base58RippleEncoding, "Hello"), "9wjdvzi")
tst.AssertEqual(t, _encStr(t, Base58RippleEncoding, "If debugging is the process of removing software bugs, then programming must be the process of putting them in."), "h3as3SMcJu26nokqPxhkUVCnd3Qwdgbd4Cp3gfReYrsFi7N44bMy11jqnGj1iJHEnzeZCq1huJM7JHifVbi7hXB7ZpEA9DVtqt894reXucNWSNZ26XVaAhy1GSWqGdFeYTJCrMdDzTg3vCcQV55CJjZX")
}
func TestBase58RippleDecoding(t *testing.T) {
tst.AssertEqual(t, _decStr(t, Base58RippleEncoding, "9wjdvzi"), "Hello")
tst.AssertEqual(t, _decStr(t, Base58RippleEncoding, "h3as3SMcJu26nokqPxhkUVCnd3Qwdgbd4Cp3gfReYrsFi7N44bMy11jqnGj1iJHEnzeZCq1huJM7JHifVbi7hXB7ZpEA9DVtqt894reXucNWSNZ26XVaAhy1GSWqGdFeYTJCrMdDzTg3vCcQV55CJjZX"), "If debugging is the process of removing software bugs, then programming must be the process of putting them in.")
}
func TestBase58BitcoinEncoding(t *testing.T) {
tst.AssertEqual(t, _encStr(t, Base58BitcoinEncoding, "Hello"), "9Ajdvzr")
tst.AssertEqual(t, _encStr(t, Base58BitcoinEncoding, "If debugging is the process of removing software bugs, then programming must be the process of putting them in."), "48638SMcJuah5okqPx4kCVf5d8QAdgbdNf28g7ReY13prUENNbMyssjq5GjsrJHF5zeZfqs4uJMUJHr7VbrU4XBUZ2Fw9DVtqtn9N1eXucEWSEZahXV6w4ysGSWqGdpeYTJf1MdDzTg8vfcQViifJjZX")
}
func TestBase58BitcoinDecoding(t *testing.T) {
tst.AssertEqual(t, _decStr(t, Base58BitcoinEncoding, "9Ajdvzr"), "Hello")
tst.AssertEqual(t, _decStr(t, Base58BitcoinEncoding, "48638SMcJuah5okqPx4kCVf5d8QAdgbdNf28g7ReY13prUENNbMyssjq5GjsrJHF5zeZfqs4uJMUJHr7VbrU4XBUZ2Fw9DVtqtn9N1eXucEWSEZahXV6w4ysGSWqGdpeYTJf1MdDzTg8vfcQViifJjZX"), "If debugging is the process of removing software bugs, then programming must be the process of putting them in.")
}
func TestBase58FlickrEncoding(t *testing.T) {
tst.AssertEqual(t, _encStr(t, Base58FlickrEncoding, "Hello"), "9aJCVZR")
tst.AssertEqual(t, _encStr(t, Base58FlickrEncoding, "If debugging is the process of removing software bugs, then programming must be the process of putting them in."), "48638rmBiUzG5NKQoX4KcuE5C8paCFACnE28F7qDx13PRtennAmYSSJQ5gJSRihf5ZDyEQS4UimtihR7uARt4wbty2fW9duTQTM9n1DwUBevreyzGwu6W4YSgrvQgCPDxsiE1mCdZsF8VEBpuHHEiJyw")
}
func TestBase58FlickrDecoding(t *testing.T) {
tst.AssertEqual(t, _decStr(t, Base58FlickrEncoding, "9aJCVZR"), "Hello")
tst.AssertEqual(t, _decStr(t, Base58FlickrEncoding, "48638rmBiUzG5NKQoX4KcuE5C8paCFACnE28F7qDx13PRtennAmYSSJQ5gJSRihf5ZDyEQS4UimtihR7uARt4wbty2fW9duTQTM9n1DwUBevreyzGwu6W4YSgrvQgCPDxsiE1mCdZsF8VEBpuHHEiJyw"), "If debugging is the process of removing software bugs, then programming must be the process of putting them in.")
}
func tst.AssertEqual(t *testing.T, actual string, expected string) {
if actual != expected {
t.Errorf("values differ: Actual: '%v', Expected: '%v'", actual, expected)
}
}

View File

@@ -60,3 +60,12 @@ func CoalesceStringer(s fmt.Stringer, def string) string {
return s.String() return s.String()
} }
} }
func SafeCast[T any](v any, def T) T {
switch r := v.(type) {
case T:
return r
default:
return def
}
}

View File

@@ -31,16 +31,16 @@ func CompareIntArr(arr1 []int, arr2 []int) bool {
return false return false
} }
func CompareArr[T OrderedConstraint](arr1 []T, arr2 []T) bool { func CompareArr[T OrderedConstraint](arr1 []T, arr2 []T) int {
for i := 0; i < len(arr1) || i < len(arr2); i++ { for i := 0; i < len(arr1) || i < len(arr2); i++ {
if i < len(arr1) && i < len(arr2) { if i < len(arr1) && i < len(arr2) {
if arr1[i] < arr2[i] { if arr1[i] < arr2[i] {
return true return -1
} else if arr1[i] > arr2[i] { } else if arr1[i] > arr2[i] {
return false return +2
} else { } else {
continue continue
} }
@@ -49,15 +49,55 @@ func CompareArr[T OrderedConstraint](arr1 []T, arr2 []T) bool {
if i < len(arr1) { if i < len(arr1) {
return true return +1
} else { // if i < len(arr2) } else { // if i < len(arr2)
return false return -1
} }
} }
return false return 0
}
func CompareString(a, b string) int {
if a == b {
return 0
}
if a < b {
return -1
}
return +1
}
func CompareInt(a, b int) int {
if a == b {
return 0
}
if a < b {
return -1
}
return +1
}
func CompareInt64(a, b int64) int {
if a == b {
return 0
}
if a < b {
return -1
}
return +1
}
func Compare[T OrderedConstraint](a, b T) int {
if a == b {
return 0
}
if a < b {
return -1
}
return +1
} }

View File

@@ -1,5 +1,10 @@
package langext package langext
type MapEntry[T comparable, V any] struct {
Key T
Value V
}
func MapKeyArr[T comparable, V any](v map[T]V) []T { func MapKeyArr[T comparable, V any](v map[T]V) []T {
result := make([]T, 0, len(v)) result := make([]T, 0, len(v))
for k := range v { for k := range v {
@@ -7,3 +12,46 @@ func MapKeyArr[T comparable, V any](v map[T]V) []T {
} }
return result return result
} }
func MapValueArr[T comparable, V any](v map[T]V) []V {
result := make([]V, 0, len(v))
for _, mv := range v {
result = append(result, mv)
}
return result
}
func ArrToMap[T comparable, V any](a []V, keyfunc func(V) T) map[T]V {
result := make(map[T]V, len(a))
for _, v := range a {
result[keyfunc(v)] = v
}
return result
}
func MapToArr[T comparable, V any](v map[T]V) []MapEntry[T, V] {
result := make([]MapEntry[T, V], 0, len(v))
for mk, mv := range v {
result = append(result, MapEntry[T, V]{
Key: mk,
Value: mv,
})
}
return result
}
func CopyMap[K comparable, V any](a map[K]V) map[K]V {
result := make(map[K]V, len(a))
for k, v := range a {
result[k] = v
}
return result
}
func ForceMap[K comparable, V any](v map[K]V) map[K]V {
if v == nil {
return make(map[K]V, 0)
} else {
return v
}
}

71
langext/panic.go Normal file
View File

@@ -0,0 +1,71 @@
package langext
type PanicWrappedErr struct {
panic any
}
func (p PanicWrappedErr) Error() string {
return "A panic occured"
}
func (p PanicWrappedErr) ReoveredObj() any {
return p.panic
}
func RunPanicSafe(fn func()) (err error) {
defer func() {
if rec := recover(); rec != nil {
err = PanicWrappedErr{panic: rec}
}
}()
fn()
return nil
}
func RunPanicSafeR1(fn func() error) (err error) {
defer func() {
if rec := recover(); rec != nil {
err = PanicWrappedErr{panic: rec}
}
}()
return fn()
}
func RunPanicSafeR2[T1 any](fn func() (T1, error)) (r1 T1, err error) {
defer func() {
if rec := recover(); rec != nil {
r1 = *new(T1)
err = PanicWrappedErr{panic: rec}
}
}()
return fn()
}
func RunPanicSafeR3[T1 any, T2 any](fn func() (T1, T2, error)) (r1 T1, r2 T2, err error) {
defer func() {
if rec := recover(); rec != nil {
r1 = *new(T1)
r2 = *new(T2)
err = PanicWrappedErr{panic: rec}
}
}()
return fn()
}
func RunPanicSafeR4[T1 any, T2 any, T3 any](fn func() (T1, T2, T3, error)) (r1 T1, r2 T2, r3 T3, err error) {
defer func() {
if rec := recover(); rec != nil {
r1 = *new(T1)
r2 = *new(T2)
r3 = *new(T3)
err = PanicWrappedErr{panic: rec}
}
}()
return fn()
}

View File

@@ -4,6 +4,12 @@ import (
"reflect" "reflect"
) )
// PTrue := &true
var PTrue = Ptr(true)
// PFalse := &false
var PFalse = Ptr(false)
func Ptr[T any](v T) *T { func Ptr[T any](v T) *T {
return &v return &v
} }

View File

@@ -41,6 +41,14 @@ func NewHexUUID() (string, error) {
return string(dst), nil return string(dst), nil
} }
func MustHexUUID() string {
v, err := NewHexUUID()
if err != nil {
panic(err)
}
return v
}
func NewUpperHexUUID() (string, error) { func NewUpperHexUUID() (string, error) {
uuid, err := NewUUID() uuid, err := NewUUID()
if err != nil { if err != nil {
@@ -64,6 +72,14 @@ func NewUpperHexUUID() (string, error) {
return strings.ToUpper(string(dst)), nil return strings.ToUpper(string(dst)), nil
} }
func MustUpperHexUUID() string {
v, err := NewUpperHexUUID()
if err != nil {
panic(err)
}
return v
}
func NewRawHexUUID() (string, error) { func NewRawHexUUID() (string, error) {
uuid, err := NewUUID() uuid, err := NewUUID()
if err != nil { if err != nil {
@@ -83,6 +99,14 @@ func NewRawHexUUID() (string, error) {
return strings.ToUpper(string(dst)), nil return strings.ToUpper(string(dst)), nil
} }
func MustRawHexUUID() string {
v, err := NewRawHexUUID()
if err != nil {
panic(err)
}
return v
}
func NewBracesUUID() (string, error) { func NewBracesUUID() (string, error) {
uuid, err := NewUUID() uuid, err := NewUUID()
if err != nil { if err != nil {
@@ -108,6 +132,14 @@ func NewBracesUUID() (string, error) {
return strings.ToUpper(string(dst)), nil return strings.ToUpper(string(dst)), nil
} }
func MustBracesUUID() string {
v, err := NewBracesUUID()
if err != nil {
panic(err)
}
return v
}
func NewParensUUID() (string, error) { func NewParensUUID() (string, error) {
uuid, err := NewUUID() uuid, err := NewUUID()
if err != nil { if err != nil {
@@ -132,3 +164,11 @@ func NewParensUUID() (string, error) {
return strings.ToUpper(string(dst)), nil return strings.ToUpper(string(dst)), nil
} }
func MustParensUUID() string {
v, err := NewParensUUID()
if err != nil {
panic(err)
}
return v
}

View File

@@ -22,6 +22,31 @@ func Max[T langext.OrderedConstraint](v1 T, v2 T) T {
} }
} }
func Max3[T langext.OrderedConstraint](v1 T, v2 T, v3 T) T {
result := v1
if v2 > result {
result = v2
}
if v3 > result {
result = v3
}
return result
}
func Max4[T langext.OrderedConstraint](v1 T, v2 T, v3 T, v4 T) T {
result := v1
if v2 > result {
result = v2
}
if v3 > result {
result = v3
}
if v4 > result {
result = v4
}
return result
}
func Min[T langext.OrderedConstraint](v1 T, v2 T) T { func Min[T langext.OrderedConstraint](v1 T, v2 T) T {
if v1 < v2 { if v1 < v2 {
return v1 return v1
@@ -30,6 +55,31 @@ func Min[T langext.OrderedConstraint](v1 T, v2 T) T {
} }
} }
func Min3[T langext.OrderedConstraint](v1 T, v2 T, v3 T) T {
result := v1
if v2 < result {
result = v2
}
if v3 < result {
result = v3
}
return result
}
func Min4[T langext.OrderedConstraint](v1 T, v2 T, v3 T, v4 T) T {
result := v1
if v2 < result {
result = v2
}
if v3 < result {
result = v3
}
if v4 < result {
result = v4
}
return result
}
func Abs[T langext.NumberConstraint](v T) T { func Abs[T langext.NumberConstraint](v T) T {
if v < 0 { if v < 0 {
return -v return -v

16
mongo/.errcheck-excludes Normal file
View File

@@ -0,0 +1,16 @@
(go.mongodb.org/mongo-driver/x/mongo/driver.Connection).Close
(*go.mongodb.org/mongo-driver/x/network/connection.connection).Close
(go.mongodb.org/mongo-driver/x/network/connection.Connection).Close
(*go.mongodb.org/mongo-driver/x/mongo/driver/topology.connection).close
(*go.mongodb.org/mongo-driver/x/mongo/driver/topology.Topology).Unsubscribe
(*go.mongodb.org/mongo-driver/x/mongo/driver/topology.Server).Close
(*go.mongodb.org/mongo-driver/x/network/connection.pool).closeConnection
(*go.mongodb.org/mongo-driver/x/mongo/driver/topology.pool).close
(go.mongodb.org/mongo-driver/x/network/wiremessage.ReadWriteCloser).Close
(*go.mongodb.org/mongo-driver/mongo.Cursor).Close
(*go.mongodb.org/mongo-driver/mongo.ChangeStream).Close
(*go.mongodb.org/mongo-driver/mongo.Client).Disconnect
(net.Conn).Close
encoding/pem.Encode
fmt.Fprintf
fmt.Fprint

13
mongo/.gitignore vendored Normal file
View File

@@ -0,0 +1,13 @@
.vscode
debug
.idea
*.iml
*.ipr
*.iws
.idea
*.sublime-project
*.sublime-workspace
driver-test-data.tar.gz
perf
**mongocryptd.pid
*.test

3
mongo/.gitmodules vendored Normal file
View File

@@ -0,0 +1,3 @@
[submodule "specifications"]
path = specifications
url = git@github.com:mongodb/specifications.git

123
mongo/.golangci.yml Normal file
View File

@@ -0,0 +1,123 @@
run:
timeout: 5m
linters:
disable-all: true
# TODO(GODRIVER-2156): Enable all commented-out linters.
enable:
- errcheck
# - errorlint
- gocritic
- goimports
- gosimple
- gosec
- govet
- ineffassign
- makezero
- misspell
- nakedret
- paralleltest
- prealloc
- revive
- staticcheck
- typecheck
- unused
- unconvert
- unparam
linters-settings:
errcheck:
exclude: .errcheck-excludes
gocritic:
enabled-checks:
# Detects suspicious append result assignments. E.g. "b := append(a, 1, 2, 3)"
- appendAssign
govet:
disable:
- cgocall
- composites
paralleltest:
# Ignore missing calls to `t.Parallel()` and only report incorrect uses of `t.Parallel()`.
ignore-missing: true
staticcheck:
checks: [
"all",
"-SA1019", # Disable deprecation warnings for now.
"-SA1012", # Disable "do not pass a nil Context" to allow testing nil contexts in tests.
]
issues:
exclude-use-default: false
exclude:
# Add all default excluded issues except issues related to exported types/functions not having
# comments; we want those warnings. The defaults are copied from the "--exclude-use-default"
# documentation on https://golangci-lint.run/usage/configuration/#command-line-options
## Defaults ##
# EXC0001 errcheck: Almost all programs ignore errors on these functions and in most cases it's ok
- Error return value of .((os\.)?std(out|err)\..*|.*Close|.*Flush|os\.Remove(All)?|.*print(f|ln)?|os\.(Un)?Setenv). is not checked
# EXC0003 golint: False positive when tests are defined in package 'test'
- func name will be used as test\.Test.* by other packages, and that stutters; consider calling this
# EXC0004 govet: Common false positives
- (possible misuse of unsafe.Pointer|should have signature)
# EXC0005 staticcheck: Developers tend to write in C-style with an explicit 'break' in a 'switch', so it's ok to ignore
- ineffective break statement. Did you mean to break out of the outer loop
# EXC0006 gosec: Too many false-positives on 'unsafe' usage
- Use of unsafe calls should be audited
# EXC0007 gosec: Too many false-positives for parametrized shell calls
- Subprocess launch(ed with variable|ing should be audited)
# EXC0008 gosec: Duplicated errcheck checks
- (G104|G307)
# EXC0009 gosec: Too many issues in popular repos
- (Expect directory permissions to be 0750 or less|Expect file permissions to be 0600 or less)
# EXC0010 gosec: False positive is triggered by 'src, err := ioutil.ReadFile(filename)'
- Potential file inclusion via variable
## End Defaults ##
# Ignore capitalization warning for this weird field name.
- "var-naming: struct field CqCssWxW should be CqCSSWxW"
# Ignore warnings for common "wiremessage.Read..." usage because the safest way to use that API
# is by assigning possibly unused returned byte buffers.
- "SA4006: this value of `wm` is never used"
- "SA4006: this value of `rem` is never used"
- "ineffectual assignment to wm"
- "ineffectual assignment to rem"
skip-dirs-use-default: false
skip-dirs:
- (^|/)vendor($|/)
- (^|/)testdata($|/)
- (^|/)etc($|/)
exclude-rules:
# Ignore some linters for example code that is intentionally simplified.
- path: examples/
linters:
- revive
- errcheck
# Disable unused code linters for the copy/pasted "awsv4" package.
- path: x/mongo/driver/auth/internal/awsv4
linters:
- unused
# Disable "unused" linter for code files that depend on the "mongocrypt.MongoCrypt" type because
# the linter build doesn't work correctly with CGO enabled. As a result, all calls to a
# "mongocrypt.MongoCrypt" API appear to always panic (see mongocrypt_not_enabled.go), leading
# to confusing messages about unused code.
- path: x/mongo/driver/crypt.go|mongo/(crypt_retrievers|mongocryptd).go
linters:
- unused
# Ignore "TLS MinVersion too low", "TLS InsecureSkipVerify set true", and "Use of weak random
# number generator (math/rand instead of crypto/rand)" in tests.
- path: _test\.go
text: G401|G402|G404
linters:
- gosec
# Ignore missing comments for exported variable/function/type for code in the "internal" and
# "benchmark" directories.
- path: (internal\/|benchmark\/)
text: exported (.+) should have comment( \(or a comment on this block\))? or be unexported
# Ignore missing package comments for directories that aren't frequently used by external users.
- path: (internal\/|benchmark\/|x\/|cmd\/|mongo\/integration\/)
text: should have a package comment
# Disable unused linter for "golang.org/x/exp/rand" package in internal/randutil/rand.
- path: internal/randutil/rand
linters:
- unused

201
mongo/LICENSE Normal file
View File

@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

210
mongo/Makefile Normal file
View File

@@ -0,0 +1,210 @@
ATLAS_URIS = "$(ATLAS_FREE)" "$(ATLAS_REPLSET)" "$(ATLAS_SHARD)" "$(ATLAS_TLS11)" "$(ATLAS_TLS12)" "$(ATLAS_FREE_SRV)" "$(ATLAS_REPLSET_SRV)" "$(ATLAS_SHARD_SRV)" "$(ATLAS_TLS11_SRV)" "$(ATLAS_TLS12_SRV)" "$(ATLAS_SERVERLESS)" "$(ATLAS_SERVERLESS_SRV)"
TEST_TIMEOUT = 1800
### Utility targets. ###
.PHONY: default
default: build check-license check-fmt check-modules lint test-short
.PHONY: add-license
add-license:
etc/check_license.sh -a
.PHONY: check-license
check-license:
etc/check_license.sh
.PHONY: build
build: cross-compile build-tests build-compile-check
go build ./...
go build $(BUILD_TAGS) ./...
# Use ^$ to match no tests so that no tests are actually run but all tests are
# compiled. Run with -short to ensure none of the TestMain functions try to
# connect to a server.
.PHONY: build-tests
build-tests:
go test -short $(BUILD_TAGS) -run ^$$ ./...
.PHONY: build-compile-check
build-compile-check:
etc/compile_check.sh
# Cross-compiling on Linux for architectures 386, arm, arm64, amd64, ppc64le, and s390x.
# Omit any build tags because we don't expect our build environment to support compiling the C
# libraries for other architectures.
.PHONY: cross-compile
cross-compile:
GOOS=linux GOARCH=386 go build ./...
GOOS=linux GOARCH=arm go build ./...
GOOS=linux GOARCH=arm64 go build ./...
GOOS=linux GOARCH=amd64 go build ./...
GOOS=linux GOARCH=ppc64le go build ./...
GOOS=linux GOARCH=s390x go build ./...
.PHONY: install-lll
install-lll:
go install github.com/walle/lll/...@latest
.PHONY: check-fmt
check-fmt: install-lll
etc/check_fmt.sh
# check-modules runs "go mod tidy" then "go mod vendor" and exits with a non-zero exit code if there
# are any module or vendored modules changes. The intent is to confirm two properties:
#
# 1. Exactly the required modules are declared as dependencies. We should always be able to run
# "go mod tidy" and expect that no unrelated changes are made to the "go.mod" file.
#
# 2. All required modules are copied into the vendor/ directory and are an exact copy of the
# original module source code (i.e. the vendored modules are not modified from their original code).
.PHONY: check-modules
check-modules:
go mod tidy -v
go mod vendor
git diff --exit-code go.mod go.sum ./vendor
.PHONY: doc
doc:
godoc -http=:6060 -index
.PHONY: fmt
fmt:
go fmt ./...
.PHONY: install-golangci-lint
install-golangci-lint:
go install github.com/golangci/golangci-lint/cmd/golangci-lint@v1.51.0
# Lint with various GOOS and GOARCH targets to catch static analysis failures that may only affect
# specific operating systems or architectures. For example, staticcheck will only check for 64-bit
# alignment of atomically accessed variables on 32-bit architectures (see
# https://staticcheck.io/docs/checks#SA1027)
.PHONY: lint
lint: install-golangci-lint
GOOS=linux GOARCH=386 golangci-lint run --config .golangci.yml ./...
GOOS=linux GOARCH=arm golangci-lint run --config .golangci.yml ./...
GOOS=linux GOARCH=arm64 golangci-lint run --config .golangci.yml ./...
GOOS=linux GOARCH=amd64 golangci-lint run --config .golangci.yml ./...
GOOS=linux GOARCH=ppc64le golangci-lint run --config .golangci.yml ./...
GOOS=linux GOARCH=s390x golangci-lint run --config .golangci.yml ./...
.PHONY: update-notices
update-notices:
etc/generate_notices.pl > THIRD-PARTY-NOTICES
### Local testing targets. ###
.PHONY: test
test:
go test $(BUILD_TAGS) -timeout $(TEST_TIMEOUT)s -p 1 ./...
.PHONY: test-cover
test-cover:
go test $(BUILD_TAGS) -timeout $(TEST_TIMEOUT)s -cover $(COVER_ARGS) -p 1 ./...
.PHONY: test-race
test-race:
go test $(BUILD_TAGS) -timeout $(TEST_TIMEOUT)s -race -p 1 ./...
.PHONY: test-short
test-short:
go test $(BUILD_TAGS) -timeout 60s -short ./...
### Evergreen specific targets. ###
.PHONY: build-aws-ecs-test
build-aws-ecs-test:
go build $(BUILD_TAGS) ./cmd/testaws/main.go
.PHONY: evg-test
evg-test:
go test -exec "env PKG_CONFIG_PATH=$(PKG_CONFIG_PATH) LD_LIBRARY_PATH=$(LD_LIBRARY_PATH)" $(BUILD_TAGS) -v -timeout $(TEST_TIMEOUT)s -p 1 ./... >> test.suite
.PHONY: evg-test-atlas
evg-test-atlas:
go run ./cmd/testatlas/main.go $(ATLAS_URIS)
.PHONY: evg-test-atlas-data-lake
evg-test-atlas-data-lake:
ATLAS_DATA_LAKE_INTEGRATION_TEST=true go test -v ./mongo/integration -run TestUnifiedSpecs/atlas-data-lake-testing >> spec_test.suite
ATLAS_DATA_LAKE_INTEGRATION_TEST=true go test -v ./mongo/integration -run TestAtlasDataLake >> spec_test.suite
.PHONY: evg-test-enterprise-auth
evg-test-enterprise-auth:
go run -tags gssapi ./cmd/testentauth/main.go
.PHONY: evg-test-kmip
evg-test-kmip:
go test -exec "env PKG_CONFIG_PATH=$(PKG_CONFIG_PATH) LD_LIBRARY_PATH=$(LD_LIBRARY_PATH)" $(BUILD_TAGS) -v -timeout $(TEST_TIMEOUT)s ./mongo/integration -run TestClientSideEncryptionSpec/kmipKMS >> test.suite
go test -exec "env PKG_CONFIG_PATH=$(PKG_CONFIG_PATH) LD_LIBRARY_PATH=$(LD_LIBRARY_PATH)" $(BUILD_TAGS) -v -timeout $(TEST_TIMEOUT)s ./mongo/integration -run TestClientSideEncryptionProse/data_key_and_double_encryption >> test.suite
go test -exec "env PKG_CONFIG_PATH=$(PKG_CONFIG_PATH) LD_LIBRARY_PATH=$(LD_LIBRARY_PATH)" $(BUILD_TAGS) -v -timeout $(TEST_TIMEOUT)s ./mongo/integration -run TestClientSideEncryptionProse/corpus >> test.suite
go test -exec "env PKG_CONFIG_PATH=$(PKG_CONFIG_PATH) LD_LIBRARY_PATH=$(LD_LIBRARY_PATH)" $(BUILD_TAGS) -v -timeout $(TEST_TIMEOUT)s ./mongo/integration -run TestClientSideEncryptionProse/custom_endpoint >> test.suite
go test -exec "env PKG_CONFIG_PATH=$(PKG_CONFIG_PATH) LD_LIBRARY_PATH=$(LD_LIBRARY_PATH)" $(BUILD_TAGS) -v -timeout $(TEST_TIMEOUT)s ./mongo/integration -run TestClientSideEncryptionProse/kms_tls_options_test >> test.suite
.PHONY: evg-test-kms
evg-test-kms:
go test -exec "env PKG_CONFIG_PATH=$(PKG_CONFIG_PATH) LD_LIBRARY_PATH=$(LD_LIBRARY_PATH)" $(BUILD_TAGS) -v -timeout $(TEST_TIMEOUT)s ./mongo/integration -run TestClientSideEncryptionProse/kms_tls_tests >> test.suite
.PHONY: evg-test-load-balancers
evg-test-load-balancers:
# Load balancer should be tested with all unified tests as well as tests in the following
# components: retryable reads, retryable writes, change streams, initial DNS seedlist discovery.
go test $(BUILD_TAGS) ./mongo/integration -run TestUnifiedSpecs/retryable-reads -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration -run TestRetryableWritesSpec -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration -run TestChangeStreamSpec -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration -run TestInitialDNSSeedlistDiscoverySpec/load_balanced -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration -run TestLoadBalancerSupport -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration/unified -run TestUnifiedSpec -v -timeout $(TEST_TIMEOUT)s >> test.suite
.PHONY: evg-test-ocsp
evg-test-ocsp:
go test -v ./mongo -run TestOCSP $(OCSP_TLS_SHOULD_SUCCEED) >> test.suite
.PHONY: evg-test-serverless
evg-test-serverless:
# Serverless should be tested with all unified tests as well as tests in the following components: CRUD, load balancer,
# retryable reads, retryable writes, sessions, transactions and cursor behavior.
go test $(BUILD_TAGS) ./mongo/integration -run TestCrudSpec -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration -run TestWriteErrorsWithLabels -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration -run TestWriteErrorsDetails -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration -run TestHintErrors -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration -run TestWriteConcernError -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration -run TestErrorsCodeNamePropagated -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration -run TestLoadBalancerSupport -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration -run TestUnifiedSpecs/retryable-reads -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration -run TestRetryableReadsProse -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration -run TestRetryableWritesSpec -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration -run TestRetryableWritesProse -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration -run TestUnifiedSpecs/sessions -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration -run TestSessionsProse -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration -run TestUnifiedSpecs/transactions/legacy -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration -run TestConvenientTransactions -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration -run TestCursor -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration/unified -run TestUnifiedSpec -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test -exec "env PKG_CONFIG_PATH=$(PKG_CONFIG_PATH) LD_LIBRARY_PATH=$(LD_LIBRARY_PATH)" $(BUILD_TAGS) -v -timeout $(TEST_TIMEOUT)s ./mongo/integration -run TestClientSideEncryptionSpec >> test.suite
go test -exec "env PKG_CONFIG_PATH=$(PKG_CONFIG_PATH) LD_LIBRARY_PATH=$(LD_LIBRARY_PATH)" $(BUILD_TAGS) -v -timeout $(TEST_TIMEOUT)s ./mongo/integration -run TestClientSideEncryptionProse >> test.suite
.PHONY: evg-test-versioned-api
evg-test-versioned-api:
# Versioned API related tests are in the mongo, integration and unified packages.
go test -exec "env PKG_CONFIG_PATH=$(PKG_CONFIG_PATH) LD_LIBRARY_PATH=$(LD_LIBRARY_PATH)" $(BUILD_TAGS) -v -timeout $(TEST_TIMEOUT)s ./mongo >> test.suite
go test -exec "env PKG_CONFIG_PATH=$(PKG_CONFIG_PATH) LD_LIBRARY_PATH=$(LD_LIBRARY_PATH)" $(BUILD_TAGS) -v -timeout $(TEST_TIMEOUT)s ./mongo/integration >> test.suite
go test -exec "env PKG_CONFIG_PATH=$(PKG_CONFIG_PATH) LD_LIBRARY_PATH=$(LD_LIBRARY_PATH)" $(BUILD_TAGS) -v -timeout $(TEST_TIMEOUT)s ./mongo/integration/unified >> test.suite
.PHONY: build-gcpkms-test
build-gcpkms-test:
go build $(BUILD_TAGS) ./cmd/testgcpkms
### Benchmark specific targets and support. ###
.PHONY: benchmark
benchmark:perf
go test $(BUILD_TAGS) -benchmem -bench=. ./benchmark
.PHONY: driver-benchmark
driver-benchmark:perf
@go run cmd/godriver-benchmark/main.go | tee perf.suite
perf:driver-test-data.tar.gz
tar -zxf $< $(if $(eq $(UNAME_S),Darwin),-s , --transform=s)/testdata/perf/
@touch $@
driver-test-data.tar.gz:
curl --retry 5 "https://s3.amazonaws.com/boxes.10gen.com/build/driver-test-data.tar.gz" -o driver-test-data.tar.gz --silent --max-time 120

251
mongo/README.md Normal file
View File

@@ -0,0 +1,251 @@
<p align="center"><img src="etc/assets/mongo-gopher.png" width="250"></p>
<p align="center">
<a href="https://goreportcard.com/report/go.mongodb.org/mongo-driver"><img src="https://goreportcard.com/badge/go.mongodb.org/mongo-driver"></a>
<a href="https://pkg.go.dev/go.mongodb.org/mongo-driver/mongo"><img src="etc/assets/godev-mongo-blue.svg" alt="docs"></a>
<a href="https://pkg.go.dev/go.mongodb.org/mongo-driver/bson"><img src="etc/assets/godev-bson-blue.svg" alt="docs"></a>
<a href="https://www.mongodb.com/docs/drivers/go/current/"><img src="etc/assets/docs-mongodb-green.svg"></a>
</p>
# MongoDB Go Driver
The MongoDB supported driver for Go.
-------------------------
- [Requirements](#requirements)
- [Installation](#installation)
- [Usage](#usage)
- [Feedback](#feedback)
- [Testing / Development](#testing--development)
- [Continuous Integration](#continuous-integration)
- [License](#license)
-------------------------
## Requirements
- Go 1.13 or higher. We aim to support the latest versions of Go.
- Go 1.20 or higher is required to run the driver test suite.
- MongoDB 3.6 and higher.
-------------------------
## Installation
The recommended way to get started using the MongoDB Go driver is by using Go modules to install the dependency in
your project. This can be done either by importing packages from `go.mongodb.org/mongo-driver` and having the build
step install the dependency or by explicitly running
```bash
go get go.mongodb.org/mongo-driver/mongo
```
When using a version of Go that does not support modules, the driver can be installed using `dep` by running
```bash
dep ensure -add "go.mongodb.org/mongo-driver/mongo"
```
-------------------------
## Usage
To get started with the driver, import the `mongo` package and create a `mongo.Client` with the `Connect` function:
```go
import (
"context"
"time"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/mongo/readpref"
)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
client, err := mongo.Connect(ctx, options.Client().ApplyURI("mongodb://localhost:27017"))
```
Make sure to defer a call to `Disconnect` after instantiating your client:
```go
defer func() {
if err = client.Disconnect(ctx); err != nil {
panic(err)
}
}()
```
For more advanced configuration and authentication, see the [documentation for mongo.Connect](https://pkg.go.dev/go.mongodb.org/mongo-driver/mongo#Connect).
Calling `Connect` does not block for server discovery. If you wish to know if a MongoDB server has been found and connected to,
use the `Ping` method:
```go
ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
err = client.Ping(ctx, readpref.Primary())
```
To insert a document into a collection, first retrieve a `Database` and then `Collection` instance from the `Client`:
```go
collection := client.Database("testing").Collection("numbers")
```
The `Collection` instance can then be used to insert documents:
```go
ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
res, err := collection.InsertOne(ctx, bson.D{{"name", "pi"}, {"value", 3.14159}})
id := res.InsertedID
```
To use `bson.D`, you will need to add `"go.mongodb.org/mongo-driver/bson"` to your imports.
Your import statement should now look like this:
```go
import (
"context"
"log"
"time"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/mongo/readpref"
)
```
Several query methods return a cursor, which can be used like this:
```go
ctx, cancel = context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
cur, err := collection.Find(ctx, bson.D{})
if err != nil { log.Fatal(err) }
defer cur.Close(ctx)
for cur.Next(ctx) {
var result bson.D
err := cur.Decode(&result)
if err != nil { log.Fatal(err) }
// do something with result....
}
if err := cur.Err(); err != nil {
log.Fatal(err)
}
```
For methods that return a single item, a `SingleResult` instance is returned:
```go
var result struct {
Value float64
}
filter := bson.D{{"name", "pi"}}
ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err = collection.FindOne(ctx, filter).Decode(&result)
if err == mongo.ErrNoDocuments {
// Do something when no record was found
fmt.Println("record does not exist")
} else if err != nil {
log.Fatal(err)
}
// Do something with result...
```
Additional examples and documentation can be found under the examples directory and [on the MongoDB Documentation website](https://www.mongodb.com/docs/drivers/go/current/).
-------------------------
## Feedback
For help with the driver, please post in the [MongoDB Community Forums](https://developer.mongodb.com/community/forums/tag/golang/).
New features and bugs can be reported on jira: https://jira.mongodb.org/browse/GODRIVER
-------------------------
## Testing / Development
The driver tests can be run against several database configurations. The most simple configuration is a standalone mongod with no auth, no ssl, and no compression. To run these basic driver tests, make sure a standalone MongoDB server instance is running at localhost:27017. To run the tests, you can run `make` (on Windows, run `nmake`). This will run coverage, run go-lint, run go-vet, and build the examples.
### Testing Different Topologies
To test a **replica set** or **sharded cluster**, set `MONGODB_URI="<connection-string>"` for the `make` command.
For example, for a local replica set named `rs1` comprised of three nodes on ports 27017, 27018, and 27019:
```
MONGODB_URI="mongodb://localhost:27017,localhost:27018,localhost:27019/?replicaSet=rs1" make
```
### Testing Auth and TLS
To test authentication and TLS, first set up a MongoDB cluster with auth and TLS configured. Testing authentication requires a user with the `root` role on the `admin` database. Here is an example command that would run a mongod with TLS correctly configured for tests. Either set or replace PATH_TO_SERVER_KEY_FILE and PATH_TO_CA_FILE with paths to their respective files:
```
mongod \
--auth \
--tlsMode requireTLS \
--tlsCertificateKeyFile $PATH_TO_SERVER_KEY_FILE \
--tlsCAFile $PATH_TO_CA_FILE \
--tlsAllowInvalidCertificates
```
To run the tests with `make`, set:
- `MONGO_GO_DRIVER_CA_FILE` to the location of the CA file used by the database
- `MONGO_GO_DRIVER_KEY_FILE` to the location of the client key file
- `MONGO_GO_DRIVER_PKCS8_ENCRYPTED_KEY_FILE` to the location of the pkcs8 client key file encrypted with the password string: `password`
- `MONGO_GO_DRIVER_PKCS8_UNENCRYPTED_KEY_FILE` to the location of the unencrypted pkcs8 key file
- `MONGODB_URI` to the connection string of the server
- `AUTH=auth`
- `SSL=ssl`
For example:
```
AUTH=auth SSL=ssl \
MONGO_GO_DRIVER_CA_FILE=$PATH_TO_CA_FILE \
MONGO_GO_DRIVER_KEY_FILE=$PATH_TO_CLIENT_KEY_FILE \
MONGO_GO_DRIVER_PKCS8_ENCRYPTED_KEY_FILE=$PATH_TO_ENCRYPTED_KEY_FILE \
MONGO_GO_DRIVER_PKCS8_UNENCRYPTED_KEY_FILE=$PATH_TO_UNENCRYPTED_KEY_FILE \
MONGODB_URI="mongodb://user:password@localhost:27017/?authSource=admin" \
make
```
Notes:
- The `--tlsAllowInvalidCertificates` flag is required on the server for the test suite to work correctly.
- The test suite requires the auth database to be set with `?authSource=admin`, not `/admin`.
### Testing Compression
The MongoDB Go Driver supports wire protocol compression using Snappy, zLib, or zstd. To run tests with wire protocol compression, set `MONGO_GO_DRIVER_COMPRESSOR` to `snappy`, `zlib`, or `zstd`. For example:
```
MONGO_GO_DRIVER_COMPRESSOR=snappy make
```
Ensure the [`--networkMessageCompressors` flag](https://www.mongodb.com/docs/manual/reference/program/mongod/#cmdoption-mongod-networkmessagecompressors) on mongod or mongos includes `zlib` if testing zLib compression.
-------------------------
## Contribution
Check out the [project page](https://jira.mongodb.org/browse/GODRIVER) for tickets that need completing. See our [contribution guidelines](docs/CONTRIBUTING.md) for details.
-------------------------
## Continuous Integration
Commits to master are run automatically on [evergreen](https://evergreen.mongodb.com/waterfall/mongo-go-driver).
-------------------------
## Frequently Encountered Issues
See our [common issues](docs/common-issues.md) documentation for troubleshooting frequently encountered issues.
-------------------------
## Thanks and Acknowledgement
<a href="https://github.com/ashleymcnamara">@ashleymcnamara</a> - Mongo Gopher Artwork
-------------------------
## License
The MongoDB Go Driver is licensed under the [Apache License](LICENSE).

1554
mongo/THIRD-PARTY-NOTICES Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,307 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"compress/gzip"
"encoding/json"
"fmt"
"io/ioutil"
"os"
"path"
"testing"
)
type encodetest struct {
Field1String string
Field1Int64 int64
Field1Float64 float64
Field2String string
Field2Int64 int64
Field2Float64 float64
Field3String string
Field3Int64 int64
Field3Float64 float64
Field4String string
Field4Int64 int64
Field4Float64 float64
}
type nestedtest1 struct {
Nested nestedtest2
}
type nestedtest2 struct {
Nested nestedtest3
}
type nestedtest3 struct {
Nested nestedtest4
}
type nestedtest4 struct {
Nested nestedtest5
}
type nestedtest5 struct {
Nested nestedtest6
}
type nestedtest6 struct {
Nested nestedtest7
}
type nestedtest7 struct {
Nested nestedtest8
}
type nestedtest8 struct {
Nested nestedtest9
}
type nestedtest9 struct {
Nested nestedtest10
}
type nestedtest10 struct {
Nested nestedtest11
}
type nestedtest11 struct {
Nested encodetest
}
var encodetestInstance = encodetest{
Field1String: "foo",
Field1Int64: 1,
Field1Float64: 3.0,
Field2String: "bar",
Field2Int64: 2,
Field2Float64: 3.1,
Field3String: "baz",
Field3Int64: 3,
Field3Float64: 3.14,
Field4String: "qux",
Field4Int64: 4,
Field4Float64: 3.141,
}
var nestedInstance = nestedtest1{
nestedtest2{
nestedtest3{
nestedtest4{
nestedtest5{
nestedtest6{
nestedtest7{
nestedtest8{
nestedtest9{
nestedtest10{
nestedtest11{
encodetest{
Field1String: "foo",
Field1Int64: 1,
Field1Float64: 3.0,
Field2String: "bar",
Field2Int64: 2,
Field2Float64: 3.1,
Field3String: "baz",
Field3Int64: 3,
Field3Float64: 3.14,
Field4String: "qux",
Field4Int64: 4,
Field4Float64: 3.141,
},
},
},
},
},
},
},
},
},
},
},
}
const extendedBSONDir = "../testdata/extended_bson"
// readExtJSONFile reads the GZIP-compressed extended JSON document from the given filename in the
// "extended BSON" test data directory (../testdata/extended_bson) and returns it as a
// map[string]interface{}. It panics on any errors.
func readExtJSONFile(filename string) map[string]interface{} {
filePath := path.Join(extendedBSONDir, filename)
file, err := os.Open(filePath)
if err != nil {
panic(fmt.Sprintf("error opening file %q: %s", filePath, err))
}
defer func() {
_ = file.Close()
}()
gz, err := gzip.NewReader(file)
if err != nil {
panic(fmt.Sprintf("error creating GZIP reader: %s", err))
}
defer func() {
_ = gz.Close()
}()
data, err := ioutil.ReadAll(gz)
if err != nil {
panic(fmt.Sprintf("error reading GZIP contents of file: %s", err))
}
var v map[string]interface{}
err = UnmarshalExtJSON(data, false, &v)
if err != nil {
panic(fmt.Sprintf("error unmarshalling extended JSON: %s", err))
}
return v
}
func BenchmarkMarshal(b *testing.B) {
cases := []struct {
desc string
value interface{}
}{
{
desc: "simple struct",
value: encodetestInstance,
},
{
desc: "nested struct",
value: nestedInstance,
},
{
desc: "deep_bson.json.gz",
value: readExtJSONFile("deep_bson.json.gz"),
},
{
desc: "flat_bson.json.gz",
value: readExtJSONFile("flat_bson.json.gz"),
},
{
desc: "full_bson.json.gz",
value: readExtJSONFile("full_bson.json.gz"),
},
}
for _, tc := range cases {
b.Run(tc.desc, func(b *testing.B) {
b.Run("BSON", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_, err := Marshal(tc.value)
if err != nil {
b.Errorf("error marshalling BSON: %s", err)
}
}
})
b.Run("extJSON", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_, err := MarshalExtJSON(tc.value, true, false)
if err != nil {
b.Errorf("error marshalling extended JSON: %s", err)
}
}
})
b.Run("JSON", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_, err := json.Marshal(tc.value)
if err != nil {
b.Errorf("error marshalling JSON: %s", err)
}
}
})
})
}
}
func BenchmarkUnmarshal(b *testing.B) {
cases := []struct {
desc string
value interface{}
}{
{
desc: "simple struct",
value: encodetestInstance,
},
{
desc: "nested struct",
value: nestedInstance,
},
{
desc: "deep_bson.json.gz",
value: readExtJSONFile("deep_bson.json.gz"),
},
{
desc: "flat_bson.json.gz",
value: readExtJSONFile("flat_bson.json.gz"),
},
{
desc: "full_bson.json.gz",
value: readExtJSONFile("full_bson.json.gz"),
},
}
for _, tc := range cases {
b.Run(tc.desc, func(b *testing.B) {
b.Run("BSON", func(b *testing.B) {
data, err := Marshal(tc.value)
if err != nil {
b.Errorf("error marshalling BSON: %s", err)
return
}
b.ResetTimer()
var v2 map[string]interface{}
for i := 0; i < b.N; i++ {
err := Unmarshal(data, &v2)
if err != nil {
b.Errorf("error unmarshalling BSON: %s", err)
}
}
})
b.Run("extJSON", func(b *testing.B) {
data, err := MarshalExtJSON(tc.value, true, false)
if err != nil {
b.Errorf("error marshalling extended JSON: %s", err)
return
}
b.ResetTimer()
var v2 map[string]interface{}
for i := 0; i < b.N; i++ {
err := UnmarshalExtJSON(data, true, &v2)
if err != nil {
b.Errorf("error unmarshalling extended JSON: %s", err)
}
}
})
b.Run("JSON", func(b *testing.B) {
data, err := json.Marshal(tc.value)
if err != nil {
b.Errorf("error marshalling JSON: %s", err)
return
}
b.ResetTimer()
var v2 map[string]interface{}
for i := 0; i < b.N; i++ {
err := json.Unmarshal(data, &v2)
if err != nil {
b.Errorf("error unmarshalling JSON: %s", err)
}
}
})
})
}
}

50
mongo/bson/bson.go Normal file
View File

@@ -0,0 +1,50 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on gopkg.in/mgo.v2/bson by Gustavo Niemeyer
// See THIRD-PARTY-NOTICES for original license terms.
package bson // import "go.mongodb.org/mongo-driver/bson"
import (
"go.mongodb.org/mongo-driver/bson/primitive"
)
// Zeroer allows custom struct types to implement a report of zero
// state. All struct types that don't implement Zeroer or where IsZero
// returns false are considered to be not zero.
type Zeroer interface {
IsZero() bool
}
// D is an ordered representation of a BSON document. This type should be used when the order of the elements matters,
// such as MongoDB command documents. If the order of the elements does not matter, an M should be used instead.
//
// A D should not be constructed with duplicate key names, as that can cause undefined server behavior.
//
// Example usage:
//
// bson.D{{"foo", "bar"}, {"hello", "world"}, {"pi", 3.14159}}
type D = primitive.D
// E represents a BSON element for a D. It is usually used inside a D.
type E = primitive.E
// M is an unordered representation of a BSON document. This type should be used when the order of the elements does not
// matter. This type is handled as a regular map[string]interface{} when encoding and decoding. Elements will be
// serialized in an undefined, random order. If the order of the elements matters, a D should be used instead.
//
// Example usage:
//
// bson.M{"foo": "bar", "hello": "world", "pi": 3.14159}
type M = primitive.M
// An A is an ordered representation of a BSON array.
//
// Example usage:
//
// bson.A{"bar", "world", 3.14159, bson.D{{"qux", 12345}}}
type A = primitive.A

View File

@@ -0,0 +1,530 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"encoding/hex"
"encoding/json"
"fmt"
"math"
"os"
"path"
"strconv"
"strings"
"testing"
"unicode"
"unicode/utf8"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/require"
"github.com/tidwall/pretty"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/internal/testutil/assert"
)
type testCase struct {
Description string `json:"description"`
BsonType string `json:"bson_type"`
TestKey *string `json:"test_key"`
Valid []validityTestCase `json:"valid"`
DecodeErrors []decodeErrorTestCase `json:"decodeErrors"`
ParseErrors []parseErrorTestCase `json:"parseErrors"`
Deprecated *bool `json:"deprecated"`
}
type validityTestCase struct {
Description string `json:"description"`
CanonicalBson string `json:"canonical_bson"`
CanonicalExtJSON string `json:"canonical_extjson"`
RelaxedExtJSON *string `json:"relaxed_extjson"`
DegenerateBSON *string `json:"degenerate_bson"`
DegenerateExtJSON *string `json:"degenerate_extjson"`
ConvertedBSON *string `json:"converted_bson"`
ConvertedExtJSON *string `json:"converted_extjson"`
Lossy *bool `json:"lossy"`
}
type decodeErrorTestCase struct {
Description string `json:"description"`
Bson string `json:"bson"`
}
type parseErrorTestCase struct {
Description string `json:"description"`
String string `json:"string"`
}
const dataDir = "../testdata/bson-corpus/"
func findJSONFilesInDir(dir string) ([]string, error) {
files := make([]string, 0)
entries, err := os.ReadDir(dir)
if err != nil {
return nil, err
}
for _, entry := range entries {
if entry.IsDir() || path.Ext(entry.Name()) != ".json" {
continue
}
files = append(files, entry.Name())
}
return files, nil
}
// seedExtJSON will add the byte representation of the "extJSON" string to the fuzzer's coprus.
func seedExtJSON(f *testing.F, extJSON string, extJSONType string, desc string) {
jbytes, err := jsonToBytes(extJSON, extJSONType, desc)
if err != nil {
f.Fatalf("failed to convert JSON to bytes: %v", err)
}
f.Add(jbytes)
}
// seedTestCase will add the byte representation for each "extJSON" string of each valid test case to the fuzzer's
// corpus.
func seedTestCase(f *testing.F, tcase *testCase) {
for _, vtc := range tcase.Valid {
seedExtJSON(f, vtc.CanonicalExtJSON, "canonical", vtc.Description)
// Seed the relaxed extended JSON.
if vtc.RelaxedExtJSON != nil {
seedExtJSON(f, *vtc.RelaxedExtJSON, "relaxed", vtc.Description)
}
// Seed the degenerate extended JSON.
if vtc.DegenerateExtJSON != nil {
seedExtJSON(f, *vtc.DegenerateExtJSON, "degenerate", vtc.Description)
}
// Seed the converted extended JSON.
if vtc.ConvertedExtJSON != nil {
seedExtJSON(f, *vtc.ConvertedExtJSON, "converted", vtc.Description)
}
}
}
// seedBSONCorpus will unmarshal the data from "testdata/bson-corpus" into a slice of "testCase" structs and then
// marshal the "*_extjson" field of each "validityTestCase" into a slice of bytes to seed the fuzz corpus.
func seedBSONCorpus(f *testing.F) {
fileNames, err := findJSONFilesInDir(dataDir)
if err != nil {
f.Fatalf("failed to find JSON files in directory %q: %v", dataDir, err)
}
for _, fileName := range fileNames {
filePath := path.Join(dataDir, fileName)
file, err := os.Open(filePath)
if err != nil {
f.Fatalf("failed to open file %q: %v", filePath, err)
}
var tcase testCase
if err := json.NewDecoder(file).Decode(&tcase); err != nil {
f.Fatal(err)
}
seedTestCase(f, &tcase)
}
}
func needsEscapedUnicode(bsonType string) bool {
return bsonType == "0x02" || bsonType == "0x0D" || bsonType == "0x0E" || bsonType == "0x0F"
}
func unescapeUnicode(s, bsonType string) string {
if !needsEscapedUnicode(bsonType) {
return s
}
newS := ""
for i := 0; i < len(s); i++ {
c := s[i]
switch c {
case '\\':
switch s[i+1] {
case 'u':
us := s[i : i+6]
u, err := strconv.Unquote(strings.Replace(strconv.Quote(us), `\\u`, `\u`, 1))
if err != nil {
return ""
}
for _, r := range u {
if r < ' ' {
newS += fmt.Sprintf(`\u%04x`, r)
} else {
newS += string(r)
}
}
i += 5
default:
newS += string(c)
}
default:
if c > unicode.MaxASCII {
r, size := utf8.DecodeRune([]byte(s[i:]))
newS += string(r)
i += size - 1
} else {
newS += string(c)
}
}
}
return newS
}
func formatDouble(f float64) string {
var s string
if math.IsInf(f, 1) {
s = "Infinity"
} else if math.IsInf(f, -1) {
s = "-Infinity"
} else if math.IsNaN(f) {
s = "NaN"
} else {
// Print exactly one decimalType place for integers; otherwise, print as many are necessary to
// perfectly represent it.
s = strconv.FormatFloat(f, 'G', -1, 64)
if !strings.ContainsRune(s, 'E') && !strings.ContainsRune(s, '.') {
s += ".0"
}
}
return s
}
func normalizeCanonicalDouble(t *testing.T, key string, cEJ string) string {
// Unmarshal string into map
cEJMap := make(map[string]map[string]string)
err := json.Unmarshal([]byte(cEJ), &cEJMap)
require.NoError(t, err)
// Parse the float contained by the map.
expectedString := cEJMap[key]["$numberDouble"]
expectedFloat, err := strconv.ParseFloat(expectedString, 64)
require.NoError(t, err)
// Normalize the string
return fmt.Sprintf(`{"%s":{"$numberDouble":"%s"}}`, key, formatDouble(expectedFloat))
}
func normalizeRelaxedDouble(t *testing.T, key string, rEJ string) string {
// Unmarshal string into map
rEJMap := make(map[string]float64)
err := json.Unmarshal([]byte(rEJ), &rEJMap)
if err != nil {
return normalizeCanonicalDouble(t, key, rEJ)
}
// Parse the float contained by the map.
expectedFloat := rEJMap[key]
// Normalize the string
return fmt.Sprintf(`{"%s":%s}`, key, formatDouble(expectedFloat))
}
// bsonToNative decodes the BSON bytes (b) into a native Document
func bsonToNative(t *testing.T, b []byte, bType, testDesc string) D {
var doc D
err := Unmarshal(b, &doc)
expectNoError(t, err, fmt.Sprintf("%s: decoding %s BSON", testDesc, bType))
return doc
}
// nativeToBSON encodes the native Document (doc) into canonical BSON and compares it to the expected
// canonical BSON (cB)
func nativeToBSON(t *testing.T, cB []byte, doc D, testDesc, bType, docSrcDesc string) {
actual, err := Marshal(doc)
expectNoError(t, err, fmt.Sprintf("%s: encoding %s BSON", testDesc, bType))
if diff := cmp.Diff(cB, actual); diff != "" {
t.Errorf("%s: 'native_to_bson(%s) = cB' failed (-want, +got):\n-%v\n+%v\n",
testDesc, docSrcDesc, cB, actual)
t.FailNow()
}
}
// jsonToNative decodes the extended JSON string (ej) into a native Document
func jsonToNative(ej, ejType, testDesc string) (D, error) {
var doc D
if err := UnmarshalExtJSON([]byte(ej), ejType != "relaxed", &doc); err != nil {
return nil, fmt.Errorf("%s: decoding %s extended JSON: %w", testDesc, ejType, err)
}
return doc, nil
}
// jsonToBytes decodes the extended JSON string (ej) into canonical BSON and then encodes it into a byte slice.
func jsonToBytes(ej, ejType, testDesc string) ([]byte, error) {
native, err := jsonToNative(ej, ejType, testDesc)
if err != nil {
return nil, err
}
b, err := Marshal(native)
if err != nil {
return nil, fmt.Errorf("%s: encoding %s BSON: %w", testDesc, ejType, err)
}
return b, nil
}
// nativeToJSON encodes the native Document (doc) into an extended JSON string
func nativeToJSON(t *testing.T, ej string, doc D, testDesc, ejType, ejShortName, docSrcDesc string) {
actualEJ, err := MarshalExtJSON(doc, ejType != "relaxed", true)
expectNoError(t, err, fmt.Sprintf("%s: encoding %s extended JSON", testDesc, ejType))
if diff := cmp.Diff(ej, string(actualEJ)); diff != "" {
t.Errorf("%s: 'native_to_%s_extended_json(%s) = %s' failed (-want, +got):\n%s\n",
testDesc, ejType, docSrcDesc, ejShortName, diff)
t.FailNow()
}
}
func runTest(t *testing.T, file string) {
filepath := path.Join(dataDir, file)
content, err := os.ReadFile(filepath)
require.NoError(t, err)
// Remove ".json" from filename.
file = file[:len(file)-5]
testName := "bson_corpus--" + file
t.Run(testName, func(t *testing.T) {
var test testCase
require.NoError(t, json.Unmarshal(content, &test))
t.Run("valid", func(t *testing.T) {
for _, v := range test.Valid {
t.Run(v.Description, func(t *testing.T) {
// get canonical BSON
cB, err := hex.DecodeString(v.CanonicalBson)
expectNoError(t, err, fmt.Sprintf("%s: reading canonical BSON", v.Description))
// get canonical extended JSON
cEJ := unescapeUnicode(string(pretty.Ugly([]byte(v.CanonicalExtJSON))), test.BsonType)
if test.BsonType == "0x01" {
cEJ = normalizeCanonicalDouble(t, *test.TestKey, cEJ)
}
/*** canonical BSON round-trip tests ***/
doc := bsonToNative(t, cB, "canonical", v.Description)
// native_to_bson(bson_to_native(cB)) = cB
nativeToBSON(t, cB, doc, v.Description, "canonical", "bson_to_native(cB)")
// native_to_canonical_extended_json(bson_to_native(cB)) = cEJ
nativeToJSON(t, cEJ, doc, v.Description, "canonical", "cEJ", "bson_to_native(cB)")
// native_to_relaxed_extended_json(bson_to_native(cB)) = rEJ (if rEJ exists)
if v.RelaxedExtJSON != nil {
rEJ := unescapeUnicode(string(pretty.Ugly([]byte(*v.RelaxedExtJSON))), test.BsonType)
if test.BsonType == "0x01" {
rEJ = normalizeRelaxedDouble(t, *test.TestKey, rEJ)
}
nativeToJSON(t, rEJ, doc, v.Description, "relaxed", "rEJ", "bson_to_native(cB)")
/*** relaxed extended JSON round-trip tests (if exists) ***/
doc, err = jsonToNative(rEJ, "relaxed", v.Description)
require.NoError(t, err)
// native_to_relaxed_extended_json(json_to_native(rEJ)) = rEJ
nativeToJSON(t, rEJ, doc, v.Description, "relaxed", "eJR", "json_to_native(rEJ)")
}
/*** canonical extended JSON round-trip tests ***/
doc, err = jsonToNative(cEJ, "canonical", v.Description)
require.NoError(t, err)
// native_to_canonical_extended_json(json_to_native(cEJ)) = cEJ
nativeToJSON(t, cEJ, doc, v.Description, "canonical", "cEJ", "json_to_native(cEJ)")
// native_to_bson(json_to_native(cEJ)) = cb (unless lossy)
if v.Lossy == nil || !*v.Lossy {
nativeToBSON(t, cB, doc, v.Description, "canonical", "json_to_native(cEJ)")
}
/*** degenerate BSON round-trip tests (if exists) ***/
if v.DegenerateBSON != nil {
dB, err := hex.DecodeString(*v.DegenerateBSON)
expectNoError(t, err, fmt.Sprintf("%s: reading degenerate BSON", v.Description))
doc = bsonToNative(t, dB, "degenerate", v.Description)
// native_to_bson(bson_to_native(dB)) = cB
nativeToBSON(t, cB, doc, v.Description, "degenerate", "bson_to_native(dB)")
}
/*** degenerate JSON round-trip tests (if exists) ***/
if v.DegenerateExtJSON != nil {
dEJ := unescapeUnicode(string(pretty.Ugly([]byte(*v.DegenerateExtJSON))), test.BsonType)
if test.BsonType == "0x01" {
dEJ = normalizeCanonicalDouble(t, *test.TestKey, dEJ)
}
doc, err = jsonToNative(dEJ, "degenerate canonical", v.Description)
require.NoError(t, err)
// native_to_canonical_extended_json(json_to_native(dEJ)) = cEJ
nativeToJSON(t, cEJ, doc, v.Description, "degenerate canonical", "cEJ", "json_to_native(dEJ)")
// native_to_bson(json_to_native(dEJ)) = cB (unless lossy)
if v.Lossy == nil || !*v.Lossy {
nativeToBSON(t, cB, doc, v.Description, "canonical", "json_to_native(dEJ)")
}
}
})
}
})
t.Run("decode error", func(t *testing.T) {
for _, d := range test.DecodeErrors {
t.Run(d.Description, func(t *testing.T) {
b, err := hex.DecodeString(d.Bson)
expectNoError(t, err, d.Description)
var doc D
err = Unmarshal(b, &doc)
// The driver unmarshals invalid UTF-8 strings without error. Loop over the unmarshalled elements
// and assert that there was no error if any of the string or DBPointer values contain invalid UTF-8
// characters.
for _, elem := range doc {
str, ok := elem.Value.(string)
invalidString := ok && !utf8.ValidString(str)
dbPtr, ok := elem.Value.(primitive.DBPointer)
invalidDBPtr := ok && !utf8.ValidString(dbPtr.DB)
if invalidString || invalidDBPtr {
expectNoError(t, err, d.Description)
return
}
}
expectError(t, err, fmt.Sprintf("%s: expected decode error", d.Description))
})
}
})
t.Run("parse error", func(t *testing.T) {
for _, p := range test.ParseErrors {
t.Run(p.Description, func(t *testing.T) {
s := unescapeUnicode(p.String, test.BsonType)
if test.BsonType == "0x13" {
s = fmt.Sprintf(`{"decimal128": {"$numberDecimal": "%s"}}`, s)
}
switch test.BsonType {
case "0x00", "0x05", "0x13":
var doc D
err := UnmarshalExtJSON([]byte(s), true, &doc)
// Null bytes are validated when marshaling to BSON
if strings.Contains(p.Description, "Null") {
_, err = Marshal(doc)
}
expectError(t, err, fmt.Sprintf("%s: expected parse error", p.Description))
default:
t.Errorf("Update test to check for parse errors for type %s", test.BsonType)
t.Fail()
}
})
}
})
})
}
func Test_BsonCorpus(t *testing.T) {
jsonFiles, err := findJSONFilesInDir(dataDir)
if err != nil {
t.Fatalf("error finding JSON files in %s: %v", dataDir, err)
}
for _, file := range jsonFiles {
runTest(t, file)
}
}
func expectNoError(t *testing.T, err error, desc string) {
if err != nil {
t.Helper()
t.Errorf("%s: Unepexted error: %v", desc, err)
t.FailNow()
}
}
func expectError(t *testing.T, err error, desc string) {
if err == nil {
t.Helper()
t.Errorf("%s: Expected error", desc)
t.FailNow()
}
}
func TestRelaxedUUIDValidation(t *testing.T) {
testCases := []struct {
description string
canonicalExtJSON string
degenerateExtJSON string
expectedErr string
}{
{
"valid uuid",
"{\"x\" : { \"$binary\" : {\"base64\" : \"c//SZESzTGmQ6OfR38A11A==\", \"subType\" : \"04\"}}}",
"{\"x\" : { \"$uuid\" : \"73ffd264-44b3-4c69-90e8-e7d1dfc035d4\"}}",
"",
},
{
"invalid uuid--no hyphens",
"",
"{\"x\" : { \"$uuid\" : \"73ffd26444b34c6990e8e7d1dfc035d4\"}}",
"$uuid value does not follow RFC 4122 format regarding length and hyphens",
},
{
"invalid uuid--trailing hyphens",
"",
"{\"x\" : { \"$uuid\" : \"73ffd264-44b3-4c69-90e8-e7d1dfc035--\"}}",
"$uuid value does not follow RFC 4122 format regarding length and hyphens",
},
{
"invalid uuid--malformed hex",
"",
"{\"x\" : { \"$uuid\" : \"q3@fd26l-44b3-4c69-90e8-e7d1dfc035d4\"}}",
"$uuid value does not follow RFC 4122 format regarding hex bytes: encoding/hex: invalid byte: U+0071 'q'",
},
}
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
// get canonical extended JSON
cEJ := unescapeUnicode(string(pretty.Ugly([]byte(tc.canonicalExtJSON))), "0x05")
// get degenerate extended JSON
dEJ := unescapeUnicode(string(pretty.Ugly([]byte(tc.degenerateExtJSON))), "0x05")
// convert dEJ to native doc
var doc D
err := UnmarshalExtJSON([]byte(dEJ), true, &doc)
if tc.expectedErr != "" {
assert.Equal(t, tc.expectedErr, err.Error(), "expected error %v, got %v", tc.expectedErr, err)
} else {
assert.Nil(t, err, "expected no error, got error: %v", err)
// Marshal doc into extended JSON and compare with cEJ
nativeToJSON(t, cEJ, doc, tc.description, "degenerate canonical", "cEJ", "json_to_native(dEJ)")
}
})
}
}

279
mongo/bson/bson_test.go Normal file
View File

@@ -0,0 +1,279 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"bytes"
"fmt"
"reflect"
"strconv"
"strings"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"go.mongodb.org/mongo-driver/bson/bsoncodec"
"go.mongodb.org/mongo-driver/bson/bsonoptions"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/internal/testutil/assert"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
func noerr(t *testing.T, err error) {
if err != nil {
t.Helper()
t.Errorf("Unexpected error: (%T)%v", err, err)
t.FailNow()
}
}
func TestTimeRoundTrip(t *testing.T) {
val := struct {
Value time.Time
ID string
}{
ID: "time-rt-test",
}
if !val.Value.IsZero() {
t.Errorf("Did not get zero time as expected.")
}
bsonOut, err := Marshal(val)
noerr(t, err)
rtval := struct {
Value time.Time
ID string
}{}
err = Unmarshal(bsonOut, &rtval)
noerr(t, err)
if !cmp.Equal(val, rtval) {
t.Errorf("Did not round trip properly. got %v; want %v", val, rtval)
}
if !rtval.Value.IsZero() {
t.Errorf("Did not get zero time as expected.")
}
}
func TestNonNullTimeRoundTrip(t *testing.T) {
now := time.Now()
now = time.Unix(now.Unix(), 0)
val := struct {
Value time.Time
ID string
}{
ID: "time-rt-test",
Value: now,
}
bsonOut, err := Marshal(val)
noerr(t, err)
rtval := struct {
Value time.Time
ID string
}{}
err = Unmarshal(bsonOut, &rtval)
noerr(t, err)
if !cmp.Equal(val, rtval) {
t.Errorf("Did not round trip properly. got %v; want %v", val, rtval)
}
}
func TestD(t *testing.T) {
t.Run("can marshal", func(t *testing.T) {
d := D{{"foo", "bar"}, {"hello", "world"}, {"pi", 3.14159}}
idx, want := bsoncore.AppendDocumentStart(nil)
want = bsoncore.AppendStringElement(want, "foo", "bar")
want = bsoncore.AppendStringElement(want, "hello", "world")
want = bsoncore.AppendDoubleElement(want, "pi", 3.14159)
want, err := bsoncore.AppendDocumentEnd(want, idx)
noerr(t, err)
got, err := Marshal(d)
noerr(t, err)
if !bytes.Equal(got, want) {
t.Errorf("Marshaled documents do not match. got %v; want %v", Raw(got), Raw(want))
}
})
t.Run("can unmarshal", func(t *testing.T) {
want := D{{"foo", "bar"}, {"hello", "world"}, {"pi", 3.14159}}
idx, doc := bsoncore.AppendDocumentStart(nil)
doc = bsoncore.AppendStringElement(doc, "foo", "bar")
doc = bsoncore.AppendStringElement(doc, "hello", "world")
doc = bsoncore.AppendDoubleElement(doc, "pi", 3.14159)
doc, err := bsoncore.AppendDocumentEnd(doc, idx)
noerr(t, err)
var got D
err = Unmarshal(doc, &got)
noerr(t, err)
if !cmp.Equal(got, want) {
t.Errorf("Unmarshaled documents do not match. got %v; want %v", got, want)
}
})
}
type stringerString string
func (ss stringerString) String() string {
return "bar"
}
type keyBool bool
func (kb keyBool) MarshalKey() (string, error) {
return fmt.Sprintf("%v", kb), nil
}
func (kb *keyBool) UnmarshalKey(key string) error {
switch key {
case "true":
*kb = true
case "false":
*kb = false
default:
return fmt.Errorf("invalid bool value %v", key)
}
return nil
}
type keyStruct struct {
val int64
}
func (k keyStruct) MarshalText() (text []byte, err error) {
str := strconv.FormatInt(k.val, 10)
return []byte(str), nil
}
func (k *keyStruct) UnmarshalText(text []byte) error {
val, err := strconv.ParseInt(string(text), 10, 64)
if err != nil {
return err
}
*k = keyStruct{
val: val,
}
return nil
}
func TestMapCodec(t *testing.T) {
t.Run("EncodeKeysWithStringer", func(t *testing.T) {
strstr := stringerString("foo")
mapObj := map[stringerString]int{strstr: 1}
testCases := []struct {
name string
opts *bsonoptions.MapCodecOptions
key string
}{
{"default", bsonoptions.MapCodec(), "foo"},
{"true", bsonoptions.MapCodec().SetEncodeKeysWithStringer(true), "bar"},
{"false", bsonoptions.MapCodec().SetEncodeKeysWithStringer(false), "foo"},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
mapCodec := bsoncodec.NewMapCodec(tc.opts)
mapRegistry := NewRegistryBuilder().RegisterDefaultEncoder(reflect.Map, mapCodec).Build()
val, err := MarshalWithRegistry(mapRegistry, mapObj)
assert.Nil(t, err, "Marshal error: %v", err)
assert.True(t, strings.Contains(string(val), tc.key), "expected result to contain %v, got: %v", tc.key, string(val))
})
}
})
t.Run("keys implements keyMarshaler and keyUnmarshaler", func(t *testing.T) {
mapObj := map[keyBool]int{keyBool(true): 1}
doc, err := Marshal(mapObj)
assert.Nil(t, err, "Marshal error: %v", err)
idx, want := bsoncore.AppendDocumentStart(nil)
want = bsoncore.AppendInt32Element(want, "true", 1)
want, _ = bsoncore.AppendDocumentEnd(want, idx)
assert.Equal(t, want, doc, "expected result %v, got %v", string(want), string(doc))
var got map[keyBool]int
err = Unmarshal(doc, &got)
assert.Nil(t, err, "Unmarshal error: %v", err)
assert.Equal(t, mapObj, got, "expected result %v, got %v", mapObj, got)
})
t.Run("keys implements encoding.TextMarshaler and encoding.TextUnmarshaler", func(t *testing.T) {
mapObj := map[keyStruct]int{
{val: 10}: 100,
}
doc, err := Marshal(mapObj)
assert.Nil(t, err, "Marshal error: %v", err)
idx, want := bsoncore.AppendDocumentStart(nil)
want = bsoncore.AppendInt32Element(want, "10", 100)
want, _ = bsoncore.AppendDocumentEnd(want, idx)
assert.Equal(t, want, doc, "expected result %v, got %v", string(want), string(doc))
var got map[keyStruct]int
err = Unmarshal(doc, &got)
assert.Nil(t, err, "Unmarshal error: %v", err)
assert.Equal(t, mapObj, got, "expected result %v, got %v", mapObj, got)
})
}
func TestExtJSONEscapeKey(t *testing.T) {
doc := D{{Key: "\\usb#", Value: int32(1)}}
b, err := MarshalExtJSON(&doc, false, false)
noerr(t, err)
want := "{\"\\\\usb#\":1}"
if diff := cmp.Diff(want, string(b)); diff != "" {
t.Errorf("Marshaled documents do not match. got %v, want %v", string(b), want)
}
var got D
err = UnmarshalExtJSON(b, false, &got)
noerr(t, err)
if !cmp.Equal(got, doc) {
t.Errorf("Unmarshaled documents do not match. got %v; want %v", got, doc)
}
}
func TestBsoncoreArray(t *testing.T) {
type BSONDocumentArray struct {
Array []D `bson:"array"`
}
type BSONArray struct {
Array bsoncore.Array `bson:"array"`
}
bda := BSONDocumentArray{
Array: []D{
{{"x", 1}},
{{"x", 2}},
{{"x", 3}},
},
}
expectedBSON, err := Marshal(bda)
assert.Nil(t, err, "Marshal bsoncore.Document array error: %v", err)
var ba BSONArray
err = Unmarshal(expectedBSON, &ba)
assert.Nil(t, err, "Unmarshal error: %v", err)
actualBSON, err := Marshal(ba)
assert.Nil(t, err, "Marshal bsoncore.Array error: %v", err)
assert.Equal(t, expectedBSON, actualBSON,
"expected BSON to be %v after Marshalling again; got %v", expectedBSON, actualBSON)
doc := bsoncore.Document(actualBSON)
v := doc.Lookup("array")
assert.Equal(t, bsontype.Array, v.Type, "expected type array, got %v", v.Type)
}

View File

@@ -0,0 +1,50 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import (
"reflect"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
// ArrayCodec is the Codec used for bsoncore.Array values.
type ArrayCodec struct{}
var defaultArrayCodec = NewArrayCodec()
// NewArrayCodec returns an ArrayCodec.
func NewArrayCodec() *ArrayCodec {
return &ArrayCodec{}
}
// EncodeValue is the ValueEncoder for bsoncore.Array values.
func (ac *ArrayCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tCoreArray {
return ValueEncoderError{Name: "CoreArrayEncodeValue", Types: []reflect.Type{tCoreArray}, Received: val}
}
arr := val.Interface().(bsoncore.Array)
return bsonrw.Copier{}.CopyArrayFromBytes(vw, arr)
}
// DecodeValue is the ValueDecoder for bsoncore.Array values.
func (ac *ArrayCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
if !val.CanSet() || val.Type() != tCoreArray {
return ValueDecoderError{Name: "CoreArrayDecodeValue", Types: []reflect.Type{tCoreArray}, Received: val}
}
if val.IsNil() {
val.Set(reflect.MakeSlice(val.Type(), 0, 0))
}
val.SetLen(0)
arr, err := bsonrw.Copier{}.AppendArrayBytes(val.Interface().(bsoncore.Array), vr)
val.Set(reflect.ValueOf(arr))
return err
}

View File

@@ -0,0 +1,238 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec // import "go.mongodb.org/mongo-driver/bson/bsoncodec"
import (
"fmt"
"reflect"
"strings"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
)
var (
emptyValue = reflect.Value{}
)
// Marshaler is an interface implemented by types that can marshal themselves
// into a BSON document represented as bytes. The bytes returned must be a valid
// BSON document if the error is nil.
type Marshaler interface {
MarshalBSON() ([]byte, error)
}
// ValueMarshaler is an interface implemented by types that can marshal
// themselves into a BSON value as bytes. The type must be the valid type for
// the bytes returned. The bytes and byte type together must be valid if the
// error is nil.
type ValueMarshaler interface {
MarshalBSONValue() (bsontype.Type, []byte, error)
}
// Unmarshaler is an interface implemented by types that can unmarshal a BSON
// document representation of themselves. The BSON bytes can be assumed to be
// valid. UnmarshalBSON must copy the BSON bytes if it wishes to retain the data
// after returning.
type Unmarshaler interface {
UnmarshalBSON([]byte) error
}
// ValueUnmarshaler is an interface implemented by types that can unmarshal a
// BSON value representation of themselves. The BSON bytes and type can be
// assumed to be valid. UnmarshalBSONValue must copy the BSON value bytes if it
// wishes to retain the data after returning.
type ValueUnmarshaler interface {
UnmarshalBSONValue(bsontype.Type, []byte) error
}
// ValueEncoderError is an error returned from a ValueEncoder when the provided value can't be
// encoded by the ValueEncoder.
type ValueEncoderError struct {
Name string
Types []reflect.Type
Kinds []reflect.Kind
Received reflect.Value
}
func (vee ValueEncoderError) Error() string {
typeKinds := make([]string, 0, len(vee.Types)+len(vee.Kinds))
for _, t := range vee.Types {
typeKinds = append(typeKinds, t.String())
}
for _, k := range vee.Kinds {
if k == reflect.Map {
typeKinds = append(typeKinds, "map[string]*")
continue
}
typeKinds = append(typeKinds, k.String())
}
received := vee.Received.Kind().String()
if vee.Received.IsValid() {
received = vee.Received.Type().String()
}
return fmt.Sprintf("%s can only encode valid %s, but got %s", vee.Name, strings.Join(typeKinds, ", "), received)
}
// ValueDecoderError is an error returned from a ValueDecoder when the provided value can't be
// decoded by the ValueDecoder.
type ValueDecoderError struct {
Name string
Types []reflect.Type
Kinds []reflect.Kind
Received reflect.Value
}
func (vde ValueDecoderError) Error() string {
typeKinds := make([]string, 0, len(vde.Types)+len(vde.Kinds))
for _, t := range vde.Types {
typeKinds = append(typeKinds, t.String())
}
for _, k := range vde.Kinds {
if k == reflect.Map {
typeKinds = append(typeKinds, "map[string]*")
continue
}
typeKinds = append(typeKinds, k.String())
}
received := vde.Received.Kind().String()
if vde.Received.IsValid() {
received = vde.Received.Type().String()
}
return fmt.Sprintf("%s can only decode valid and settable %s, but got %s", vde.Name, strings.Join(typeKinds, ", "), received)
}
// EncodeContext is the contextual information required for a Codec to encode a
// value.
type EncodeContext struct {
*Registry
MinSize bool
}
// DecodeContext is the contextual information required for a Codec to decode a
// value.
type DecodeContext struct {
*Registry
Truncate bool
// Ancestor is the type of a containing document. This is mainly used to determine what type
// should be used when decoding an embedded document into an empty interface. For example, if
// Ancestor is a bson.M, BSON embedded document values being decoded into an empty interface
// will be decoded into a bson.M.
//
// Deprecated: Use DefaultDocumentM or DefaultDocumentD instead.
Ancestor reflect.Type
// defaultDocumentType specifies the Go type to decode top-level and nested BSON documents into. In particular, the
// usage for this field is restricted to data typed as "interface{}" or "map[string]interface{}". If DocumentType is
// set to a type that a BSON document cannot be unmarshaled into (e.g. "string"), unmarshalling will result in an
// error. DocumentType overrides the Ancestor field.
defaultDocumentType reflect.Type
}
// DefaultDocumentM will decode empty documents using the primitive.M type. This behavior is restricted to data typed as
// "interface{}" or "map[string]interface{}".
func (dc *DecodeContext) DefaultDocumentM() {
dc.defaultDocumentType = reflect.TypeOf(primitive.M{})
}
// DefaultDocumentD will decode empty documents using the primitive.D type. This behavior is restricted to data typed as
// "interface{}" or "map[string]interface{}".
func (dc *DecodeContext) DefaultDocumentD() {
dc.defaultDocumentType = reflect.TypeOf(primitive.D{})
}
// ValueCodec is the interface that groups the methods to encode and decode
// values.
type ValueCodec interface {
ValueEncoder
ValueDecoder
}
// ValueEncoder is the interface implemented by types that can handle the encoding of a value.
type ValueEncoder interface {
EncodeValue(EncodeContext, bsonrw.ValueWriter, reflect.Value) error
}
// ValueEncoderFunc is an adapter function that allows a function with the correct signature to be
// used as a ValueEncoder.
type ValueEncoderFunc func(EncodeContext, bsonrw.ValueWriter, reflect.Value) error
// EncodeValue implements the ValueEncoder interface.
func (fn ValueEncoderFunc) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
return fn(ec, vw, val)
}
// ValueDecoder is the interface implemented by types that can handle the decoding of a value.
type ValueDecoder interface {
DecodeValue(DecodeContext, bsonrw.ValueReader, reflect.Value) error
}
// ValueDecoderFunc is an adapter function that allows a function with the correct signature to be
// used as a ValueDecoder.
type ValueDecoderFunc func(DecodeContext, bsonrw.ValueReader, reflect.Value) error
// DecodeValue implements the ValueDecoder interface.
func (fn ValueDecoderFunc) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
return fn(dc, vr, val)
}
// typeDecoder is the interface implemented by types that can handle the decoding of a value given its type.
type typeDecoder interface {
decodeType(DecodeContext, bsonrw.ValueReader, reflect.Type) (reflect.Value, error)
}
// typeDecoderFunc is an adapter function that allows a function with the correct signature to be used as a typeDecoder.
type typeDecoderFunc func(DecodeContext, bsonrw.ValueReader, reflect.Type) (reflect.Value, error)
func (fn typeDecoderFunc) decodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) {
return fn(dc, vr, t)
}
// decodeAdapter allows two functions with the correct signatures to be used as both a ValueDecoder and typeDecoder.
type decodeAdapter struct {
ValueDecoderFunc
typeDecoderFunc
}
var _ ValueDecoder = decodeAdapter{}
var _ typeDecoder = decodeAdapter{}
// decodeTypeOrValue calls decoder.decodeType is decoder is a typeDecoder. Otherwise, it allocates a new element of type
// t and calls decoder.DecodeValue on it.
func decodeTypeOrValue(decoder ValueDecoder, dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) {
td, _ := decoder.(typeDecoder)
return decodeTypeOrValueWithInfo(decoder, td, dc, vr, t, true)
}
func decodeTypeOrValueWithInfo(vd ValueDecoder, td typeDecoder, dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type, convert bool) (reflect.Value, error) {
if td != nil {
val, err := td.decodeType(dc, vr, t)
if err == nil && convert && val.Type() != t {
// This conversion step is necessary for slices and maps. If a user declares variables like:
//
// type myBool bool
// var m map[string]myBool
//
// and tries to decode BSON bytes into the map, the decoding will fail if this conversion is not present
// because we'll try to assign a value of type bool to one of type myBool.
val = val.Convert(t)
}
return val, err
}
val := reflect.New(t).Elem()
err := vd.DecodeValue(dc, vr, val)
return val, err
}
// CodecZeroer is the interface implemented by Codecs that can also determine if
// a value of the type that would be encoded is zero.
type CodecZeroer interface {
IsTypeZero(interface{}) bool
}

View File

@@ -0,0 +1,143 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import (
"fmt"
"reflect"
"testing"
"time"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
)
func ExampleValueEncoder() {
var _ ValueEncoderFunc = func(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if val.Kind() != reflect.String {
return ValueEncoderError{Name: "StringEncodeValue", Kinds: []reflect.Kind{reflect.String}, Received: val}
}
return vw.WriteString(val.String())
}
}
func ExampleValueDecoder() {
var _ ValueDecoderFunc = func(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
if !val.CanSet() || val.Kind() != reflect.String {
return ValueDecoderError{Name: "StringDecodeValue", Kinds: []reflect.Kind{reflect.String}, Received: val}
}
if vr.Type() != bsontype.String {
return fmt.Errorf("cannot decode %v into a string type", vr.Type())
}
str, err := vr.ReadString()
if err != nil {
return err
}
val.SetString(str)
return nil
}
}
func noerr(t *testing.T, err error) {
if err != nil {
t.Helper()
t.Errorf("Unexpected error: (%T)%v", err, err)
t.FailNow()
}
}
func compareTime(t1, t2 time.Time) bool {
if t1.Location() != t2.Location() {
return false
}
return t1.Equal(t2)
}
func compareErrors(err1, err2 error) bool {
if err1 == nil && err2 == nil {
return true
}
if err1 == nil || err2 == nil {
return false
}
if err1.Error() != err2.Error() {
return false
}
return true
}
func compareDecimal128(d1, d2 primitive.Decimal128) bool {
d1H, d1L := d1.GetBytes()
d2H, d2L := d2.GetBytes()
if d1H != d2H {
return false
}
if d1L != d2L {
return false
}
return true
}
type noPrivateFields struct {
a string
}
func compareNoPrivateFields(npf1, npf2 noPrivateFields) bool {
return npf1.a != npf2.a // We don't want these to be equal
}
type zeroTest struct {
reportZero bool
}
func (z zeroTest) IsZero() bool { return z.reportZero }
func compareZeroTest(_, _ zeroTest) bool { return true }
type nonZeroer struct {
value bool
}
type llCodec struct {
t *testing.T
decodeval interface{}
encodeval interface{}
err error
}
func (llc *llCodec) EncodeValue(_ EncodeContext, _ bsonrw.ValueWriter, i interface{}) error {
if llc.err != nil {
return llc.err
}
llc.encodeval = i
return nil
}
func (llc *llCodec) DecodeValue(_ DecodeContext, _ bsonrw.ValueReader, val reflect.Value) error {
if llc.err != nil {
return llc.err
}
if !reflect.TypeOf(llc.decodeval).AssignableTo(val.Type()) {
llc.t.Errorf("decodeval must be assignable to val provided to DecodeValue, but is not. decodeval %T; val %T", llc.decodeval, val)
return nil
}
val.Set(reflect.ValueOf(llc.decodeval))
return nil
}

View File

@@ -0,0 +1,111 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import (
"fmt"
"reflect"
"go.mongodb.org/mongo-driver/bson/bsonoptions"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/bsontype"
)
// ByteSliceCodec is the Codec used for []byte values.
type ByteSliceCodec struct {
EncodeNilAsEmpty bool
}
var (
defaultByteSliceCodec = NewByteSliceCodec()
_ ValueCodec = defaultByteSliceCodec
_ typeDecoder = defaultByteSliceCodec
)
// NewByteSliceCodec returns a StringCodec with options opts.
func NewByteSliceCodec(opts ...*bsonoptions.ByteSliceCodecOptions) *ByteSliceCodec {
byteSliceOpt := bsonoptions.MergeByteSliceCodecOptions(opts...)
codec := ByteSliceCodec{}
if byteSliceOpt.EncodeNilAsEmpty != nil {
codec.EncodeNilAsEmpty = *byteSliceOpt.EncodeNilAsEmpty
}
return &codec
}
// EncodeValue is the ValueEncoder for []byte.
func (bsc *ByteSliceCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tByteSlice {
return ValueEncoderError{Name: "ByteSliceEncodeValue", Types: []reflect.Type{tByteSlice}, Received: val}
}
if val.IsNil() && !bsc.EncodeNilAsEmpty {
return vw.WriteNull()
}
return vw.WriteBinary(val.Interface().([]byte))
}
func (bsc *ByteSliceCodec) decodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) {
if t != tByteSlice {
return emptyValue, ValueDecoderError{
Name: "ByteSliceDecodeValue",
Types: []reflect.Type{tByteSlice},
Received: reflect.Zero(t),
}
}
var data []byte
var err error
switch vrType := vr.Type(); vrType {
case bsontype.String:
str, err := vr.ReadString()
if err != nil {
return emptyValue, err
}
data = []byte(str)
case bsontype.Symbol:
sym, err := vr.ReadSymbol()
if err != nil {
return emptyValue, err
}
data = []byte(sym)
case bsontype.Binary:
var subtype byte
data, subtype, err = vr.ReadBinary()
if err != nil {
return emptyValue, err
}
if subtype != bsontype.BinaryGeneric && subtype != bsontype.BinaryBinaryOld {
return emptyValue, decodeBinaryError{subtype: subtype, typeName: "[]byte"}
}
case bsontype.Null:
err = vr.ReadNull()
case bsontype.Undefined:
err = vr.ReadUndefined()
default:
return emptyValue, fmt.Errorf("cannot decode %v into a []byte", vrType)
}
if err != nil {
return emptyValue, err
}
return reflect.ValueOf(data), nil
}
// DecodeValue is the ValueDecoder for []byte.
func (bsc *ByteSliceCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
if !val.CanSet() || val.Type() != tByteSlice {
return ValueDecoderError{Name: "ByteSliceDecodeValue", Types: []reflect.Type{tByteSlice}, Received: val}
}
elem, err := bsc.decodeType(dc, vr, tByteSlice)
if err != nil {
return err
}
val.Set(elem)
return nil
}

View File

@@ -0,0 +1,63 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import (
"reflect"
"go.mongodb.org/mongo-driver/bson/bsonrw"
)
// condAddrEncoder is the encoder used when a pointer to the encoding value has an encoder.
type condAddrEncoder struct {
canAddrEnc ValueEncoder
elseEnc ValueEncoder
}
var _ ValueEncoder = (*condAddrEncoder)(nil)
// newCondAddrEncoder returns an condAddrEncoder.
func newCondAddrEncoder(canAddrEnc, elseEnc ValueEncoder) *condAddrEncoder {
encoder := condAddrEncoder{canAddrEnc: canAddrEnc, elseEnc: elseEnc}
return &encoder
}
// EncodeValue is the ValueEncoderFunc for a value that may be addressable.
func (cae *condAddrEncoder) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if val.CanAddr() {
return cae.canAddrEnc.EncodeValue(ec, vw, val)
}
if cae.elseEnc != nil {
return cae.elseEnc.EncodeValue(ec, vw, val)
}
return ErrNoEncoder{Type: val.Type()}
}
// condAddrDecoder is the decoder used when a pointer to the value has a decoder.
type condAddrDecoder struct {
canAddrDec ValueDecoder
elseDec ValueDecoder
}
var _ ValueDecoder = (*condAddrDecoder)(nil)
// newCondAddrDecoder returns an CondAddrDecoder.
func newCondAddrDecoder(canAddrDec, elseDec ValueDecoder) *condAddrDecoder {
decoder := condAddrDecoder{canAddrDec: canAddrDec, elseDec: elseDec}
return &decoder
}
// DecodeValue is the ValueDecoderFunc for a value that may be addressable.
func (cad *condAddrDecoder) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
if val.CanAddr() {
return cad.canAddrDec.DecodeValue(dc, vr, val)
}
if cad.elseDec != nil {
return cad.elseDec.DecodeValue(dc, vr, val)
}
return ErrNoDecoder{Type: val.Type()}
}

View File

@@ -0,0 +1,97 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import (
"reflect"
"testing"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/bsonrw/bsonrwtest"
"go.mongodb.org/mongo-driver/internal/testutil/assert"
)
func TestCondAddrCodec(t *testing.T) {
var inner int
canAddrVal := reflect.ValueOf(&inner)
addressable := canAddrVal.Elem()
unaddressable := reflect.ValueOf(inner)
rw := &bsonrwtest.ValueReaderWriter{}
t.Run("addressEncode", func(t *testing.T) {
invoked := 0
encode1 := ValueEncoderFunc(func(EncodeContext, bsonrw.ValueWriter, reflect.Value) error {
invoked = 1
return nil
})
encode2 := ValueEncoderFunc(func(EncodeContext, bsonrw.ValueWriter, reflect.Value) error {
invoked = 2
return nil
})
condEncoder := newCondAddrEncoder(encode1, encode2)
testCases := []struct {
name string
val reflect.Value
invoked int
}{
{"canAddr", addressable, 1},
{"else", unaddressable, 2},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := condEncoder.EncodeValue(EncodeContext{}, rw, tc.val)
assert.Nil(t, err, "CondAddrEncoder error: %v", err)
assert.Equal(t, invoked, tc.invoked, "Expected function %v to be called, called %v", tc.invoked, invoked)
})
}
t.Run("error", func(t *testing.T) {
errEncoder := newCondAddrEncoder(encode1, nil)
err := errEncoder.EncodeValue(EncodeContext{}, rw, unaddressable)
want := ErrNoEncoder{Type: unaddressable.Type()}
assert.Equal(t, err, want, "expected error %v, got %v", want, err)
})
})
t.Run("addressDecode", func(t *testing.T) {
invoked := 0
decode1 := ValueDecoderFunc(func(DecodeContext, bsonrw.ValueReader, reflect.Value) error {
invoked = 1
return nil
})
decode2 := ValueDecoderFunc(func(DecodeContext, bsonrw.ValueReader, reflect.Value) error {
invoked = 2
return nil
})
condDecoder := newCondAddrDecoder(decode1, decode2)
testCases := []struct {
name string
val reflect.Value
invoked int
}{
{"canAddr", addressable, 1},
{"else", unaddressable, 2},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := condDecoder.DecodeValue(DecodeContext{}, rw, tc.val)
assert.Nil(t, err, "CondAddrDecoder error: %v", err)
assert.Equal(t, invoked, tc.invoked, "Expected function %v to be called, called %v", tc.invoked, invoked)
})
}
t.Run("error", func(t *testing.T) {
errDecoder := newCondAddrDecoder(decode1, nil)
err := errDecoder.DecodeValue(DecodeContext{}, rw, unaddressable)
want := ErrNoDecoder{Type: unaddressable.Type()}
assert.Equal(t, err, want, "expected error %v, got %v", want, err)
})
})
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,766 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import (
"encoding/json"
"errors"
"fmt"
"math"
"net/url"
"reflect"
"sync"
"time"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
var defaultValueEncoders DefaultValueEncoders
var bvwPool = bsonrw.NewBSONValueWriterPool()
var errInvalidValue = errors.New("cannot encode invalid element")
var sliceWriterPool = sync.Pool{
New: func() interface{} {
sw := make(bsonrw.SliceWriter, 0)
return &sw
},
}
func encodeElement(ec EncodeContext, dw bsonrw.DocumentWriter, e primitive.E) error {
vw, err := dw.WriteDocumentElement(e.Key)
if err != nil {
return err
}
if e.Value == nil {
return vw.WriteNull()
}
encoder, err := ec.LookupEncoder(reflect.TypeOf(e.Value))
if err != nil {
return err
}
err = encoder.EncodeValue(ec, vw, reflect.ValueOf(e.Value))
if err != nil {
return err
}
return nil
}
// DefaultValueEncoders is a namespace type for the default ValueEncoders used
// when creating a registry.
type DefaultValueEncoders struct{}
// RegisterDefaultEncoders will register the encoder methods attached to DefaultValueEncoders with
// the provided RegistryBuilder.
func (dve DefaultValueEncoders) RegisterDefaultEncoders(rb *RegistryBuilder) {
if rb == nil {
panic(errors.New("argument to RegisterDefaultEncoders must not be nil"))
}
rb.
RegisterTypeEncoder(tByteSlice, defaultByteSliceCodec).
RegisterTypeEncoder(tTime, defaultTimeCodec).
RegisterTypeEncoder(tEmpty, defaultEmptyInterfaceCodec).
RegisterTypeEncoder(tCoreArray, defaultArrayCodec).
RegisterTypeEncoder(tOID, ValueEncoderFunc(dve.ObjectIDEncodeValue)).
RegisterTypeEncoder(tDecimal, ValueEncoderFunc(dve.Decimal128EncodeValue)).
RegisterTypeEncoder(tJSONNumber, ValueEncoderFunc(dve.JSONNumberEncodeValue)).
RegisterTypeEncoder(tURL, ValueEncoderFunc(dve.URLEncodeValue)).
RegisterTypeEncoder(tJavaScript, ValueEncoderFunc(dve.JavaScriptEncodeValue)).
RegisterTypeEncoder(tSymbol, ValueEncoderFunc(dve.SymbolEncodeValue)).
RegisterTypeEncoder(tBinary, ValueEncoderFunc(dve.BinaryEncodeValue)).
RegisterTypeEncoder(tUndefined, ValueEncoderFunc(dve.UndefinedEncodeValue)).
RegisterTypeEncoder(tDateTime, ValueEncoderFunc(dve.DateTimeEncodeValue)).
RegisterTypeEncoder(tNull, ValueEncoderFunc(dve.NullEncodeValue)).
RegisterTypeEncoder(tRegex, ValueEncoderFunc(dve.RegexEncodeValue)).
RegisterTypeEncoder(tDBPointer, ValueEncoderFunc(dve.DBPointerEncodeValue)).
RegisterTypeEncoder(tTimestamp, ValueEncoderFunc(dve.TimestampEncodeValue)).
RegisterTypeEncoder(tMinKey, ValueEncoderFunc(dve.MinKeyEncodeValue)).
RegisterTypeEncoder(tMaxKey, ValueEncoderFunc(dve.MaxKeyEncodeValue)).
RegisterTypeEncoder(tCoreDocument, ValueEncoderFunc(dve.CoreDocumentEncodeValue)).
RegisterTypeEncoder(tCodeWithScope, ValueEncoderFunc(dve.CodeWithScopeEncodeValue)).
RegisterDefaultEncoder(reflect.Bool, ValueEncoderFunc(dve.BooleanEncodeValue)).
RegisterDefaultEncoder(reflect.Int, ValueEncoderFunc(dve.IntEncodeValue)).
RegisterDefaultEncoder(reflect.Int8, ValueEncoderFunc(dve.IntEncodeValue)).
RegisterDefaultEncoder(reflect.Int16, ValueEncoderFunc(dve.IntEncodeValue)).
RegisterDefaultEncoder(reflect.Int32, ValueEncoderFunc(dve.IntEncodeValue)).
RegisterDefaultEncoder(reflect.Int64, ValueEncoderFunc(dve.IntEncodeValue)).
RegisterDefaultEncoder(reflect.Uint, defaultUIntCodec).
RegisterDefaultEncoder(reflect.Uint8, defaultUIntCodec).
RegisterDefaultEncoder(reflect.Uint16, defaultUIntCodec).
RegisterDefaultEncoder(reflect.Uint32, defaultUIntCodec).
RegisterDefaultEncoder(reflect.Uint64, defaultUIntCodec).
RegisterDefaultEncoder(reflect.Float32, ValueEncoderFunc(dve.FloatEncodeValue)).
RegisterDefaultEncoder(reflect.Float64, ValueEncoderFunc(dve.FloatEncodeValue)).
RegisterDefaultEncoder(reflect.Array, ValueEncoderFunc(dve.ArrayEncodeValue)).
RegisterDefaultEncoder(reflect.Map, defaultMapCodec).
RegisterDefaultEncoder(reflect.Slice, defaultSliceCodec).
RegisterDefaultEncoder(reflect.String, defaultStringCodec).
RegisterDefaultEncoder(reflect.Struct, newDefaultStructCodec()).
RegisterDefaultEncoder(reflect.Ptr, NewPointerCodec()).
RegisterHookEncoder(tValueMarshaler, ValueEncoderFunc(dve.ValueMarshalerEncodeValue)).
RegisterHookEncoder(tMarshaler, ValueEncoderFunc(dve.MarshalerEncodeValue)).
RegisterHookEncoder(tProxy, ValueEncoderFunc(dve.ProxyEncodeValue))
}
// BooleanEncodeValue is the ValueEncoderFunc for bool types.
func (dve DefaultValueEncoders) BooleanEncodeValue(ectx EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Kind() != reflect.Bool {
return ValueEncoderError{Name: "BooleanEncodeValue", Kinds: []reflect.Kind{reflect.Bool}, Received: val}
}
return vw.WriteBoolean(val.Bool())
}
func fitsIn32Bits(i int64) bool {
return math.MinInt32 <= i && i <= math.MaxInt32
}
// IntEncodeValue is the ValueEncoderFunc for int types.
func (dve DefaultValueEncoders) IntEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
switch val.Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32:
return vw.WriteInt32(int32(val.Int()))
case reflect.Int:
i64 := val.Int()
if fitsIn32Bits(i64) {
return vw.WriteInt32(int32(i64))
}
return vw.WriteInt64(i64)
case reflect.Int64:
i64 := val.Int()
if ec.MinSize && fitsIn32Bits(i64) {
return vw.WriteInt32(int32(i64))
}
return vw.WriteInt64(i64)
}
return ValueEncoderError{
Name: "IntEncodeValue",
Kinds: []reflect.Kind{reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int},
Received: val,
}
}
// UintEncodeValue is the ValueEncoderFunc for uint types.
//
// Deprecated: UintEncodeValue is not registered by default. Use UintCodec.EncodeValue instead.
func (dve DefaultValueEncoders) UintEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
switch val.Kind() {
case reflect.Uint8, reflect.Uint16:
return vw.WriteInt32(int32(val.Uint()))
case reflect.Uint, reflect.Uint32, reflect.Uint64:
u64 := val.Uint()
if ec.MinSize && u64 <= math.MaxInt32 {
return vw.WriteInt32(int32(u64))
}
if u64 > math.MaxInt64 {
return fmt.Errorf("%d overflows int64", u64)
}
return vw.WriteInt64(int64(u64))
}
return ValueEncoderError{
Name: "UintEncodeValue",
Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint},
Received: val,
}
}
// FloatEncodeValue is the ValueEncoderFunc for float types.
func (dve DefaultValueEncoders) FloatEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
switch val.Kind() {
case reflect.Float32, reflect.Float64:
return vw.WriteDouble(val.Float())
}
return ValueEncoderError{Name: "FloatEncodeValue", Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, Received: val}
}
// StringEncodeValue is the ValueEncoderFunc for string types.
//
// Deprecated: StringEncodeValue is not registered by default. Use StringCodec.EncodeValue instead.
func (dve DefaultValueEncoders) StringEncodeValue(ectx EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if val.Kind() != reflect.String {
return ValueEncoderError{
Name: "StringEncodeValue",
Kinds: []reflect.Kind{reflect.String},
Received: val,
}
}
return vw.WriteString(val.String())
}
// ObjectIDEncodeValue is the ValueEncoderFunc for primitive.ObjectID.
func (dve DefaultValueEncoders) ObjectIDEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tOID {
return ValueEncoderError{Name: "ObjectIDEncodeValue", Types: []reflect.Type{tOID}, Received: val}
}
return vw.WriteObjectID(val.Interface().(primitive.ObjectID))
}
// Decimal128EncodeValue is the ValueEncoderFunc for primitive.Decimal128.
func (dve DefaultValueEncoders) Decimal128EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tDecimal {
return ValueEncoderError{Name: "Decimal128EncodeValue", Types: []reflect.Type{tDecimal}, Received: val}
}
return vw.WriteDecimal128(val.Interface().(primitive.Decimal128))
}
// JSONNumberEncodeValue is the ValueEncoderFunc for json.Number.
func (dve DefaultValueEncoders) JSONNumberEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tJSONNumber {
return ValueEncoderError{Name: "JSONNumberEncodeValue", Types: []reflect.Type{tJSONNumber}, Received: val}
}
jsnum := val.Interface().(json.Number)
// Attempt int first, then float64
if i64, err := jsnum.Int64(); err == nil {
return dve.IntEncodeValue(ec, vw, reflect.ValueOf(i64))
}
f64, err := jsnum.Float64()
if err != nil {
return err
}
return dve.FloatEncodeValue(ec, vw, reflect.ValueOf(f64))
}
// URLEncodeValue is the ValueEncoderFunc for url.URL.
func (dve DefaultValueEncoders) URLEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tURL {
return ValueEncoderError{Name: "URLEncodeValue", Types: []reflect.Type{tURL}, Received: val}
}
u := val.Interface().(url.URL)
return vw.WriteString(u.String())
}
// TimeEncodeValue is the ValueEncoderFunc for time.TIme.
//
// Deprecated: TimeEncodeValue is not registered by default. Use TimeCodec.EncodeValue instead.
func (dve DefaultValueEncoders) TimeEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tTime {
return ValueEncoderError{Name: "TimeEncodeValue", Types: []reflect.Type{tTime}, Received: val}
}
tt := val.Interface().(time.Time)
dt := primitive.NewDateTimeFromTime(tt)
return vw.WriteDateTime(int64(dt))
}
// ByteSliceEncodeValue is the ValueEncoderFunc for []byte.
//
// Deprecated: ByteSliceEncodeValue is not registered by default. Use ByteSliceCodec.EncodeValue instead.
func (dve DefaultValueEncoders) ByteSliceEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tByteSlice {
return ValueEncoderError{Name: "ByteSliceEncodeValue", Types: []reflect.Type{tByteSlice}, Received: val}
}
if val.IsNil() {
return vw.WriteNull()
}
return vw.WriteBinary(val.Interface().([]byte))
}
// MapEncodeValue is the ValueEncoderFunc for map[string]* types.
//
// Deprecated: MapEncodeValue is not registered by default. Use MapCodec.EncodeValue instead.
func (dve DefaultValueEncoders) MapEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Kind() != reflect.Map || val.Type().Key().Kind() != reflect.String {
return ValueEncoderError{Name: "MapEncodeValue", Kinds: []reflect.Kind{reflect.Map}, Received: val}
}
if val.IsNil() {
// If we have a nill map but we can't WriteNull, that means we're probably trying to encode
// to a TopLevel document. We can't currently tell if this is what actually happened, but if
// there's a deeper underlying problem, the error will also be returned from WriteDocument,
// so just continue. The operations on a map reflection value are valid, so we can call
// MapKeys within mapEncodeValue without a problem.
err := vw.WriteNull()
if err == nil {
return nil
}
}
dw, err := vw.WriteDocument()
if err != nil {
return err
}
return dve.mapEncodeValue(ec, dw, val, nil)
}
// mapEncodeValue handles encoding of the values of a map. The collisionFn returns
// true if the provided key exists, this is mainly used for inline maps in the
// struct codec.
func (dve DefaultValueEncoders) mapEncodeValue(ec EncodeContext, dw bsonrw.DocumentWriter, val reflect.Value, collisionFn func(string) bool) error {
elemType := val.Type().Elem()
encoder, err := ec.LookupEncoder(elemType)
if err != nil && elemType.Kind() != reflect.Interface {
return err
}
keys := val.MapKeys()
for _, key := range keys {
if collisionFn != nil && collisionFn(key.String()) {
return fmt.Errorf("Key %s of inlined map conflicts with a struct field name", key)
}
currEncoder, currVal, lookupErr := dve.lookupElementEncoder(ec, encoder, val.MapIndex(key))
if lookupErr != nil && lookupErr != errInvalidValue {
return lookupErr
}
vw, err := dw.WriteDocumentElement(key.String())
if err != nil {
return err
}
if lookupErr == errInvalidValue {
err = vw.WriteNull()
if err != nil {
return err
}
continue
}
err = currEncoder.EncodeValue(ec, vw, currVal)
if err != nil {
return err
}
}
return dw.WriteDocumentEnd()
}
// ArrayEncodeValue is the ValueEncoderFunc for array types.
func (dve DefaultValueEncoders) ArrayEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Kind() != reflect.Array {
return ValueEncoderError{Name: "ArrayEncodeValue", Kinds: []reflect.Kind{reflect.Array}, Received: val}
}
// If we have a []primitive.E we want to treat it as a document instead of as an array.
if val.Type().Elem() == tE {
dw, err := vw.WriteDocument()
if err != nil {
return err
}
for idx := 0; idx < val.Len(); idx++ {
e := val.Index(idx).Interface().(primitive.E)
err = encodeElement(ec, dw, e)
if err != nil {
return err
}
}
return dw.WriteDocumentEnd()
}
// If we have a []byte we want to treat it as a binary instead of as an array.
if val.Type().Elem() == tByte {
var byteSlice []byte
for idx := 0; idx < val.Len(); idx++ {
byteSlice = append(byteSlice, val.Index(idx).Interface().(byte))
}
return vw.WriteBinary(byteSlice)
}
aw, err := vw.WriteArray()
if err != nil {
return err
}
elemType := val.Type().Elem()
encoder, err := ec.LookupEncoder(elemType)
if err != nil && elemType.Kind() != reflect.Interface {
return err
}
for idx := 0; idx < val.Len(); idx++ {
currEncoder, currVal, lookupErr := dve.lookupElementEncoder(ec, encoder, val.Index(idx))
if lookupErr != nil && lookupErr != errInvalidValue {
return lookupErr
}
vw, err := aw.WriteArrayElement()
if err != nil {
return err
}
if lookupErr == errInvalidValue {
err = vw.WriteNull()
if err != nil {
return err
}
continue
}
err = currEncoder.EncodeValue(ec, vw, currVal)
if err != nil {
return err
}
}
return aw.WriteArrayEnd()
}
// SliceEncodeValue is the ValueEncoderFunc for slice types.
//
// Deprecated: SliceEncodeValue is not registered by default. Use SliceCodec.EncodeValue instead.
func (dve DefaultValueEncoders) SliceEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Kind() != reflect.Slice {
return ValueEncoderError{Name: "SliceEncodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: val}
}
if val.IsNil() {
return vw.WriteNull()
}
// If we have a []primitive.E we want to treat it as a document instead of as an array.
if val.Type().ConvertibleTo(tD) {
d := val.Convert(tD).Interface().(primitive.D)
dw, err := vw.WriteDocument()
if err != nil {
return err
}
for _, e := range d {
err = encodeElement(ec, dw, e)
if err != nil {
return err
}
}
return dw.WriteDocumentEnd()
}
aw, err := vw.WriteArray()
if err != nil {
return err
}
elemType := val.Type().Elem()
encoder, err := ec.LookupEncoder(elemType)
if err != nil && elemType.Kind() != reflect.Interface {
return err
}
for idx := 0; idx < val.Len(); idx++ {
currEncoder, currVal, lookupErr := dve.lookupElementEncoder(ec, encoder, val.Index(idx))
if lookupErr != nil && lookupErr != errInvalidValue {
return lookupErr
}
vw, err := aw.WriteArrayElement()
if err != nil {
return err
}
if lookupErr == errInvalidValue {
err = vw.WriteNull()
if err != nil {
return err
}
continue
}
err = currEncoder.EncodeValue(ec, vw, currVal)
if err != nil {
return err
}
}
return aw.WriteArrayEnd()
}
func (dve DefaultValueEncoders) lookupElementEncoder(ec EncodeContext, origEncoder ValueEncoder, currVal reflect.Value) (ValueEncoder, reflect.Value, error) {
if origEncoder != nil || (currVal.Kind() != reflect.Interface) {
return origEncoder, currVal, nil
}
currVal = currVal.Elem()
if !currVal.IsValid() {
return nil, currVal, errInvalidValue
}
currEncoder, err := ec.LookupEncoder(currVal.Type())
return currEncoder, currVal, err
}
// EmptyInterfaceEncodeValue is the ValueEncoderFunc for interface{}.
//
// Deprecated: EmptyInterfaceEncodeValue is not registered by default. Use EmptyInterfaceCodec.EncodeValue instead.
func (dve DefaultValueEncoders) EmptyInterfaceEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tEmpty {
return ValueEncoderError{Name: "EmptyInterfaceEncodeValue", Types: []reflect.Type{tEmpty}, Received: val}
}
if val.IsNil() {
return vw.WriteNull()
}
encoder, err := ec.LookupEncoder(val.Elem().Type())
if err != nil {
return err
}
return encoder.EncodeValue(ec, vw, val.Elem())
}
// ValueMarshalerEncodeValue is the ValueEncoderFunc for ValueMarshaler implementations.
func (dve DefaultValueEncoders) ValueMarshalerEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
// Either val or a pointer to val must implement ValueMarshaler
switch {
case !val.IsValid():
return ValueEncoderError{Name: "ValueMarshalerEncodeValue", Types: []reflect.Type{tValueMarshaler}, Received: val}
case val.Type().Implements(tValueMarshaler):
// If ValueMarshaler is implemented on a concrete type, make sure that val isn't a nil pointer
if isImplementationNil(val, tValueMarshaler) {
return vw.WriteNull()
}
case reflect.PtrTo(val.Type()).Implements(tValueMarshaler) && val.CanAddr():
val = val.Addr()
default:
return ValueEncoderError{Name: "ValueMarshalerEncodeValue", Types: []reflect.Type{tValueMarshaler}, Received: val}
}
fn := val.Convert(tValueMarshaler).MethodByName("MarshalBSONValue")
returns := fn.Call(nil)
if !returns[2].IsNil() {
return returns[2].Interface().(error)
}
t, data := returns[0].Interface().(bsontype.Type), returns[1].Interface().([]byte)
return bsonrw.Copier{}.CopyValueFromBytes(vw, t, data)
}
// MarshalerEncodeValue is the ValueEncoderFunc for Marshaler implementations.
func (dve DefaultValueEncoders) MarshalerEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
// Either val or a pointer to val must implement Marshaler
switch {
case !val.IsValid():
return ValueEncoderError{Name: "MarshalerEncodeValue", Types: []reflect.Type{tMarshaler}, Received: val}
case val.Type().Implements(tMarshaler):
// If Marshaler is implemented on a concrete type, make sure that val isn't a nil pointer
if isImplementationNil(val, tMarshaler) {
return vw.WriteNull()
}
case reflect.PtrTo(val.Type()).Implements(tMarshaler) && val.CanAddr():
val = val.Addr()
default:
return ValueEncoderError{Name: "MarshalerEncodeValue", Types: []reflect.Type{tMarshaler}, Received: val}
}
fn := val.Convert(tMarshaler).MethodByName("MarshalBSON")
returns := fn.Call(nil)
if !returns[1].IsNil() {
return returns[1].Interface().(error)
}
data := returns[0].Interface().([]byte)
return bsonrw.Copier{}.CopyValueFromBytes(vw, bsontype.EmbeddedDocument, data)
}
// ProxyEncodeValue is the ValueEncoderFunc for Proxy implementations.
func (dve DefaultValueEncoders) ProxyEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
// Either val or a pointer to val must implement Proxy
switch {
case !val.IsValid():
return ValueEncoderError{Name: "ProxyEncodeValue", Types: []reflect.Type{tProxy}, Received: val}
case val.Type().Implements(tProxy):
// If Proxy is implemented on a concrete type, make sure that val isn't a nil pointer
if isImplementationNil(val, tProxy) {
return vw.WriteNull()
}
case reflect.PtrTo(val.Type()).Implements(tProxy) && val.CanAddr():
val = val.Addr()
default:
return ValueEncoderError{Name: "ProxyEncodeValue", Types: []reflect.Type{tProxy}, Received: val}
}
fn := val.Convert(tProxy).MethodByName("ProxyBSON")
returns := fn.Call(nil)
if !returns[1].IsNil() {
return returns[1].Interface().(error)
}
data := returns[0]
var encoder ValueEncoder
var err error
if data.Elem().IsValid() {
encoder, err = ec.LookupEncoder(data.Elem().Type())
} else {
encoder, err = ec.LookupEncoder(nil)
}
if err != nil {
return err
}
return encoder.EncodeValue(ec, vw, data.Elem())
}
// JavaScriptEncodeValue is the ValueEncoderFunc for the primitive.JavaScript type.
func (DefaultValueEncoders) JavaScriptEncodeValue(ectx EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tJavaScript {
return ValueEncoderError{Name: "JavaScriptEncodeValue", Types: []reflect.Type{tJavaScript}, Received: val}
}
return vw.WriteJavascript(val.String())
}
// SymbolEncodeValue is the ValueEncoderFunc for the primitive.Symbol type.
func (DefaultValueEncoders) SymbolEncodeValue(ectx EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tSymbol {
return ValueEncoderError{Name: "SymbolEncodeValue", Types: []reflect.Type{tSymbol}, Received: val}
}
return vw.WriteSymbol(val.String())
}
// BinaryEncodeValue is the ValueEncoderFunc for Binary.
func (DefaultValueEncoders) BinaryEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tBinary {
return ValueEncoderError{Name: "BinaryEncodeValue", Types: []reflect.Type{tBinary}, Received: val}
}
b := val.Interface().(primitive.Binary)
return vw.WriteBinaryWithSubtype(b.Data, b.Subtype)
}
// UndefinedEncodeValue is the ValueEncoderFunc for Undefined.
func (DefaultValueEncoders) UndefinedEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tUndefined {
return ValueEncoderError{Name: "UndefinedEncodeValue", Types: []reflect.Type{tUndefined}, Received: val}
}
return vw.WriteUndefined()
}
// DateTimeEncodeValue is the ValueEncoderFunc for DateTime.
func (DefaultValueEncoders) DateTimeEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tDateTime {
return ValueEncoderError{Name: "DateTimeEncodeValue", Types: []reflect.Type{tDateTime}, Received: val}
}
return vw.WriteDateTime(val.Int())
}
// NullEncodeValue is the ValueEncoderFunc for Null.
func (DefaultValueEncoders) NullEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tNull {
return ValueEncoderError{Name: "NullEncodeValue", Types: []reflect.Type{tNull}, Received: val}
}
return vw.WriteNull()
}
// RegexEncodeValue is the ValueEncoderFunc for Regex.
func (DefaultValueEncoders) RegexEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tRegex {
return ValueEncoderError{Name: "RegexEncodeValue", Types: []reflect.Type{tRegex}, Received: val}
}
regex := val.Interface().(primitive.Regex)
return vw.WriteRegex(regex.Pattern, regex.Options)
}
// DBPointerEncodeValue is the ValueEncoderFunc for DBPointer.
func (DefaultValueEncoders) DBPointerEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tDBPointer {
return ValueEncoderError{Name: "DBPointerEncodeValue", Types: []reflect.Type{tDBPointer}, Received: val}
}
dbp := val.Interface().(primitive.DBPointer)
return vw.WriteDBPointer(dbp.DB, dbp.Pointer)
}
// TimestampEncodeValue is the ValueEncoderFunc for Timestamp.
func (DefaultValueEncoders) TimestampEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tTimestamp {
return ValueEncoderError{Name: "TimestampEncodeValue", Types: []reflect.Type{tTimestamp}, Received: val}
}
ts := val.Interface().(primitive.Timestamp)
return vw.WriteTimestamp(ts.T, ts.I)
}
// MinKeyEncodeValue is the ValueEncoderFunc for MinKey.
func (DefaultValueEncoders) MinKeyEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tMinKey {
return ValueEncoderError{Name: "MinKeyEncodeValue", Types: []reflect.Type{tMinKey}, Received: val}
}
return vw.WriteMinKey()
}
// MaxKeyEncodeValue is the ValueEncoderFunc for MaxKey.
func (DefaultValueEncoders) MaxKeyEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tMaxKey {
return ValueEncoderError{Name: "MaxKeyEncodeValue", Types: []reflect.Type{tMaxKey}, Received: val}
}
return vw.WriteMaxKey()
}
// CoreDocumentEncodeValue is the ValueEncoderFunc for bsoncore.Document.
func (DefaultValueEncoders) CoreDocumentEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tCoreDocument {
return ValueEncoderError{Name: "CoreDocumentEncodeValue", Types: []reflect.Type{tCoreDocument}, Received: val}
}
cdoc := val.Interface().(bsoncore.Document)
return bsonrw.Copier{}.CopyDocumentFromBytes(vw, cdoc)
}
// CodeWithScopeEncodeValue is the ValueEncoderFunc for CodeWithScope.
func (dve DefaultValueEncoders) CodeWithScopeEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tCodeWithScope {
return ValueEncoderError{Name: "CodeWithScopeEncodeValue", Types: []reflect.Type{tCodeWithScope}, Received: val}
}
cws := val.Interface().(primitive.CodeWithScope)
dw, err := vw.WriteCodeWithScope(string(cws.Code))
if err != nil {
return err
}
sw := sliceWriterPool.Get().(*bsonrw.SliceWriter)
defer sliceWriterPool.Put(sw)
*sw = (*sw)[:0]
scopeVW := bvwPool.Get(sw)
defer bvwPool.Put(scopeVW)
encoder, err := ec.LookupEncoder(reflect.TypeOf(cws.Scope))
if err != nil {
return err
}
err = encoder.EncodeValue(ec, scopeVW, reflect.ValueOf(cws.Scope))
if err != nil {
return err
}
err = bsonrw.Copier{}.CopyBytesToDocumentWriter(dw, *sw)
if err != nil {
return err
}
return dw.WriteDocumentEnd()
}
// isImplementationNil returns if val is a nil pointer and inter is implemented on a concrete type
func isImplementationNil(val reflect.Value, inter reflect.Type) bool {
vt := val.Type()
for vt.Kind() == reflect.Ptr {
vt = vt.Elem()
}
return vt.Implements(inter) && val.Kind() == reflect.Ptr && val.IsNil()
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,90 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Package bsoncodec provides a system for encoding values to BSON representations and decoding
// values from BSON representations. This package considers both binary BSON and ExtendedJSON as
// BSON representations. The types in this package enable a flexible system for handling this
// encoding and decoding.
//
// The codec system is composed of two parts:
//
// 1) ValueEncoders and ValueDecoders that handle encoding and decoding Go values to and from BSON
// representations.
//
// 2) A Registry that holds these ValueEncoders and ValueDecoders and provides methods for
// retrieving them.
//
// # ValueEncoders and ValueDecoders
//
// The ValueEncoder interface is implemented by types that can encode a provided Go type to BSON.
// The value to encode is provided as a reflect.Value and a bsonrw.ValueWriter is used within the
// EncodeValue method to actually create the BSON representation. For convenience, ValueEncoderFunc
// is provided to allow use of a function with the correct signature as a ValueEncoder. An
// EncodeContext instance is provided to allow implementations to lookup further ValueEncoders and
// to provide configuration information.
//
// The ValueDecoder interface is the inverse of the ValueEncoder. Implementations should ensure that
// the value they receive is settable. Similar to ValueEncoderFunc, ValueDecoderFunc is provided to
// allow the use of a function with the correct signature as a ValueDecoder. A DecodeContext
// instance is provided and serves similar functionality to the EncodeContext.
//
// # Registry and RegistryBuilder
//
// A Registry is an immutable store for ValueEncoders, ValueDecoders, and a type map. See the Registry type
// documentation for examples of registering various custom encoders and decoders. A Registry can be constructed using a
// RegistryBuilder, which handles three main types of codecs:
//
// 1. Type encoders/decoders - These can be registered using the RegisterTypeEncoder and RegisterTypeDecoder methods.
// The registered codec will be invoked when encoding/decoding a value whose type matches the registered type exactly.
// If the registered type is an interface, the codec will be invoked when encoding or decoding values whose type is the
// interface, but not for values with concrete types that implement the interface.
//
// 2. Hook encoders/decoders - These can be registered using the RegisterHookEncoder and RegisterHookDecoder methods.
// These methods only accept interface types and the registered codecs will be invoked when encoding or decoding values
// whose types implement the interface. An example of a hook defined by the driver is bson.Marshaler. The driver will
// call the MarshalBSON method for any value whose type implements bson.Marshaler, regardless of the value's concrete
// type.
//
// 3. Type map entries - This can be used to associate a BSON type with a Go type. These type associations are used when
// decoding into a bson.D/bson.M or a struct field of type interface{}. For example, by default, BSON int32 and int64
// values decode as Go int32 and int64 instances, respectively, when decoding into a bson.D. The following code would
// change the behavior so these values decode as Go int instances instead:
//
// intType := reflect.TypeOf(int(0))
// registryBuilder.RegisterTypeMapEntry(bsontype.Int32, intType).RegisterTypeMapEntry(bsontype.Int64, intType)
//
// 4. Kind encoder/decoders - These can be registered using the RegisterDefaultEncoder and RegisterDefaultDecoder
// methods. The registered codec will be invoked when encoding or decoding values whose reflect.Kind matches the
// registered reflect.Kind as long as the value's type doesn't match a registered type or hook encoder/decoder first.
// These methods should be used to change the behavior for all values for a specific kind.
//
// # Registry Lookup Procedure
//
// When looking up an encoder in a Registry, the precedence rules are as follows:
//
// 1. A type encoder registered for the exact type of the value.
//
// 2. A hook encoder registered for an interface that is implemented by the value or by a pointer to the value. If the
// value matches multiple hooks (e.g. the type implements bsoncodec.Marshaler and bsoncodec.ValueMarshaler), the first
// one registered will be selected. Note that registries constructed using bson.NewRegistryBuilder have driver-defined
// hooks registered for the bsoncodec.Marshaler, bsoncodec.ValueMarshaler, and bsoncodec.Proxy interfaces, so those
// will take precedence over any new hooks.
//
// 3. A kind encoder registered for the value's kind.
//
// If all of these lookups fail to find an encoder, an error of type ErrNoEncoder is returned. The same precedence
// rules apply for decoders, with the exception that an error of type ErrNoDecoder will be returned if no decoder is
// found.
//
// # DefaultValueEncoders and DefaultValueDecoders
//
// The DefaultValueEncoders and DefaultValueDecoders types provide a full set of ValueEncoders and
// ValueDecoders for handling a wide range of Go types, including all of the types within the
// primitive package. To make registering these codecs easier, a helper method on each type is
// provided. For the DefaultValueEncoders type the method is called RegisterDefaultEncoders and for
// the DefaultValueDecoders type the method is called RegisterDefaultDecoders, this method also
// handles registering type map entries for each BSON type.
package bsoncodec

View File

@@ -0,0 +1,147 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import (
"reflect"
"go.mongodb.org/mongo-driver/bson/bsonoptions"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
)
// EmptyInterfaceCodec is the Codec used for interface{} values.
type EmptyInterfaceCodec struct {
DecodeBinaryAsSlice bool
}
var (
defaultEmptyInterfaceCodec = NewEmptyInterfaceCodec()
_ ValueCodec = defaultEmptyInterfaceCodec
_ typeDecoder = defaultEmptyInterfaceCodec
)
// NewEmptyInterfaceCodec returns a EmptyInterfaceCodec with options opts.
func NewEmptyInterfaceCodec(opts ...*bsonoptions.EmptyInterfaceCodecOptions) *EmptyInterfaceCodec {
interfaceOpt := bsonoptions.MergeEmptyInterfaceCodecOptions(opts...)
codec := EmptyInterfaceCodec{}
if interfaceOpt.DecodeBinaryAsSlice != nil {
codec.DecodeBinaryAsSlice = *interfaceOpt.DecodeBinaryAsSlice
}
return &codec
}
// EncodeValue is the ValueEncoderFunc for interface{}.
func (eic EmptyInterfaceCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tEmpty {
return ValueEncoderError{Name: "EmptyInterfaceEncodeValue", Types: []reflect.Type{tEmpty}, Received: val}
}
if val.IsNil() {
return vw.WriteNull()
}
encoder, err := ec.LookupEncoder(val.Elem().Type())
if err != nil {
return err
}
return encoder.EncodeValue(ec, vw, val.Elem())
}
func (eic EmptyInterfaceCodec) getEmptyInterfaceDecodeType(dc DecodeContext, valueType bsontype.Type) (reflect.Type, error) {
isDocument := valueType == bsontype.Type(0) || valueType == bsontype.EmbeddedDocument
if isDocument {
if dc.defaultDocumentType != nil {
// If the bsontype is an embedded document and the DocumentType is set on the DecodeContext, then return
// that type.
return dc.defaultDocumentType, nil
}
if dc.Ancestor != nil {
// Using ancestor information rather than looking up the type map entry forces consistent decoding.
// If we're decoding into a bson.D, subdocuments should also be decoded as bson.D, even if a type map entry
// has been registered.
return dc.Ancestor, nil
}
}
rtype, err := dc.LookupTypeMapEntry(valueType)
if err == nil {
return rtype, nil
}
if isDocument {
// For documents, fallback to looking up a type map entry for bsontype.Type(0) or bsontype.EmbeddedDocument,
// depending on the original valueType.
var lookupType bsontype.Type
switch valueType {
case bsontype.Type(0):
lookupType = bsontype.EmbeddedDocument
case bsontype.EmbeddedDocument:
lookupType = bsontype.Type(0)
}
rtype, err = dc.LookupTypeMapEntry(lookupType)
if err == nil {
return rtype, nil
}
}
return nil, err
}
func (eic EmptyInterfaceCodec) decodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) {
if t != tEmpty {
return emptyValue, ValueDecoderError{Name: "EmptyInterfaceDecodeValue", Types: []reflect.Type{tEmpty}, Received: reflect.Zero(t)}
}
rtype, err := eic.getEmptyInterfaceDecodeType(dc, vr.Type())
if err != nil {
switch vr.Type() {
case bsontype.Null:
return reflect.Zero(t), vr.ReadNull()
default:
return emptyValue, err
}
}
decoder, err := dc.LookupDecoder(rtype)
if err != nil {
return emptyValue, err
}
elem, err := decodeTypeOrValue(decoder, dc, vr, rtype)
if err != nil {
return emptyValue, err
}
if eic.DecodeBinaryAsSlice && rtype == tBinary {
binElem := elem.Interface().(primitive.Binary)
if binElem.Subtype == bsontype.BinaryGeneric || binElem.Subtype == bsontype.BinaryBinaryOld {
elem = reflect.ValueOf(binElem.Data)
}
}
return elem, nil
}
// DecodeValue is the ValueDecoderFunc for interface{}.
func (eic EmptyInterfaceCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
if !val.CanSet() || val.Type() != tEmpty {
return ValueDecoderError{Name: "EmptyInterfaceDecodeValue", Types: []reflect.Type{tEmpty}, Received: val}
}
elem, err := eic.decodeType(dc, vr, val.Type())
if err != nil {
return err
}
val.Set(elem)
return nil
}

View File

@@ -0,0 +1,309 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import (
"encoding"
"fmt"
"reflect"
"strconv"
"go.mongodb.org/mongo-driver/bson/bsonoptions"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/bsontype"
)
var defaultMapCodec = NewMapCodec()
// MapCodec is the Codec used for map values.
type MapCodec struct {
DecodeZerosMap bool
EncodeNilAsEmpty bool
EncodeKeysWithStringer bool
}
var _ ValueCodec = &MapCodec{}
// KeyMarshaler is the interface implemented by an object that can marshal itself into a string key.
// This applies to types used as map keys and is similar to encoding.TextMarshaler.
type KeyMarshaler interface {
MarshalKey() (key string, err error)
}
// KeyUnmarshaler is the interface implemented by an object that can unmarshal a string representation
// of itself. This applies to types used as map keys and is similar to encoding.TextUnmarshaler.
//
// UnmarshalKey must be able to decode the form generated by MarshalKey.
// UnmarshalKey must copy the text if it wishes to retain the text
// after returning.
type KeyUnmarshaler interface {
UnmarshalKey(key string) error
}
// NewMapCodec returns a MapCodec with options opts.
func NewMapCodec(opts ...*bsonoptions.MapCodecOptions) *MapCodec {
mapOpt := bsonoptions.MergeMapCodecOptions(opts...)
codec := MapCodec{}
if mapOpt.DecodeZerosMap != nil {
codec.DecodeZerosMap = *mapOpt.DecodeZerosMap
}
if mapOpt.EncodeNilAsEmpty != nil {
codec.EncodeNilAsEmpty = *mapOpt.EncodeNilAsEmpty
}
if mapOpt.EncodeKeysWithStringer != nil {
codec.EncodeKeysWithStringer = *mapOpt.EncodeKeysWithStringer
}
return &codec
}
// EncodeValue is the ValueEncoder for map[*]* types.
func (mc *MapCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Kind() != reflect.Map {
return ValueEncoderError{Name: "MapEncodeValue", Kinds: []reflect.Kind{reflect.Map}, Received: val}
}
if val.IsNil() && !mc.EncodeNilAsEmpty {
// If we have a nil map but we can't WriteNull, that means we're probably trying to encode
// to a TopLevel document. We can't currently tell if this is what actually happened, but if
// there's a deeper underlying problem, the error will also be returned from WriteDocument,
// so just continue. The operations on a map reflection value are valid, so we can call
// MapKeys within mapEncodeValue without a problem.
err := vw.WriteNull()
if err == nil {
return nil
}
}
dw, err := vw.WriteDocument()
if err != nil {
return err
}
return mc.mapEncodeValue(ec, dw, val, nil)
}
// mapEncodeValue handles encoding of the values of a map. The collisionFn returns
// true if the provided key exists, this is mainly used for inline maps in the
// struct codec.
func (mc *MapCodec) mapEncodeValue(ec EncodeContext, dw bsonrw.DocumentWriter, val reflect.Value, collisionFn func(string) bool) error {
elemType := val.Type().Elem()
encoder, err := ec.LookupEncoder(elemType)
if err != nil && elemType.Kind() != reflect.Interface {
return err
}
keys := val.MapKeys()
for _, key := range keys {
keyStr, err := mc.encodeKey(key)
if err != nil {
return err
}
if collisionFn != nil && collisionFn(keyStr) {
return fmt.Errorf("Key %s of inlined map conflicts with a struct field name", key)
}
currEncoder, currVal, lookupErr := defaultValueEncoders.lookupElementEncoder(ec, encoder, val.MapIndex(key))
if lookupErr != nil && lookupErr != errInvalidValue {
return lookupErr
}
vw, err := dw.WriteDocumentElement(keyStr)
if err != nil {
return err
}
if lookupErr == errInvalidValue {
err = vw.WriteNull()
if err != nil {
return err
}
continue
}
err = currEncoder.EncodeValue(ec, vw, currVal)
if err != nil {
return err
}
}
return dw.WriteDocumentEnd()
}
// DecodeValue is the ValueDecoder for map[string/decimal]* types.
func (mc *MapCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
if val.Kind() != reflect.Map || (!val.CanSet() && val.IsNil()) {
return ValueDecoderError{Name: "MapDecodeValue", Kinds: []reflect.Kind{reflect.Map}, Received: val}
}
switch vrType := vr.Type(); vrType {
case bsontype.Type(0), bsontype.EmbeddedDocument:
case bsontype.Null:
val.Set(reflect.Zero(val.Type()))
return vr.ReadNull()
case bsontype.Undefined:
val.Set(reflect.Zero(val.Type()))
return vr.ReadUndefined()
default:
return fmt.Errorf("cannot decode %v into a %s", vrType, val.Type())
}
dr, err := vr.ReadDocument()
if err != nil {
return err
}
if val.IsNil() {
val.Set(reflect.MakeMap(val.Type()))
}
if val.Len() > 0 && mc.DecodeZerosMap {
clearMap(val)
}
eType := val.Type().Elem()
decoder, err := dc.LookupDecoder(eType)
if err != nil {
return err
}
eTypeDecoder, _ := decoder.(typeDecoder)
if eType == tEmpty {
dc.Ancestor = val.Type()
}
keyType := val.Type().Key()
for {
key, vr, err := dr.ReadElement()
if err == bsonrw.ErrEOD {
break
}
if err != nil {
return err
}
k, err := mc.decodeKey(key, keyType)
if err != nil {
return err
}
elem, err := decodeTypeOrValueWithInfo(decoder, eTypeDecoder, dc, vr, eType, true)
if err != nil {
return newDecodeError(key, err)
}
val.SetMapIndex(k, elem)
}
return nil
}
func clearMap(m reflect.Value) {
var none reflect.Value
for _, k := range m.MapKeys() {
m.SetMapIndex(k, none)
}
}
func (mc *MapCodec) encodeKey(val reflect.Value) (string, error) {
if mc.EncodeKeysWithStringer {
return fmt.Sprint(val), nil
}
// keys of any string type are used directly
if val.Kind() == reflect.String {
return val.String(), nil
}
// KeyMarshalers are marshaled
if km, ok := val.Interface().(KeyMarshaler); ok {
if val.Kind() == reflect.Ptr && val.IsNil() {
return "", nil
}
buf, err := km.MarshalKey()
if err == nil {
return buf, nil
}
return "", err
}
// keys implement encoding.TextMarshaler are marshaled.
if km, ok := val.Interface().(encoding.TextMarshaler); ok {
if val.Kind() == reflect.Ptr && val.IsNil() {
return "", nil
}
buf, err := km.MarshalText()
if err != nil {
return "", err
}
return string(buf), nil
}
switch val.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return strconv.FormatInt(val.Int(), 10), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return strconv.FormatUint(val.Uint(), 10), nil
}
return "", fmt.Errorf("unsupported key type: %v", val.Type())
}
var keyUnmarshalerType = reflect.TypeOf((*KeyUnmarshaler)(nil)).Elem()
var textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
func (mc *MapCodec) decodeKey(key string, keyType reflect.Type) (reflect.Value, error) {
keyVal := reflect.ValueOf(key)
var err error
switch {
// First, if EncodeKeysWithStringer is not enabled, try to decode withKeyUnmarshaler
case !mc.EncodeKeysWithStringer && reflect.PtrTo(keyType).Implements(keyUnmarshalerType):
keyVal = reflect.New(keyType)
v := keyVal.Interface().(KeyUnmarshaler)
err = v.UnmarshalKey(key)
keyVal = keyVal.Elem()
// Try to decode encoding.TextUnmarshalers.
case reflect.PtrTo(keyType).Implements(textUnmarshalerType):
keyVal = reflect.New(keyType)
v := keyVal.Interface().(encoding.TextUnmarshaler)
err = v.UnmarshalText([]byte(key))
keyVal = keyVal.Elem()
// Otherwise, go to type specific behavior
default:
switch keyType.Kind() {
case reflect.String:
keyVal = reflect.ValueOf(key).Convert(keyType)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
n, parseErr := strconv.ParseInt(key, 10, 64)
if parseErr != nil || reflect.Zero(keyType).OverflowInt(n) {
err = fmt.Errorf("failed to unmarshal number key %v", key)
}
keyVal = reflect.ValueOf(n).Convert(keyType)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
n, parseErr := strconv.ParseUint(key, 10, 64)
if parseErr != nil || reflect.Zero(keyType).OverflowUint(n) {
err = fmt.Errorf("failed to unmarshal number key %v", key)
break
}
keyVal = reflect.ValueOf(n).Convert(keyType)
case reflect.Float32, reflect.Float64:
if mc.EncodeKeysWithStringer {
parsed, err := strconv.ParseFloat(key, 64)
if err != nil {
return keyVal, fmt.Errorf("Map key is defined to be a decimal type (%v) but got error %v", keyType.Kind(), err)
}
keyVal = reflect.ValueOf(parsed)
break
}
fallthrough
default:
return keyVal, fmt.Errorf("unsupported key type: %v", keyType)
}
}
return keyVal, err
}

View File

@@ -0,0 +1,65 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import "fmt"
type mode int
const (
_ mode = iota
mTopLevel
mDocument
mArray
mValue
mElement
mCodeWithScope
mSpacer
)
func (m mode) String() string {
var str string
switch m {
case mTopLevel:
str = "TopLevel"
case mDocument:
str = "DocumentMode"
case mArray:
str = "ArrayMode"
case mValue:
str = "ValueMode"
case mElement:
str = "ElementMode"
case mCodeWithScope:
str = "CodeWithScopeMode"
case mSpacer:
str = "CodeWithScopeSpacerFrame"
default:
str = "UnknownMode"
}
return str
}
// TransitionError is an error returned when an invalid progressing a
// ValueReader or ValueWriter state machine occurs.
type TransitionError struct {
parent mode
current mode
destination mode
}
func (te TransitionError) Error() string {
if te.destination == mode(0) {
return fmt.Sprintf("invalid state transition: cannot read/write value while in %s", te.current)
}
if te.parent == mode(0) {
return fmt.Sprintf("invalid state transition: %s -> %s", te.current, te.destination)
}
return fmt.Sprintf("invalid state transition: %s -> %s; parent %s", te.current, te.destination, te.parent)
}

View File

@@ -0,0 +1,109 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import (
"reflect"
"sync"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/bsontype"
)
var _ ValueEncoder = &PointerCodec{}
var _ ValueDecoder = &PointerCodec{}
// PointerCodec is the Codec used for pointers.
type PointerCodec struct {
ecache map[reflect.Type]ValueEncoder
dcache map[reflect.Type]ValueDecoder
l sync.RWMutex
}
// NewPointerCodec returns a PointerCodec that has been initialized.
func NewPointerCodec() *PointerCodec {
return &PointerCodec{
ecache: make(map[reflect.Type]ValueEncoder),
dcache: make(map[reflect.Type]ValueDecoder),
}
}
// EncodeValue handles encoding a pointer by either encoding it to BSON Null if the pointer is nil
// or looking up an encoder for the type of value the pointer points to.
func (pc *PointerCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if val.Kind() != reflect.Ptr {
if !val.IsValid() {
return vw.WriteNull()
}
return ValueEncoderError{Name: "PointerCodec.EncodeValue", Kinds: []reflect.Kind{reflect.Ptr}, Received: val}
}
if val.IsNil() {
return vw.WriteNull()
}
pc.l.RLock()
enc, ok := pc.ecache[val.Type()]
pc.l.RUnlock()
if ok {
if enc == nil {
return ErrNoEncoder{Type: val.Type()}
}
return enc.EncodeValue(ec, vw, val.Elem())
}
enc, err := ec.LookupEncoder(val.Type().Elem())
pc.l.Lock()
pc.ecache[val.Type()] = enc
pc.l.Unlock()
if err != nil {
return err
}
return enc.EncodeValue(ec, vw, val.Elem())
}
// DecodeValue handles decoding a pointer by looking up a decoder for the type it points to and
// using that to decode. If the BSON value is Null, this method will set the pointer to nil.
func (pc *PointerCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
if !val.CanSet() || val.Kind() != reflect.Ptr {
return ValueDecoderError{Name: "PointerCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Ptr}, Received: val}
}
if vr.Type() == bsontype.Null {
val.Set(reflect.Zero(val.Type()))
return vr.ReadNull()
}
if vr.Type() == bsontype.Undefined {
val.Set(reflect.Zero(val.Type()))
return vr.ReadUndefined()
}
if val.IsNil() {
val.Set(reflect.New(val.Type().Elem()))
}
pc.l.RLock()
dec, ok := pc.dcache[val.Type()]
pc.l.RUnlock()
if ok {
if dec == nil {
return ErrNoDecoder{Type: val.Type()}
}
return dec.DecodeValue(dc, vr, val.Elem())
}
dec, err := dc.LookupDecoder(val.Type().Elem())
pc.l.Lock()
pc.dcache[val.Type()] = dec
pc.l.Unlock()
if err != nil {
return err
}
return dec.DecodeValue(dc, vr, val.Elem())
}

View File

@@ -0,0 +1,14 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
// Proxy is an interface implemented by types that cannot themselves be directly encoded. Types
// that implement this interface with have ProxyBSON called during the encoding process and that
// value will be encoded in place for the implementer.
type Proxy interface {
ProxyBSON() (interface{}, error)
}

View File

@@ -0,0 +1,469 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import (
"errors"
"fmt"
"reflect"
"sync"
"go.mongodb.org/mongo-driver/bson/bsontype"
)
// ErrNilType is returned when nil is passed to either LookupEncoder or LookupDecoder.
var ErrNilType = errors.New("cannot perform a decoder lookup on <nil>")
// ErrNotPointer is returned when a non-pointer type is provided to LookupDecoder.
var ErrNotPointer = errors.New("non-pointer provided to LookupDecoder")
// ErrNoEncoder is returned when there wasn't an encoder available for a type.
type ErrNoEncoder struct {
Type reflect.Type
}
func (ene ErrNoEncoder) Error() string {
if ene.Type == nil {
return "no encoder found for <nil>"
}
return "no encoder found for " + ene.Type.String()
}
// ErrNoDecoder is returned when there wasn't a decoder available for a type.
type ErrNoDecoder struct {
Type reflect.Type
}
func (end ErrNoDecoder) Error() string {
return "no decoder found for " + end.Type.String()
}
// ErrNoTypeMapEntry is returned when there wasn't a type available for the provided BSON type.
type ErrNoTypeMapEntry struct {
Type bsontype.Type
}
func (entme ErrNoTypeMapEntry) Error() string {
return "no type map entry found for " + entme.Type.String()
}
// ErrNotInterface is returned when the provided type is not an interface.
var ErrNotInterface = errors.New("The provided type is not an interface")
// A RegistryBuilder is used to build a Registry. This type is not goroutine
// safe.
type RegistryBuilder struct {
typeEncoders map[reflect.Type]ValueEncoder
interfaceEncoders []interfaceValueEncoder
kindEncoders map[reflect.Kind]ValueEncoder
typeDecoders map[reflect.Type]ValueDecoder
interfaceDecoders []interfaceValueDecoder
kindDecoders map[reflect.Kind]ValueDecoder
typeMap map[bsontype.Type]reflect.Type
}
// A Registry is used to store and retrieve codecs for types and interfaces. This type is the main
// typed passed around and Encoders and Decoders are constructed from it.
type Registry struct {
typeEncoders map[reflect.Type]ValueEncoder
typeDecoders map[reflect.Type]ValueDecoder
interfaceEncoders []interfaceValueEncoder
interfaceDecoders []interfaceValueDecoder
kindEncoders map[reflect.Kind]ValueEncoder
kindDecoders map[reflect.Kind]ValueDecoder
typeMap map[bsontype.Type]reflect.Type
mu sync.RWMutex
}
// NewRegistryBuilder creates a new empty RegistryBuilder.
func NewRegistryBuilder() *RegistryBuilder {
return &RegistryBuilder{
typeEncoders: make(map[reflect.Type]ValueEncoder),
typeDecoders: make(map[reflect.Type]ValueDecoder),
interfaceEncoders: make([]interfaceValueEncoder, 0),
interfaceDecoders: make([]interfaceValueDecoder, 0),
kindEncoders: make(map[reflect.Kind]ValueEncoder),
kindDecoders: make(map[reflect.Kind]ValueDecoder),
typeMap: make(map[bsontype.Type]reflect.Type),
}
}
func buildDefaultRegistry() *Registry {
rb := NewRegistryBuilder()
defaultValueEncoders.RegisterDefaultEncoders(rb)
defaultValueDecoders.RegisterDefaultDecoders(rb)
return rb.Build()
}
// RegisterCodec will register the provided ValueCodec for the provided type.
func (rb *RegistryBuilder) RegisterCodec(t reflect.Type, codec ValueCodec) *RegistryBuilder {
rb.RegisterTypeEncoder(t, codec)
rb.RegisterTypeDecoder(t, codec)
return rb
}
// RegisterTypeEncoder will register the provided ValueEncoder for the provided type.
//
// The type will be used directly, so an encoder can be registered for a type and a different encoder can be registered
// for a pointer to that type.
//
// If the given type is an interface, the encoder will be called when marshalling a type that is that interface. It
// will not be called when marshalling a non-interface type that implements the interface.
func (rb *RegistryBuilder) RegisterTypeEncoder(t reflect.Type, enc ValueEncoder) *RegistryBuilder {
rb.typeEncoders[t] = enc
return rb
}
// RegisterHookEncoder will register an encoder for the provided interface type t. This encoder will be called when
// marshalling a type if the type implements t or a pointer to the type implements t. If the provided type is not
// an interface (i.e. t.Kind() != reflect.Interface), this method will panic.
func (rb *RegistryBuilder) RegisterHookEncoder(t reflect.Type, enc ValueEncoder) *RegistryBuilder {
if t.Kind() != reflect.Interface {
panicStr := fmt.Sprintf("RegisterHookEncoder expects a type with kind reflect.Interface, "+
"got type %s with kind %s", t, t.Kind())
panic(panicStr)
}
for idx, encoder := range rb.interfaceEncoders {
if encoder.i == t {
rb.interfaceEncoders[idx].ve = enc
return rb
}
}
rb.interfaceEncoders = append(rb.interfaceEncoders, interfaceValueEncoder{i: t, ve: enc})
return rb
}
// RegisterTypeDecoder will register the provided ValueDecoder for the provided type.
//
// The type will be used directly, so a decoder can be registered for a type and a different decoder can be registered
// for a pointer to that type.
//
// If the given type is an interface, the decoder will be called when unmarshalling into a type that is that interface.
// It will not be called when unmarshalling into a non-interface type that implements the interface.
func (rb *RegistryBuilder) RegisterTypeDecoder(t reflect.Type, dec ValueDecoder) *RegistryBuilder {
rb.typeDecoders[t] = dec
return rb
}
// RegisterHookDecoder will register an decoder for the provided interface type t. This decoder will be called when
// unmarshalling into a type if the type implements t or a pointer to the type implements t. If the provided type is not
// an interface (i.e. t.Kind() != reflect.Interface), this method will panic.
func (rb *RegistryBuilder) RegisterHookDecoder(t reflect.Type, dec ValueDecoder) *RegistryBuilder {
if t.Kind() != reflect.Interface {
panicStr := fmt.Sprintf("RegisterHookDecoder expects a type with kind reflect.Interface, "+
"got type %s with kind %s", t, t.Kind())
panic(panicStr)
}
for idx, decoder := range rb.interfaceDecoders {
if decoder.i == t {
rb.interfaceDecoders[idx].vd = dec
return rb
}
}
rb.interfaceDecoders = append(rb.interfaceDecoders, interfaceValueDecoder{i: t, vd: dec})
return rb
}
// RegisterEncoder registers the provided type and encoder pair.
//
// Deprecated: Use RegisterTypeEncoder or RegisterHookEncoder instead.
func (rb *RegistryBuilder) RegisterEncoder(t reflect.Type, enc ValueEncoder) *RegistryBuilder {
if t == tEmpty {
rb.typeEncoders[t] = enc
return rb
}
switch t.Kind() {
case reflect.Interface:
for idx, ir := range rb.interfaceEncoders {
if ir.i == t {
rb.interfaceEncoders[idx].ve = enc
return rb
}
}
rb.interfaceEncoders = append(rb.interfaceEncoders, interfaceValueEncoder{i: t, ve: enc})
default:
rb.typeEncoders[t] = enc
}
return rb
}
// RegisterDecoder registers the provided type and decoder pair.
//
// Deprecated: Use RegisterTypeDecoder or RegisterHookDecoder instead.
func (rb *RegistryBuilder) RegisterDecoder(t reflect.Type, dec ValueDecoder) *RegistryBuilder {
if t == nil {
rb.typeDecoders[nil] = dec
return rb
}
if t == tEmpty {
rb.typeDecoders[t] = dec
return rb
}
switch t.Kind() {
case reflect.Interface:
for idx, ir := range rb.interfaceDecoders {
if ir.i == t {
rb.interfaceDecoders[idx].vd = dec
return rb
}
}
rb.interfaceDecoders = append(rb.interfaceDecoders, interfaceValueDecoder{i: t, vd: dec})
default:
rb.typeDecoders[t] = dec
}
return rb
}
// RegisterDefaultEncoder will registr the provided ValueEncoder to the provided
// kind.
func (rb *RegistryBuilder) RegisterDefaultEncoder(kind reflect.Kind, enc ValueEncoder) *RegistryBuilder {
rb.kindEncoders[kind] = enc
return rb
}
// RegisterDefaultDecoder will register the provided ValueDecoder to the
// provided kind.
func (rb *RegistryBuilder) RegisterDefaultDecoder(kind reflect.Kind, dec ValueDecoder) *RegistryBuilder {
rb.kindDecoders[kind] = dec
return rb
}
// RegisterTypeMapEntry will register the provided type to the BSON type. The primary usage for this
// mapping is decoding situations where an empty interface is used and a default type needs to be
// created and decoded into.
//
// By default, BSON documents will decode into interface{} values as bson.D. To change the default type for BSON
// documents, a type map entry for bsontype.EmbeddedDocument should be registered. For example, to force BSON documents
// to decode to bson.Raw, use the following code:
//
// rb.RegisterTypeMapEntry(bsontype.EmbeddedDocument, reflect.TypeOf(bson.Raw{}))
func (rb *RegistryBuilder) RegisterTypeMapEntry(bt bsontype.Type, rt reflect.Type) *RegistryBuilder {
rb.typeMap[bt] = rt
return rb
}
// Build creates a Registry from the current state of this RegistryBuilder.
func (rb *RegistryBuilder) Build() *Registry {
registry := new(Registry)
registry.typeEncoders = make(map[reflect.Type]ValueEncoder)
for t, enc := range rb.typeEncoders {
registry.typeEncoders[t] = enc
}
registry.typeDecoders = make(map[reflect.Type]ValueDecoder)
for t, dec := range rb.typeDecoders {
registry.typeDecoders[t] = dec
}
registry.interfaceEncoders = make([]interfaceValueEncoder, len(rb.interfaceEncoders))
copy(registry.interfaceEncoders, rb.interfaceEncoders)
registry.interfaceDecoders = make([]interfaceValueDecoder, len(rb.interfaceDecoders))
copy(registry.interfaceDecoders, rb.interfaceDecoders)
registry.kindEncoders = make(map[reflect.Kind]ValueEncoder)
for kind, enc := range rb.kindEncoders {
registry.kindEncoders[kind] = enc
}
registry.kindDecoders = make(map[reflect.Kind]ValueDecoder)
for kind, dec := range rb.kindDecoders {
registry.kindDecoders[kind] = dec
}
registry.typeMap = make(map[bsontype.Type]reflect.Type)
for bt, rt := range rb.typeMap {
registry.typeMap[bt] = rt
}
return registry
}
// LookupEncoder inspects the registry for an encoder for the given type. The lookup precedence works as follows:
//
// 1. An encoder registered for the exact type. If the given type represents an interface, an encoder registered using
// RegisterTypeEncoder for the interface will be selected.
//
// 2. An encoder registered using RegisterHookEncoder for an interface implemented by the type or by a pointer to the
// type.
//
// 3. An encoder registered for the reflect.Kind of the value.
//
// If no encoder is found, an error of type ErrNoEncoder is returned.
func (r *Registry) LookupEncoder(t reflect.Type) (ValueEncoder, error) {
encodererr := ErrNoEncoder{Type: t}
r.mu.RLock()
enc, found := r.lookupTypeEncoder(t)
r.mu.RUnlock()
if found {
if enc == nil {
return nil, ErrNoEncoder{Type: t}
}
return enc, nil
}
enc, found = r.lookupInterfaceEncoder(t, true)
if found {
r.mu.Lock()
r.typeEncoders[t] = enc
r.mu.Unlock()
return enc, nil
}
if t == nil {
r.mu.Lock()
r.typeEncoders[t] = nil
r.mu.Unlock()
return nil, encodererr
}
enc, found = r.kindEncoders[t.Kind()]
if !found {
r.mu.Lock()
r.typeEncoders[t] = nil
r.mu.Unlock()
return nil, encodererr
}
r.mu.Lock()
r.typeEncoders[t] = enc
r.mu.Unlock()
return enc, nil
}
func (r *Registry) lookupTypeEncoder(t reflect.Type) (ValueEncoder, bool) {
enc, found := r.typeEncoders[t]
return enc, found
}
func (r *Registry) lookupInterfaceEncoder(t reflect.Type, allowAddr bool) (ValueEncoder, bool) {
if t == nil {
return nil, false
}
for _, ienc := range r.interfaceEncoders {
if t.Implements(ienc.i) {
return ienc.ve, true
}
if allowAddr && t.Kind() != reflect.Ptr && reflect.PtrTo(t).Implements(ienc.i) {
// if *t implements an interface, this will catch if t implements an interface further ahead
// in interfaceEncoders
defaultEnc, found := r.lookupInterfaceEncoder(t, false)
if !found {
defaultEnc = r.kindEncoders[t.Kind()]
}
return newCondAddrEncoder(ienc.ve, defaultEnc), true
}
}
return nil, false
}
// LookupDecoder inspects the registry for an decoder for the given type. The lookup precedence works as follows:
//
// 1. A decoder registered for the exact type. If the given type represents an interface, a decoder registered using
// RegisterTypeDecoder for the interface will be selected.
//
// 2. A decoder registered using RegisterHookDecoder for an interface implemented by the type or by a pointer to the
// type.
//
// 3. A decoder registered for the reflect.Kind of the value.
//
// If no decoder is found, an error of type ErrNoDecoder is returned.
func (r *Registry) LookupDecoder(t reflect.Type) (ValueDecoder, error) {
if t == nil {
return nil, ErrNilType
}
decodererr := ErrNoDecoder{Type: t}
r.mu.RLock()
dec, found := r.lookupTypeDecoder(t)
r.mu.RUnlock()
if found {
if dec == nil {
return nil, ErrNoDecoder{Type: t}
}
return dec, nil
}
dec, found = r.lookupInterfaceDecoder(t, true)
if found {
r.mu.Lock()
r.typeDecoders[t] = dec
r.mu.Unlock()
return dec, nil
}
dec, found = r.kindDecoders[t.Kind()]
if !found {
r.mu.Lock()
r.typeDecoders[t] = nil
r.mu.Unlock()
return nil, decodererr
}
r.mu.Lock()
r.typeDecoders[t] = dec
r.mu.Unlock()
return dec, nil
}
func (r *Registry) lookupTypeDecoder(t reflect.Type) (ValueDecoder, bool) {
dec, found := r.typeDecoders[t]
return dec, found
}
func (r *Registry) lookupInterfaceDecoder(t reflect.Type, allowAddr bool) (ValueDecoder, bool) {
for _, idec := range r.interfaceDecoders {
if t.Implements(idec.i) {
return idec.vd, true
}
if allowAddr && t.Kind() != reflect.Ptr && reflect.PtrTo(t).Implements(idec.i) {
// if *t implements an interface, this will catch if t implements an interface further ahead
// in interfaceDecoders
defaultDec, found := r.lookupInterfaceDecoder(t, false)
if !found {
defaultDec = r.kindDecoders[t.Kind()]
}
return newCondAddrDecoder(idec.vd, defaultDec), true
}
}
return nil, false
}
// LookupTypeMapEntry inspects the registry's type map for a Go type for the corresponding BSON
// type. If no type is found, ErrNoTypeMapEntry is returned.
func (r *Registry) LookupTypeMapEntry(bt bsontype.Type) (reflect.Type, error) {
t, ok := r.typeMap[bt]
if !ok || t == nil {
return nil, ErrNoTypeMapEntry{Type: bt}
}
return t, nil
}
type interfaceValueEncoder struct {
i reflect.Type
ve ValueEncoder
}
type interfaceValueDecoder struct {
i reflect.Type
vd ValueDecoder
}

View File

@@ -0,0 +1,125 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec_test
import (
"fmt"
"math"
"reflect"
"go.mongodb.org/mongo-driver/bson/bsoncodec"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/bsontype"
)
func ExampleRegistry_customEncoder() {
// Write a custom encoder for an integer type that is multiplied by -1 when
// encoding.
// To register the default encoders and decoders in addition to this custom
// one, use bson.NewRegistryBuilder instead.
rb := bsoncodec.NewRegistryBuilder()
type negatedInt int
niType := reflect.TypeOf(negatedInt(0))
encoder := func(
ec bsoncodec.EncodeContext,
vw bsonrw.ValueWriter,
val reflect.Value,
) error {
// All encoder implementations should check that val is valid and is of
// the correct type before proceeding.
if !val.IsValid() || val.Type() != niType {
return bsoncodec.ValueEncoderError{
Name: "negatedIntEncodeValue",
Types: []reflect.Type{niType},
Received: val,
}
}
// Negate val and encode as a BSON int32 if it can fit in 32 bits and a
// BSON int64 otherwise.
negatedVal := val.Int() * -1
if math.MinInt32 <= negatedVal && negatedVal <= math.MaxInt32 {
return vw.WriteInt32(int32(negatedVal))
}
return vw.WriteInt64(negatedVal)
}
rb.RegisterTypeEncoder(niType, bsoncodec.ValueEncoderFunc(encoder))
}
func ExampleRegistry_customDecoder() {
// Write a custom decoder for a boolean type that can be stored in the
// database as a BSON boolean, int32, int64, double, or null. BSON int32,
// int64, and double values are considered "true" in this decoder if they
// are non-zero. BSON null values are always considered false.
// To register the default encoders and decoders in addition to this custom
// one, use bson.NewRegistryBuilder instead.
rb := bsoncodec.NewRegistryBuilder()
type lenientBool bool
decoder := func(
dc bsoncodec.DecodeContext,
vr bsonrw.ValueReader,
val reflect.Value,
) error {
// All decoder implementations should check that val is valid, settable,
// and is of the correct kind before proceeding.
if !val.IsValid() || !val.CanSet() || val.Kind() != reflect.Bool {
return bsoncodec.ValueDecoderError{
Name: "lenientBoolDecodeValue",
Kinds: []reflect.Kind{reflect.Bool},
Received: val,
}
}
var result bool
switch vr.Type() {
case bsontype.Boolean:
b, err := vr.ReadBoolean()
if err != nil {
return err
}
result = b
case bsontype.Int32:
i32, err := vr.ReadInt32()
if err != nil {
return err
}
result = i32 != 0
case bsontype.Int64:
i64, err := vr.ReadInt64()
if err != nil {
return err
}
result = i64 != 0
case bsontype.Double:
f64, err := vr.ReadDouble()
if err != nil {
return err
}
result = f64 != 0
case bsontype.Null:
if err := vr.ReadNull(); err != nil {
return err
}
result = false
default:
return fmt.Errorf(
"received invalid BSON type to decode into lenientBool: %s",
vr.Type())
}
val.SetBool(result)
return nil
}
lbType := reflect.TypeOf(lenientBool(true))
rb.RegisterTypeDecoder(lbType, bsoncodec.ValueDecoderFunc(decoder))
}

View File

@@ -0,0 +1,452 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import (
"reflect"
"testing"
"github.com/google/go-cmp/cmp"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/internal/testutil/assert"
)
func TestRegistry(t *testing.T) {
t.Run("Register", func(t *testing.T) {
fc1, fc2, fc3, fc4 := new(fakeCodec), new(fakeCodec), new(fakeCodec), new(fakeCodec)
t.Run("interface", func(t *testing.T) {
var t1f *testInterface1
var t2f *testInterface2
var t4f *testInterface4
ips := []interfaceValueEncoder{
{i: reflect.TypeOf(t1f).Elem(), ve: fc1},
{i: reflect.TypeOf(t2f).Elem(), ve: fc2},
{i: reflect.TypeOf(t1f).Elem(), ve: fc3},
{i: reflect.TypeOf(t4f).Elem(), ve: fc4},
}
want := []interfaceValueEncoder{
{i: reflect.TypeOf(t1f).Elem(), ve: fc3},
{i: reflect.TypeOf(t2f).Elem(), ve: fc2},
{i: reflect.TypeOf(t4f).Elem(), ve: fc4},
}
rb := NewRegistryBuilder()
for _, ip := range ips {
rb.RegisterHookEncoder(ip.i, ip.ve)
}
got := rb.interfaceEncoders
if !cmp.Equal(got, want, cmp.AllowUnexported(interfaceValueEncoder{}, fakeCodec{}), cmp.Comparer(typeComparer)) {
t.Errorf("The registered interfaces are not correct. got %v; want %v", got, want)
}
})
t.Run("type", func(t *testing.T) {
ft1, ft2, ft4 := fakeType1{}, fakeType2{}, fakeType4{}
rb := NewRegistryBuilder().
RegisterTypeEncoder(reflect.TypeOf(ft1), fc1).
RegisterTypeEncoder(reflect.TypeOf(ft2), fc2).
RegisterTypeEncoder(reflect.TypeOf(ft1), fc3).
RegisterTypeEncoder(reflect.TypeOf(ft4), fc4)
want := []struct {
t reflect.Type
c ValueEncoder
}{
{reflect.TypeOf(ft1), fc3},
{reflect.TypeOf(ft2), fc2},
{reflect.TypeOf(ft4), fc4},
}
got := rb.typeEncoders
for _, s := range want {
wantT, wantC := s.t, s.c
gotC, exists := got[wantT]
if !exists {
t.Errorf("Did not find type in the type registry: %v", wantT)
}
if !cmp.Equal(gotC, wantC, cmp.AllowUnexported(fakeCodec{})) {
t.Errorf("Codecs did not match. got %#v; want %#v", gotC, wantC)
}
}
})
t.Run("kind", func(t *testing.T) {
k1, k2, k4 := reflect.Struct, reflect.Slice, reflect.Map
rb := NewRegistryBuilder().
RegisterDefaultEncoder(k1, fc1).
RegisterDefaultEncoder(k2, fc2).
RegisterDefaultEncoder(k1, fc3).
RegisterDefaultEncoder(k4, fc4)
want := []struct {
k reflect.Kind
c ValueEncoder
}{
{k1, fc3},
{k2, fc2},
{k4, fc4},
}
got := rb.kindEncoders
for _, s := range want {
wantK, wantC := s.k, s.c
gotC, exists := got[wantK]
if !exists {
t.Errorf("Did not find kind in the kind registry: %v", wantK)
}
if !cmp.Equal(gotC, wantC, cmp.AllowUnexported(fakeCodec{})) {
t.Errorf("Codecs did not match. got %#v; want %#v", gotC, wantC)
}
}
})
t.Run("RegisterDefault", func(t *testing.T) {
t.Run("MapCodec", func(t *testing.T) {
codec := fakeCodec{num: 1}
codec2 := fakeCodec{num: 2}
rb := NewRegistryBuilder()
rb.RegisterDefaultEncoder(reflect.Map, codec)
if rb.kindEncoders[reflect.Map] != codec {
t.Errorf("Did not properly set the map codec. got %v; want %v", rb.kindEncoders[reflect.Map], codec)
}
rb.RegisterDefaultEncoder(reflect.Map, codec2)
if rb.kindEncoders[reflect.Map] != codec2 {
t.Errorf("Did not properly set the map codec. got %v; want %v", rb.kindEncoders[reflect.Map], codec2)
}
})
t.Run("StructCodec", func(t *testing.T) {
codec := fakeCodec{num: 1}
codec2 := fakeCodec{num: 2}
rb := NewRegistryBuilder()
rb.RegisterDefaultEncoder(reflect.Struct, codec)
if rb.kindEncoders[reflect.Struct] != codec {
t.Errorf("Did not properly set the struct codec. got %v; want %v", rb.kindEncoders[reflect.Struct], codec)
}
rb.RegisterDefaultEncoder(reflect.Struct, codec2)
if rb.kindEncoders[reflect.Struct] != codec2 {
t.Errorf("Did not properly set the struct codec. got %v; want %v", rb.kindEncoders[reflect.Struct], codec2)
}
})
t.Run("SliceCodec", func(t *testing.T) {
codec := fakeCodec{num: 1}
codec2 := fakeCodec{num: 2}
rb := NewRegistryBuilder()
rb.RegisterDefaultEncoder(reflect.Slice, codec)
if rb.kindEncoders[reflect.Slice] != codec {
t.Errorf("Did not properly set the slice codec. got %v; want %v", rb.kindEncoders[reflect.Slice], codec)
}
rb.RegisterDefaultEncoder(reflect.Slice, codec2)
if rb.kindEncoders[reflect.Slice] != codec2 {
t.Errorf("Did not properly set the slice codec. got %v; want %v", rb.kindEncoders[reflect.Slice], codec2)
}
})
t.Run("ArrayCodec", func(t *testing.T) {
codec := fakeCodec{num: 1}
codec2 := fakeCodec{num: 2}
rb := NewRegistryBuilder()
rb.RegisterDefaultEncoder(reflect.Array, codec)
if rb.kindEncoders[reflect.Array] != codec {
t.Errorf("Did not properly set the slice codec. got %v; want %v", rb.kindEncoders[reflect.Array], codec)
}
rb.RegisterDefaultEncoder(reflect.Array, codec2)
if rb.kindEncoders[reflect.Array] != codec2 {
t.Errorf("Did not properly set the slice codec. got %v; want %v", rb.kindEncoders[reflect.Array], codec2)
}
})
})
t.Run("Lookup", func(t *testing.T) {
type Codec interface {
ValueEncoder
ValueDecoder
}
var arrinstance [12]int
arr := reflect.TypeOf(arrinstance)
slc := reflect.TypeOf(make([]int, 12))
m := reflect.TypeOf(make(map[string]int))
strct := reflect.TypeOf(struct{ Foo string }{})
ft1 := reflect.PtrTo(reflect.TypeOf(fakeType1{}))
ft2 := reflect.TypeOf(fakeType2{})
ft3 := reflect.TypeOf(fakeType5(func(string, string) string { return "fakeType5" }))
ti1 := reflect.TypeOf((*testInterface1)(nil)).Elem()
ti2 := reflect.TypeOf((*testInterface2)(nil)).Elem()
ti1Impl := reflect.TypeOf(testInterface1Impl{})
ti2Impl := reflect.TypeOf(testInterface2Impl{})
ti3 := reflect.TypeOf((*testInterface3)(nil)).Elem()
ti3Impl := reflect.TypeOf(testInterface3Impl{})
ti3ImplPtr := reflect.TypeOf((*testInterface3Impl)(nil))
fc1, fc2 := fakeCodec{num: 1}, fakeCodec{num: 2}
fsc, fslcc, fmc := new(fakeStructCodec), new(fakeSliceCodec), new(fakeMapCodec)
pc := NewPointerCodec()
reg := NewRegistryBuilder().
RegisterTypeEncoder(ft1, fc1).
RegisterTypeEncoder(ft2, fc2).
RegisterTypeEncoder(ti1, fc1).
RegisterDefaultEncoder(reflect.Struct, fsc).
RegisterDefaultEncoder(reflect.Slice, fslcc).
RegisterDefaultEncoder(reflect.Array, fslcc).
RegisterDefaultEncoder(reflect.Map, fmc).
RegisterDefaultEncoder(reflect.Ptr, pc).
RegisterTypeDecoder(ft1, fc1).
RegisterTypeDecoder(ft2, fc2).
RegisterTypeDecoder(ti1, fc1). // values whose exact type is testInterface1 will use fc1 encoder
RegisterDefaultDecoder(reflect.Struct, fsc).
RegisterDefaultDecoder(reflect.Slice, fslcc).
RegisterDefaultDecoder(reflect.Array, fslcc).
RegisterDefaultDecoder(reflect.Map, fmc).
RegisterDefaultDecoder(reflect.Ptr, pc).
RegisterHookEncoder(ti2, fc2).
RegisterHookDecoder(ti2, fc2).
RegisterHookEncoder(ti3, fc3).
RegisterHookDecoder(ti3, fc3).
Build()
testCases := []struct {
name string
t reflect.Type
wantcodec Codec
wanterr error
testcache bool
}{
{
"type registry (pointer)",
ft1,
fc1,
nil,
false,
},
{
"type registry (non-pointer)",
ft2,
fc2,
nil,
false,
},
{
// lookup an interface type and expect that the registered encoder is returned
"interface with type encoder",
ti1,
fc1,
nil,
true,
},
{
// lookup a type that implements an interface and expect that the default struct codec is returned
"interface implementation with type encoder",
ti1Impl,
fsc,
nil,
false,
},
{
// lookup an interface type and expect that the registered hook is returned
"interface with hook",
ti2,
fc2,
nil,
false,
},
{
// lookup a type that implements an interface and expect that the registered hook is returned
"interface implementation with hook",
ti2Impl,
fc2,
nil,
false,
},
{
// lookup a pointer to a type where the pointer implements an interface and expect that the
// registered hook is returned
"interface pointer to implementation with hook (pointer)",
ti3ImplPtr,
fc3,
nil,
false,
},
{
"default struct codec (pointer)",
reflect.PtrTo(strct),
pc,
nil,
false,
},
{
"default struct codec (non-pointer)",
strct,
fsc,
nil,
false,
},
{
"default array codec",
arr,
fslcc,
nil,
false,
},
{
"default slice codec",
slc,
fslcc,
nil,
false,
},
{
"default map",
m,
fmc,
nil,
false,
},
{
"map non-string key",
reflect.TypeOf(map[int]int{}),
fmc,
nil,
false,
},
{
"No Codec Registered",
ft3,
nil,
ErrNoEncoder{Type: ft3},
false,
},
}
allowunexported := cmp.AllowUnexported(fakeCodec{}, fakeStructCodec{}, fakeSliceCodec{}, fakeMapCodec{})
comparepc := func(pc1, pc2 *PointerCodec) bool { return true }
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Run("Encoder", func(t *testing.T) {
gotcodec, goterr := reg.LookupEncoder(tc.t)
if !cmp.Equal(goterr, tc.wanterr, cmp.Comparer(compareErrors)) {
t.Errorf("Errors did not match. got %v; want %v", goterr, tc.wanterr)
}
if !cmp.Equal(gotcodec, tc.wantcodec, allowunexported, cmp.Comparer(comparepc)) {
t.Errorf("Codecs did not match. got %v; want %v", gotcodec, tc.wantcodec)
}
})
t.Run("Decoder", func(t *testing.T) {
var wanterr error
if ene, ok := tc.wanterr.(ErrNoEncoder); ok {
wanterr = ErrNoDecoder(ene)
} else {
wanterr = tc.wanterr
}
gotcodec, goterr := reg.LookupDecoder(tc.t)
if !cmp.Equal(goterr, wanterr, cmp.Comparer(compareErrors)) {
t.Errorf("Errors did not match. got %v; want %v", goterr, wanterr)
}
if !cmp.Equal(gotcodec, tc.wantcodec, allowunexported, cmp.Comparer(comparepc)) {
t.Errorf("Codecs did not match. got %v; want %v", gotcodec, tc.wantcodec)
t.Errorf("Codecs did not match. got %T; want %T", gotcodec, tc.wantcodec)
}
})
})
}
// lookup a type whose pointer implements an interface and expect that the registered hook is
// returned
t.Run("interface implementation with hook (pointer)", func(t *testing.T) {
t.Run("Encoder", func(t *testing.T) {
gotEnc, err := reg.LookupEncoder(ti3Impl)
assert.Nil(t, err, "LookupEncoder error: %v", err)
cae, ok := gotEnc.(*condAddrEncoder)
assert.True(t, ok, "Expected CondAddrEncoder, got %T", gotEnc)
if !cmp.Equal(cae.canAddrEnc, fc3, allowunexported, cmp.Comparer(comparepc)) {
t.Errorf("expected canAddrEnc %v, got %v", cae.canAddrEnc, fc3)
}
if !cmp.Equal(cae.elseEnc, fsc, allowunexported, cmp.Comparer(comparepc)) {
t.Errorf("expected elseEnc %v, got %v", cae.elseEnc, fsc)
}
})
t.Run("Decoder", func(t *testing.T) {
gotDec, err := reg.LookupDecoder(ti3Impl)
assert.Nil(t, err, "LookupDecoder error: %v", err)
cad, ok := gotDec.(*condAddrDecoder)
assert.True(t, ok, "Expected CondAddrDecoder, got %T", gotDec)
if !cmp.Equal(cad.canAddrDec, fc3, allowunexported, cmp.Comparer(comparepc)) {
t.Errorf("expected canAddrDec %v, got %v", cad.canAddrDec, fc3)
}
if !cmp.Equal(cad.elseDec, fsc, allowunexported, cmp.Comparer(comparepc)) {
t.Errorf("expected elseDec %v, got %v", cad.elseDec, fsc)
}
})
})
})
})
t.Run("Type Map", func(t *testing.T) {
reg := NewRegistryBuilder().
RegisterTypeMapEntry(bsontype.String, reflect.TypeOf("")).
RegisterTypeMapEntry(bsontype.Int32, reflect.TypeOf(int(0))).
Build()
var got, want reflect.Type
want = reflect.TypeOf("")
got, err := reg.LookupTypeMapEntry(bsontype.String)
noerr(t, err)
if got != want {
t.Errorf("Did not get expected type. got %v; want %v", got, want)
}
want = reflect.TypeOf(int(0))
got, err = reg.LookupTypeMapEntry(bsontype.Int32)
noerr(t, err)
if got != want {
t.Errorf("Did not get expected type. got %v; want %v", got, want)
}
want = nil
wanterr := ErrNoTypeMapEntry{Type: bsontype.ObjectID}
got, err = reg.LookupTypeMapEntry(bsontype.ObjectID)
if err != wanterr {
t.Errorf("Did not get expected error. got %v; want %v", err, wanterr)
}
if got != want {
t.Errorf("Did not get expected type. got %v; want %v", got, want)
}
})
}
type fakeType1 struct{}
type fakeType2 struct{}
type fakeType4 struct{}
type fakeType5 func(string, string) string
type fakeStructCodec struct{ fakeCodec }
type fakeSliceCodec struct{ fakeCodec }
type fakeMapCodec struct{ fakeCodec }
type fakeCodec struct{ num int }
func (fc fakeCodec) EncodeValue(EncodeContext, bsonrw.ValueWriter, reflect.Value) error {
return nil
}
func (fc fakeCodec) DecodeValue(DecodeContext, bsonrw.ValueReader, reflect.Value) error {
return nil
}
type testInterface1 interface{ test1() }
type testInterface2 interface{ test2() }
type testInterface3 interface{ test3() }
type testInterface4 interface{ test4() }
type testInterface1Impl struct{}
var _ testInterface1 = testInterface1Impl{}
func (testInterface1Impl) test1() {}
type testInterface2Impl struct{}
var _ testInterface2 = testInterface2Impl{}
func (testInterface2Impl) test2() {}
type testInterface3Impl struct{}
var _ testInterface3 = (*testInterface3Impl)(nil)
func (*testInterface3Impl) test3() {}
func typeComparer(i1, i2 reflect.Type) bool { return i1 == i2 }

View File

@@ -0,0 +1,199 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import (
"fmt"
"reflect"
"go.mongodb.org/mongo-driver/bson/bsonoptions"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
)
var defaultSliceCodec = NewSliceCodec()
// SliceCodec is the Codec used for slice values.
type SliceCodec struct {
EncodeNilAsEmpty bool
}
var _ ValueCodec = &MapCodec{}
// NewSliceCodec returns a MapCodec with options opts.
func NewSliceCodec(opts ...*bsonoptions.SliceCodecOptions) *SliceCodec {
sliceOpt := bsonoptions.MergeSliceCodecOptions(opts...)
codec := SliceCodec{}
if sliceOpt.EncodeNilAsEmpty != nil {
codec.EncodeNilAsEmpty = *sliceOpt.EncodeNilAsEmpty
}
return &codec
}
// EncodeValue is the ValueEncoder for slice types.
func (sc SliceCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Kind() != reflect.Slice {
return ValueEncoderError{Name: "SliceEncodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: val}
}
if val.IsNil() && !sc.EncodeNilAsEmpty {
return vw.WriteNull()
}
// If we have a []byte we want to treat it as a binary instead of as an array.
if val.Type().Elem() == tByte {
var byteSlice []byte
for idx := 0; idx < val.Len(); idx++ {
byteSlice = append(byteSlice, val.Index(idx).Interface().(byte))
}
return vw.WriteBinary(byteSlice)
}
// If we have a []primitive.E we want to treat it as a document instead of as an array.
if val.Type().ConvertibleTo(tD) {
d := val.Convert(tD).Interface().(primitive.D)
dw, err := vw.WriteDocument()
if err != nil {
return err
}
for _, e := range d {
err = encodeElement(ec, dw, e)
if err != nil {
return err
}
}
return dw.WriteDocumentEnd()
}
aw, err := vw.WriteArray()
if err != nil {
return err
}
elemType := val.Type().Elem()
encoder, err := ec.LookupEncoder(elemType)
if err != nil && elemType.Kind() != reflect.Interface {
return err
}
for idx := 0; idx < val.Len(); idx++ {
currEncoder, currVal, lookupErr := defaultValueEncoders.lookupElementEncoder(ec, encoder, val.Index(idx))
if lookupErr != nil && lookupErr != errInvalidValue {
return lookupErr
}
vw, err := aw.WriteArrayElement()
if err != nil {
return err
}
if lookupErr == errInvalidValue {
err = vw.WriteNull()
if err != nil {
return err
}
continue
}
err = currEncoder.EncodeValue(ec, vw, currVal)
if err != nil {
return err
}
}
return aw.WriteArrayEnd()
}
// DecodeValue is the ValueDecoder for slice types.
func (sc *SliceCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
if !val.CanSet() || val.Kind() != reflect.Slice {
return ValueDecoderError{Name: "SliceDecodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: val}
}
switch vrType := vr.Type(); vrType {
case bsontype.Array:
case bsontype.Null:
val.Set(reflect.Zero(val.Type()))
return vr.ReadNull()
case bsontype.Undefined:
val.Set(reflect.Zero(val.Type()))
return vr.ReadUndefined()
case bsontype.Type(0), bsontype.EmbeddedDocument:
if val.Type().Elem() != tE {
return fmt.Errorf("cannot decode document into %s", val.Type())
}
case bsontype.Binary:
if val.Type().Elem() != tByte {
return fmt.Errorf("SliceDecodeValue can only decode a binary into a byte array, got %v", vrType)
}
data, subtype, err := vr.ReadBinary()
if err != nil {
return err
}
if subtype != bsontype.BinaryGeneric && subtype != bsontype.BinaryBinaryOld {
return fmt.Errorf("SliceDecodeValue can only be used to decode subtype 0x00 or 0x02 for %s, got %v", bsontype.Binary, subtype)
}
if val.IsNil() {
val.Set(reflect.MakeSlice(val.Type(), 0, len(data)))
}
val.SetLen(0)
for _, elem := range data {
val.Set(reflect.Append(val, reflect.ValueOf(elem)))
}
return nil
case bsontype.String:
if sliceType := val.Type().Elem(); sliceType != tByte {
return fmt.Errorf("SliceDecodeValue can only decode a string into a byte array, got %v", sliceType)
}
str, err := vr.ReadString()
if err != nil {
return err
}
byteStr := []byte(str)
if val.IsNil() {
val.Set(reflect.MakeSlice(val.Type(), 0, len(byteStr)))
}
val.SetLen(0)
for _, elem := range byteStr {
val.Set(reflect.Append(val, reflect.ValueOf(elem)))
}
return nil
default:
return fmt.Errorf("cannot decode %v into a slice", vrType)
}
var elemsFunc func(DecodeContext, bsonrw.ValueReader, reflect.Value) ([]reflect.Value, error)
switch val.Type().Elem() {
case tE:
dc.Ancestor = val.Type()
elemsFunc = defaultValueDecoders.decodeD
default:
elemsFunc = defaultValueDecoders.decodeDefault
}
elems, err := elemsFunc(dc, vr, val)
if err != nil {
return err
}
if val.IsNil() {
val.Set(reflect.MakeSlice(val.Type(), 0, len(elems)))
}
val.SetLen(0)
val.Set(reflect.Append(val, elems...))
return nil
}

View File

@@ -0,0 +1,119 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import (
"fmt"
"reflect"
"go.mongodb.org/mongo-driver/bson/bsonoptions"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/bsontype"
)
// StringCodec is the Codec used for struct values.
type StringCodec struct {
DecodeObjectIDAsHex bool
}
var (
defaultStringCodec = NewStringCodec()
_ ValueCodec = defaultStringCodec
_ typeDecoder = defaultStringCodec
)
// NewStringCodec returns a StringCodec with options opts.
func NewStringCodec(opts ...*bsonoptions.StringCodecOptions) *StringCodec {
stringOpt := bsonoptions.MergeStringCodecOptions(opts...)
return &StringCodec{*stringOpt.DecodeObjectIDAsHex}
}
// EncodeValue is the ValueEncoder for string types.
func (sc *StringCodec) EncodeValue(ectx EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if val.Kind() != reflect.String {
return ValueEncoderError{
Name: "StringEncodeValue",
Kinds: []reflect.Kind{reflect.String},
Received: val,
}
}
return vw.WriteString(val.String())
}
func (sc *StringCodec) decodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) {
if t.Kind() != reflect.String {
return emptyValue, ValueDecoderError{
Name: "StringDecodeValue",
Kinds: []reflect.Kind{reflect.String},
Received: reflect.Zero(t),
}
}
var str string
var err error
switch vr.Type() {
case bsontype.String:
str, err = vr.ReadString()
if err != nil {
return emptyValue, err
}
case bsontype.ObjectID:
oid, err := vr.ReadObjectID()
if err != nil {
return emptyValue, err
}
if sc.DecodeObjectIDAsHex {
str = oid.Hex()
} else {
byteArray := [12]byte(oid)
str = string(byteArray[:])
}
case bsontype.Symbol:
str, err = vr.ReadSymbol()
if err != nil {
return emptyValue, err
}
case bsontype.Binary:
data, subtype, err := vr.ReadBinary()
if err != nil {
return emptyValue, err
}
if subtype != bsontype.BinaryGeneric && subtype != bsontype.BinaryBinaryOld {
return emptyValue, decodeBinaryError{subtype: subtype, typeName: "string"}
}
str = string(data)
case bsontype.Null:
if err = vr.ReadNull(); err != nil {
return emptyValue, err
}
case bsontype.Undefined:
if err = vr.ReadUndefined(); err != nil {
return emptyValue, err
}
default:
return emptyValue, fmt.Errorf("cannot decode %v into a string type", vr.Type())
}
return reflect.ValueOf(str), nil
}
// DecodeValue is the ValueDecoder for string types.
func (sc *StringCodec) DecodeValue(dctx DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
if !val.CanSet() || val.Kind() != reflect.String {
return ValueDecoderError{Name: "StringDecodeValue", Kinds: []reflect.Kind{reflect.String}, Received: val}
}
elem, err := sc.decodeType(dctx, vr, val.Type())
if err != nil {
return err
}
val.SetString(elem.String())
return nil
}

View File

@@ -0,0 +1,48 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import (
"reflect"
"testing"
"go.mongodb.org/mongo-driver/bson/bsonoptions"
"go.mongodb.org/mongo-driver/bson/bsonrw/bsonrwtest"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/internal/testutil/assert"
)
func TestStringCodec(t *testing.T) {
t.Run("ObjectIDAsHex", func(t *testing.T) {
oid := primitive.NewObjectID()
byteArray := [12]byte(oid)
reader := &bsonrwtest.ValueReaderWriter{BSONType: bsontype.ObjectID, Return: oid}
testCases := []struct {
name string
opts *bsonoptions.StringCodecOptions
hex bool
result string
}{
{"default", bsonoptions.StringCodec(), true, oid.Hex()},
{"true", bsonoptions.StringCodec().SetDecodeObjectIDAsHex(true), true, oid.Hex()},
{"false", bsonoptions.StringCodec().SetDecodeObjectIDAsHex(false), false, string(byteArray[:])},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
stringCodec := NewStringCodec(tc.opts)
actual := reflect.New(reflect.TypeOf("")).Elem()
err := stringCodec.DecodeValue(DecodeContext{}, reader, actual)
assert.Nil(t, err, "StringCodec.DecodeValue error: %v", err)
actualString := actual.Interface().(string)
assert.Equal(t, tc.result, actualString, "Expected string %v, got %v", tc.result, actualString)
})
}
})
}

View File

@@ -0,0 +1,675 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import (
"errors"
"fmt"
"reflect"
"sort"
"strings"
"sync"
"time"
"go.mongodb.org/mongo-driver/bson/bsonoptions"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/bsontype"
)
// DecodeError represents an error that occurs when unmarshalling BSON bytes into a native Go type.
type DecodeError struct {
keys []string
wrapped error
}
// Unwrap returns the underlying error
func (de *DecodeError) Unwrap() error {
return de.wrapped
}
// Error implements the error interface.
func (de *DecodeError) Error() string {
// The keys are stored in reverse order because the de.keys slice is builtup while propagating the error up the
// stack of BSON keys, so we call de.Keys(), which reverses them.
keyPath := strings.Join(de.Keys(), ".")
return fmt.Sprintf("error decoding key %s: %v", keyPath, de.wrapped)
}
// Keys returns the BSON key path that caused an error as a slice of strings. The keys in the slice are in top-down
// order. For example, if the document being unmarshalled was {a: {b: {c: 1}}} and the value for c was supposed to be
// a string, the keys slice will be ["a", "b", "c"].
func (de *DecodeError) Keys() []string {
reversedKeys := make([]string, 0, len(de.keys))
for idx := len(de.keys) - 1; idx >= 0; idx-- {
reversedKeys = append(reversedKeys, de.keys[idx])
}
return reversedKeys
}
// Zeroer allows custom struct types to implement a report of zero
// state. All struct types that don't implement Zeroer or where IsZero
// returns false are considered to be not zero.
type Zeroer interface {
IsZero() bool
}
// StructCodec is the Codec used for struct values.
type StructCodec struct {
cache map[reflect.Type]*structDescription
l sync.RWMutex
parser StructTagParser
DecodeZeroStruct bool
DecodeDeepZeroInline bool
EncodeOmitDefaultStruct bool
AllowUnexportedFields bool
OverwriteDuplicatedInlinedFields bool
}
var _ ValueEncoder = &StructCodec{}
var _ ValueDecoder = &StructCodec{}
// NewStructCodec returns a StructCodec that uses p for struct tag parsing.
func NewStructCodec(p StructTagParser, opts ...*bsonoptions.StructCodecOptions) (*StructCodec, error) {
if p == nil {
return nil, errors.New("a StructTagParser must be provided to NewStructCodec")
}
structOpt := bsonoptions.MergeStructCodecOptions(opts...)
codec := &StructCodec{
cache: make(map[reflect.Type]*structDescription),
parser: p,
}
if structOpt.DecodeZeroStruct != nil {
codec.DecodeZeroStruct = *structOpt.DecodeZeroStruct
}
if structOpt.DecodeDeepZeroInline != nil {
codec.DecodeDeepZeroInline = *structOpt.DecodeDeepZeroInline
}
if structOpt.EncodeOmitDefaultStruct != nil {
codec.EncodeOmitDefaultStruct = *structOpt.EncodeOmitDefaultStruct
}
if structOpt.OverwriteDuplicatedInlinedFields != nil {
codec.OverwriteDuplicatedInlinedFields = *structOpt.OverwriteDuplicatedInlinedFields
}
if structOpt.AllowUnexportedFields != nil {
codec.AllowUnexportedFields = *structOpt.AllowUnexportedFields
}
return codec, nil
}
// EncodeValue handles encoding generic struct types.
func (sc *StructCodec) EncodeValue(r EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Kind() != reflect.Struct {
return ValueEncoderError{Name: "StructCodec.EncodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val}
}
sd, err := sc.describeStruct(r.Registry, val.Type())
if err != nil {
return err
}
dw, err := vw.WriteDocument()
if err != nil {
return err
}
var rv reflect.Value
for _, desc := range sd.fl {
if desc.omitAlways {
continue
}
if desc.inline == nil {
rv = val.Field(desc.idx)
} else {
rv, err = fieldByIndexErr(val, desc.inline)
if err != nil {
continue
}
}
desc.encoder, rv, err = defaultValueEncoders.lookupElementEncoder(r, desc.encoder, rv)
if err != nil && err != errInvalidValue {
return err
}
if err == errInvalidValue {
if desc.omitEmpty {
continue
}
vw2, err := dw.WriteDocumentElement(desc.name)
if err != nil {
return err
}
err = vw2.WriteNull()
if err != nil {
return err
}
continue
}
if desc.encoder == nil {
return ErrNoEncoder{Type: rv.Type()}
}
encoder := desc.encoder
var isZero bool
rvInterface := rv.Interface()
if cz, ok := encoder.(CodecZeroer); ok {
isZero = cz.IsTypeZero(rvInterface)
} else if rv.Kind() == reflect.Interface {
// sc.isZero will not treat an interface rv as an interface, so we need to check for the zero interface separately.
isZero = rv.IsNil()
} else {
isZero = sc.isZero(rvInterface)
}
if desc.omitEmpty && isZero {
continue
}
vw2, err := dw.WriteDocumentElement(desc.name)
if err != nil {
return err
}
ectx := EncodeContext{Registry: r.Registry, MinSize: desc.minSize}
err = encoder.EncodeValue(ectx, vw2, rv)
if err != nil {
return err
}
}
if sd.inlineMap >= 0 {
rv := val.Field(sd.inlineMap)
collisionFn := func(key string) bool {
_, exists := sd.fm[key]
return exists
}
return defaultMapCodec.mapEncodeValue(r, dw, rv, collisionFn)
}
return dw.WriteDocumentEnd()
}
func newDecodeError(key string, original error) error {
de, ok := original.(*DecodeError)
if !ok {
return &DecodeError{
keys: []string{key},
wrapped: original,
}
}
de.keys = append(de.keys, key)
return de
}
// DecodeValue implements the Codec interface.
// By default, map types in val will not be cleared. If a map has existing key/value pairs, it will be extended with the new ones from vr.
// For slices, the decoder will set the length of the slice to zero and append all elements. The underlying array will not be cleared.
func (sc *StructCodec) DecodeValue(r DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
if !val.CanSet() || val.Kind() != reflect.Struct {
return ValueDecoderError{Name: "StructCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val}
}
switch vrType := vr.Type(); vrType {
case bsontype.Type(0), bsontype.EmbeddedDocument:
case bsontype.Null:
if err := vr.ReadNull(); err != nil {
return err
}
val.Set(reflect.Zero(val.Type()))
return nil
case bsontype.Undefined:
if err := vr.ReadUndefined(); err != nil {
return err
}
val.Set(reflect.Zero(val.Type()))
return nil
default:
return fmt.Errorf("cannot decode %v into a %s", vrType, val.Type())
}
sd, err := sc.describeStruct(r.Registry, val.Type())
if err != nil {
return err
}
if sc.DecodeZeroStruct {
val.Set(reflect.Zero(val.Type()))
}
if sc.DecodeDeepZeroInline && sd.inline {
val.Set(deepZero(val.Type()))
}
var decoder ValueDecoder
var inlineMap reflect.Value
if sd.inlineMap >= 0 {
inlineMap = val.Field(sd.inlineMap)
decoder, err = r.LookupDecoder(inlineMap.Type().Elem())
if err != nil {
return err
}
}
dr, err := vr.ReadDocument()
if err != nil {
return err
}
for {
name, vr, err := dr.ReadElement()
if err == bsonrw.ErrEOD {
break
}
if err != nil {
return err
}
fd, exists := sd.fm[name]
if !exists {
// if the original name isn't found in the struct description, try again with the name in lowercase
// this could match if a BSON tag isn't specified because by default, describeStruct lowercases all field
// names
fd, exists = sd.fm[strings.ToLower(name)]
}
if !exists {
if sd.inlineMap < 0 {
// The encoding/json package requires a flag to return on error for non-existent fields.
// This functionality seems appropriate for the struct codec.
err = vr.Skip()
if err != nil {
return err
}
continue
}
if inlineMap.IsNil() {
inlineMap.Set(reflect.MakeMap(inlineMap.Type()))
}
elem := reflect.New(inlineMap.Type().Elem()).Elem()
r.Ancestor = inlineMap.Type()
err = decoder.DecodeValue(r, vr, elem)
if err != nil {
return err
}
inlineMap.SetMapIndex(reflect.ValueOf(name), elem)
continue
}
var field reflect.Value
if fd.inline == nil {
field = val.Field(fd.idx)
} else {
field, err = getInlineField(val, fd.inline)
if err != nil {
return err
}
}
if !field.CanSet() { // Being settable is a super set of being addressable.
innerErr := fmt.Errorf("field %v is not settable", field)
return newDecodeError(fd.name, innerErr)
}
if field.Kind() == reflect.Ptr && field.IsNil() {
field.Set(reflect.New(field.Type().Elem()))
}
field = field.Addr()
dctx := DecodeContext{
Registry: r.Registry,
Truncate: fd.truncate || r.Truncate,
defaultDocumentType: r.defaultDocumentType,
}
if fd.decoder == nil {
return newDecodeError(fd.name, ErrNoDecoder{Type: field.Elem().Type()})
}
err = fd.decoder.DecodeValue(dctx, vr, field.Elem())
if err != nil {
return newDecodeError(fd.name, err)
}
}
return nil
}
func (sc *StructCodec) isZero(i interface{}) bool {
v := reflect.ValueOf(i)
// check the value validity
if !v.IsValid() {
return true
}
if z, ok := v.Interface().(Zeroer); ok && (v.Kind() != reflect.Ptr || !v.IsNil()) {
return z.IsZero()
}
switch v.Kind() {
case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
return v.Len() == 0
case reflect.Bool:
return !v.Bool()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return v.Int() == 0
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return v.Uint() == 0
case reflect.Float32, reflect.Float64:
return v.Float() == 0
case reflect.Interface, reflect.Ptr:
return v.IsNil()
case reflect.Struct:
if sc.EncodeOmitDefaultStruct {
vt := v.Type()
if vt == tTime {
return v.Interface().(time.Time).IsZero()
}
for i := 0; i < v.NumField(); i++ {
if vt.Field(i).PkgPath != "" && !vt.Field(i).Anonymous {
continue // Private field
}
fld := v.Field(i)
if !sc.isZero(fld.Interface()) {
return false
}
}
return true
}
}
return false
}
type structDescription struct {
fm map[string]fieldDescription
fl []fieldDescription
inlineMap int
inline bool
}
type fieldDescription struct {
name string // BSON key name
fieldName string // struct field name
idx int
omitEmpty bool
omitAlways bool
minSize bool
truncate bool
inline []int
encoder ValueEncoder
decoder ValueDecoder
}
type byIndex []fieldDescription
func (bi byIndex) Len() int { return len(bi) }
func (bi byIndex) Swap(i, j int) { bi[i], bi[j] = bi[j], bi[i] }
func (bi byIndex) Less(i, j int) bool {
// If a field is inlined, its index in the top level struct is stored at inline[0]
iIdx, jIdx := bi[i].idx, bi[j].idx
if len(bi[i].inline) > 0 {
iIdx = bi[i].inline[0]
}
if len(bi[j].inline) > 0 {
jIdx = bi[j].inline[0]
}
if iIdx != jIdx {
return iIdx < jIdx
}
for k, biik := range bi[i].inline {
if k >= len(bi[j].inline) {
return false
}
if biik != bi[j].inline[k] {
return biik < bi[j].inline[k]
}
}
return len(bi[i].inline) < len(bi[j].inline)
}
func (sc *StructCodec) describeStruct(r *Registry, t reflect.Type) (*structDescription, error) {
// We need to analyze the struct, including getting the tags, collecting
// information about inlining, and create a map of the field name to the field.
sc.l.RLock()
ds, exists := sc.cache[t]
sc.l.RUnlock()
if exists {
return ds, nil
}
numFields := t.NumField()
sd := &structDescription{
fm: make(map[string]fieldDescription, numFields),
fl: make([]fieldDescription, 0, numFields),
inlineMap: -1,
}
var fields []fieldDescription
for i := 0; i < numFields; i++ {
sf := t.Field(i)
if sf.PkgPath != "" && (!sc.AllowUnexportedFields || !sf.Anonymous) {
// field is private or unexported fields aren't allowed, ignore
continue
}
sfType := sf.Type
encoder, err := r.LookupEncoder(sfType)
if err != nil {
encoder = nil
}
decoder, err := r.LookupDecoder(sfType)
if err != nil {
decoder = nil
}
description := fieldDescription{
fieldName: sf.Name,
idx: i,
encoder: encoder,
decoder: decoder,
}
stags, err := sc.parser.ParseStructTags(sf)
if err != nil {
return nil, err
}
if stags.Skip {
continue
}
description.name = stags.Name
description.omitEmpty = stags.OmitEmpty
description.omitAlways = stags.OmitAlways
description.minSize = stags.MinSize
description.truncate = stags.Truncate
if stags.Inline {
sd.inline = true
switch sfType.Kind() {
case reflect.Map:
if sd.inlineMap >= 0 {
return nil, errors.New("(struct " + t.String() + ") multiple inline maps")
}
if sfType.Key() != tString {
return nil, errors.New("(struct " + t.String() + ") inline map must have a string keys")
}
sd.inlineMap = description.idx
case reflect.Ptr:
sfType = sfType.Elem()
if sfType.Kind() != reflect.Struct {
return nil, fmt.Errorf("(struct %s) inline fields must be a struct, a struct pointer, or a map", t.String())
}
fallthrough
case reflect.Struct:
inlinesf, err := sc.describeStruct(r, sfType)
if err != nil {
return nil, err
}
for _, fd := range inlinesf.fl {
if fd.inline == nil {
fd.inline = []int{i, fd.idx}
} else {
fd.inline = append([]int{i}, fd.inline...)
}
fields = append(fields, fd)
}
default:
return nil, fmt.Errorf("(struct %s) inline fields must be a struct, a struct pointer, or a map", t.String())
}
continue
}
fields = append(fields, description)
}
// Sort fieldDescriptions by name and use dominance rules to determine which should be added for each name
sort.Slice(fields, func(i, j int) bool {
x := fields
// sort field by name, breaking ties with depth, then
// breaking ties with index sequence.
if x[i].name != x[j].name {
return x[i].name < x[j].name
}
if len(x[i].inline) != len(x[j].inline) {
return len(x[i].inline) < len(x[j].inline)
}
return byIndex(x).Less(i, j)
})
for advance, i := 0, 0; i < len(fields); i += advance {
// One iteration per name.
// Find the sequence of fields with the name of this first field.
fi := fields[i]
name := fi.name
for advance = 1; i+advance < len(fields); advance++ {
fj := fields[i+advance]
if fj.name != name {
break
}
}
if advance == 1 { // Only one field with this name
sd.fl = append(sd.fl, fi)
sd.fm[name] = fi
continue
}
dominant, ok := dominantField(fields[i : i+advance])
if !ok || !sc.OverwriteDuplicatedInlinedFields {
return nil, fmt.Errorf("struct %s has duplicated key %s", t.String(), name)
}
sd.fl = append(sd.fl, dominant)
sd.fm[name] = dominant
}
sort.Sort(byIndex(sd.fl))
sc.l.Lock()
sc.cache[t] = sd
sc.l.Unlock()
return sd, nil
}
// dominantField looks through the fields, all of which are known to
// have the same name, to find the single field that dominates the
// others using Go's inlining rules. If there are multiple top-level
// fields, the boolean will be false: This condition is an error in Go
// and we skip all the fields.
func dominantField(fields []fieldDescription) (fieldDescription, bool) {
// The fields are sorted in increasing index-length order, then by presence of tag.
// That means that the first field is the dominant one. We need only check
// for error cases: two fields at top level.
if len(fields) > 1 &&
len(fields[0].inline) == len(fields[1].inline) {
return fieldDescription{}, false
}
return fields[0], true
}
func fieldByIndexErr(v reflect.Value, index []int) (result reflect.Value, err error) {
defer func() {
if recovered := recover(); recovered != nil {
switch r := recovered.(type) {
case string:
err = fmt.Errorf("%s", r)
case error:
err = r
}
}
}()
result = v.FieldByIndex(index)
return
}
func getInlineField(val reflect.Value, index []int) (reflect.Value, error) {
field, err := fieldByIndexErr(val, index)
if err == nil {
return field, nil
}
// if parent of this element doesn't exist, fix its parent
inlineParent := index[:len(index)-1]
var fParent reflect.Value
if fParent, err = fieldByIndexErr(val, inlineParent); err != nil {
fParent, err = getInlineField(val, inlineParent)
if err != nil {
return fParent, err
}
}
fParent.Set(reflect.New(fParent.Type().Elem()))
return fieldByIndexErr(val, index)
}
// DeepZero returns recursive zero object
func deepZero(st reflect.Type) (result reflect.Value) {
result = reflect.Indirect(reflect.New(st))
if result.Kind() == reflect.Struct {
for i := 0; i < result.NumField(); i++ {
if f := result.Field(i); f.Kind() == reflect.Ptr {
if f.CanInterface() {
if ft := reflect.TypeOf(f.Interface()); ft.Elem().Kind() == reflect.Struct {
result.Field(i).Set(recursivePointerTo(deepZero(ft.Elem())))
}
}
}
}
}
return
}
// recursivePointerTo calls reflect.New(v.Type) but recursively for its fields inside
func recursivePointerTo(v reflect.Value) reflect.Value {
v = reflect.Indirect(v)
result := reflect.New(v.Type())
if v.Kind() == reflect.Struct {
for i := 0; i < v.NumField(); i++ {
if f := v.Field(i); f.Kind() == reflect.Ptr {
if f.Elem().Kind() == reflect.Struct {
result.Elem().Field(i).Set(recursivePointerTo(f))
}
}
}
}
return result
}

View File

@@ -0,0 +1,47 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestZeoerInterfaceUsedByDecoder(t *testing.T) {
enc := &StructCodec{}
// cases that are zero, because they are known types or pointers
var st *nonZeroer
assert.True(t, enc.isZero(st))
assert.True(t, enc.isZero(0))
assert.True(t, enc.isZero(false))
// cases that shouldn't be zero
st = &nonZeroer{value: false}
assert.False(t, enc.isZero(struct{ val bool }{val: true}))
assert.False(t, enc.isZero(struct{ val bool }{val: false}))
assert.False(t, enc.isZero(st))
st.value = true
assert.False(t, enc.isZero(st))
// a test to see if the interface impacts the outcome
z := zeroTest{}
assert.False(t, enc.isZero(z))
z.reportZero = true
assert.True(t, enc.isZero(z))
// *time.Time with nil should be zero
var tp *time.Time
assert.True(t, enc.isZero(tp))
// actually all zeroer if nil should also be zero
var zp *zeroTest
assert.True(t, enc.isZero(zp))
}

View File

@@ -0,0 +1,142 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import (
"reflect"
"strings"
)
// StructTagParser returns the struct tags for a given struct field.
type StructTagParser interface {
ParseStructTags(reflect.StructField) (StructTags, error)
}
// StructTagParserFunc is an adapter that allows a generic function to be used
// as a StructTagParser.
type StructTagParserFunc func(reflect.StructField) (StructTags, error)
// ParseStructTags implements the StructTagParser interface.
func (stpf StructTagParserFunc) ParseStructTags(sf reflect.StructField) (StructTags, error) {
return stpf(sf)
}
// StructTags represents the struct tag fields that the StructCodec uses during
// the encoding and decoding process.
//
// In the case of a struct, the lowercased field name is used as the key for each exported
// field but this behavior may be changed using a struct tag. The tag may also contain flags to
// adjust the marshalling behavior for the field.
//
// The properties are defined below:
//
// OmitEmpty Only include the field if it's not set to the zero value for the type or to
// empty slices or maps.
//
// MinSize Marshal an integer of a type larger than 32 bits value as an int32, if that's
// feasible while preserving the numeric value.
//
// Truncate When unmarshaling a BSON double, it is permitted to lose precision to fit within
// a float32.
//
// Inline Inline the field, which must be a struct or a map, causing all of its fields
// or keys to be processed as if they were part of the outer struct. For maps,
// keys must not conflict with the bson keys of other struct fields.
//
// Skip This struct field should be skipped. This is usually denoted by parsing a "-"
// for the name.
//
// TODO(skriptble): Add tags for undefined as nil and for null as nil.
type StructTags struct {
Name string
OmitEmpty bool
OmitAlways bool
MinSize bool
Truncate bool
Inline bool
Skip bool
}
// DefaultStructTagParser is the StructTagParser used by the StructCodec by default.
// It will handle the bson struct tag. See the documentation for StructTags to see
// what each of the returned fields means.
//
// If there is no name in the struct tag fields, the struct field name is lowercased.
// The tag formats accepted are:
//
// "[<key>][,<flag1>[,<flag2>]]"
//
// `(...) bson:"[<key>][,<flag1>[,<flag2>]]" (...)`
//
// An example:
//
// type T struct {
// A bool
// B int "myb"
// C string "myc,omitempty"
// D string `bson:",omitempty" json:"jsonkey"`
// E int64 ",minsize"
// F int64 "myf,omitempty,minsize"
// }
//
// A struct tag either consisting entirely of '-' or with a bson key with a
// value consisting entirely of '-' will return a StructTags with Skip true and
// the remaining fields will be their default values.
var DefaultStructTagParser StructTagParserFunc = func(sf reflect.StructField) (StructTags, error) {
key := strings.ToLower(sf.Name)
tag, ok := sf.Tag.Lookup("bson")
if !ok && !strings.Contains(string(sf.Tag), ":") && len(sf.Tag) > 0 {
tag = string(sf.Tag)
}
return parseTags(key, tag)
}
func parseTags(key string, tag string) (StructTags, error) {
var st StructTags
if tag == "-" {
st.Skip = true
return st, nil
}
for idx, str := range strings.Split(tag, ",") {
if idx == 0 && str != "" {
key = str
}
switch str {
case "omitempty":
st.OmitEmpty = true
case "omitalways":
st.OmitAlways = true
case "minsize":
st.MinSize = true
case "truncate":
st.Truncate = true
case "inline":
st.Inline = true
}
}
st.Name = key
return st, nil
}
// JSONFallbackStructTagParser has the same behavior as DefaultStructTagParser
// but will also fallback to parsing the json tag instead on a field where the
// bson tag isn't available.
var JSONFallbackStructTagParser StructTagParserFunc = func(sf reflect.StructField) (StructTags, error) {
key := strings.ToLower(sf.Name)
tag, ok := sf.Tag.Lookup("bson")
if !ok {
tag, ok = sf.Tag.Lookup("json")
}
if !ok && !strings.Contains(string(sf.Tag), ":") && len(sf.Tag) > 0 {
tag = string(sf.Tag)
}
return parseTags(key, tag)
}

Some files were not shown because too many files have changed in this diff Show More