Compare commits

..

3 Commits

Author SHA1 Message Date
9d88ab3a2b mongo data 2023-06-22 15:06:07 +02:00
2ec88e81f3 Patch mongo (add omitalways) 2023-06-18 16:04:34 +02:00
d471d7c396 Copied mongo repo (to patch it) 2023-06-18 15:52:17 +02:00
624 changed files with 142491 additions and 4997 deletions

View File

@@ -1,9 +0,0 @@
FROM golang:latest
RUN apt install -y make curl python3 && go install gotest.tools/gotestsum@latest
COPY . /source
WORKDIR /source
CMD ["make", "test"]

View File

@@ -1,30 +0,0 @@
# https://docs.gitea.com/next/usage/actions/quickstart
# https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions
# https://docs.github.com/en/actions/learn-github-actions/contexts#github-context
name: Build Docker and Deploy
run-name: Build & Deploy ${{ gitea.ref }} on ${{ gitea.actor }}
on: [push]
jobs:
run_tests:
name: Run goext test-suite
runs-on: bfb-cicd-latest
steps:
- name: Check out code
uses: actions/checkout@v3
- name: Build test docker
id: build_docker
run: echo "DOCKER_IMG_ID=$(docker build -q . -f .gitea/workflows/Dockerfile_tests || echo __err_build__)" >> $GITHUB_OUTPUT
- name: Run tests
run: docker run --rm "${{ steps.build_docker.outputs.DOCKER_IMG_ID }}"
- name: Cleanup
if: always()
run: docker image rm "${{ steps.build_docker.outputs.DOCKER_IMG_ID }}"

View File

@@ -1,17 +1,16 @@
.PHONY: run test version update-mongo
run:
echo "This is a library - can't be run" && false
test:
# go test ./...
which gotestsum || go install gotest.tools/gotestsum@latest
gotestsum --format "testname" -- -tags="timetzdata sqlite_fts5 sqlite_foreign_keys" "./..."
test-in-docker:
tag="goext_temp_test_image:$(shell uuidgen | tr -d '-')"; \
docker build --tag $$tag . -f .gitea/workflows/Dockerfile_tests; \
docker run --rm $$tag; \
docker rmi $$tag
gotestsum --format "testname" -- -tags="timetzdata sqlite_fts5 sqlite_foreign_keys" "./test"
version:
_data/version.sh
update-mongo:
_data/update-mongo.sh

12
TODO.md
View File

@@ -2,6 +2,12 @@
- cronext
- rfctime.DateOnly
- rfctime.HMSTimeOnly
- rfctime.NanoTimeOnly
- cursortoken
- typed/geenric mongo wrapper
- error package
- rfctime.DateOnly
- rfctime.HMSTimeOnly
- rfctime.NanoTimeOnly

80
_data/mongo.patch Normal file
View File

@@ -0,0 +1,80 @@
diff --git a/mongo/bson/bsoncodec/struct_codec.go b/mongo/bson/bsoncodec/struct_codec.go
--- a/mongo/bson/bsoncodec/struct_codec.go
+++ b/mongo/bson/bsoncodec/struct_codec.go
@@ -122,6 +122,10 @@ func (sc *StructCodec) EncodeValue(r EncodeContext, vw bsonrw.ValueWriter, val r
}
var rv reflect.Value
for _, desc := range sd.fl {
+ if desc.omitAlways {
+ continue
+ }
+
if desc.inline == nil {
rv = val.Field(desc.idx)
} else {
@@ -400,15 +404,16 @@ type structDescription struct {
}
type fieldDescription struct {
- name string // BSON key name
- fieldName string // struct field name
- idx int
- omitEmpty bool
- minSize bool
- truncate bool
- inline []int
- encoder ValueEncoder
- decoder ValueDecoder
+ name string // BSON key name
+ fieldName string // struct field name
+ idx int
+ omitEmpty bool
+ omitAlways bool
+ minSize bool
+ truncate bool
+ inline []int
+ encoder ValueEncoder
+ decoder ValueDecoder
}
type byIndex []fieldDescription
@@ -491,6 +496,7 @@ func (sc *StructCodec) describeStruct(r *Registry, t reflect.Type) (*structDescr
}
description.name = stags.Name
description.omitEmpty = stags.OmitEmpty
+ description.omitAlways = stags.OmitAlways
description.minSize = stags.MinSize
description.truncate = stags.Truncate
diff --git a/mongo/bson/bsoncodec/struct_tag_parser.go b/mongo/bson/bsoncodec/struct_tag_parser.go
--- a/mongo/bson/bsoncodec/struct_tag_parser.go
+++ b/mongo/bson/bsoncodec/struct_tag_parser.go
@@ -52,12 +52,13 @@ func (stpf StructTagParserFunc) ParseStructTags(sf reflect.StructField) (StructT
//
// TODO(skriptble): Add tags for undefined as nil and for null as nil.
type StructTags struct {
- Name string
- OmitEmpty bool
- MinSize bool
- Truncate bool
- Inline bool
- Skip bool
+ Name string
+ OmitEmpty bool
+ OmitAlways bool
+ MinSize bool
+ Truncate bool
+ Inline bool
+ Skip bool
}
// DefaultStructTagParser is the StructTagParser used by the StructCodec by default.
@@ -108,6 +109,8 @@ func parseTags(key string, tag string) (StructTags, error) {
switch str {
case "omitempty":
st.OmitEmpty = true
+ case "omitalways":
+ st.OmitAlways = true
case "minsize":
st.MinSize = true
case "truncate":

95
_data/update-mongo.sh Executable file
View File

@@ -0,0 +1,95 @@
#!/bin/bash
set -o nounset # disallow usage of unset vars ( set -u )
set -o errexit # Exit immediately if a pipeline returns non-zero. ( set -e )
set -o errtrace # Allow the above trap be inherited by all functions in the script. ( set -E )
set -o pipefail # Return value of a pipeline is the value of the last (rightmost) command to exit with a non-zero status
IFS=$'\n\t' # Set $IFS to only newline and tab.
dir="/tmp/mongo_repo_$( uuidgen )"
echo ""
echo "> Clone https://github.dev/mongodb/mongo-go-driver"
echo ""
git clone "https://github.com/mongodb/mongo-go-driver" "$dir"
pushd "$dir"
git fetch --tags
latestTag="$( git describe --tags `git rev-list --tags --max-count=1` )"
git -c "advice.detachedHead=false" checkout $latestTag
latestSHA="$( git rev-parse HEAD )"
popd
existingTag=$( cat mongoPatchVersion.go | grep -oP "(?<=const MongoCloneTag = \")([A-Za-z0-9.]+)(?=\")" )
existingSHA=$( cat mongoPatchVersion.go | grep -oP "(?<=const MongoCloneCommit = \")([A-Za-z0-9.]+)(?=\")" )
echo "===================================="
echo "ID (online) $latestSHA"
echo "ID (local) $existingSHA"
echo "Tag (online) $latestTag"
echo "Tag (local) $existingTag"
echo "===================================="
if [[ "$latestTag" == "$existingTag" ]]; then
echo "Nothing to do"
rm -rf "$dir"
exit 0
fi
echo ""
echo "> Copy repository"
echo ""
rm -rf mongo
cp -r "$dir" "mongo"
rm -rf "$dir"
echo ""
echo "> Clean repository"
echo ""
rm -rf "mongo/.git"
rm -rf "mongo/.evergreen"
rm -rf "mongo/cmd"
rm -rf "mongo/docs"
rm -rf "mongo/etc"
rm -rf "mongo/examples"
rm -rf "mongo/testdata"
rm -rf "mongo/benchmark"
rm -rf "mongo/vendor"
rm -rf "mongo/internal/test"
rm -rf "mongo/go.mod"
rm -rf "mongo/go.sum"
echo ""
echo "> Update mongoPatchVersion.go"
echo ""
{
printf "package goext\n"
printf "\n"
printf "// %s\n" "$( date +"%Y-%m-%d %H:%M:%S%z" )"
printf "\n"
printf "const MongoCloneTag = \"%s\"\n" "$latestTag"
printf "const MongoCloneCommit = \"%s\"\n" "$latestSHA"
} > mongoPatchVersion.go
echo ""
echo "> Patch mongo"
echo ""
git apply -v _data/mongo.patch
echo ""
echo "Done."

View File

@@ -21,11 +21,6 @@ if [ "$( git rev-parse --abbrev-ref HEAD )" != "master" ]; then
exit 1
fi
echo ""
echo -n "Insert optional commit message: "
read commitMessage
echo ""
git pull --ff
go get -u ./...
@@ -45,11 +40,6 @@ git add --verbose .
msg="v${next_ver}"
if [[ "$commitMessage" != "" ]]; then
msg="${msg} ${commitMessage}"
fi
if [ $# -gt 0 ]; then
msg="$1"
fi

Binary file not shown.

View File

@@ -31,13 +31,13 @@ type EnumDef struct {
Values []EnumDefVal
}
var rexEnumPackage = rext.W(regexp.MustCompile(`^package\s+(?P<name>[A-Za-z0-9_]+)\s*$`))
var rexPackage = rext.W(regexp.MustCompile("^package\\s+(?P<name>[A-Za-z0-9_]+)\\s*$"))
var rexEnumDef = rext.W(regexp.MustCompile(`^\s*type\s+(?P<name>[A-Za-z0-9_]+)\s+(?P<type>[A-Za-z0-9_]+)\s*//\s*(@enum:type).*$`))
var rexEnumDef = rext.W(regexp.MustCompile("^\\s*type\\s+(?P<name>[A-Za-z0-9_]+)\\s+(?P<type>[A-Za-z0-9_]+)\\s*//\\s*(@enum:type).*$"))
var rexEnumValueDef = rext.W(regexp.MustCompile(`^\s*(?P<name>[A-Za-z0-9_]+)\s+(?P<type>[A-Za-z0-9_]+)\s*=\s*(?P<value>("[A-Za-z0-9_:\s]+"|[0-9]+))\s*(//(?P<descr>.*))?.*$`))
var rexValueDef = rext.W(regexp.MustCompile("^\\s*(?P<name>[A-Za-z0-9_]+)\\s+(?P<type>[A-Za-z0-9_]+)\\s*=\\s*(?P<value>(\"[A-Za-z0-9_:]+\"|[0-9]+))\\s*(//(?P<descr>.*))?.*$"))
var rexEnumChecksumConst = rext.W(regexp.MustCompile(`const ChecksumEnumGenerator = "(?P<cs>[A-Za-z0-9_]*)"`))
var rexChecksumConst = rext.W(regexp.MustCompile("const ChecksumGenerator = \"(?P<cs>[A-Za-z0-9_]*)\""))
func GenerateEnumSpecs(sourceDir string, destFile string) error {
@@ -52,14 +52,13 @@ func GenerateEnumSpecs(sourceDir string, destFile string) error {
if err != nil {
return err
}
if m, ok := rexEnumChecksumConst.MatchFirst(string(content)); ok {
if m, ok := rexChecksumConst.MatchFirst(string(content)); ok {
oldChecksum = m.GroupByName("cs").Value()
}
}
files = langext.ArrFilter(files, func(v os.DirEntry) bool { return v.Name() != path.Base(destFile) })
files = langext.ArrFilter(files, func(v os.DirEntry) bool { return strings.HasSuffix(v.Name(), ".go") })
files = langext.ArrFilter(files, func(v os.DirEntry) bool { return !strings.HasSuffix(v.Name(), "_gen.go") })
langext.SortBy(files, func(v os.DirEntry) string { return v.Name() })
newChecksumStr := goext.GoextVersion
@@ -86,7 +85,7 @@ func GenerateEnumSpecs(sourceDir string, destFile string) error {
for _, f := range files {
fmt.Printf("========= %s =========\n\n", f.Name())
fileEnums, pn, err := processEnumFile(sourceDir, path.Join(sourceDir, f.Name()))
fileEnums, pn, err := processFile(sourceDir, path.Join(sourceDir, f.Name()))
if err != nil {
return err
}
@@ -104,7 +103,7 @@ func GenerateEnumSpecs(sourceDir string, destFile string) error {
return errors.New("no package name found in any file")
}
err = os.WriteFile(destFile, []byte(fmtEnumOutput(newChecksum, allEnums, pkgname)), 0o755)
err = os.WriteFile(destFile, []byte(fmtOutput(newChecksum, allEnums, pkgname)), 0o755)
if err != nil {
return err
}
@@ -126,7 +125,7 @@ func GenerateEnumSpecs(sourceDir string, destFile string) error {
return nil
}
func processEnumFile(basedir string, fn string) ([]EnumDef, string, error) {
func processFile(basedir string, fn string) ([]EnumDef, string, error) {
file, err := os.Open(fn)
if err != nil {
return nil, "", err
@@ -150,7 +149,7 @@ func processEnumFile(basedir string, fn string) ([]EnumDef, string, error) {
break
}
if match, ok := rexEnumPackage.MatchFirst(line); i == 0 && ok {
if match, ok := rexPackage.MatchFirst(line); i == 0 && ok {
pkgname = match.GroupByName("name").Value()
continue
}
@@ -173,7 +172,7 @@ func processEnumFile(basedir string, fn string) ([]EnumDef, string, error) {
fmt.Printf("Found enum definition { '%s' -> '%s' }\n", def.EnumTypeName, def.Type)
}
if match, ok := rexEnumValueDef.MatchFirst(line); ok {
if match, ok := rexValueDef.MatchFirst(line); ok {
typename := match.GroupByName("type").Value()
def := EnumDefVal{
VarName: match.GroupByName("name").Value(),
@@ -203,17 +202,43 @@ func processEnumFile(basedir string, fn string) ([]EnumDef, string, error) {
return enums, pkgname, nil
}
func fmtEnumOutput(cs string, enums []EnumDef, pkgname string) string {
func fmtOutput(cs string, enums []EnumDef, pkgname string) string {
str := "// Code generated by enum-generate.go DO NOT EDIT.\n"
str += "\n"
str += "package " + pkgname + "\n"
str += "\n"
str += "import \"gogs.mikescher.com/BlackForestBytes/goext/langext\"" + "\n"
str += "import \"gogs.mikescher.com/BlackForestBytes/goext/enums\"" + "\n"
str += "\n"
str += "const ChecksumEnumGenerator = \"" + cs + "\" // GoExtVersion: " + goext.GoextVersion + "\n"
str += "const ChecksumGenerator = \"" + cs + "\"" + "\n"
str += "\n"
str += "type Enum interface {" + "\n"
str += " Valid() bool" + "\n"
str += " ValuesAny() []any" + "\n"
str += " ValuesMeta() []EnumMetaValue" + "\n"
str += " VarName() string" + "\n"
str += "}" + "\n"
str += "" + "\n"
str += "type StringEnum interface {" + "\n"
str += " Enum" + "\n"
str += " String() string" + "\n"
str += "}" + "\n"
str += "" + "\n"
str += "type DescriptionEnum interface {" + "\n"
str += " Enum" + "\n"
str += " Description() string" + "\n"
str += "}" + "\n"
str += "\n"
str += "type EnumMetaValue struct {" + "\n"
str += " VarName string `json:\"varName\"`" + "\n"
str += " Value any `json:\"value\"`" + "\n"
str += " Description *string `json:\"description\"`" + "\n"
str += "}" + "\n"
str += "\n"
for _, enumdef := range enums {
@@ -267,8 +292,16 @@ func fmtEnumOutput(cs string, enums []EnumDef, pkgname string) string {
str += "}" + "\n"
str += "" + "\n"
str += "func (e " + enumdef.EnumTypeName + ") ValuesMeta() []enums.EnumMetaValue {" + "\n"
str += " return " + enumdef.EnumTypeName + "ValuesMeta()"
str += "func (e " + enumdef.EnumTypeName + ") ValuesMeta() []EnumMetaValue {" + "\n"
str += " return []EnumMetaValue{" + "\n"
for _, v := range enumdef.Values {
if hasDescr {
str += " " + fmt.Sprintf("EnumMetaValue{VarName: \"%s\", Value: %s, Description: langext.Ptr(\"%s\")},", v.VarName, v.VarName, strings.TrimSpace(*v.Description)) + "\n"
} else {
str += " " + fmt.Sprintf("EnumMetaValue{VarName: \"%s\", Value: %s, Description: nil},", v.VarName, v.VarName) + "\n"
}
}
str += " }" + "\n"
str += "}" + "\n"
str += "" + "\n"
@@ -297,15 +330,6 @@ func fmtEnumOutput(cs string, enums []EnumDef, pkgname string) string {
str += "}" + "\n"
str += "" + "\n"
str += "func (e " + enumdef.EnumTypeName + ") Meta() enums.EnumMetaValue {" + "\n"
if hasDescr {
str += " return enums.EnumMetaValue{VarName: e.VarName(), Value: e, Description: langext.Ptr(e.Description())}"
} else {
str += " return enums.EnumMetaValue{VarName: e.VarName(), Value: e, Description: nil}"
}
str += "}" + "\n"
str += "" + "\n"
str += "func Parse" + enumdef.EnumTypeName + "(vv string) (" + enumdef.EnumTypeName + ", bool) {" + "\n"
str += " for _, ev := range __" + enumdef.EnumTypeName + "Values {" + "\n"
str += " if string(ev) == vv {" + "\n"
@@ -321,10 +345,14 @@ func fmtEnumOutput(cs string, enums []EnumDef, pkgname string) string {
str += "}" + "\n"
str += "" + "\n"
str += "func " + enumdef.EnumTypeName + "ValuesMeta() []enums.EnumMetaValue {" + "\n"
str += " return []enums.EnumMetaValue{" + "\n"
str += "func " + enumdef.EnumTypeName + "ValuesMeta() []EnumMetaValue {" + "\n"
str += " return []EnumMetaValue{" + "\n"
for _, v := range enumdef.Values {
str += " " + v.VarName + ".Meta(),\n"
if hasDescr {
str += " " + fmt.Sprintf("EnumMetaValue{VarName: \"%s\", Value: %s, Description: langext.Ptr(\"%s\")},", v.VarName, v.VarName, strings.TrimSpace(*v.Description)) + "\n"
} else {
str += " " + fmt.Sprintf("EnumMetaValue{VarName: \"%s\", Value: %s, Description: nil},", v.VarName, v.VarName) + "\n"
}
}
str += " }" + "\n"
str += "}" + "\n"

View File

@@ -1,42 +1,15 @@
package bfcodegen
import (
_ "embed"
"gogs.mikescher.com/BlackForestBytes/goext/cmdext"
"gogs.mikescher.com/BlackForestBytes/goext/langext"
"gogs.mikescher.com/BlackForestBytes/goext/tst"
"os"
"path/filepath"
"testing"
"time"
)
//go:embed _test_example.tgz
var ExampleModels []byte
func TestApplyEnvOverridesSimple(t *testing.T) {
func TestGenerateEnumSpecs(t *testing.T) {
tmpFile := filepath.Join(t.TempDir(), langext.MustHexUUID()+".tgz")
tmpDir := filepath.Join(t.TempDir(), langext.MustHexUUID())
err := os.WriteFile(tmpFile, ExampleModels, 0o777)
tst.AssertNoErr(t, err)
t.Cleanup(func() { _ = os.Remove(tmpFile) })
err = os.Mkdir(tmpDir, 0o777)
tst.AssertNoErr(t, err)
t.Cleanup(func() { _ = os.RemoveAll(tmpFile) })
_, err = cmdext.Runner("tar").Arg("-xvzf").Arg(tmpFile).Arg("-C").Arg(tmpDir).FailOnExitCode().FailOnTimeout().Timeout(time.Minute).Run()
tst.AssertNoErr(t, err)
err = GenerateEnumSpecs(tmpDir, tmpDir+"/enums_gen.go")
tst.AssertNoErr(t, err)
err = GenerateEnumSpecs(tmpDir, tmpDir+"/enums_gen.go")
tst.AssertNoErr(t, err)
err := GenerateEnumSpecs("/home/mike/Code/reiff/badennet/bnet-backend/models", "/home/mike/Code/reiff/badennet/bnet-backend/models/enums_gen.go")
if err != nil {
t.Error(err)
t.Fail()
}
}

View File

@@ -1,236 +0,0 @@
package bfcodegen
import (
"errors"
"fmt"
"gogs.mikescher.com/BlackForestBytes/goext"
"gogs.mikescher.com/BlackForestBytes/goext/cmdext"
"gogs.mikescher.com/BlackForestBytes/goext/cryptext"
"gogs.mikescher.com/BlackForestBytes/goext/langext"
"gogs.mikescher.com/BlackForestBytes/goext/rext"
"io"
"os"
"path"
"path/filepath"
"regexp"
"strings"
"time"
)
type IDDef struct {
File string
FileRelative string
Name string
}
var rexIDPackage = rext.W(regexp.MustCompile(`^package\s+(?P<name>[A-Za-z0-9_]+)\s*$`))
var rexIDDef = rext.W(regexp.MustCompile(`^\s*type\s+(?P<name>[A-Za-z0-9_]+)\s+string\s*//\s*(@id:type).*$`))
var rexIDChecksumConst = rext.W(regexp.MustCompile(`const ChecksumIDGenerator = "(?P<cs>[A-Za-z0-9_]*)"`))
func GenerateIDSpecs(sourceDir string, destFile string) error {
files, err := os.ReadDir(sourceDir)
if err != nil {
return err
}
oldChecksum := "N/A"
if _, err := os.Stat(destFile); !os.IsNotExist(err) {
content, err := os.ReadFile(destFile)
if err != nil {
return err
}
if m, ok := rexIDChecksumConst.MatchFirst(string(content)); ok {
oldChecksum = m.GroupByName("cs").Value()
}
}
files = langext.ArrFilter(files, func(v os.DirEntry) bool { return v.Name() != path.Base(destFile) })
files = langext.ArrFilter(files, func(v os.DirEntry) bool { return strings.HasSuffix(v.Name(), ".go") })
files = langext.ArrFilter(files, func(v os.DirEntry) bool { return !strings.HasSuffix(v.Name(), "_gen.go") })
langext.SortBy(files, func(v os.DirEntry) string { return v.Name() })
newChecksumStr := goext.GoextVersion
for _, f := range files {
content, err := os.ReadFile(path.Join(sourceDir, f.Name()))
if err != nil {
return err
}
newChecksumStr += "\n" + f.Name() + "\t" + cryptext.BytesSha256(content)
}
newChecksum := cryptext.BytesSha256([]byte(newChecksumStr))
if newChecksum != oldChecksum {
fmt.Printf("[IDGenerate] Checksum has changed ( %s -> %s ), will generate new file\n\n", oldChecksum, newChecksum)
} else {
fmt.Printf("[IDGenerate] Checksum unchanged ( %s ), nothing to do\n", oldChecksum)
return nil
}
allIDs := make([]IDDef, 0)
pkgname := ""
for _, f := range files {
fmt.Printf("========= %s =========\n\n", f.Name())
fileIDs, pn, err := processIDFile(sourceDir, path.Join(sourceDir, f.Name()))
if err != nil {
return err
}
fmt.Printf("\n")
allIDs = append(allIDs, fileIDs...)
if pn != "" {
pkgname = pn
}
}
if pkgname == "" {
return errors.New("no package name found in any file")
}
err = os.WriteFile(destFile, []byte(fmtIDOutput(newChecksum, allIDs, pkgname)), 0o755)
if err != nil {
return err
}
res, err := cmdext.RunCommand("go", []string{"fmt", destFile}, langext.Ptr(2*time.Second))
if err != nil {
return err
}
if res.CommandTimedOut {
fmt.Println(res.StdCombined)
return errors.New("go fmt timed out")
}
if res.ExitCode != 0 {
fmt.Println(res.StdCombined)
return errors.New("go fmt did not succeed")
}
return nil
}
func processIDFile(basedir string, fn string) ([]IDDef, string, error) {
file, err := os.Open(fn)
if err != nil {
return nil, "", err
}
defer func() { _ = file.Close() }()
bin, err := io.ReadAll(file)
if err != nil {
return nil, "", err
}
lines := strings.Split(string(bin), "\n")
ids := make([]IDDef, 0)
pkgname := ""
for i, line := range lines {
if i == 0 && strings.HasPrefix(line, "// Code generated by") {
break
}
if match, ok := rexIDPackage.MatchFirst(line); i == 0 && ok {
pkgname = match.GroupByName("name").Value()
continue
}
if match, ok := rexIDDef.MatchFirst(line); ok {
rfp, err := filepath.Rel(basedir, fn)
if err != nil {
return nil, "", err
}
def := IDDef{
File: fn,
FileRelative: rfp,
Name: match.GroupByName("name").Value(),
}
fmt.Printf("Found ID definition { '%s' }\n", def.Name)
ids = append(ids, def)
}
}
return ids, pkgname, nil
}
func fmtIDOutput(cs string, ids []IDDef, pkgname string) string {
str := "// Code generated by id-generate.go DO NOT EDIT.\n"
str += "\n"
str += "package " + pkgname + "\n"
str += "\n"
str += "import \"go.mongodb.org/mongo-driver/bson\"" + "\n"
str += "import \"go.mongodb.org/mongo-driver/bson/bsontype\"" + "\n"
str += "import \"go.mongodb.org/mongo-driver/bson/primitive\"" + "\n"
str += "import \"gogs.mikescher.com/BlackForestBytes/goext/exerr\"" + "\n"
str += "\n"
str += "const ChecksumIDGenerator = \"" + cs + "\" // GoExtVersion: " + goext.GoextVersion + "\n"
str += "\n"
anyDef := langext.ArrFirstOrNil(ids, func(def IDDef) bool { return def.Name == "AnyID" || def.Name == "AnyId" })
for _, iddef := range ids {
str += "// ================================ " + iddef.Name + " (" + iddef.FileRelative + ") ================================" + "\n"
str += "" + "\n"
str += "func (i " + iddef.Name + ") MarshalBSONValue() (bsontype.Type, []byte, error) {" + "\n"
str += " if objId, err := primitive.ObjectIDFromHex(string(i)); err == nil {" + "\n"
str += " return bson.MarshalValue(objId)" + "\n"
str += " } else {" + "\n"
str += " return 0, nil, exerr.New(exerr.TypeMarshalEntityID, \"Failed to marshal " + iddef.Name + "(\"+i.String()+\") to ObjectId\").Str(\"value\", string(i)).Type(\"type\", i).Build()" + "\n"
str += " }" + "\n"
str += "}" + "\n"
str += "" + "\n"
str += "func (i " + iddef.Name + ") String() string {" + "\n"
str += " return string(i)" + "\n"
str += "}" + "\n"
str += "" + "\n"
str += "func (i " + iddef.Name + ") ObjID() (primitive.ObjectID, error) {" + "\n"
str += " return primitive.ObjectIDFromHex(string(i))" + "\n"
str += "}" + "\n"
str += "" + "\n"
str += "func (i " + iddef.Name + ") Valid() bool {" + "\n"
str += " _, err := primitive.ObjectIDFromHex(string(i))" + "\n"
str += " return err == nil" + "\n"
str += "}" + "\n"
str += "" + "\n"
if anyDef != nil {
str += "func (i " + iddef.Name + ") AsAny() " + anyDef.Name + " {" + "\n"
str += " return " + anyDef.Name + "(i)" + "\n"
str += "}" + "\n"
str += "" + "\n"
}
str += "func New" + iddef.Name + "() " + iddef.Name + " {" + "\n"
str += " return " + iddef.Name + "(primitive.NewObjectID().Hex())" + "\n"
str += "}" + "\n"
str += "" + "\n"
}
return str
}

View File

@@ -14,7 +14,6 @@ type CommandRunner struct {
listener []CommandListener
enforceExitCodes *[]int
enforceNoTimeout bool
enforceNoStderr bool
}
func Runner(program string) *CommandRunner {
@@ -26,7 +25,6 @@ func Runner(program string) *CommandRunner {
listener: make([]CommandListener, 0),
enforceExitCodes: nil,
enforceNoTimeout: false,
enforceNoStderr: false,
}
}
@@ -75,11 +73,6 @@ func (r *CommandRunner) FailOnTimeout() *CommandRunner {
return r
}
func (r *CommandRunner) FailOnStderr() *CommandRunner {
r.enforceNoStderr = true
return r
}
func (r *CommandRunner) Listen(lstr CommandListener) *CommandRunner {
r.listener = append(r.listener, lstr)
return r

View File

@@ -11,7 +11,6 @@ import (
var ErrExitCode = errors.New("process exited with an unexpected exitcode")
var ErrTimeout = errors.New("process did not exit after the specified timeout")
var ErrStderrPrint = errors.New("process did print to stderr stream")
type CommandResult struct {
StdOut string
@@ -54,27 +53,12 @@ func run(opt CommandRunner) (CommandResult, error) {
err error
}
stderrFailChan := make(chan bool)
outputChan := make(chan resultObj)
go func() {
// we need to first fully read the pipes and then call Wait
// see https://pkg.go.dev/os/exec#Cmd.StdoutPipe
listener := make([]CommandListener, 0)
listener = append(listener, opt.listener...)
if opt.enforceNoStderr {
listener = append(listener, genericCommandListener{
_readRawStderr: langext.Ptr(func(v []byte) {
if len(v) > 0 {
stderrFailChan <- true
}
}),
})
}
stdout, stderr, stdcombined, err := preader.Read(listener)
stdout, stderr, stdcombined, err := preader.Read(opt.listener)
if err != nil {
outputChan <- resultObj{stdout, stderr, stdcombined, err}
_ = cmd.Process.Kill()
@@ -131,34 +115,8 @@ func run(opt CommandRunner) (CommandResult, error) {
return res, nil
}
case <-stderrFailChan:
_ = cmd.Process.Kill()
if fallback, ok := syncext.ReadChannelWithTimeout(outputChan, 32*time.Millisecond); ok {
// most of the time the cmd.Process.Kill() should also have finished the pipereader
// and we can at least return the already collected stdout, stderr, etc
res := CommandResult{
StdOut: fallback.stdout,
StdErr: fallback.stderr,
StdCombined: fallback.stdcombined,
ExitCode: -1,
CommandTimedOut: false,
}
return res, ErrStderrPrint
} else {
res := CommandResult{
StdOut: "",
StdErr: "",
StdCombined: "",
ExitCode: -1,
CommandTimedOut: false,
}
return res, ErrStderrPrint
}
case outobj := <-outputChan:
var exiterr *exec.ExitError
if errors.As(outobj.err, &exiterr) {
if exiterr, ok := outobj.err.(*exec.ExitError); ok {
excode := exiterr.ExitCode()
for _, lstr := range opt.listener {
lstr.Finished(excode)

View File

@@ -1,7 +1,6 @@
package cmdext
import (
"errors"
"fmt"
"testing"
"time"
@@ -33,7 +32,7 @@ func TestStdout(t *testing.T) {
func TestStderr(t *testing.T) {
res1, err := Runner("python3").Arg("-c").Arg("import sys; print(\"error\", file=sys.stderr, end='')").Run()
res1, err := Runner("python").Arg("-c").Arg("import sys; print(\"error\", file=sys.stderr, end='')").Run()
if err != nil {
t.Errorf("%v", err)
}
@@ -56,7 +55,7 @@ func TestStderr(t *testing.T) {
}
func TestStdcombined(t *testing.T) {
res1, err := Runner("python3").
res1, err := Runner("python").
Arg("-c").
Arg("import sys; import time; print(\"1\", file=sys.stderr, flush=True); time.sleep(0.1); print(\"2\", file=sys.stdout, flush=True); time.sleep(0.1); print(\"3\", file=sys.stderr, flush=True)").
Run()
@@ -82,7 +81,7 @@ func TestStdcombined(t *testing.T) {
}
func TestPartialRead(t *testing.T) {
res1, err := Runner("python3").
res1, err := Runner("python").
Arg("-c").
Arg("import sys; import time; print(\"first message\", flush=True); time.sleep(5); print(\"cant see me\", flush=True);").
Timeout(100 * time.Millisecond).
@@ -106,7 +105,7 @@ func TestPartialRead(t *testing.T) {
}
func TestPartialReadStderr(t *testing.T) {
res1, err := Runner("python3").
res1, err := Runner("python").
Arg("-c").
Arg("import sys; import time; print(\"first message\", file=sys.stderr, flush=True); time.sleep(5); print(\"cant see me\", file=sys.stderr, flush=True);").
Timeout(100 * time.Millisecond).
@@ -131,7 +130,7 @@ func TestPartialReadStderr(t *testing.T) {
func TestReadUnflushedStdout(t *testing.T) {
res1, err := Runner("python3").Arg("-c").Arg("import sys; print(\"message101\", file=sys.stdout, end='')").Run()
res1, err := Runner("python").Arg("-c").Arg("import sys; print(\"message101\", file=sys.stdout, end='')").Run()
if err != nil {
t.Errorf("%v", err)
}
@@ -155,7 +154,7 @@ func TestReadUnflushedStdout(t *testing.T) {
func TestReadUnflushedStderr(t *testing.T) {
res1, err := Runner("python3").Arg("-c").Arg("import sys; print(\"message101\", file=sys.stderr, end='')").Run()
res1, err := Runner("python").Arg("-c").Arg("import sys; print(\"message101\", file=sys.stderr, end='')").Run()
if err != nil {
t.Errorf("%v", err)
}
@@ -180,7 +179,7 @@ func TestReadUnflushedStderr(t *testing.T) {
func TestPartialReadUnflushed(t *testing.T) {
t.SkipNow()
res1, err := Runner("python3").
res1, err := Runner("python").
Arg("-c").
Arg("import sys; import time; print(\"first message\", end=''); time.sleep(5); print(\"cant see me\", end='');").
Timeout(100 * time.Millisecond).
@@ -206,7 +205,7 @@ func TestPartialReadUnflushed(t *testing.T) {
func TestPartialReadUnflushedStderr(t *testing.T) {
t.SkipNow()
res1, err := Runner("python3").
res1, err := Runner("python").
Arg("-c").
Arg("import sys; import time; print(\"first message\", file=sys.stderr, end=''); time.sleep(5); print(\"cant see me\", file=sys.stderr, end='');").
Timeout(100 * time.Millisecond).
@@ -231,7 +230,7 @@ func TestPartialReadUnflushedStderr(t *testing.T) {
func TestListener(t *testing.T) {
res1, err := Runner("python3").
res1, err := Runner("python").
Arg("-c").
Arg("import sys;" +
"import time;" +
@@ -264,7 +263,7 @@ func TestListener(t *testing.T) {
func TestLongStdout(t *testing.T) {
res1, err := Runner("python3").
res1, err := Runner("python").
Arg("-c").
Arg("import sys; import time; print(\"X\" * 125001 + \"\\n\"); print(\"Y\" * 125001 + \"\\n\"); print(\"Z\" * 125001 + \"\\n\");").
Timeout(5000 * time.Millisecond).
@@ -290,40 +289,16 @@ func TestLongStdout(t *testing.T) {
func TestFailOnTimeout(t *testing.T) {
_, err := Runner("sleep").Arg("2").Timeout(200 * time.Millisecond).FailOnTimeout().Run()
if !errors.Is(err, ErrTimeout) {
if err != ErrTimeout {
t.Errorf("wrong err := %v", err)
}
}
func TestFailOnStderr(t *testing.T) {
res1, err := Runner("python3").Arg("-c").Arg("import sys; print(\"error\", file=sys.stderr, end='')").FailOnStderr().Run()
if err == nil {
t.Errorf("no err")
}
if res1.CommandTimedOut {
t.Errorf("Timeout")
}
if res1.ExitCode != -1 {
t.Errorf("res1.ExitCode == %v", res1.ExitCode)
}
if res1.StdErr != "error" {
t.Errorf("res1.StdErr == '%v'", res1.StdErr)
}
if res1.StdOut != "" {
t.Errorf("res1.StdOut == '%v'", res1.StdOut)
}
if res1.StdCombined != "error\n" {
t.Errorf("res1.StdCombined == '%v'", res1.StdCombined)
}
}
func TestFailOnExitcode(t *testing.T) {
_, err := Runner("false").Timeout(200 * time.Millisecond).FailOnExitCode().Run()
if !errors.Is(err, ErrExitCode) {
if err != ErrExitCode {
t.Errorf("wrong err := %v", err)
}

View File

@@ -32,8 +32,8 @@ func (pr *pipeReader) Read(listener []CommandListener) (string, string, string,
stdout := ""
go func() {
buf := make([]byte, 128)
for {
n, err := pr.stdout.Read(buf)
for true {
n, out := pr.stdout.Read(buf)
if n > 0 {
txt := string(buf[:n])
stdout += txt
@@ -42,11 +42,11 @@ func (pr *pipeReader) Read(listener []CommandListener) (string, string, string,
lstr.ReadRawStdout(buf[:n])
}
}
if err == io.EOF {
if out == io.EOF {
break
}
if err != nil {
errch <- err
if out != nil {
errch <- out
break
}
}
@@ -61,7 +61,7 @@ func (pr *pipeReader) Read(listener []CommandListener) (string, string, string,
stderr := ""
go func() {
buf := make([]byte, 128)
for {
for true {
n, err := pr.stderr.Read(buf)
if n > 0 {

View File

@@ -3,7 +3,6 @@ package cryptext
import (
"crypto/rand"
"crypto/sha256"
"crypto/sha512"
"encoding/base64"
"encoding/hex"
"errors"
@@ -15,15 +14,14 @@ import (
"strings"
)
const LatestPassHashVersion = 5
const LatestPassHashVersion = 4
// PassHash
// - [v0]: plaintext password ( `0|...` ) // simple, used to write PW's directly in DB
// - [v1]: sha256(plaintext) // simple hashing
// - [v2]: seed | sha256<seed>(plaintext) // add seed
// - [v3]: seed | sha256<seed>(plaintext) | [hex(totp)] // add TOTP support
// - [v4]: bcrypt(plaintext) | [hex(totp)] // use proper bcrypt
// - [v5]: bcrypt(sha512(plaintext)) | [hex(totp)] // hash pw before bcrypt (otherwise max pw-len = 72)
// - [v0]: plaintext password ( `0|...` )
// - [v1]: sha256(plaintext)
// - [v2]: seed | sha256<seed>(plaintext)
// - [v3]: seed | sha256<seed>(plaintext) | [hex(totp)]
// - [v4]: bcrypt(plaintext) | [hex(totp)]
type PassHash string
func (ph PassHash) Valid() bool {
@@ -111,21 +109,7 @@ func (ph PassHash) Data() (_version int, _seed []byte, _payload []byte, _totp bo
totp := false
totpsecret := make([]byte, 0)
if split[2] != "0" {
totpsecret, err = hex.DecodeString(split[2])
totp = true
}
return int(version), nil, payload, totp, totpsecret, true
}
if version == 5 {
if len(split) != 3 {
return -1, nil, nil, false, nil, false
}
payload := []byte(split[1])
totp := false
totpsecret := make([]byte, 0)
if split[2] != "0" {
totpsecret, err = hex.DecodeString(split[2])
totpsecret, err = hex.DecodeString(split[3])
totp = true
}
return int(version), nil, payload, totp, totpsecret, true
@@ -172,14 +156,6 @@ func (ph PassHash) Verify(plainpass string, totp *string) bool {
}
}
if version == 5 {
if !hastotp {
return bcrypt.CompareHashAndPassword(payload, hash512(plainpass)) == nil
} else {
return bcrypt.CompareHashAndPassword(payload, hash512(plainpass)) == nil && totpext.Validate(totpsecret, *totp)
}
}
return false
}
@@ -233,12 +209,6 @@ func (ph PassHash) ClearTOTP() (PassHash, error) {
return PassHash(strings.Join(split, "|")), nil
}
if version == 5 {
split := strings.Split(string(ph), "|")
split[2] = "0"
return PassHash(strings.Join(split, "|")), nil
}
return "", errors.New("unknown version")
}
@@ -272,12 +242,6 @@ func (ph PassHash) WithTOTP(totpSecret []byte) (PassHash, error) {
return PassHash(strings.Join(split, "|")), nil
}
if version == 5 {
split := strings.Split(string(ph), "|")
split[2] = hex.EncodeToString(totpSecret)
return PassHash(strings.Join(split, "|")), nil
}
return "", errors.New("unknown version")
}
@@ -307,10 +271,6 @@ func (ph PassHash) Change(newPlainPass string) (PassHash, error) {
return HashPasswordV4(newPlainPass, langext.Conditional(hastotp, totpsecret, nil))
}
if version == 5 {
return HashPasswordV5(newPlainPass, langext.Conditional(hastotp, totpsecret, nil))
}
return "", errors.New("unknown version")
}
@@ -319,24 +279,7 @@ func (ph PassHash) String() string {
}
func HashPassword(plainpass string, totpSecret []byte) (PassHash, error) {
return HashPasswordV5(plainpass, totpSecret)
}
func HashPasswordV5(plainpass string, totpSecret []byte) (PassHash, error) {
var strtotp string
if totpSecret == nil {
strtotp = "0"
} else {
strtotp = hex.EncodeToString(totpSecret)
}
payload, err := bcrypt.GenerateFromPassword(hash512(plainpass), bcrypt.MinCost)
if err != nil {
return "", err
}
return PassHash(fmt.Sprintf("5|%s|%s", string(payload), strtotp)), nil
return HashPasswordV4(plainpass, totpSecret)
}
func HashPasswordV4(plainpass string, totpSecret []byte) (PassHash, error) {
@@ -397,13 +340,6 @@ func HashPasswordV0(plainpass string) (PassHash, error) {
return PassHash(fmt.Sprintf("0|%s", plainpass)), nil
}
func hash512(s string) []byte {
h := sha512.New()
h.Write([]byte(s))
bs := h.Sum(nil)
return bs
}
func hash256(s string) []byte {
h := sha256.New()
h.Write([]byte(s))

View File

@@ -1,210 +0,0 @@
package cryptext
import (
"gogs.mikescher.com/BlackForestBytes/goext/langext"
"gogs.mikescher.com/BlackForestBytes/goext/totpext"
"gogs.mikescher.com/BlackForestBytes/goext/tst"
"testing"
)
func TestPassHash1(t *testing.T) {
ph, err := HashPassword("test123", nil)
tst.AssertNoErr(t, err)
tst.AssertTrue(t, ph.Valid())
tst.AssertFalse(t, ph.HasTOTP())
tst.AssertFalse(t, ph.NeedsPasswordUpgrade())
tst.AssertTrue(t, ph.Verify("test123", nil))
tst.AssertFalse(t, ph.Verify("test124", nil))
}
func TestPassHashTOTP(t *testing.T) {
sec, err := totpext.GenerateSecret()
tst.AssertNoErr(t, err)
ph, err := HashPassword("test123", sec)
tst.AssertNoErr(t, err)
tst.AssertTrue(t, ph.Valid())
tst.AssertTrue(t, ph.HasTOTP())
tst.AssertFalse(t, ph.NeedsPasswordUpgrade())
tst.AssertFalse(t, ph.Verify("test123", nil))
tst.AssertFalse(t, ph.Verify("test124", nil))
tst.AssertTrue(t, ph.Verify("test123", langext.Ptr(totpext.TOTP(sec))))
tst.AssertFalse(t, ph.Verify("test124", nil))
}
func TestPassHashUpgrade_V0(t *testing.T) {
ph, err := HashPasswordV0("test123")
tst.AssertNoErr(t, err)
tst.AssertTrue(t, ph.Valid())
tst.AssertFalse(t, ph.HasTOTP())
tst.AssertTrue(t, ph.NeedsPasswordUpgrade())
tst.AssertTrue(t, ph.Verify("test123", nil))
tst.AssertFalse(t, ph.Verify("test124", nil))
ph, err = ph.Upgrade("test123")
tst.AssertNoErr(t, err)
tst.AssertTrue(t, ph.Valid())
tst.AssertFalse(t, ph.HasTOTP())
tst.AssertFalse(t, ph.NeedsPasswordUpgrade())
tst.AssertTrue(t, ph.Verify("test123", nil))
tst.AssertFalse(t, ph.Verify("test124", nil))
}
func TestPassHashUpgrade_V1(t *testing.T) {
ph, err := HashPasswordV1("test123")
tst.AssertNoErr(t, err)
tst.AssertTrue(t, ph.Valid())
tst.AssertFalse(t, ph.HasTOTP())
tst.AssertTrue(t, ph.NeedsPasswordUpgrade())
tst.AssertTrue(t, ph.Verify("test123", nil))
tst.AssertFalse(t, ph.Verify("test124", nil))
ph, err = ph.Upgrade("test123")
tst.AssertNoErr(t, err)
tst.AssertTrue(t, ph.Valid())
tst.AssertFalse(t, ph.HasTOTP())
tst.AssertFalse(t, ph.NeedsPasswordUpgrade())
tst.AssertTrue(t, ph.Verify("test123", nil))
tst.AssertFalse(t, ph.Verify("test124", nil))
}
func TestPassHashUpgrade_V2(t *testing.T) {
ph, err := HashPasswordV2("test123")
tst.AssertNoErr(t, err)
tst.AssertTrue(t, ph.Valid())
tst.AssertFalse(t, ph.HasTOTP())
tst.AssertTrue(t, ph.NeedsPasswordUpgrade())
tst.AssertTrue(t, ph.Verify("test123", nil))
tst.AssertFalse(t, ph.Verify("test124", nil))
ph, err = ph.Upgrade("test123")
tst.AssertNoErr(t, err)
tst.AssertTrue(t, ph.Valid())
tst.AssertFalse(t, ph.HasTOTP())
tst.AssertFalse(t, ph.NeedsPasswordUpgrade())
tst.AssertTrue(t, ph.Verify("test123", nil))
tst.AssertFalse(t, ph.Verify("test124", nil))
}
func TestPassHashUpgrade_V3(t *testing.T) {
ph, err := HashPasswordV3("test123", nil)
tst.AssertNoErr(t, err)
tst.AssertTrue(t, ph.Valid())
tst.AssertFalse(t, ph.HasTOTP())
tst.AssertTrue(t, ph.NeedsPasswordUpgrade())
tst.AssertTrue(t, ph.Verify("test123", nil))
tst.AssertFalse(t, ph.Verify("test124", nil))
ph, err = ph.Upgrade("test123")
tst.AssertNoErr(t, err)
tst.AssertTrue(t, ph.Valid())
tst.AssertFalse(t, ph.HasTOTP())
tst.AssertFalse(t, ph.NeedsPasswordUpgrade())
tst.AssertTrue(t, ph.Verify("test123", nil))
tst.AssertFalse(t, ph.Verify("test124", nil))
}
func TestPassHashUpgrade_V3_TOTP(t *testing.T) {
sec, err := totpext.GenerateSecret()
tst.AssertNoErr(t, err)
ph, err := HashPasswordV3("test123", sec)
tst.AssertNoErr(t, err)
tst.AssertTrue(t, ph.Valid())
tst.AssertTrue(t, ph.HasTOTP())
tst.AssertTrue(t, ph.NeedsPasswordUpgrade())
tst.AssertFalse(t, ph.Verify("test123", nil))
tst.AssertFalse(t, ph.Verify("test124", nil))
tst.AssertTrue(t, ph.Verify("test123", langext.Ptr(totpext.TOTP(sec))))
tst.AssertFalse(t, ph.Verify("test124", nil))
ph, err = ph.Upgrade("test123")
tst.AssertNoErr(t, err)
tst.AssertTrue(t, ph.Valid())
tst.AssertTrue(t, ph.HasTOTP())
tst.AssertFalse(t, ph.NeedsPasswordUpgrade())
tst.AssertFalse(t, ph.Verify("test123", nil))
tst.AssertFalse(t, ph.Verify("test124", nil))
tst.AssertTrue(t, ph.Verify("test123", langext.Ptr(totpext.TOTP(sec))))
tst.AssertFalse(t, ph.Verify("test124", nil))
}
func TestPassHashUpgrade_V4(t *testing.T) {
ph, err := HashPasswordV4("test123", nil)
tst.AssertNoErr(t, err)
tst.AssertTrue(t, ph.Valid())
tst.AssertFalse(t, ph.HasTOTP())
tst.AssertTrue(t, ph.NeedsPasswordUpgrade())
tst.AssertTrue(t, ph.Verify("test123", nil))
tst.AssertFalse(t, ph.Verify("test124", nil))
ph, err = ph.Upgrade("test123")
tst.AssertNoErr(t, err)
tst.AssertTrue(t, ph.Valid())
tst.AssertFalse(t, ph.HasTOTP())
tst.AssertFalse(t, ph.NeedsPasswordUpgrade())
tst.AssertTrue(t, ph.Verify("test123", nil))
tst.AssertFalse(t, ph.Verify("test124", nil))
}
func TestPassHashUpgrade_V4_TOTP(t *testing.T) {
sec, err := totpext.GenerateSecret()
tst.AssertNoErr(t, err)
ph, err := HashPasswordV4("test123", sec)
tst.AssertNoErr(t, err)
tst.AssertTrue(t, ph.Valid())
tst.AssertTrue(t, ph.HasTOTP())
tst.AssertTrue(t, ph.NeedsPasswordUpgrade())
tst.AssertFalse(t, ph.Verify("test123", nil))
tst.AssertFalse(t, ph.Verify("test124", nil))
tst.AssertTrue(t, ph.Verify("test123", langext.Ptr(totpext.TOTP(sec))))
tst.AssertFalse(t, ph.Verify("test124", nil))
ph, err = ph.Upgrade("test123")
tst.AssertNoErr(t, err)
tst.AssertTrue(t, ph.Valid())
tst.AssertTrue(t, ph.HasTOTP())
tst.AssertFalse(t, ph.NeedsPasswordUpgrade())
tst.AssertFalse(t, ph.Verify("test123", nil))
tst.AssertFalse(t, ph.Verify("test124", nil))
tst.AssertTrue(t, ph.Verify("test123", langext.Ptr(totpext.TOTP(sec))))
tst.AssertFalse(t, ph.Verify("test124", nil))
}

View File

@@ -7,9 +7,6 @@ type SyncSet[TData comparable] struct {
lock sync.Mutex
}
// Add adds `value` to the set
// returns true if the value was actually inserted
// returns false if the value already existed
func (s *SyncSet[TData]) Add(value TData) bool {
s.lock.Lock()
defer s.lock.Unlock()
@@ -18,10 +15,10 @@ func (s *SyncSet[TData]) Add(value TData) bool {
s.data = make(map[TData]bool)
}
_, existsInPreState := s.data[value]
_, ok := s.data[value]
s.data[value] = true
return !existsInPreState
return !ok
}
func (s *SyncSet[TData]) AddAll(values []TData) {

View File

@@ -1,170 +0,0 @@
package dataext
type ValueGroup interface {
TupleLength() int
TupleValues() []any
}
// ----------------------------------------------------------------------------
type Single[T1 any] struct {
V1 T1
}
func (s Single[T1]) TupleLength() int {
return 1
}
func (s Single[T1]) TupleValues() []any {
return []any{s.V1}
}
// ----------------------------------------------------------------------------
type Tuple[T1 any, T2 any] struct {
V1 T1
V2 T2
}
func (t Tuple[T1, T2]) TupleLength() int {
return 2
}
func (t Tuple[T1, T2]) TupleValues() []any {
return []any{t.V1, t.V2}
}
// ----------------------------------------------------------------------------
type Triple[T1 any, T2 any, T3 any] struct {
V1 T1
V2 T2
V3 T3
}
func (t Triple[T1, T2, T3]) TupleLength() int {
return 3
}
func (t Triple[T1, T2, T3]) TupleValues() []any {
return []any{t.V1, t.V2, t.V3}
}
// ----------------------------------------------------------------------------
type Quadruple[T1 any, T2 any, T3 any, T4 any] struct {
V1 T1
V2 T2
V3 T3
V4 T4
}
func (t Quadruple[T1, T2, T3, T4]) TupleLength() int {
return 4
}
func (t Quadruple[T1, T2, T3, T4]) TupleValues() []any {
return []any{t.V1, t.V2, t.V3, t.V4}
}
// ----------------------------------------------------------------------------
type Quintuple[T1 any, T2 any, T3 any, T4 any, T5 any] struct {
V1 T1
V2 T2
V3 T3
V4 T4
V5 T5
}
func (t Quintuple[T1, T2, T3, T4, T5]) TupleLength() int {
return 5
}
func (t Quintuple[T1, T2, T3, T4, T5]) TupleValues() []any {
return []any{t.V1, t.V2, t.V3, t.V4, t.V5}
}
// ----------------------------------------------------------------------------
type Sextuple[T1 any, T2 any, T3 any, T4 any, T5 any, T6 any] struct {
V1 T1
V2 T2
V3 T3
V4 T4
V5 T5
V6 T6
}
func (t Sextuple[T1, T2, T3, T4, T5, T6]) TupleLength() int {
return 6
}
func (t Sextuple[T1, T2, T3, T4, T5, T6]) TupleValues() []any {
return []any{t.V1, t.V2, t.V3, t.V4, t.V5, t.V6}
}
// ----------------------------------------------------------------------------
type Septuple[T1 any, T2 any, T3 any, T4 any, T5 any, T6 any, T7 any] struct {
V1 T1
V2 T2
V3 T3
V4 T4
V5 T5
V6 T6
V7 T7
}
func (t Septuple[T1, T2, T3, T4, T5, T6, T7]) TupleLength() int {
return 7
}
func (t Septuple[T1, T2, T3, T4, T5, T6, T7]) TupleValues() []any {
return []any{t.V1, t.V2, t.V3, t.V4, t.V5, t.V6, t.V7}
}
// ----------------------------------------------------------------------------
type Octuple[T1 any, T2 any, T3 any, T4 any, T5 any, T6 any, T7 any, T8 any] struct {
V1 T1
V2 T2
V3 T3
V4 T4
V5 T5
V6 T6
V7 T7
V8 T8
}
func (t Octuple[T1, T2, T3, T4, T5, T6, T7, T8]) TupleLength() int {
return 8
}
func (t Octuple[T1, T2, T3, T4, T5, T6, T7, T8]) TupleValues() []any {
return []any{t.V1, t.V2, t.V3, t.V4, t.V5, t.V6, t.V7, t.V8}
}
// ----------------------------------------------------------------------------
type Nonuple[T1 any, T2 any, T3 any, T4 any, T5 any, T6 any, T7 any, T8 any, T9 any] struct {
V1 T1
V2 T2
V3 T3
V4 T4
V5 T5
V6 T6
V7 T7
V8 T8
V9 T9
}
func (t Nonuple[T1, T2, T3, T4, T5, T6, T7, T8, T9]) TupleLength() int {
return 9
}
func (t Nonuple[T1, T2, T3, T4, T5, T6, T7, T8, T9]) TupleValues() []any {
return []any{t.V1, t.V2, t.V3, t.V4, t.V5, t.V6, t.V7, t.V8, t.V9}
}

View File

@@ -1,24 +0,0 @@
package enums
type Enum interface {
Valid() bool
ValuesAny() []any
ValuesMeta() []EnumMetaValue
VarName() string
}
type StringEnum interface {
Enum
String() string
}
type DescriptionEnum interface {
Enum
Description() string
}
type EnumMetaValue struct {
VarName string `json:"varName"`
Value any `json:"value"`
Description *string `json:"description"`
}

View File

@@ -1,470 +0,0 @@
package exerr
import (
"bytes"
"context"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/rs/zerolog"
"go.mongodb.org/mongo-driver/bson/primitive"
"gogs.mikescher.com/BlackForestBytes/goext/dataext"
"gogs.mikescher.com/BlackForestBytes/goext/enums"
"gogs.mikescher.com/BlackForestBytes/goext/langext"
"net/http"
"os"
"runtime/debug"
"strings"
"time"
)
//
// ==== USAGE =====
//
// If some method returns an error _always wrap it into an exerror:
// value, err := do_something(..)
// if err != nil {
// return nil, exerror.Wrap(err, "do something failed").Build()
// }
//
// If possible add metadata to the error (eg the id that was not found, ...), the methods are the same as in zerolog
// return nil, exerror.Wrap(err, "do something failed").Str("someid", id).Int("count", in.Count).Build()
//
// You can change the errortype with `.User()` and `.System()` (User-errors are 400 and System-errors 500)
// You can also manually set the statuscode with `.WithStatuscode(http.NotFound)`
// You can set the type with `WithType(..)`
//
// New Errors (that don't wrap an existing err object) are created with New
// return nil, exerror.New(exerror.TypeInternal, "womethign wen horrible wrong").Build()
// You can eitehr use an existing ErrorType, the "catch-all" ErrInternal, or add you own ErrType in consts.go
//
// All errors should be handled one of the following four ways:
// - return the error to the caller and let him handle it:
// (also auto-prints the error to the log)
// => Wrap/New + Build
// - Print the error
// (also auto-sends it to the error-service)
// This is useful for errors that happen asynchron or are non-fatal for the current request
// => Wrap/New + Print
// - Return the error to the Rest-API caller
// (also auto-prints the error to the log)
// (also auto-sends it to the error-service)
// => Wrap/New + Output
// - Print and stop the service
// (also auto-sends it to the error-service)
// => Wrap/New + Fatal
//
var stackSkipLogger zerolog.Logger
func init() {
cw := zerolog.ConsoleWriter{
Out: os.Stdout,
TimeFormat: "2006-01-02 15:04:05 Z07:00",
}
multi := zerolog.MultiLevelWriter(cw)
stackSkipLogger = zerolog.New(multi).With().Timestamp().CallerWithSkipFrameCount(4).Logger()
}
type Builder struct {
errorData *ExErr
containsGinData bool
}
func Get(err error) *Builder {
return &Builder{errorData: FromError(err)}
}
func New(t ErrorType, msg string) *Builder {
return &Builder{errorData: newExErr(CatSystem, t, msg)}
}
func Wrap(err error, msg string) *Builder {
if err == nil {
return &Builder{errorData: newExErr(CatSystem, TypeInternal, msg)} // prevent NPE if we call Wrap with err==nil
}
if !pkgconfig.RecursiveErrors {
v := FromError(err)
v.Message = msg
return &Builder{errorData: v}
}
return &Builder{errorData: wrapExErr(FromError(err), msg, CatWrap, 1)}
}
// ----------------------------------------------------------------------------
func (b *Builder) WithType(t ErrorType) *Builder {
b.errorData.Type = t
return b
}
func (b *Builder) WithStatuscode(status int) *Builder {
b.errorData.StatusCode = &status
return b
}
func (b *Builder) WithMessage(msg string) *Builder {
b.errorData.Message = msg
return b
}
// ----------------------------------------------------------------------------
// Err changes the Severity to ERROR (default)
// The error will be:
//
// - On Build():
//
// - Short-Logged as Err
//
// - On Print():
//
// - Logged as Err
//
// - Send to the error-service
//
// - On Output():
//
// - Logged as Err
//
// - Send to the error-service
func (b *Builder) Err() *Builder {
b.errorData.Severity = SevErr
return b
}
// Warn changes the Severity to WARN
// The error will be:
//
// - On Build():
//
// - -(nothing)-
//
// - On Print():
//
// - Short-Logged as Warn
//
// - On Output():
//
// - Logged as Warn
func (b *Builder) Warn() *Builder {
b.errorData.Severity = SevWarn
return b
}
// Info changes the Severity to INFO
// The error will be:
//
// - On Build():
//
// - -(nothing)-
//
// - On Print():
//
// - -(nothing)-
//
// - On Output():
//
// - -(nothing)-
func (b *Builder) Info() *Builder {
b.errorData.Severity = SevInfo
return b
}
// ----------------------------------------------------------------------------
// User sets the Category to CatUser
//
// Errors with category
func (b *Builder) User() *Builder {
b.errorData.Category = CatUser
return b
}
func (b *Builder) System() *Builder {
b.errorData.Category = CatSystem
return b
}
// ----------------------------------------------------------------------------
func (b *Builder) Id(key string, val fmt.Stringer) *Builder {
return b.addMeta(key, MDTID, newIDWrap(val))
}
func (b *Builder) StrPtr(key string, val *string) *Builder {
return b.addMeta(key, MDTStringPtr, val)
}
func (b *Builder) Str(key string, val string) *Builder {
return b.addMeta(key, MDTString, val)
}
func (b *Builder) Int(key string, val int) *Builder {
return b.addMeta(key, MDTInt, val)
}
func (b *Builder) Int8(key string, val int8) *Builder {
return b.addMeta(key, MDTInt8, val)
}
func (b *Builder) Int16(key string, val int16) *Builder {
return b.addMeta(key, MDTInt16, val)
}
func (b *Builder) Int32(key string, val int32) *Builder {
return b.addMeta(key, MDTInt32, val)
}
func (b *Builder) Int64(key string, val int64) *Builder {
return b.addMeta(key, MDTInt64, val)
}
func (b *Builder) Float32(key string, val float32) *Builder {
return b.addMeta(key, MDTFloat32, val)
}
func (b *Builder) Float64(key string, val float64) *Builder {
return b.addMeta(key, MDTFloat64, val)
}
func (b *Builder) Bool(key string, val bool) *Builder {
return b.addMeta(key, MDTBool, val)
}
func (b *Builder) Bytes(key string, val []byte) *Builder {
return b.addMeta(key, MDTBytes, val)
}
func (b *Builder) ObjectID(key string, val primitive.ObjectID) *Builder {
return b.addMeta(key, MDTObjectID, val)
}
func (b *Builder) Time(key string, val time.Time) *Builder {
return b.addMeta(key, MDTTime, val)
}
func (b *Builder) Dur(key string, val time.Duration) *Builder {
return b.addMeta(key, MDTDuration, val)
}
func (b *Builder) Strs(key string, val []string) *Builder {
return b.addMeta(key, MDTStringArray, val)
}
func (b *Builder) Ints(key string, val []int) *Builder {
return b.addMeta(key, MDTIntArray, val)
}
func (b *Builder) Ints32(key string, val []int32) *Builder {
return b.addMeta(key, MDTInt32Array, val)
}
func (b *Builder) Type(key string, cls interface{}) *Builder {
return b.addMeta(key, MDTString, fmt.Sprintf("%T", cls))
}
func (b *Builder) Interface(key string, val interface{}) *Builder {
return b.addMeta(key, MDTAny, newAnyWrap(val))
}
func (b *Builder) Any(key string, val any) *Builder {
return b.addMeta(key, MDTAny, newAnyWrap(val))
}
func (b *Builder) Stringer(key string, val fmt.Stringer) *Builder {
if val == nil {
return b.addMeta(key, MDTString, "(!nil)")
} else {
return b.addMeta(key, MDTString, val.String())
}
}
func (b *Builder) Enum(key string, val enums.Enum) *Builder {
return b.addMeta(key, MDTEnum, newEnumWrap(val))
}
func (b *Builder) Stack() *Builder {
return b.addMeta("@Stack", MDTString, string(debug.Stack()))
}
func (b *Builder) Errs(key string, val []error) *Builder {
for i, valerr := range val {
b.addMeta(fmt.Sprintf("%v[%v]", key, i), MDTString, Get(valerr).errorData.FormatLog(LogPrintFull))
}
return b
}
func (b *Builder) GinReq(ctx context.Context, g *gin.Context, req *http.Request) *Builder {
if v := ctx.Value("start_timestamp"); v != nil {
if t, ok := v.(time.Time); ok {
b.Time("ctx.startTimestamp", t)
b.Time("ctx.endTimestamp", time.Now())
}
}
b.Str("gin.method", req.Method)
b.Str("gin.path", g.FullPath())
b.Strs("gin.header", extractHeader(g.Request.Header))
if req.URL != nil {
b.Str("gin.url", req.URL.String())
}
if ctxVal := g.GetString("apiversion"); ctxVal != "" {
b.Str("gin.context.apiversion", ctxVal)
}
if ctxVal := g.GetString("uid"); ctxVal != "" {
b.Str("gin.context.uid", ctxVal)
}
if ctxVal := g.GetString("fcmId"); ctxVal != "" {
b.Str("gin.context.fcmid", ctxVal)
}
if ctxVal := g.GetString("reqid"); ctxVal != "" {
b.Str("gin.context.reqid", ctxVal)
}
if req.Method != "GET" && req.Body != nil {
if req.Header.Get("Content-Type") == "application/json" {
if brc, ok := req.Body.(dataext.BufferedReadCloser); ok {
if bin, err := brc.BufferedAll(); err == nil {
if len(bin) < 16*1024 {
var prettyJSON bytes.Buffer
err = json.Indent(&prettyJSON, bin, "", " ")
if err == nil {
b.Str("gin.body", string(prettyJSON.Bytes()))
} else {
b.Bytes("gin.body", bin)
}
} else {
b.Str("gin.body", fmt.Sprintf("[[%v bytes | %s]]", len(bin), req.Header.Get("Content-Type")))
}
}
}
}
if req.Header.Get("Content-Type") == "multipart/form-data" || req.Header.Get("Content-Type") == "x-www-form-urlencoded" {
if brc, ok := req.Body.(dataext.BufferedReadCloser); ok {
if bin, err := brc.BufferedAll(); err == nil {
if len(bin) < 16*1024 {
b.Bytes("gin.body", bin)
} else {
b.Str("gin.body", fmt.Sprintf("[[%v bytes | %s]]", len(bin), req.Header.Get("Content-Type")))
}
}
}
}
}
b.containsGinData = true
return b
}
func formatHeader(header map[string][]string) string {
ml := 1
for k, _ := range header {
if len(k) > ml {
ml = len(k)
}
}
r := ""
for k, v := range header {
if r != "" {
r += "\n"
}
for _, hval := range v {
value := hval
value = strings.ReplaceAll(value, "\n", "\\n")
value = strings.ReplaceAll(value, "\r", "\\r")
value = strings.ReplaceAll(value, "\t", "\\t")
r += langext.StrPadRight(k, " ", ml) + " := " + value
}
}
return r
}
func extractHeader(header map[string][]string) []string {
r := make([]string, 0, len(header))
for k, v := range header {
for _, hval := range v {
value := hval
value = strings.ReplaceAll(value, "\n", "\\n")
value = strings.ReplaceAll(value, "\r", "\\r")
value = strings.ReplaceAll(value, "\t", "\\t")
r = append(r, k+": "+value)
}
}
return r
}
// ----------------------------------------------------------------------------
// Build creates a new error, ready to pass up the stack
// If the errors is not SevWarn or SevInfo it gets also logged (in short form, without stacktrace) onto stdout
func (b *Builder) Build() error {
warnOnPkgConfigNotInitialized()
if pkgconfig.ZeroLogErrTraces && (b.errorData.Severity == SevErr || b.errorData.Severity == SevFatal) {
b.errorData.ShortLog(stackSkipLogger.Error())
} else if pkgconfig.ZeroLogAllTraces {
b.errorData.ShortLog(stackSkipLogger.Error())
}
b.CallListener(MethodBuild)
return b.errorData
}
// Output prints the error onto the gin stdout.
// The error also gets printed to stdout/stderr
// If the error is SevErr|SevFatal we also send it to the error-service
func (b *Builder) Output(ctx context.Context, g *gin.Context) {
if !b.containsGinData && g.Request != nil {
// Auto-Add gin metadata if the caller hasn't already done it
b.GinReq(ctx, g, g.Request)
}
b.errorData.Output(g)
if b.errorData.Severity == SevErr || b.errorData.Severity == SevFatal {
b.errorData.Log(stackSkipLogger.Error())
} else if b.errorData.Severity == SevWarn {
b.errorData.Log(stackSkipLogger.Warn())
}
b.CallListener(MethodOutput)
}
// Print prints the error
// If the error is SevErr we also send it to the error-service
func (b *Builder) Print() {
if b.errorData.Severity == SevErr || b.errorData.Severity == SevFatal {
b.errorData.Log(stackSkipLogger.Error())
} else if b.errorData.Severity == SevWarn {
b.errorData.ShortLog(stackSkipLogger.Warn())
}
b.CallListener(MethodPrint)
}
func (b *Builder) Format(level LogPrintLevel) string {
return b.errorData.FormatLog(level)
}
// Fatal prints the error and terminates the program
// If the error is SevErr we also send it to the error-service
func (b *Builder) Fatal() {
b.errorData.Severity = SevFatal
b.errorData.Log(stackSkipLogger.WithLevel(zerolog.FatalLevel))
b.CallListener(MethodFatal)
os.Exit(1)
}
// ----------------------------------------------------------------------------
func (b *Builder) addMeta(key string, mdtype metaDataType, val interface{}) *Builder {
b.errorData.Meta.add(key, mdtype, val)
return b
}

View File

@@ -1,204 +0,0 @@
package exerr
import (
"encoding/json"
"fmt"
"go.mongodb.org/mongo-driver/bson/primitive"
"gogs.mikescher.com/BlackForestBytes/goext/langext"
"reflect"
"time"
)
var reflectTypeStr = reflect.TypeOf("")
func FromError(err error) *ExErr {
if verr, ok := err.(*ExErr); ok {
// A simple ExErr
return verr
}
// A foreign error (eg a MongoDB exception)
return &ExErr{
UniqueID: newID(),
Category: CatForeign,
Type: TypeInternal,
Severity: SevErr,
Timestamp: time.Time{},
StatusCode: nil,
Message: err.Error(),
WrappedErrType: fmt.Sprintf("%T", err),
WrappedErr: err,
Caller: "",
OriginalError: nil,
Meta: getForeignMeta(err),
}
}
func newExErr(cat ErrorCategory, errtype ErrorType, msg string) *ExErr {
return &ExErr{
UniqueID: newID(),
Category: cat,
Type: errtype,
Severity: SevErr,
Timestamp: time.Now(),
StatusCode: nil,
Message: msg,
WrappedErrType: "",
WrappedErr: nil,
Caller: callername(2),
OriginalError: nil,
Meta: make(map[string]MetaValue),
}
}
func wrapExErr(e *ExErr, msg string, cat ErrorCategory, stacktraceskip int) *ExErr {
return &ExErr{
UniqueID: newID(),
Category: cat,
Type: TypeWrap,
Severity: SevErr,
Timestamp: time.Now(),
StatusCode: e.StatusCode,
Message: msg,
WrappedErrType: "",
WrappedErr: nil,
Caller: callername(1 + stacktraceskip),
OriginalError: e,
Meta: make(map[string]MetaValue),
}
}
func getForeignMeta(err error) (mm MetaMap) {
mm = make(map[string]MetaValue)
defer func() {
if panicerr := recover(); panicerr != nil {
New(TypePanic, "Panic while trying to get foreign meta").
Str("source", err.Error()).
Interface("panic-object", panicerr).
Stack().
Print()
}
}()
rval := reflect.ValueOf(err)
if rval.Kind() == reflect.Interface || rval.Kind() == reflect.Ptr {
rval = reflect.ValueOf(err).Elem()
}
mm.add("foreign.errortype", MDTString, rval.Type().String())
for k, v := range addMetaPrefix("foreign", getReflectedMetaValues(err, 8)) {
mm[k] = v
}
return mm
}
func getReflectedMetaValues(value interface{}, remainingDepth int) map[string]MetaValue {
if remainingDepth <= 0 {
return map[string]MetaValue{}
}
if langext.IsNil(value) {
return map[string]MetaValue{"": {DataType: MDTNil, Value: nil}}
}
rval := reflect.ValueOf(value)
if rval.Type().Kind() == reflect.Ptr {
if rval.IsNil() {
return map[string]MetaValue{"*": {DataType: MDTNil, Value: nil}}
}
elem := rval.Elem()
return addMetaPrefix("*", getReflectedMetaValues(elem.Interface(), remainingDepth-1))
}
if !rval.CanInterface() {
return map[string]MetaValue{"": {DataType: MDTString, Value: "<<no-interface>>"}}
}
raw := rval.Interface()
switch ifraw := raw.(type) {
case time.Time:
return map[string]MetaValue{"": {DataType: MDTTime, Value: ifraw}}
case time.Duration:
return map[string]MetaValue{"": {DataType: MDTDuration, Value: ifraw}}
case int:
return map[string]MetaValue{"": {DataType: MDTInt, Value: ifraw}}
case int8:
return map[string]MetaValue{"": {DataType: MDTInt8, Value: ifraw}}
case int16:
return map[string]MetaValue{"": {DataType: MDTInt16, Value: ifraw}}
case int32:
return map[string]MetaValue{"": {DataType: MDTInt32, Value: ifraw}}
case int64:
return map[string]MetaValue{"": {DataType: MDTInt64, Value: ifraw}}
case string:
return map[string]MetaValue{"": {DataType: MDTString, Value: ifraw}}
case bool:
return map[string]MetaValue{"": {DataType: MDTBool, Value: ifraw}}
case []byte:
return map[string]MetaValue{"": {DataType: MDTBytes, Value: ifraw}}
case float32:
return map[string]MetaValue{"": {DataType: MDTFloat32, Value: ifraw}}
case float64:
return map[string]MetaValue{"": {DataType: MDTFloat64, Value: ifraw}}
case []int:
return map[string]MetaValue{"": {DataType: MDTIntArray, Value: ifraw}}
case []int32:
return map[string]MetaValue{"": {DataType: MDTInt32Array, Value: ifraw}}
case primitive.ObjectID:
return map[string]MetaValue{"": {DataType: MDTObjectID, Value: ifraw}}
case []string:
return map[string]MetaValue{"": {DataType: MDTStringArray, Value: ifraw}}
}
if rval.Type().Kind() == reflect.Struct {
m := make(map[string]MetaValue)
for i := 0; i < rval.NumField(); i++ {
fieldtype := rval.Type().Field(i)
fieldname := fieldtype.Name
if fieldtype.IsExported() {
for k, v := range addMetaPrefix(fieldname, getReflectedMetaValues(rval.Field(i).Interface(), remainingDepth-1)) {
m[k] = v
}
}
}
return m
}
if rval.Type().ConvertibleTo(reflectTypeStr) {
return map[string]MetaValue{"": {DataType: MDTString, Value: rval.Convert(reflectTypeStr).String()}}
}
jsonval, err := json.Marshal(value)
if err != nil {
panic(err) // gets recovered later up
}
return map[string]MetaValue{"": {DataType: MDTString, Value: string(jsonval)}}
}
func addMetaPrefix(prefix string, m map[string]MetaValue) map[string]MetaValue {
if len(m) == 1 {
for k, v := range m {
if k == "" {
return map[string]MetaValue{prefix: v}
}
}
}
r := make(map[string]MetaValue, len(m))
for k, v := range m {
r[prefix+"."+k] = v
}
return r
}

View File

@@ -1,83 +0,0 @@
package exerr
import (
"gogs.mikescher.com/BlackForestBytes/goext/dataext"
"gogs.mikescher.com/BlackForestBytes/goext/langext"
)
type ErrorCategory struct{ Category string }
var (
CatWrap = ErrorCategory{"Wrap"} // The error is simply wrapping another error (e.g. when a grpc call returns an error)
CatSystem = ErrorCategory{"System"} // An internal system error (e.g. connection to db failed)
CatUser = ErrorCategory{"User"} // The user (the API caller) did something wrong (e.g. he has no permissions to do this)
CatForeign = ErrorCategory{"Foreign"} // A foreign error that some component threw (e.g. an unknown mongodb error), happens if we call Wrap(..) on an non-bmerror value
)
//goland:noinspection GoUnusedGlobalVariable
var AllCategories = []ErrorCategory{CatWrap, CatSystem, CatUser, CatForeign}
type ErrorSeverity struct{ Severity string }
var (
SevTrace = ErrorSeverity{"Trace"}
SevDebug = ErrorSeverity{"Debug"}
SevInfo = ErrorSeverity{"Info"}
SevWarn = ErrorSeverity{"Warn"}
SevErr = ErrorSeverity{"Err"}
SevFatal = ErrorSeverity{"Fatal"}
)
//goland:noinspection GoUnusedGlobalVariable
var AllSeverities = []ErrorSeverity{SevTrace, SevDebug, SevInfo, SevWarn, SevErr, SevFatal}
type ErrorType struct {
Key string
DefaultStatusCode *int
}
//goland:noinspection GoUnusedGlobalVariable
var (
TypeInternal = NewType("INTERNAL_ERROR", langext.Ptr(500))
TypePanic = NewType("PANIC", langext.Ptr(500))
TypeNotImplemented = NewType("NOT_IMPLEMENTED", langext.Ptr(500))
TypeMongoQuery = NewType("MONGO_QUERY", langext.Ptr(500))
TypeCursorTokenDecode = NewType("CURSOR_TOKEN_DECODE", langext.Ptr(500))
TypeMongoFilter = NewType("MONGO_FILTER", langext.Ptr(500))
TypeMongoReflection = NewType("MONGO_REFLECTION", langext.Ptr(500))
TypeWrap = NewType("Wrap", nil)
TypeBindFailURI = NewType("BINDFAIL_URI", langext.Ptr(400))
TypeBindFailQuery = NewType("BINDFAIL_QUERY", langext.Ptr(400))
TypeBindFailJSON = NewType("BINDFAIL_JSON", langext.Ptr(400))
TypeBindFailFormData = NewType("BINDFAIL_FORMDATA", langext.Ptr(400))
TypeBindFailHeader = NewType("BINDFAIL_HEADER", langext.Ptr(400))
TypeMarshalEntityID = NewType("MARSHAL_ENTITY_ID", langext.Ptr(400))
TypeUnauthorized = NewType("UNAUTHORIZED", langext.Ptr(401))
TypeAuthFailed = NewType("AUTH_FAILED", langext.Ptr(401))
// other values come the used package
)
var registeredTypes = dataext.SyncSet[string]{}
func NewType(key string, defStatusCode *int) ErrorType {
insertOkay := registeredTypes.Add(key)
if !insertOkay {
panic("Cannot register same ErrType ('" + key + "') more than once")
}
return ErrorType{key, defStatusCode}
}
type LogPrintLevel string
const (
LogPrintFull LogPrintLevel = "Full"
LogPrintOverview LogPrintLevel = "Overview"
LogPrintShort LogPrintLevel = "Short"
)

View File

@@ -1,80 +0,0 @@
package exerr
import (
"fmt"
"gogs.mikescher.com/BlackForestBytes/goext/langext"
)
type ErrorPackageConfig struct {
ZeroLogErrTraces bool // autom print zerolog logs on .Build() (for SevErr and SevFatal)
ZeroLogAllTraces bool // autom print zerolog logs on .Build() (for all Severities)
RecursiveErrors bool // errors contains their Origin-Error
ExtendedGinOutput bool // Log extended data (trace, meta, ...) to gin in err.Output()
IncludeMetaInGinOutput bool // Log meta fields ( from e.g. `.Str(key, val).Build()` ) to gin in err.Output()
ExtendGinOutput func(err *ExErr, json map[string]any) // (Optionally) extend the gin output with more fields
ExtendGinDataOutput func(err *ExErr, depth int, json map[string]any) // (Optionally) extend the gin `__data` output with more fields
}
type ErrorPackageConfigInit struct {
ZeroLogErrTraces *bool
ZeroLogAllTraces *bool
RecursiveErrors *bool
ExtendedGinOutput *bool
IncludeMetaInGinOutput *bool
ExtendGinOutput func(err *ExErr, json map[string]any)
ExtendGinDataOutput func(err *ExErr, depth int, json map[string]any)
}
var initialized = false
var pkgconfig = ErrorPackageConfig{
ZeroLogErrTraces: true,
ZeroLogAllTraces: false,
RecursiveErrors: true,
ExtendedGinOutput: false,
IncludeMetaInGinOutput: true,
ExtendGinOutput: func(err *ExErr, json map[string]any) {},
ExtendGinDataOutput: func(err *ExErr, depth int, json map[string]any) {},
}
// Init initializes the exerr packages
// Must be called at the program start, before (!) any errors
// Is not thread-safe
func Init(cfg ErrorPackageConfigInit) {
if initialized {
panic("Cannot re-init error package")
}
ego := func(err *ExErr, json map[string]any) {}
egdo := func(err *ExErr, depth int, json map[string]any) {}
if cfg.ExtendGinOutput != nil {
ego = cfg.ExtendGinOutput
}
if cfg.ExtendGinDataOutput != nil {
egdo = cfg.ExtendGinDataOutput
}
pkgconfig = ErrorPackageConfig{
ZeroLogErrTraces: langext.Coalesce(cfg.ZeroLogErrTraces, pkgconfig.ZeroLogErrTraces),
ZeroLogAllTraces: langext.Coalesce(cfg.ZeroLogAllTraces, pkgconfig.ZeroLogAllTraces),
RecursiveErrors: langext.Coalesce(cfg.RecursiveErrors, pkgconfig.RecursiveErrors),
ExtendedGinOutput: langext.Coalesce(cfg.ExtendedGinOutput, pkgconfig.ExtendedGinOutput),
IncludeMetaInGinOutput: langext.Coalesce(cfg.IncludeMetaInGinOutput, pkgconfig.IncludeMetaInGinOutput),
ExtendGinOutput: ego,
ExtendGinDataOutput: egdo,
}
initialized = true
}
func warnOnPkgConfigNotInitialized() {
if !initialized {
fmt.Printf("\n")
fmt.Printf("%s\n", langext.StrRepeat("=", 80))
fmt.Printf("%s\n", "[WARNING] exerr package used without initializiation")
fmt.Printf("%s\n", " call exerr.Init() in your main() function")
fmt.Printf("%s\n", langext.StrRepeat("=", 80))
fmt.Printf("\n")
}
}

View File

@@ -1,298 +0,0 @@
package exerr
import (
"github.com/rs/xid"
"github.com/rs/zerolog"
"gogs.mikescher.com/BlackForestBytes/goext/langext"
"reflect"
"strings"
"time"
)
type ExErr struct {
UniqueID string `json:"uniqueID"`
Timestamp time.Time `json:"timestamp"`
Category ErrorCategory `json:"category"`
Severity ErrorSeverity `json:"severity"`
Type ErrorType `json:"type"`
StatusCode *int `json:"statusCode"`
Message string `json:"message"`
WrappedErrType string `json:"wrappedErrType"`
WrappedErr any `json:"-"`
Caller string `json:"caller"`
OriginalError *ExErr `json:"originalError"`
Meta MetaMap `json:"meta"`
}
func (ee *ExErr) Error() string {
return ee.Message
}
// Unwrap must be implemented so that some error.XXX methods work
func (ee *ExErr) Unwrap() error {
if ee.OriginalError == nil {
return nil // this is neccessary - otherwise we return a wrapped nil and the `x == nil` comparison fails (= panic in errors.Is and other failures)
}
return ee.OriginalError
}
// Is must be implemented so that error.Is(x) works
func (ee *ExErr) Is(e error) bool {
return IsFrom(ee, e)
}
// As must be implemented so that error.As(x) works
//
//goland:noinspection GoTypeAssertionOnErrors
func (ee *ExErr) As(target any) bool {
if dstErr, ok := target.(*ExErr); ok {
if dst0, ok := ee.contains(dstErr); ok {
dstErr = dst0
return true
} else {
return false
}
} else {
val := reflect.ValueOf(target)
typStr := val.Type().Elem().String()
for curr := ee; curr != nil; curr = curr.OriginalError {
if curr.Category == CatForeign && curr.WrappedErrType == typStr && curr.WrappedErr != nil {
val.Elem().Set(reflect.ValueOf(curr.WrappedErr))
return true
}
}
return false
}
}
func (ee *ExErr) Log(evt *zerolog.Event) {
evt.Msg(ee.FormatLog(LogPrintFull))
}
func (ee *ExErr) FormatLog(lvl LogPrintLevel) string {
if lvl == LogPrintShort {
msg := ee.Message
if ee.OriginalError != nil && ee.OriginalError.Category == CatForeign {
msg = msg + " (" + strings.ReplaceAll(ee.OriginalError.Message, "\n", " ") + ")"
}
if ee.Type != TypeWrap {
return "[" + ee.Type.Key + "] " + msg
} else {
return msg
}
} else if lvl == LogPrintOverview {
str := "[" + ee.RecursiveType().Key + "] <" + ee.UniqueID + "> " + strings.ReplaceAll(ee.RecursiveMessage(), "\n", " ") + "\n"
indent := ""
for curr := ee; curr != nil; curr = curr.OriginalError {
indent += " "
str += indent
str += "-> "
strmsg := strings.Trim(curr.Message, " \r\n\t")
if lbidx := strings.Index(curr.Message, "\n"); lbidx >= 0 {
strmsg = strmsg[0:lbidx]
}
strmsg = langext.StrLimit(strmsg, 61, "...")
str += strmsg
str += "\n"
}
return str
} else if lvl == LogPrintFull {
str := "[" + ee.RecursiveType().Key + "] <" + ee.UniqueID + "> " + strings.ReplaceAll(ee.RecursiveMessage(), "\n", " ") + "\n"
indent := ""
for curr := ee; curr != nil; curr = curr.OriginalError {
indent += " "
etype := ee.Type.Key
if ee.Type == TypeWrap {
etype = "~"
}
str += indent
str += "-> ["
str += etype
if curr.Category == CatForeign {
str += "|Foreign"
}
str += "] "
str += strings.ReplaceAll(curr.Message, "\n", " ")
if curr.Caller != "" {
str += " (@ "
str += curr.Caller
str += ")"
}
str += "\n"
if curr.Meta.Any() {
meta := indent + " {" + curr.Meta.FormatOneLine(240) + "}"
if len(meta) < 200 {
str += meta
str += "\n"
} else {
str += curr.Meta.FormatMultiLine(indent+" ", " ", 1024)
str += "\n"
}
}
}
return str
} else {
return "[?[" + ee.UniqueID + "]?]"
}
}
func (ee *ExErr) ShortLog(evt *zerolog.Event) {
ee.Meta.Apply(evt, langext.Ptr(240)).Msg(ee.FormatLog(LogPrintShort))
}
// RecursiveMessage returns the message to show
// = first error (top-down) that is not wrapping/foreign/empty
func (ee *ExErr) RecursiveMessage() string {
for curr := ee; curr != nil; curr = curr.OriginalError {
if curr.Message != "" && curr.Category != CatWrap && curr.Category != CatForeign {
return curr.Message
}
}
// fallback to self
return ee.Message
}
// RecursiveType returns the statuscode to use
// = first error (top-down) that is not wrapping/empty
func (ee *ExErr) RecursiveType() ErrorType {
for curr := ee; curr != nil; curr = curr.OriginalError {
if curr.Type != TypeWrap {
return curr.Type
}
}
// fallback to self
return ee.Type
}
// RecursiveStatuscode returns the HTTP Statuscode to use
// = first error (top-down) that has a statuscode set
func (ee *ExErr) RecursiveStatuscode() *int {
for curr := ee; curr != nil; curr = curr.OriginalError {
if curr.StatusCode != nil {
return langext.Ptr(*curr.StatusCode)
}
}
return nil
}
// RecursiveCategory returns the ErrorCategory to use
// = first error (top-down) that has a statuscode set
func (ee *ExErr) RecursiveCategory() ErrorCategory {
for curr := ee; curr != nil; curr = curr.OriginalError {
if curr.Category != CatWrap {
return curr.Category
}
}
// fallback to <empty>
return ee.Category
}
// RecursiveMeta searches (top-down) for teh first error that has a meta value with teh specified key
// and returns its value (or nil)
func (ee *ExErr) RecursiveMeta(key string) *MetaValue {
for curr := ee; curr != nil; curr = curr.OriginalError {
if metaval, ok := curr.Meta[key]; ok {
return langext.Ptr(metaval)
}
}
return nil
}
// Depth returns the depth of recursively contained errors
func (ee *ExErr) Depth() int {
if ee.OriginalError == nil {
return 1
} else {
return ee.OriginalError.Depth() + 1
}
}
// contains test if the supplied error is contained in this error (anywhere in the chain)
func (ee *ExErr) contains(original *ExErr) (*ExErr, bool) {
if original == nil {
return nil, false
}
if ee == original {
return ee, true
}
for curr := ee; curr != nil; curr = curr.OriginalError {
if curr.equalsDirectProperties(curr) {
return curr, true
}
}
return nil, false
}
// equalsDirectProperties tests if ee and other are equals, but only looks at primary properties (not `OriginalError` or `Meta`)
func (ee *ExErr) equalsDirectProperties(other *ExErr) bool {
if ee.UniqueID != other.UniqueID {
return false
}
if ee.Timestamp != other.Timestamp {
return false
}
if ee.Category != other.Category {
return false
}
if ee.Severity != other.Severity {
return false
}
if ee.Type != other.Type {
return false
}
if ee.StatusCode != other.StatusCode {
return false
}
if ee.Message != other.Message {
return false
}
if ee.WrappedErrType != other.WrappedErrType {
return false
}
if ee.Caller != other.Caller {
return false
}
return true
}
func newID() string {
return xid.New().String()
}

View File

@@ -1,93 +0,0 @@
package exerr
import (
"errors"
"gogs.mikescher.com/BlackForestBytes/goext/tst"
"testing"
)
type golangErr struct {
Message string
}
func (g golangErr) Error() string {
return g.Message
}
type golangErr2 struct {
Message string
}
func (g golangErr2) Error() string {
return g.Message
}
type simpleError struct {
}
func (g simpleError) Error() string {
return "Something simple went wroong"
}
type simpleError2 struct {
}
func (g simpleError2) Error() string {
return "Something simple went wroong"
}
func TestExErrIs1(t *testing.T) {
e0 := simpleError{}
wrap := Wrap(e0, "something went wrong").Str("test", "123").Build()
tst.AssertTrue(t, errors.Is(wrap, simpleError{}))
tst.AssertFalse(t, errors.Is(wrap, golangErr{}))
tst.AssertFalse(t, errors.Is(wrap, golangErr{"error1"}))
}
func TestExErrIs2(t *testing.T) {
e0 := golangErr{"error1"}
wrap := Wrap(e0, "something went wrong").Str("test", "123").Build()
tst.AssertTrue(t, errors.Is(wrap, e0))
tst.AssertTrue(t, errors.Is(wrap, golangErr{"error1"}))
tst.AssertFalse(t, errors.Is(wrap, golangErr{"error2"}))
tst.AssertFalse(t, errors.Is(wrap, simpleError{}))
}
func TestExErrAs(t *testing.T) {
e0 := golangErr{"error1"}
w0 := Wrap(e0, "something went wrong").Str("test", "123").Build()
{
out := golangErr{}
ok := errors.As(w0, &out)
tst.AssertTrue(t, ok)
tst.AssertEqual(t, out.Message, "error1")
}
w1 := Wrap(w0, "outher error").Build()
{
out := golangErr{}
ok := errors.As(w1, &out)
tst.AssertTrue(t, ok)
tst.AssertEqual(t, out.Message, "error1")
}
{
out := golangErr2{}
ok := errors.As(w1, &out)
tst.AssertFalse(t, ok)
}
{
out := simpleError2{}
ok := errors.As(w1, &out)
tst.AssertFalse(t, ok)
}
}

View File

@@ -1,112 +0,0 @@
package exerr
import (
"github.com/gin-gonic/gin"
json "gogs.mikescher.com/BlackForestBytes/goext/gojson"
"gogs.mikescher.com/BlackForestBytes/goext/langext"
"net/http"
"time"
)
func (ee *ExErr) toJson(depth int, applyExtendListener bool, outputMeta bool) langext.H {
ginJson := langext.H{}
if ee.UniqueID != "" {
ginJson["id"] = ee.UniqueID
}
if ee.Category != CatWrap {
ginJson["category"] = ee.Category
}
if ee.Type != TypeWrap {
ginJson["type"] = ee.Type
}
if ee.StatusCode != nil {
ginJson["statuscode"] = ee.StatusCode
}
if ee.Message != "" {
ginJson["message"] = ee.Message
}
if ee.Caller != "" {
ginJson["caller"] = ee.Caller
}
if ee.Severity != SevErr {
ginJson["severity"] = ee.Severity
}
if ee.Timestamp != (time.Time{}) {
ginJson["time"] = ee.Timestamp.Format(time.RFC3339)
}
if ee.WrappedErrType != "" {
ginJson["wrappedErrType"] = ee.WrappedErrType
}
if ee.OriginalError != nil {
ginJson["original"] = ee.OriginalError.toJson(depth+1, applyExtendListener, outputMeta)
}
if outputMeta {
metaJson := langext.H{}
for metaKey, metaVal := range ee.Meta {
metaJson[metaKey] = metaVal.rawValueForJson()
}
ginJson["meta"] = metaJson
}
if applyExtendListener {
pkgconfig.ExtendGinDataOutput(ee, depth, ginJson)
}
return ginJson
}
// ToAPIJson converts the ExError to a json object
// (the same object as used in the Output(gin) method)
//
// Parameters:
// - [applyExtendListener]: if false the pkgconfig.ExtendGinOutput / pkgconfig.ExtendGinDataOutput will not be applied
// - [includeWrappedErrors]: if false we do not include the recursive/wrapped errors in `__data`
// - [includeMetaFields]: if true we also include meta-values (aka from `.Str(key, value).Build()`), needs includeWrappedErrors=true
func (ee *ExErr) ToAPIJson(applyExtendListener bool, includeWrappedErrors bool, includeMetaFields bool) langext.H {
apiOutput := langext.H{
"errorid": ee.UniqueID,
"message": ee.RecursiveMessage(),
"errorcode": ee.RecursiveType().Key,
"category": ee.RecursiveCategory().Category,
}
if includeWrappedErrors {
apiOutput["__data"] = ee.toJson(0, applyExtendListener, includeMetaFields)
}
if applyExtendListener {
pkgconfig.ExtendGinOutput(ee, apiOutput)
}
return apiOutput
}
func (ee *ExErr) Output(g *gin.Context) {
warnOnPkgConfigNotInitialized()
var statuscode = http.StatusInternalServerError
var baseCat = ee.RecursiveCategory()
var baseType = ee.RecursiveType()
var baseStatuscode = ee.RecursiveStatuscode()
if baseCat == CatUser {
statuscode = http.StatusBadRequest
} else if baseCat == CatSystem {
statuscode = http.StatusInternalServerError
}
if baseStatuscode != nil {
statuscode = *ee.StatusCode
} else if baseType.DefaultStatusCode != nil {
statuscode = *baseType.DefaultStatusCode
}
ginOutput := ee.ToAPIJson(true, pkgconfig.ExtendedGinOutput, pkgconfig.IncludeMetaInGinOutput)
g.Render(statuscode, json.GoJsonRender{Data: ginOutput, NilSafeSlices: true, NilSafeMaps: true})
}

View File

@@ -1,88 +0,0 @@
package exerr
import "fmt"
// IsType test if the supplied error is of the specified ErrorType.
func IsType(err error, errType ErrorType) bool {
if err == nil {
return false
}
bmerr := FromError(err)
for bmerr != nil {
if bmerr.Type == errType {
return true
}
bmerr = bmerr.OriginalError
}
return false
}
// IsFrom test if the supplied error stems originally from original
func IsFrom(e error, original error) bool {
if e == nil {
return false
}
//goland:noinspection GoDirectComparisonOfErrors
if e == original {
return true
}
bmerr := FromError(e)
for bmerr == nil {
return false
}
for curr := bmerr; curr != nil; curr = curr.OriginalError {
if curr.Category == CatForeign && curr.Message == original.Error() && curr.WrappedErrType == fmt.Sprintf("%T", original) {
return true
}
}
return false
}
// HasSourceMessage tests if the supplied error stems originally from an error with the message msg
func HasSourceMessage(e error, msg string) bool {
if e == nil {
return false
}
bmerr := FromError(e)
for bmerr == nil {
return false
}
for curr := bmerr; curr != nil; curr = curr.OriginalError {
if curr.OriginalError == nil && curr.Message == msg {
return true
}
}
return false
}
func MessageMatch(e error, matcher func(string) bool) bool {
if e == nil {
return false
}
if matcher(e.Error()) {
return true
}
bmerr := FromError(e)
for bmerr == nil {
return false
}
for curr := bmerr; curr != nil; curr = curr.OriginalError {
if matcher(curr.Message) {
return true
}
}
return false
}

View File

@@ -1,37 +0,0 @@
package exerr
import (
"sync"
)
type Method string
const (
MethodOutput Method = "OUTPUT"
MethodPrint Method = "PRINT"
MethodBuild Method = "BUILD"
MethodFatal Method = "FATAL"
)
type Listener = func(method Method, v *ExErr)
var listenerLock = sync.Mutex{}
var listener = make([]Listener, 0)
func RegisterListener(l Listener) {
listenerLock.Lock()
defer listenerLock.Unlock()
listener = append(listener, l)
}
func (b *Builder) CallListener(m Method) {
valErr := b.errorData
listenerLock.Lock()
defer listenerLock.Unlock()
for _, v := range listener {
v(m, valErr)
}
}

View File

@@ -1,736 +0,0 @@
package exerr
import (
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"github.com/rs/zerolog"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"gogs.mikescher.com/BlackForestBytes/goext/langext"
"strconv"
"strings"
"time"
)
// This is a buffed up map[string]any
// we also save type information of the map-values
// which allows us to deserialize them back into te correct types later
type MetaMap map[string]MetaValue
type metaDataType string
const (
MDTString metaDataType = "String"
MDTStringPtr metaDataType = "StringPtr"
MDTInt metaDataType = "Int"
MDTInt8 metaDataType = "Int8"
MDTInt16 metaDataType = "Int16"
MDTInt32 metaDataType = "Int32"
MDTInt64 metaDataType = "Int64"
MDTFloat32 metaDataType = "Float32"
MDTFloat64 metaDataType = "Float64"
MDTBool metaDataType = "Bool"
MDTBytes metaDataType = "Bytes"
MDTObjectID metaDataType = "ObjectID"
MDTTime metaDataType = "Time"
MDTDuration metaDataType = "Duration"
MDTStringArray metaDataType = "StringArr"
MDTIntArray metaDataType = "IntArr"
MDTInt32Array metaDataType = "Int32Arr"
MDTID metaDataType = "ID"
MDTAny metaDataType = "Interface"
MDTNil metaDataType = "Nil"
MDTEnum metaDataType = "Enum"
)
type MetaValue struct {
DataType metaDataType `json:"dataType"`
Value interface{} `json:"value"`
}
type metaValueSerialization struct {
DataType metaDataType `bson:"dataType"`
Value string `bson:"value"`
Raw interface{} `bson:"raw"`
}
func (v MetaValue) SerializeValue() (string, error) {
switch v.DataType {
case MDTString:
return v.Value.(string), nil
case MDTID:
return v.Value.(IDWrap).Serialize(), nil
case MDTAny:
return v.Value.(AnyWrap).Serialize(), nil
case MDTStringPtr:
if langext.IsNil(v.Value) {
return "#", nil
}
r := v.Value.(*string)
if r != nil {
return "*" + *r, nil
} else {
return "#", nil
}
case MDTInt:
return strconv.Itoa(v.Value.(int)), nil
case MDTInt8:
return strconv.FormatInt(int64(v.Value.(int8)), 10), nil
case MDTInt16:
return strconv.FormatInt(int64(v.Value.(int16)), 10), nil
case MDTInt32:
return strconv.FormatInt(int64(v.Value.(int32)), 10), nil
case MDTInt64:
return strconv.FormatInt(v.Value.(int64), 10), nil
case MDTFloat32:
return strconv.FormatFloat(float64(v.Value.(float32)), 'X', -1, 32), nil
case MDTFloat64:
return strconv.FormatFloat(v.Value.(float64), 'X', -1, 64), nil
case MDTBool:
if v.Value.(bool) {
return "true", nil
} else {
return "false", nil
}
case MDTBytes:
return hex.EncodeToString(v.Value.([]byte)), nil
case MDTObjectID:
return v.Value.(primitive.ObjectID).Hex(), nil
case MDTTime:
return strconv.FormatInt(v.Value.(time.Time).Unix(), 10) + "|" + strconv.FormatInt(int64(v.Value.(time.Time).Nanosecond()), 10), nil
case MDTDuration:
return v.Value.(time.Duration).String(), nil
case MDTStringArray:
if langext.IsNil(v.Value) {
return "#", nil
}
r, err := json.Marshal(v.Value.([]string))
if err != nil {
return "", err
}
return string(r), nil
case MDTIntArray:
if langext.IsNil(v.Value) {
return "#", nil
}
r, err := json.Marshal(v.Value.([]int))
if err != nil {
return "", err
}
return string(r), nil
case MDTInt32Array:
if langext.IsNil(v.Value) {
return "#", nil
}
r, err := json.Marshal(v.Value.([]int32))
if err != nil {
return "", err
}
return string(r), nil
case MDTNil:
return "", nil
case MDTEnum:
return v.Value.(EnumWrap).Serialize(), nil
}
return "", errors.New("Unknown type: " + string(v.DataType))
}
func (v MetaValue) ShortString(lim int) string {
switch v.DataType {
case MDTString:
r := strings.ReplaceAll(v.Value.(string), "\r", "")
r = strings.ReplaceAll(r, "\n", "\\n")
r = strings.ReplaceAll(r, "\t", "\\t")
return langext.StrLimit(r, lim, "...")
case MDTID:
return v.Value.(IDWrap).String()
case MDTAny:
return v.Value.(AnyWrap).String()
case MDTStringPtr:
if langext.IsNil(v.Value) {
return "<<null>>"
}
r := langext.CoalesceString(v.Value.(*string), "<<null>>")
r = strings.ReplaceAll(r, "\r", "")
r = strings.ReplaceAll(r, "\n", "\\n")
r = strings.ReplaceAll(r, "\t", "\\t")
return langext.StrLimit(r, lim, "...")
case MDTInt:
return strconv.Itoa(v.Value.(int))
case MDTInt8:
return strconv.FormatInt(int64(v.Value.(int8)), 10)
case MDTInt16:
return strconv.FormatInt(int64(v.Value.(int16)), 10)
case MDTInt32:
return strconv.FormatInt(int64(v.Value.(int32)), 10)
case MDTInt64:
return strconv.FormatInt(v.Value.(int64), 10)
case MDTFloat32:
return strconv.FormatFloat(float64(v.Value.(float32)), 'g', 4, 32)
case MDTFloat64:
return strconv.FormatFloat(v.Value.(float64), 'g', 4, 64)
case MDTBool:
return fmt.Sprintf("%v", v.Value.(bool))
case MDTBytes:
return langext.StrLimit(hex.EncodeToString(v.Value.([]byte)), lim, "...")
case MDTObjectID:
return v.Value.(primitive.ObjectID).Hex()
case MDTTime:
return v.Value.(time.Time).Format(time.RFC3339)
case MDTDuration:
return v.Value.(time.Duration).String()
case MDTStringArray:
if langext.IsNil(v.Value) {
return "<<null>>"
}
r, err := json.Marshal(v.Value.([]string))
if err != nil {
return "(err)"
}
return langext.StrLimit(string(r), lim, "...")
case MDTIntArray:
if langext.IsNil(v.Value) {
return "<<null>>"
}
r, err := json.Marshal(v.Value.([]int))
if err != nil {
return "(err)"
}
return langext.StrLimit(string(r), lim, "...")
case MDTInt32Array:
if langext.IsNil(v.Value) {
return "<<null>>"
}
r, err := json.Marshal(v.Value.([]int32))
if err != nil {
return "(err)"
}
return langext.StrLimit(string(r), lim, "...")
case MDTNil:
return "<<null>>"
case MDTEnum:
return v.Value.(EnumWrap).String()
}
return "(err)"
}
func (v MetaValue) Apply(key string, evt *zerolog.Event, limitLen *int) *zerolog.Event {
switch v.DataType {
case MDTString:
if limitLen == nil {
return evt.Str(key, v.Value.(string))
} else {
return evt.Str(key, langext.StrLimit(v.Value.(string), *limitLen, "..."))
}
case MDTID:
return evt.Str(key, v.Value.(IDWrap).Value)
case MDTAny:
if v.Value.(AnyWrap).IsError {
return evt.Str(key, "(err)")
} else {
if limitLen == nil {
return evt.Str(key, v.Value.(AnyWrap).Json)
} else {
return evt.Str(key, langext.StrLimit(v.Value.(AnyWrap).Json, *limitLen, "..."))
}
}
case MDTStringPtr:
if langext.IsNil(v.Value) {
return evt.Str(key, "<<null>>")
}
if limitLen == nil {
return evt.Str(key, langext.CoalesceString(v.Value.(*string), "<<null>>"))
} else {
return evt.Str(key, langext.StrLimit(langext.CoalesceString(v.Value.(*string), "<<null>>"), *limitLen, "..."))
}
case MDTInt:
return evt.Int(key, v.Value.(int))
case MDTInt8:
return evt.Int8(key, v.Value.(int8))
case MDTInt16:
return evt.Int16(key, v.Value.(int16))
case MDTInt32:
return evt.Int32(key, v.Value.(int32))
case MDTInt64:
return evt.Int64(key, v.Value.(int64))
case MDTFloat32:
return evt.Float32(key, v.Value.(float32))
case MDTFloat64:
return evt.Float64(key, v.Value.(float64))
case MDTBool:
return evt.Bool(key, v.Value.(bool))
case MDTBytes:
return evt.Bytes(key, v.Value.([]byte))
case MDTObjectID:
return evt.Str(key, v.Value.(primitive.ObjectID).Hex())
case MDTTime:
return evt.Time(key, v.Value.(time.Time))
case MDTDuration:
return evt.Dur(key, v.Value.(time.Duration))
case MDTStringArray:
if langext.IsNil(v.Value) {
return evt.Strs(key, nil)
}
return evt.Strs(key, v.Value.([]string))
case MDTIntArray:
if langext.IsNil(v.Value) {
return evt.Ints(key, nil)
}
return evt.Ints(key, v.Value.([]int))
case MDTInt32Array:
if langext.IsNil(v.Value) {
return evt.Ints32(key, nil)
}
return evt.Ints32(key, v.Value.([]int32))
case MDTNil:
return evt.Str(key, "<<null>>")
case MDTEnum:
if v.Value.(EnumWrap).IsNil {
return evt.Any(key, nil)
} else if v.Value.(EnumWrap).ValueRaw != nil {
return evt.Any(key, v.Value.(EnumWrap).ValueRaw)
} else {
return evt.Str(key, v.Value.(EnumWrap).ValueString)
}
}
return evt.Str(key, "(err)")
}
func (v MetaValue) MarshalJSON() ([]byte, error) {
str, err := v.SerializeValue()
if err != nil {
return nil, err
}
return json.Marshal(string(v.DataType) + ":" + str)
}
func (v *MetaValue) UnmarshalJSON(data []byte) error {
var str = ""
err := json.Unmarshal(data, &str)
if err != nil {
return err
}
split := strings.SplitN(str, ":", 2)
if len(split) != 2 {
return errors.New("failed to decode MetaValue: '" + str + "'")
}
return v.Deserialize(split[1], metaDataType(split[0]))
}
func (v MetaValue) MarshalBSON() ([]byte, error) {
serval, err := v.SerializeValue()
if err != nil {
return nil, Wrap(err, "failed to bson-marshal MetaValue (serialize)").Build()
}
// this is an kinda ugly hack - but serialization to mongodb and back can loose the correct type information....
bin, err := bson.Marshal(metaValueSerialization{
DataType: v.DataType,
Value: serval,
Raw: v.Value,
})
if err != nil {
return nil, Wrap(err, "failed to bson-marshal MetaValue (marshal)").Build()
}
return bin, nil
}
func (v *MetaValue) UnmarshalBSON(bytes []byte) error {
var serval metaValueSerialization
err := bson.Unmarshal(bytes, &serval)
if err != nil {
return Wrap(err, "failed to bson-unmarshal MetaValue (unmarshal)").Build()
}
err = v.Deserialize(serval.Value, serval.DataType)
if err != nil {
return Wrap(err, "failed to deserialize MetaValue from bson").Str("raw", serval.Value).Build()
}
return nil
}
func (v *MetaValue) Deserialize(value string, datatype metaDataType) error {
switch datatype {
case MDTString:
v.Value = value
v.DataType = datatype
return nil
case MDTID:
v.Value = deserializeIDWrap(value)
v.DataType = datatype
return nil
case MDTAny:
v.Value = deserializeAnyWrap(value)
v.DataType = datatype
return nil
case MDTStringPtr:
if len(value) <= 0 || (value[0] != '*' && value[0] != '#') {
return errors.New("Invalid StringPtr: " + value)
} else if value == "#" {
v.Value = nil
v.DataType = datatype
return nil
} else {
v.Value = langext.Ptr(value[1:])
v.DataType = datatype
return nil
}
case MDTInt:
pv, err := strconv.ParseInt(value, 10, 0)
if err != nil {
return err
}
v.Value = int(pv)
v.DataType = datatype
return nil
case MDTInt8:
pv, err := strconv.ParseInt(value, 10, 8)
if err != nil {
return err
}
v.Value = int8(pv)
v.DataType = datatype
return nil
case MDTInt16:
pv, err := strconv.ParseInt(value, 10, 16)
if err != nil {
return err
}
v.Value = int16(pv)
v.DataType = datatype
return nil
case MDTInt32:
pv, err := strconv.ParseInt(value, 10, 32)
if err != nil {
return err
}
v.Value = int32(pv)
v.DataType = datatype
return nil
case MDTInt64:
pv, err := strconv.ParseInt(value, 10, 64)
if err != nil {
return err
}
v.Value = pv
v.DataType = datatype
return nil
case MDTFloat32:
pv, err := strconv.ParseFloat(value, 64)
if err != nil {
return err
}
v.Value = float32(pv)
v.DataType = datatype
return nil
case MDTFloat64:
pv, err := strconv.ParseFloat(value, 64)
if err != nil {
return err
}
v.Value = pv
v.DataType = datatype
return nil
case MDTBool:
if value == "true" {
v.Value = true
v.DataType = datatype
return nil
}
if value == "false" {
v.Value = false
v.DataType = datatype
return nil
}
return errors.New("invalid bool value: " + value)
case MDTBytes:
r, err := hex.DecodeString(value)
if err != nil {
return err
}
v.Value = r
v.DataType = datatype
return nil
case MDTObjectID:
r, err := primitive.ObjectIDFromHex(value)
if err != nil {
return err
}
v.Value = r
v.DataType = datatype
return nil
case MDTTime:
ps := strings.Split(value, "|")
if len(ps) != 2 {
return errors.New("invalid time.time: " + value)
}
p1, err := strconv.ParseInt(ps[0], 10, 64)
if err != nil {
return err
}
p2, err := strconv.ParseInt(ps[1], 10, 32)
if err != nil {
return err
}
v.Value = time.Unix(p1, p2)
v.DataType = datatype
return nil
case MDTDuration:
r, err := time.ParseDuration(value)
if err != nil {
return err
}
v.Value = r
v.DataType = datatype
return nil
case MDTStringArray:
if value == "#" {
v.Value = nil
v.DataType = datatype
return nil
}
pj := make([]string, 0)
err := json.Unmarshal([]byte(value), &pj)
if err != nil {
return err
}
v.Value = pj
v.DataType = datatype
return nil
case MDTIntArray:
if value == "#" {
v.Value = nil
v.DataType = datatype
return nil
}
pj := make([]int, 0)
err := json.Unmarshal([]byte(value), &pj)
if err != nil {
return err
}
v.Value = pj
v.DataType = datatype
return nil
case MDTInt32Array:
if value == "#" {
v.Value = nil
v.DataType = datatype
return nil
}
pj := make([]int32, 0)
err := json.Unmarshal([]byte(value), &pj)
if err != nil {
return err
}
v.Value = pj
v.DataType = datatype
return nil
case MDTNil:
v.Value = nil
v.DataType = datatype
return nil
case MDTEnum:
v.Value = deserializeEnumWrap(value)
v.DataType = datatype
return nil
}
return errors.New("Unknown type: " + string(datatype))
}
func (v MetaValue) ValueString() string {
switch v.DataType {
case MDTString:
return v.Value.(string)
case MDTID:
return v.Value.(IDWrap).String()
case MDTAny:
return v.Value.(AnyWrap).String()
case MDTStringPtr:
if langext.IsNil(v.Value) {
return "<<null>>"
}
return langext.CoalesceString(v.Value.(*string), "<<null>>")
case MDTInt:
return strconv.Itoa(v.Value.(int))
case MDTInt8:
return strconv.FormatInt(int64(v.Value.(int8)), 10)
case MDTInt16:
return strconv.FormatInt(int64(v.Value.(int16)), 10)
case MDTInt32:
return strconv.FormatInt(int64(v.Value.(int32)), 10)
case MDTInt64:
return strconv.FormatInt(v.Value.(int64), 10)
case MDTFloat32:
return strconv.FormatFloat(float64(v.Value.(float32)), 'g', 4, 32)
case MDTFloat64:
return strconv.FormatFloat(v.Value.(float64), 'g', 4, 64)
case MDTBool:
return fmt.Sprintf("%v", v.Value.(bool))
case MDTBytes:
return hex.EncodeToString(v.Value.([]byte))
case MDTObjectID:
return v.Value.(primitive.ObjectID).Hex()
case MDTTime:
return v.Value.(time.Time).Format(time.RFC3339Nano)
case MDTDuration:
return v.Value.(time.Duration).String()
case MDTStringArray:
if langext.IsNil(v.Value) {
return "<<null>>"
}
r, err := json.MarshalIndent(v.Value.([]string), "", " ")
if err != nil {
return "(err)"
}
return string(r)
case MDTIntArray:
if langext.IsNil(v.Value) {
return "<<null>>"
}
r, err := json.MarshalIndent(v.Value.([]int), "", " ")
if err != nil {
return "(err)"
}
return string(r)
case MDTInt32Array:
if langext.IsNil(v.Value) {
return "<<null>>"
}
r, err := json.MarshalIndent(v.Value.([]int32), "", " ")
if err != nil {
return "(err)"
}
return string(r)
case MDTNil:
return "<<null>>"
case MDTEnum:
return v.Value.(EnumWrap).String()
}
return "(err)"
}
// rawValueForJson returns most-of-the-time the `Value` field
// but for some datatyes we do special processing
// all, so we can pluck the output value in json.Marshal without any suprises
func (v MetaValue) rawValueForJson() any {
if v.DataType == MDTAny {
if v.Value.(AnyWrap).IsNil {
return nil
}
if v.Value.(AnyWrap).IsError {
return bson.M{"@error": true}
}
jsonobj := primitive.M{}
jsonarr := primitive.A{}
if err := json.Unmarshal([]byte(v.Value.(AnyWrap).Json), &jsonobj); err == nil {
return jsonobj
} else if err := json.Unmarshal([]byte(v.Value.(AnyWrap).Json), &jsonarr); err == nil {
return jsonarr
} else {
return bson.M{"type": v.Value.(AnyWrap).Type, "data": v.Value.(AnyWrap).Json}
}
}
if v.DataType == MDTID {
if v.Value.(IDWrap).IsNil {
return nil
}
return v.Value.(IDWrap).Value
}
if v.DataType == MDTBytes {
return hex.EncodeToString(v.Value.([]byte))
}
if v.DataType == MDTDuration {
return v.Value.(time.Duration).String()
}
if v.DataType == MDTTime {
return v.Value.(time.Time).Format(time.RFC3339Nano)
}
if v.DataType == MDTObjectID {
return v.Value.(primitive.ObjectID).Hex()
}
if v.DataType == MDTNil {
return nil
}
if v.DataType == MDTEnum {
if v.Value.(EnumWrap).IsNil {
return nil
}
if v.Value.(EnumWrap).ValueRaw != nil {
return v.Value.(EnumWrap).ValueRaw
}
return v.Value.(EnumWrap).ValueString
}
return v.Value
}
func (mm MetaMap) FormatOneLine(singleMaxLen int) string {
r := ""
i := 0
for key, val := range mm {
if i > 0 {
r += ", "
}
r += "\"" + key + "\""
r += ": "
r += "\"" + val.ShortString(singleMaxLen) + "\""
i++
}
return r
}
func (mm MetaMap) FormatMultiLine(indentFront string, indentKeys string, maxLenValue int) string {
r := ""
r += indentFront + "{" + "\n"
for key, val := range mm {
if key == "gin.body" {
continue
}
r += indentFront
r += indentKeys
r += "\"" + key + "\""
r += ": "
r += "\"" + val.ShortString(maxLenValue) + "\""
r += ",\n"
}
r += indentFront + "}"
return r
}
func (mm MetaMap) Any() bool {
return len(mm) > 0
}
func (mm MetaMap) Apply(evt *zerolog.Event, limitLen *int) *zerolog.Event {
for key, val := range mm {
evt = val.Apply(key, evt, limitLen)
}
return evt
}
func (mm MetaMap) add(key string, mdtype metaDataType, val interface{}) {
if _, ok := mm[key]; !ok {
mm[key] = MetaValue{DataType: mdtype, Value: val}
return
}
for i := 2; ; i++ {
realkey := key + "-" + strconv.Itoa(i)
if _, ok := mm[realkey]; !ok {
mm[realkey] = MetaValue{DataType: mdtype, Value: val}
return
}
}
}

View File

@@ -1,14 +0,0 @@
package exerr
import (
"fmt"
"runtime"
)
func callername(skip int) string {
pc := make([]uintptr, 15)
n := runtime.Callers(skip+2, pc)
frames := runtime.CallersFrames(pc[:n])
frame, _ := frames.Next()
return fmt.Sprintf("%s:%d %s", frame.File, frame.Line, frame.Function)
}

View File

@@ -1,189 +0,0 @@
package exerr
import (
"encoding/json"
"fmt"
"github.com/rs/zerolog/log"
"gogs.mikescher.com/BlackForestBytes/goext/enums"
"gogs.mikescher.com/BlackForestBytes/goext/langext"
"strings"
)
//
// These are wrapper objects, because for some metadata-types we need to serialize a bit more complex data
// (eg thy actual type for ID objects, or the json representation for any types)
//
type IDWrap struct {
Type string
Value string
IsNil bool
}
func newIDWrap(val fmt.Stringer) IDWrap {
t := fmt.Sprintf("%T", val)
arr := strings.Split(t, ".")
if len(arr) > 0 {
t = arr[len(arr)-1]
}
if langext.IsNil(val) {
return IDWrap{Type: t, Value: "", IsNil: true}
}
v := val.String()
return IDWrap{Type: t, Value: v, IsNil: false}
}
func (w IDWrap) Serialize() string {
if w.IsNil {
return "!nil" + ":" + w.Type
}
return w.Type + ":" + w.Value
}
func (w IDWrap) String() string {
if w.IsNil {
return w.Type + "<<nil>>"
}
return w.Type + "(" + w.Value + ")"
}
func deserializeIDWrap(v string) IDWrap {
r := strings.SplitN(v, ":", 2)
if len(r) == 2 && r[0] == "!nil" {
return IDWrap{Type: r[1], Value: v, IsNil: true}
}
if len(r) == 0 {
return IDWrap{}
} else if len(r) == 1 {
return IDWrap{Type: "", Value: v, IsNil: false}
} else {
return IDWrap{Type: r[0], Value: r[1], IsNil: false}
}
}
type AnyWrap struct {
Type string
Json string
IsError bool
IsNil bool
}
func newAnyWrap(val any) (result AnyWrap) {
result = AnyWrap{Type: "", Json: "", IsError: true, IsNil: false} // ensure a return in case of recover()
defer func() {
if err := recover(); err != nil {
// send error should never crash our program
log.Error().Interface("err", err).Msg("Panic while trying to marshal anywrap ( bmerror.Interface )")
}
}()
t := fmt.Sprintf("%T", val)
if langext.IsNil(val) {
return AnyWrap{Type: t, Json: "", IsError: false, IsNil: true}
}
j, err := json.Marshal(val)
if err == nil {
return AnyWrap{Type: t, Json: string(j), IsError: false, IsNil: false}
} else {
return AnyWrap{Type: t, Json: "", IsError: true, IsNil: false}
}
}
func (w AnyWrap) Serialize() string {
if w.IsError {
return "ERR" + ":" + w.Type + ":" + w.Json
} else if w.IsNil {
return "NIL" + ":" + w.Type + ":" + w.Json
} else {
return "OK" + ":" + w.Type + ":" + w.Json
}
}
func (w AnyWrap) String() string {
if w.IsError {
return "(error)"
} else if w.IsNil {
return "(nil)"
} else {
return w.Json
}
}
func deserializeAnyWrap(v string) AnyWrap {
r := strings.SplitN(v, ":", 3)
if len(r) != 3 {
return AnyWrap{IsError: true, Type: "", Json: "", IsNil: false}
} else {
if r[0] == "OK" {
return AnyWrap{IsError: false, Type: r[1], Json: r[2], IsNil: false}
} else if r[0] == "ERR" {
return AnyWrap{IsError: true, Type: r[1], Json: r[2], IsNil: false}
} else if r[0] == "NIL" {
return AnyWrap{IsError: false, Type: r[1], Json: "", IsNil: true}
} else {
return AnyWrap{IsError: true, Type: "", Json: "", IsNil: false}
}
}
}
type EnumWrap struct {
Type string
ValueString string
ValueRaw enums.Enum // `ValueRaw` is lost during serialization roundtrip
IsNil bool
}
func newEnumWrap(val enums.Enum) EnumWrap {
t := fmt.Sprintf("%T", val)
arr := strings.Split(t, ".")
if len(arr) > 0 {
t = arr[len(arr)-1]
}
if langext.IsNil(val) {
return EnumWrap{Type: t, ValueString: "", ValueRaw: val, IsNil: true}
}
if enumstr, ok := val.(enums.StringEnum); ok {
return EnumWrap{Type: t, ValueString: enumstr.String(), ValueRaw: val, IsNil: false}
}
return EnumWrap{Type: t, ValueString: fmt.Sprintf("%v", val), ValueRaw: val, IsNil: false}
}
func (w EnumWrap) Serialize() string {
if w.IsNil {
return "!nil" + ":" + w.Type
}
return w.Type + ":" + w.ValueString
}
func (w EnumWrap) String() string {
if w.IsNil {
return w.Type + "<<nil>>"
}
return "[" + w.Type + "] " + w.ValueString
}
func deserializeEnumWrap(v string) EnumWrap {
r := strings.SplitN(v, ":", 2)
if len(r) == 2 && r[0] == "!nil" {
return EnumWrap{Type: r[1], ValueString: v, ValueRaw: nil, IsNil: true}
}
if len(r) == 0 {
return EnumWrap{}
} else if len(r) == 1 {
return EnumWrap{Type: "", ValueString: v, ValueRaw: nil, IsNil: false}
} else {
return EnumWrap{Type: r[0], ValueString: r[1], ValueRaw: nil, IsNil: false}
}
}

View File

@@ -1,59 +0,0 @@
package ginext
import (
"context"
"github.com/gin-gonic/gin"
"time"
)
type AppContext struct {
inner context.Context
cancelFunc context.CancelFunc
cancelled bool
GinContext *gin.Context
}
func CreateAppContext(g *gin.Context, innerCtx context.Context, cancelFn context.CancelFunc) *AppContext {
for key, value := range g.Keys {
innerCtx = context.WithValue(innerCtx, key, value)
}
return &AppContext{
inner: innerCtx,
cancelFunc: cancelFn,
cancelled: false,
GinContext: g,
}
}
func (ac *AppContext) Deadline() (deadline time.Time, ok bool) {
return ac.inner.Deadline()
}
func (ac *AppContext) Done() <-chan struct{} {
return ac.inner.Done()
}
func (ac *AppContext) Err() error {
return ac.inner.Err()
}
func (ac *AppContext) Value(key any) any {
return ac.inner.Value(key)
}
func (ac *AppContext) Set(key, value any) {
ac.inner = context.WithValue(ac.inner, key, value)
}
func (ac *AppContext) Cancel() {
ac.cancelled = true
ac.cancelFunc()
}
func (ac *AppContext) RequestURI() string {
if ac.GinContext != nil && ac.GinContext.Request != nil {
return ac.GinContext.Request.Method + " :: " + ac.GinContext.Request.RequestURI
} else {
return ""
}
}

View File

@@ -1,23 +0,0 @@
package ginext
import (
"net/http"
)
func RedirectFound(newuri string) WHandlerFunc {
return func(pctx PreContext) HTTPResponse {
return Redirect(http.StatusFound, newuri)
}
}
func RedirectTemporary(newuri string) WHandlerFunc {
return func(pctx PreContext) HTTPResponse {
return Redirect(http.StatusTemporaryRedirect, newuri)
}
}
func RedirectPermanent(newuri string) WHandlerFunc {
return func(pctx PreContext) HTTPResponse {
return Redirect(http.StatusPermanentRedirect, newuri)
}
}

View File

@@ -1,12 +0,0 @@
package ginext
import (
"github.com/gin-gonic/gin"
"gogs.mikescher.com/BlackForestBytes/goext/dataext"
)
func BodyBuffer(g *gin.Context) {
if g.Request.Body != nil {
g.Request.Body = dataext.NewBufferedReadCloser(g.Request.Body)
}
}

View File

@@ -1,21 +0,0 @@
package ginext
import (
"github.com/gin-gonic/gin"
"net/http"
)
func CorsMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With")
c.Writer.Header().Set("Access-Control-Allow-Methods", "OPTIONS, GET, POST, PUT, PATCH, DELETE, COUNT")
if c.Request.Method == "OPTIONS" {
c.AbortWithStatus(http.StatusOK)
} else {
c.Next()
}
}
}

View File

@@ -1,149 +0,0 @@
package ginext
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/rs/zerolog/log"
"gogs.mikescher.com/BlackForestBytes/goext/langext"
"gogs.mikescher.com/BlackForestBytes/goext/mathext"
"net"
"net/http"
"strings"
"time"
)
type GinWrapper struct {
engine *gin.Engine
SuppressGinLogs bool
allowCors bool
ginDebug bool
bufferBody bool
requestTimeout time.Duration
routeSpecs []ginRouteSpec
}
type ginRouteSpec struct {
Method string
URL string
Middlewares []string
Handler string
}
// NewEngine creates a new (wrapped) ginEngine
// Parameters are:
// - [allowCors] Add cors handler to allow all CORS requests on the default http methods
// - [ginDebug] Set gin.debug to true (adds more logs)
// - [bufferBody] Buffers the input body stream, this way the ginext error handler can later include the whole request body
// - [timeout] The default handler timeout
func NewEngine(allowCors bool, ginDebug bool, bufferBody bool, timeout time.Duration) *GinWrapper {
engine := gin.New()
wrapper := &GinWrapper{
engine: engine,
SuppressGinLogs: false,
allowCors: allowCors,
ginDebug: ginDebug,
bufferBody: bufferBody,
requestTimeout: timeout,
}
engine.RedirectFixedPath = false
engine.RedirectTrailingSlash = false
if allowCors {
engine.Use(CorsMiddleware())
}
// do not debug-print routes
gin.DebugPrintRouteFunc = func(_, _, _ string, _ int) {}
if !ginDebug {
gin.SetMode(gin.ReleaseMode)
ginlogger := gin.Logger()
engine.Use(func(context *gin.Context) {
if !wrapper.SuppressGinLogs {
ginlogger(context)
}
})
} else {
gin.SetMode(gin.DebugMode)
}
return wrapper
}
func (w *GinWrapper) ListenAndServeHTTP(addr string, postInit func(port string)) (chan error, *http.Server) {
w.DebugPrintRoutes()
httpserver := &http.Server{
Addr: addr,
Handler: w.engine,
}
errChan := make(chan error)
go func() {
ln, err := net.Listen("tcp", httpserver.Addr)
if err != nil {
errChan <- err
return
}
_, port, err := net.SplitHostPort(ln.Addr().String())
if err != nil {
errChan <- err
return
}
log.Info().Str("address", httpserver.Addr).Msg("HTTP-Server started on http://localhost:" + port)
if postInit != nil {
postInit(port) // the net.Listener a few lines above is at this point actually already buffering requests
}
errChan <- httpserver.Serve(ln)
}()
return errChan, httpserver
}
func (w *GinWrapper) DebugPrintRoutes() {
if !w.ginDebug {
return
}
lines := make([][4]string, 0)
pad := [4]int{0, 0, 0, 0}
for _, spec := range w.routeSpecs {
line := [4]string{
spec.Method,
spec.URL,
strings.Join(spec.Middlewares, " -> "),
spec.Handler,
}
lines = append(lines, line)
pad[0] = mathext.Max(pad[0], len(line[0]))
pad[1] = mathext.Max(pad[1], len(line[1]))
pad[2] = mathext.Max(pad[2], len(line[2]))
pad[3] = mathext.Max(pad[3], len(line[3]))
}
for _, line := range lines {
fmt.Printf("Gin-Route: %s %s --> %s --> %s\n",
langext.StrPadRight("["+line[0]+"]", " ", pad[0]+2),
langext.StrPadRight(line[1], " ", pad[1]),
langext.StrPadRight(line[2], " ", pad[2]),
langext.StrPadRight(line[3], " ", pad[3]))
}
}

View File

@@ -1,39 +0,0 @@
package ginext
import (
"fmt"
"github.com/gin-gonic/gin"
"gogs.mikescher.com/BlackForestBytes/goext/exerr"
)
type WHandlerFunc func(PreContext) HTTPResponse
func Wrap(w *GinWrapper, fn WHandlerFunc) gin.HandlerFunc {
return func(g *gin.Context) {
reqctx := g.Request.Context()
wrap, stackTrace, panicObj := callPanicSafe(fn, PreContext{wrapper: w, ginCtx: g})
if panicObj != nil {
fmt.Printf("\n======== ======== STACKTRACE ======== ========\n%s\n======== ======== ======== ========\n\n", stackTrace)
err := exerr.
New(exerr.TypePanic, "Panic occured (in gin handler)").
Any("panicObj", panicObj).
Str("trace", stackTrace).
Build()
wrap = Error(err)
}
if g.Writer.Written() {
panic("Writing in WrapperFunc is not supported")
}
if reqctx.Err() == nil {
wrap.Write(g)
}
}
}

View File

@@ -1,145 +0,0 @@
package ginext
import (
"context"
"fmt"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
"gogs.mikescher.com/BlackForestBytes/goext/exerr"
"gogs.mikescher.com/BlackForestBytes/goext/langext"
"runtime/debug"
"time"
)
type PreContext struct {
ginCtx *gin.Context
wrapper *GinWrapper
uri any
query any
body any
form any
header any
timeout *time.Duration
}
func (pctx *PreContext) URI(uri any) *PreContext {
pctx.uri = uri
return pctx
}
func (pctx *PreContext) Query(query any) *PreContext {
pctx.query = query
return pctx
}
func (pctx *PreContext) Body(body any) *PreContext {
pctx.body = body
return pctx
}
func (pctx *PreContext) Form(form any) *PreContext {
pctx.form = form
return pctx
}
func (pctx *PreContext) Header(header any) *PreContext {
pctx.header = header
return pctx
}
func (pctx *PreContext) WithTimeout(to time.Duration) *PreContext {
pctx.timeout = &to
return pctx
}
func (pctx PreContext) Start() (*AppContext, *gin.Context, *HTTPResponse) {
if pctx.uri != nil {
if err := pctx.ginCtx.ShouldBindUri(pctx.uri); err != nil {
err = exerr.Wrap(err, "Failed to read uri").
WithType(exerr.TypeBindFailURI).
Str("struct_type", fmt.Sprintf("%T", pctx.uri)).
Build()
return nil, nil, langext.Ptr(Error(err))
}
}
if pctx.query != nil {
if err := pctx.ginCtx.ShouldBindQuery(pctx.query); err != nil {
err = exerr.Wrap(err, "Failed to read query").
WithType(exerr.TypeBindFailQuery).
Str("struct_type", fmt.Sprintf("%T", pctx.query)).
Build()
return nil, nil, langext.Ptr(Error(err))
}
}
if pctx.body != nil {
if pctx.ginCtx.ContentType() == "application/json" {
if err := pctx.ginCtx.ShouldBindJSON(pctx.body); err != nil {
err = exerr.Wrap(err, "Failed to read json-body").
WithType(exerr.TypeBindFailJSON).
Str("struct_type", fmt.Sprintf("%T", pctx.body)).
Build()
return nil, nil, langext.Ptr(Error(err))
}
} else {
err := exerr.New(exerr.TypeBindFailJSON, "missing JSON body").
Str("struct_type", fmt.Sprintf("%T", pctx.body)).
Build()
return nil, nil, langext.Ptr(Error(err))
}
}
if pctx.form != nil {
if pctx.ginCtx.ContentType() == "multipart/form-data" {
if err := pctx.ginCtx.ShouldBindWith(pctx.form, binding.Form); err != nil {
err = exerr.Wrap(err, "Failed to read multipart-form").
WithType(exerr.TypeBindFailFormData).
Str("struct_type", fmt.Sprintf("%T", pctx.form)).
Build()
return nil, nil, langext.Ptr(Error(err))
}
} else if pctx.ginCtx.ContentType() == "application/x-www-form-urlencoded" {
if err := pctx.ginCtx.ShouldBindWith(pctx.form, binding.Form); err != nil {
err = exerr.Wrap(err, "Failed to read urlencoded-form").
WithType(exerr.TypeBindFailFormData).
Str("struct_type", fmt.Sprintf("%T", pctx.form)).
Build()
return nil, nil, langext.Ptr(Error(err))
}
} else {
err := exerr.New(exerr.TypeBindFailFormData, "missing form body").
Str("struct_type", fmt.Sprintf("%T", pctx.form)).
Build()
return nil, nil, langext.Ptr(Error(err))
}
}
if pctx.header != nil {
if err := pctx.ginCtx.ShouldBindHeader(pctx.header); err != nil {
err = exerr.Wrap(err, "Failed to read header").
WithType(exerr.TypeBindFailHeader).
Str("struct_type", fmt.Sprintf("%T", pctx.query)).
Build()
return nil, nil, langext.Ptr(Error(err))
}
}
ictx, cancel := context.WithTimeout(context.Background(), langext.Coalesce(pctx.timeout, pctx.wrapper.requestTimeout))
actx := CreateAppContext(pctx.ginCtx, ictx, cancel)
return actx, pctx.ginCtx, nil
}
func callPanicSafe(fn WHandlerFunc, pctx PreContext) (res HTTPResponse, stackTrace string, panicObj any) {
defer func() {
if rec := recover(); rec != nil {
res = nil
stackTrace = string(debug.Stack())
panicObj = rec
}
}()
res = fn(pctx)
return res, "", nil
}

View File

@@ -1,220 +0,0 @@
package ginext
import (
"fmt"
"github.com/gin-gonic/gin"
"gogs.mikescher.com/BlackForestBytes/goext/exerr"
json "gogs.mikescher.com/BlackForestBytes/goext/gojson"
)
type headerval struct {
Key string
Val string
}
type HTTPResponse interface {
Write(g *gin.Context)
WithHeader(k string, v string) HTTPResponse
}
type jsonHTTPResponse struct {
statusCode int
data any
headers []headerval
}
func (j jsonHTTPResponse) Write(g *gin.Context) {
for _, v := range j.headers {
g.Header(v.Key, v.Val)
}
var f *string
if jsonfilter := g.GetString("goext.jsonfilter"); jsonfilter != "" {
f = &jsonfilter
}
g.Render(j.statusCode, json.GoJsonRender{Data: j.data, NilSafeSlices: true, NilSafeMaps: true, Filter: f})
}
func (j jsonHTTPResponse) WithHeader(k string, v string) HTTPResponse {
j.headers = append(j.headers, headerval{k, v})
return j
}
type emptyHTTPResponse struct {
statusCode int
headers []headerval
}
func (j emptyHTTPResponse) Write(g *gin.Context) {
for _, v := range j.headers {
g.Header(v.Key, v.Val)
}
g.Status(j.statusCode)
}
func (j emptyHTTPResponse) WithHeader(k string, v string) HTTPResponse {
j.headers = append(j.headers, headerval{k, v})
return j
}
type textHTTPResponse struct {
statusCode int
data string
headers []headerval
}
func (j textHTTPResponse) Write(g *gin.Context) {
for _, v := range j.headers {
g.Header(v.Key, v.Val)
}
g.String(j.statusCode, "%s", j.data)
}
func (j textHTTPResponse) WithHeader(k string, v string) HTTPResponse {
j.headers = append(j.headers, headerval{k, v})
return j
}
type dataHTTPResponse struct {
statusCode int
data []byte
contentType string
headers []headerval
}
func (j dataHTTPResponse) Write(g *gin.Context) {
for _, v := range j.headers {
g.Header(v.Key, v.Val)
}
g.Data(j.statusCode, j.contentType, j.data)
}
func (j dataHTTPResponse) WithHeader(k string, v string) HTTPResponse {
j.headers = append(j.headers, headerval{k, v})
return j
}
type fileHTTPResponse struct {
mimetype string
filepath string
filename *string
headers []headerval
}
func (j fileHTTPResponse) Write(g *gin.Context) {
g.Header("Content-Type", j.mimetype) // if we don't set it here gin does weird file-sniffing later...
if j.filename != nil {
g.Header("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"", *j.filename))
}
for _, v := range j.headers {
g.Header(v.Key, v.Val)
}
g.File(j.filepath)
}
func (j fileHTTPResponse) WithHeader(k string, v string) HTTPResponse {
j.headers = append(j.headers, headerval{k, v})
return j
}
type downloadDataHTTPResponse struct {
statusCode int
mimetype string
data []byte
filename *string
headers []headerval
}
func (j downloadDataHTTPResponse) Write(g *gin.Context) {
g.Header("Content-Type", j.mimetype) // if we don't set it here gin does weird file-sniffing later...
if j.filename != nil {
g.Header("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"", *j.filename))
}
for _, v := range j.headers {
g.Header(v.Key, v.Val)
}
g.Data(j.statusCode, j.mimetype, j.data)
}
func (j downloadDataHTTPResponse) WithHeader(k string, v string) HTTPResponse {
j.headers = append(j.headers, headerval{k, v})
return j
}
type redirectHTTPResponse struct {
statusCode int
url string
headers []headerval
}
func (j redirectHTTPResponse) Write(g *gin.Context) {
g.Redirect(j.statusCode, j.url)
}
func (j redirectHTTPResponse) WithHeader(k string, v string) HTTPResponse {
j.headers = append(j.headers, headerval{k, v})
return j
}
type jsonAPIErrResponse struct {
err *exerr.ExErr
headers []headerval
}
func (j jsonAPIErrResponse) Write(g *gin.Context) {
j.err.Output(g)
}
func (j jsonAPIErrResponse) WithHeader(k string, v string) HTTPResponse {
j.headers = append(j.headers, headerval{k, v})
return j
}
func Status(sc int) HTTPResponse {
return &emptyHTTPResponse{statusCode: sc}
}
func JSON(sc int, data any) HTTPResponse {
return &jsonHTTPResponse{statusCode: sc, data: data}
}
func Data(sc int, contentType string, data []byte) HTTPResponse {
return &dataHTTPResponse{statusCode: sc, contentType: contentType, data: data}
}
func Text(sc int, data string) HTTPResponse {
return &textHTTPResponse{statusCode: sc, data: data}
}
func File(mimetype string, filepath string) HTTPResponse {
return &fileHTTPResponse{mimetype: mimetype, filepath: filepath}
}
func Download(mimetype string, filepath string, filename string) HTTPResponse {
return &fileHTTPResponse{mimetype: mimetype, filepath: filepath, filename: &filename}
}
func DownloadData(status int, mimetype string, filename string, data []byte) HTTPResponse {
return &downloadDataHTTPResponse{statusCode: status, mimetype: mimetype, data: data, filename: &filename}
}
func Redirect(sc int, newURL string) HTTPResponse {
return &redirectHTTPResponse{statusCode: sc, url: newURL}
}
func Error(e error) HTTPResponse {
return &jsonAPIErrResponse{
err: exerr.FromError(e),
}
}
func ErrWrap(e error, errorType exerr.ErrorType, msg string) HTTPResponse {
return &jsonAPIErrResponse{
err: exerr.FromError(exerr.Wrap(e, msg).WithType(errorType).Build()),
}
}
func NotImplemented() HTTPResponse {
return Error(exerr.New(exerr.TypeNotImplemented, "").Build())
}

View File

@@ -1,226 +0,0 @@
package ginext
import (
"github.com/gin-gonic/gin"
"gogs.mikescher.com/BlackForestBytes/goext/langext"
"gogs.mikescher.com/BlackForestBytes/goext/rext"
"net/http"
"path"
"reflect"
"regexp"
"runtime"
"strings"
)
var anyMethods = []string{
http.MethodGet, http.MethodPost, http.MethodPut, http.MethodPatch,
http.MethodHead, http.MethodOptions, http.MethodDelete, http.MethodConnect,
http.MethodTrace,
}
type GinRoutesWrapper struct {
wrapper *GinWrapper
routes gin.IRouter
absPath string
defaultHandler []gin.HandlerFunc
}
type GinRouteBuilder struct {
routes *GinRoutesWrapper
method string
relPath string
absPath string
handlers []gin.HandlerFunc
}
func (w *GinWrapper) Routes() *GinRoutesWrapper {
return &GinRoutesWrapper{
wrapper: w,
routes: w.engine,
absPath: "",
defaultHandler: make([]gin.HandlerFunc, 0),
}
}
func (w *GinRoutesWrapper) Group(relativePath string) *GinRoutesWrapper {
return &GinRoutesWrapper{
wrapper: w.wrapper,
routes: w.routes.Group(relativePath),
defaultHandler: langext.ArrCopy(w.defaultHandler),
absPath: joinPaths(w.absPath, relativePath),
}
}
func (w *GinRoutesWrapper) Use(middleware ...gin.HandlerFunc) *GinRoutesWrapper {
defHandler := langext.ArrCopy(w.defaultHandler)
defHandler = append(defHandler, middleware...)
return &GinRoutesWrapper{wrapper: w.wrapper, routes: w.routes, defaultHandler: defHandler}
}
func (w *GinRoutesWrapper) WithJSONFilter(filter string) *GinRoutesWrapper {
defHandler := langext.ArrCopy(w.defaultHandler)
defHandler = append(defHandler, func(g *gin.Context) {
g.Set("goext.jsonfilter", filter)
})
return &GinRoutesWrapper{wrapper: w.wrapper, routes: w.routes, defaultHandler: defHandler}
}
func (w *GinRoutesWrapper) GET(relativePath string) *GinRouteBuilder {
return w._route(http.MethodGet, relativePath)
}
func (w *GinRoutesWrapper) POST(relativePath string) *GinRouteBuilder {
return w._route(http.MethodPost, relativePath)
}
func (w *GinRoutesWrapper) DELETE(relativePath string) *GinRouteBuilder {
return w._route(http.MethodDelete, relativePath)
}
func (w *GinRoutesWrapper) PATCH(relativePath string) *GinRouteBuilder {
return w._route(http.MethodPatch, relativePath)
}
func (w *GinRoutesWrapper) PUT(relativePath string) *GinRouteBuilder {
return w._route(http.MethodPut, relativePath)
}
func (w *GinRoutesWrapper) OPTIONS(relativePath string) *GinRouteBuilder {
return w._route(http.MethodOptions, relativePath)
}
func (w *GinRoutesWrapper) HEAD(relativePath string) *GinRouteBuilder {
return w._route(http.MethodHead, relativePath)
}
func (w *GinRoutesWrapper) COUNT(relativePath string) *GinRouteBuilder {
return w._route("COUNT", relativePath)
}
func (w *GinRoutesWrapper) Any(relativePath string) *GinRouteBuilder {
return w._route("*", relativePath)
}
func (w *GinRoutesWrapper) _route(method string, relativePath string) *GinRouteBuilder {
return &GinRouteBuilder{
routes: w,
method: method,
relPath: relativePath,
absPath: joinPaths(w.absPath, relativePath),
handlers: langext.ArrCopy(w.defaultHandler),
}
}
func (w *GinRouteBuilder) Use(middleware ...gin.HandlerFunc) *GinRouteBuilder {
w.handlers = append(w.handlers, middleware...)
return w
}
func (w *GinRouteBuilder) WithJSONFilter(filter string) *GinRouteBuilder {
w.handlers = append(w.handlers, func(g *gin.Context) {
g.Set("goext.jsonfilter", filter)
})
return w
}
func (w *GinRouteBuilder) Handle(handler WHandlerFunc) {
if w.routes.wrapper.bufferBody {
arr := make([]gin.HandlerFunc, 0, len(w.handlers)+1)
arr = append(arr, BodyBuffer)
arr = append(arr, w.handlers...)
w.handlers = arr
}
middlewareNames := langext.ArrMap(w.handlers, func(v gin.HandlerFunc) string { return nameOfFunction(v) })
handlerName := nameOfFunction(handler)
w.handlers = append(w.handlers, Wrap(w.routes.wrapper, handler))
methodName := w.method
if w.method == "*" {
methodName = "ANY"
for _, method := range anyMethods {
w.routes.routes.Handle(method, w.relPath, w.handlers...)
}
} else {
w.routes.routes.Handle(w.method, w.relPath, w.handlers...)
}
w.routes.wrapper.routeSpecs = append(w.routes.wrapper.routeSpecs, ginRouteSpec{
Method: methodName,
URL: w.absPath,
Middlewares: middlewareNames,
Handler: handlerName,
})
}
func (w *GinWrapper) NoRoute(handler WHandlerFunc) {
handlers := make([]gin.HandlerFunc, 0)
if w.bufferBody {
handlers = append(handlers, BodyBuffer)
}
middlewareNames := langext.ArrMap(handlers, func(v gin.HandlerFunc) string { return nameOfFunction(v) })
handlerName := nameOfFunction(handler)
handlers = append(handlers, Wrap(w, handler))
w.engine.NoRoute(handlers...)
w.routeSpecs = append(w.routeSpecs, ginRouteSpec{
Method: "ANY",
URL: "[NO_ROUTE]",
Middlewares: middlewareNames,
Handler: handlerName,
})
}
func nameOfFunction(f any) string {
fname := runtime.FuncForPC(reflect.ValueOf(f).Pointer()).Name()
split := strings.Split(fname, "/")
if len(split) == 0 {
return ""
}
fname = split[len(split)-1]
// https://stackoverflow.com/a/32925345/1761622
if strings.HasSuffix(fname, "-fm") {
fname = fname[:len(fname)-len("-fm")]
}
suffix := rext.W(regexp.MustCompile(`\.func[0-9]+(?:\.[0-9]+)*$`))
if match, ok := suffix.MatchFirst(fname); ok {
fname = fname[:len(fname)-match.FullMatch().Length()]
}
return fname
}
// joinPaths is copied verbatim from gin@v1.9.1/gin.go
func joinPaths(absolutePath, relativePath string) string {
if relativePath == "" {
return absolutePath
}
finalPath := path.Join(absolutePath, relativePath)
if lastChar(relativePath) == '/' && lastChar(finalPath) != '/' {
return finalPath + "/"
}
return finalPath
}
func lastChar(str string) uint8 {
if str == "" {
panic("The length of the string can't be 0")
}
return str[len(str)-1]
}

57
go.mod
View File

@@ -3,47 +3,30 @@ module gogs.mikescher.com/BlackForestBytes/goext
go 1.19
require (
github.com/gin-gonic/gin v1.9.1
github.com/golang/snappy v0.0.4
github.com/google/go-cmp v0.5.9
github.com/jmoiron/sqlx v1.3.5
github.com/rs/xid v1.5.0
github.com/rs/zerolog v1.31.0
go.mongodb.org/mongo-driver v1.12.1
golang.org/x/crypto v0.14.0
golang.org/x/sys v0.13.0
golang.org/x/term v0.13.0
github.com/klauspost/compress v1.16.6
github.com/kr/pretty v0.1.0
github.com/montanaflynn/stats v0.7.1
github.com/pkg/errors v0.9.1
github.com/stretchr/testify v1.8.4
github.com/tidwall/pretty v1.0.0
github.com/xdg-go/scram v1.1.2
github.com/xdg-go/stringprep v1.0.4
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a
go.mongodb.org/mongo-driver v1.11.7
golang.org/x/crypto v0.10.0
golang.org/x/sync v0.3.0
golang.org/x/sys v0.9.0
golang.org/x/term v0.9.0
)
require (
github.com/bytedance/sonic v1.10.2 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20230717121745-296ad89f973d // indirect
github.com/chenzhuoyu/iasm v0.9.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.15.5 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/golang/snappy v0.0.4 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/compress v1.17.0 // indirect
github.com/klauspost/cpuid/v2 v2.2.5 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/montanaflynn/stats v0.7.1 // indirect
github.com/pelletier/go-toml/v2 v2.1.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/kr/text v0.1.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/xdg-go/pbkdf2 v1.0.0 // indirect
github.com/xdg-go/scram v1.1.2 // indirect
github.com/xdg-go/stringprep v1.0.4 // indirect
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a // indirect
golang.org/x/arch v0.5.0 // indirect
golang.org/x/net v0.17.0 // indirect
golang.org/x/sync v0.4.0 // indirect
golang.org/x/text v0.13.0 // indirect
google.golang.org/protobuf v1.31.0 // indirect
golang.org/x/text v0.10.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

130
go.sum
View File

@@ -1,132 +1,72 @@
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.10.0-rc/go.mod h1:ElCzW+ufi8qKqNW0FY314xriJhyJhuoJ3gFZdAHF7NM=
github.com/bytedance/sonic v1.10.2 h1:GQebETVBxYB7JGWJtLBi07OVzWwt+8dWA00gEVW2ZFE=
github.com/bytedance/sonic v1.10.2/go.mod h1:iZcSUejdk5aukTND/Eu/ivjQuEL0Cu9/rf50Hi0u/g4=
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
github.com/chenzhuoyu/base64x v0.0.0-20230717121745-296ad89f973d h1:77cEq6EriyTZ0g/qfRdp61a3Uu/AWrgIq2s0ClJV1g0=
github.com/chenzhuoyu/base64x v0.0.0-20230717121745-296ad89f973d/go.mod h1:8EPpVsBuRksnlj1mLy4AWzRNQYxauNi62uWcE3to6eA=
github.com/chenzhuoyu/iasm v0.9.0 h1:9fhXjVzq5hUy2gkhhgHl95zG2cEAhw9OSGs8toWWAwo=
github.com/chenzhuoyu/iasm v0.9.0/go.mod h1:Xjy2NpN3h7aUqeqM+woSuuvxmIe6+DDsiNLIrkAmYog=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.15.5 h1:LEBecTWb/1j5TNY1YYG2RcOUN3R7NLylN+x8TTueE24=
github.com/go-playground/validator/v10 v10.15.5/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM=
github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/jmoiron/sqlx v1.3.5 h1:vFFPA71p1o5gAeqtEAwLU4dnX2napprKtHr7PYIcN3g=
github.com/jmoiron/sqlx v1.3.5/go.mod h1:nRVWtLre0KfCLJvgxzCsLVMogSvQ1zNJtpYr2Ccp0mQ=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk=
github.com/klauspost/compress v1.17.0 h1:Rnbp4K9EjcDuVuHtd0dgA4qNuv9yKDYKK1ulpJwgrqM=
github.com/klauspost/compress v1.17.0/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.5 h1:0E5MSMDEoAulmXNFquVs//DdoomxaoTY1kUhbc/qbZg=
github.com/klauspost/cpuid/v2 v2.2.5/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M=
github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
github.com/klauspost/compress v1.16.6 h1:91SKEy4K37vkp255cJ8QesJhjyRO0hn9i9G0GoUwLsk=
github.com/klauspost/compress v1.16.6/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/lib/pq v1.2.0 h1:LXpIM/LZ5xGFhOpXAQUIMM1HdyqzVYM13zNdjCEEcA0=
github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.6 h1:dNPt6NO46WmLVt2DLNpwczCmdV5boIZ6g/tlDrlRUbg=
github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc=
github.com/montanaflynn/stats v0.7.1 h1:etflOAAHORrCC44V+aR6Ftzort912ZU+YLiSTuV8eaE=
github.com/montanaflynn/stats v0.7.1/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow=
github.com/pelletier/go-toml/v2 v2.1.0 h1:FnwAJ4oYMvbT/34k9zzHuZNrhlz48GB3/s6at6/MHO4=
github.com/pelletier/go-toml/v2 v2.1.0/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/zerolog v1.31.0 h1:FcTR3NnLWW+NnTwwhFWiJSZr4ECLpqCm6QsEnyvbV4A=
github.com/rs/zerolog v1.31.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/tidwall/pretty v1.0.0 h1:HsD+QiTn7sK6flMKIvNmpqz1qrpP3Ps6jOKIKMooyg4=
github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk=
github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c=
github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI=
github.com/xdg-go/scram v1.1.1/go.mod h1:RaEWvsqvNKKvBPvcKeFjrG2cJqOkHTiyTpzz23ni57g=
github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY=
github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4=
github.com/xdg-go/stringprep v1.0.3/go.mod h1:W3f5j4i+9rC0kuIEJL0ky1VpHXQU3ocBgklLGvcBnW8=
github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8=
github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM=
github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7JulP+udvsHwJoVG1YGAP6VLg4y9I5dyZdqmA=
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a h1:fZHgsYlfvtyqToslyjUt3VOPF4J7aK/3MPcK7xp3PDk=
github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a/go.mod h1:ul22v+Nro/R083muKhosV54bj5niojjWZvU8xrevuH4=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
go.mongodb.org/mongo-driver v1.12.1 h1:nLkghSU8fQNaK7oUmDhQFsnrtcoNy7Z6LVFKsEecqgE=
go.mongodb.org/mongo-driver v1.12.1/go.mod h1:/rGBTebI3XYboVmgz+Wv3Bcbl3aD0QF9zl6kDDw18rQ=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.5.0 h1:jpGode6huXQxcskEIpOCvrU+tzo81b6+oFLUYXWtH/Y=
golang.org/x/arch v0.5.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
go.mongodb.org/mongo-driver v1.11.7 h1:LIwYxASDLGUg/8wOhgOOZhX8tQa/9tgZPgzZoVqJvcs=
go.mongodb.org/mongo-driver v1.11.7/go.mod h1:G9TgswdsWjX4tmDA5zfs2+6AEPpYJwqblyjsfuh8oXY=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20200302210943-78000ba7a073/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
golang.org/x/crypto v0.10.0 h1:LKqV2xt9+kDzSTfOhx4FrkEBcMrAgHSYgzywV9zcGmM=
golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.16.0 h1:7eBu7KsSvFDtSXUIDbh3aqlK4DPsZ1rByC8PFfBThos=
golang.org/x/net v0.16.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.4.0 h1:zxkM55ReGkDlKSM+Fu41A+zmbZuaPVbGMzvvdUPznYQ=
golang.org/x/sync v0.4.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -134,37 +74,27 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s=
golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.13.0 h1:bb+I9cTfFazGW51MZqBVmZy7+JEJMouUHTUSKVQLBek=
golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U=
golang.org/x/term v0.9.0 h1:GRRCnKYhdQrD8kfRAdQ6Zcw1P0OcELxGLKJvtjVMZ28=
golang.org/x/term v0.9.0/go.mod h1:M6DEAAIenWoTxdKrOltXcmDY3rSplQUkrvaDU5FcQyo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.10.0 h1:UpjohKhiEgNc0CSauXmwYftY1+LlaC75SJwh0SgCX58=
golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=

View File

@@ -1,5 +1,5 @@
package goext
const GoextVersion = "0.0.286"
const GoextVersion = "0.0.166"
const GoextVersionTimestamp = "2023-10-11T11:27:18+0200"
const GoextVersionTimestamp = "2023-06-19T10:25:41+0200"

View File

@@ -156,6 +156,7 @@ import (
// an error.
func Marshal(v any) ([]byte, error) {
e := newEncodeState()
defer encodeStatePool.Put(e)
err := e.marshal(v, encOpts{escapeHTML: true})
if err != nil {
@@ -163,8 +164,6 @@ func Marshal(v any) ([]byte, error) {
}
buf := append([]byte(nil), e.Bytes()...)
encodeStatePool.Put(e)
return buf, nil
}
@@ -175,9 +174,9 @@ type IndentOpt struct {
// MarshalSafeCollections is like Marshal except it will marshal nil maps and
// slices as '{}' and '[]' respectfully instead of 'null'
func MarshalSafeCollections(v interface{}, nilSafeSlices bool, nilSafeMaps bool, indent *IndentOpt, filter *string) ([]byte, error) {
func MarshalSafeCollections(v interface{}, nilSafeSlices bool, nilSafeMaps bool, indent *IndentOpt) ([]byte, error) {
e := &encodeState{}
err := e.marshal(v, encOpts{escapeHTML: true, nilSafeSlices: nilSafeSlices, nilSafeMaps: nilSafeMaps, filter: filter})
err := e.marshal(v, encOpts{escapeHTML: true, nilSafeSlices: nilSafeSlices, nilSafeMaps: nilSafeMaps})
if err != nil {
return nil, err
}
@@ -394,9 +393,6 @@ type encOpts struct {
nilSafeSlices bool
// nilSafeMaps marshals a nil maps '{}' instead of 'null'
nilSafeMaps bool
// filter matches jsonfilter tag of struct
// marshals if no jsonfilter is set or otherwise if jsonfilter has the filter value
filter *string
}
type encoderFunc func(e *encodeState, v reflect.Value, opts encOpts)
@@ -781,8 +777,6 @@ FieldLoop:
if f.omitEmpty && isEmptyValue(fv) {
continue
} else if opts.filter != nil && len(f.jsonfilter) > 0 && !f.jsonfilter.Contains(*opts.filter) {
continue
}
e.WriteByte(next)
next = ','
@@ -1226,28 +1220,15 @@ type field struct {
nameNonEsc string // `"` + name + `":`
nameEscHTML string // `"` + HTMLEscape(name) + `":`
tag bool
index []int
typ reflect.Type
omitEmpty bool
jsonfilter jsonfilter
quoted bool
tag bool
index []int
typ reflect.Type
omitEmpty bool
quoted bool
encoder encoderFunc
}
// jsonfilter stores the value of the jsonfilter struct tag
type jsonfilter []string
func (j jsonfilter) Contains(t string) bool {
for _, tag := range j {
if t == tag {
return true
}
}
return false
}
// byIndex sorts field by index sequence.
type byIndex []field
@@ -1323,13 +1304,6 @@ func typeFields(t reflect.Type) structFields {
if !isValidTag(name) {
name = ""
}
var jsonfilter []string
jsonfilterTag := sf.Tag.Get("jsonfilter")
if jsonfilterTag != "" && jsonfilterTag != "-" {
jsonfilter = strings.Split(jsonfilterTag, ",")
}
index := make([]int, len(f.index)+1)
copy(index, f.index)
index[len(f.index)] = i
@@ -1360,13 +1334,12 @@ func typeFields(t reflect.Type) structFields {
name = sf.Name
}
field := field{
name: name,
tag: tagged,
index: index,
typ: ft,
omitEmpty: opts.Contains("omitempty"),
jsonfilter: jsonfilter,
quoted: quoted,
name: name,
tag: tagged,
index: index,
typ: ft,
omitEmpty: opts.Contains("omitempty"),
quoted: quoted,
}
field.nameBytes = []byte(field.name)
field.equalFold = foldFunc(field.nameBytes)

View File

@@ -1253,10 +1253,6 @@ func TestMarshalSafeCollections(t *testing.T) {
nilMapStruct struct {
NilMap map[string]interface{} `json:"nil_map"`
}
testWithFilter struct {
Test1 string `json:"test1" jsonfilter:"FILTERONE"`
Test2 string `json:"test2" jsonfilter:"FILTERTWO"`
}
)
tests := []struct {
@@ -1275,12 +1271,10 @@ func TestMarshalSafeCollections(t *testing.T) {
{map[string]interface{}{"1": 1, "2": 2, "3": 3}, "{\"1\":1,\"2\":2,\"3\":3}"},
{pNilMap, "null"},
{nilMapStruct{}, "{\"nil_map\":{}}"},
{testWithFilter{}, "{\"test1\":\"\"}"},
}
filter := "FILTERONE"
for i, tt := range tests {
b, err := MarshalSafeCollections(tt.in, true, true, nil, &filter)
b, err := MarshalSafeCollections(tt.in, true, true, nil)
if err != nil {
t.Errorf("test %d, unexpected failure: %v", i, err)
}

View File

@@ -97,10 +97,7 @@ func equalFoldRight(s, t []byte) bool {
t = t[size:]
}
if len(t) > 0 {
return false
}
return true
return len(t) == 0
}
// asciiEqualFold is a specialization of bytes.EqualFold for use when

View File

@@ -52,9 +52,7 @@ func TestFold(t *testing.T) {
}
func TestFoldAgainstUnicode(t *testing.T) {
const bufSize = 5
buf1 := make([]byte, 0, bufSize)
buf2 := make([]byte, 0, bufSize)
var buf1, buf2 []byte
var runes []rune
for i := 0x20; i <= 0x7f; i++ {
runes = append(runes, rune(i))
@@ -96,12 +94,8 @@ func TestFoldAgainstUnicode(t *testing.T) {
continue
}
for _, r2 := range runes {
buf1 := append(buf1[:0], 'x')
buf2 := append(buf2[:0], 'x')
buf1 = buf1[:1+utf8.EncodeRune(buf1[1:bufSize], r)]
buf2 = buf2[:1+utf8.EncodeRune(buf2[1:bufSize], r2)]
buf1 = append(buf1, 'x')
buf2 = append(buf2, 'x')
buf1 = append(utf8.AppendRune(append(buf1[:0], 'x'), r), 'x')
buf2 = append(utf8.AppendRune(append(buf2[:0], 'x'), r2), 'x')
want := bytes.EqualFold(buf1, buf2)
if got := ff.fold(buf1, buf2); got != want {
t.Errorf("%s(%q, %q) = %v; want %v", ff.name, buf1, buf2, got, want)

View File

@@ -17,7 +17,6 @@ type GoJsonRender struct {
NilSafeSlices bool
NilSafeMaps bool
Indent *IndentOpt
Filter *string
}
func (r GoJsonRender) Render(w http.ResponseWriter) error {
@@ -26,7 +25,7 @@ func (r GoJsonRender) Render(w http.ResponseWriter) error {
header["Content-Type"] = []string{"application/json; charset=utf-8"}
}
jsonBytes, err := MarshalSafeCollections(r.Data, r.NilSafeSlices, r.NilSafeMaps, r.Indent, r.Filter)
jsonBytes, err := MarshalSafeCollections(r.Data, r.NilSafeSlices, r.NilSafeMaps, r.Indent)
if err != nil {
panic(err)
}

View File

@@ -116,18 +116,3 @@ func TestNumberIsValid(t *testing.T) {
}
}
}
func BenchmarkNumberIsValid(b *testing.B) {
s := "-61657.61667E+61673"
for i := 0; i < b.N; i++ {
isValidNumber(s)
}
}
func BenchmarkNumberIsValidRegexp(b *testing.B) {
var jsonNumberRegexp = regexp.MustCompile(`^-?(?:0|[1-9]\d*)(?:\.\d+)?(?:[eE][+-]?\d+)?$`)
s := "-61657.61667E+61673"
for i := 0; i < b.N; i++ {
jsonNumberRegexp.MatchString(s)
}
}

View File

@@ -594,7 +594,7 @@ func (s *scanner) error(c byte, context string) int {
return scanError
}
// quoteChar formats c as a quoted character literal
// quoteChar formats c as a quoted character literal.
func quoteChar(c byte) string {
// special cases - different from quoted strings
if c == '\'' {

View File

@@ -179,9 +179,11 @@ func nonSpace(b []byte) bool {
// An Encoder writes JSON values to an output stream.
type Encoder struct {
w io.Writer
err error
escapeHTML bool
w io.Writer
err error
escapeHTML bool
nilSafeSlices bool
nilSafeMaps bool
indentBuf *bytes.Buffer
indentPrefix string
@@ -202,8 +204,11 @@ func (enc *Encoder) Encode(v any) error {
if enc.err != nil {
return enc.err
}
e := newEncodeState()
err := e.marshal(v, encOpts{escapeHTML: enc.escapeHTML})
defer encodeStatePool.Put(e)
err := e.marshal(v, encOpts{escapeHTML: enc.escapeHTML, nilSafeMaps: enc.nilSafeMaps, nilSafeSlices: enc.nilSafeSlices})
if err != nil {
return err
}
@@ -231,7 +236,6 @@ func (enc *Encoder) Encode(v any) error {
if _, err = enc.w.Write(b); err != nil {
enc.err = err
}
encodeStatePool.Put(e)
return err
}
@@ -243,6 +247,13 @@ func (enc *Encoder) SetIndent(prefix, indent string) {
enc.indentValue = indent
}
// SetNilSafeCollection specifies whether to represent nil slices and maps as
// '[]' or '{}' respectfully (flag on) instead of 'null' (default) when marshaling json.
func (enc *Encoder) SetNilSafeCollection(nilSafeSlices bool, nilSafeMaps bool) {
enc.nilSafeSlices = nilSafeSlices
enc.nilSafeMaps = nilSafeMaps
}
// SetEscapeHTML specifies whether problematic HTML characters
// should be escaped inside JSON quoted strings.
// The default behavior is to escape &, <, and > to \u0026, \u003c, and \u003e

View File

@@ -12,6 +12,7 @@ import (
"net/http"
"net/http/httptest"
"reflect"
"runtime/debug"
"strings"
"testing"
)
@@ -41,7 +42,7 @@ false
func TestEncoder(t *testing.T) {
for i := 0; i <= len(streamTest); i++ {
var buf bytes.Buffer
var buf strings.Builder
enc := NewEncoder(&buf)
// Check that enc.SetIndent("", "") turns off indentation.
enc.SetIndent(">", ".")
@@ -59,6 +60,43 @@ func TestEncoder(t *testing.T) {
}
}
func TestEncoderErrorAndReuseEncodeState(t *testing.T) {
// Disable the GC temporarily to prevent encodeState's in Pool being cleaned away during the test.
percent := debug.SetGCPercent(-1)
defer debug.SetGCPercent(percent)
// Trigger an error in Marshal with cyclic data.
type Dummy struct {
Name string
Next *Dummy
}
dummy := Dummy{Name: "Dummy"}
dummy.Next = &dummy
var buf bytes.Buffer
enc := NewEncoder(&buf)
if err := enc.Encode(dummy); err == nil {
t.Errorf("Encode(dummy) == nil; want error")
}
type Data struct {
A string
I int
}
data := Data{A: "a", I: 1}
if err := enc.Encode(data); err != nil {
t.Errorf("Marshal(%v) = %v", data, err)
}
var data2 Data
if err := Unmarshal(buf.Bytes(), &data2); err != nil {
t.Errorf("Unmarshal(%v) = %v", data2, err)
}
if data2 != data {
t.Errorf("expect: %v, but get: %v", data, data2)
}
}
var streamEncodedIndent = `0.1
"hello"
null
@@ -77,7 +115,7 @@ false
`
func TestEncoderIndent(t *testing.T) {
var buf bytes.Buffer
var buf strings.Builder
enc := NewEncoder(&buf)
enc.SetIndent(">", ".")
for _, v := range streamTest {
@@ -147,7 +185,7 @@ func TestEncoderSetEscapeHTML(t *testing.T) {
`{"bar":"\"<html>foobar</html>\""}`,
},
} {
var buf bytes.Buffer
var buf strings.Builder
enc := NewEncoder(&buf)
if err := enc.Encode(tt.v); err != nil {
t.Errorf("Encode(%s): %s", tt.name, err)
@@ -309,21 +347,6 @@ func TestBlocking(t *testing.T) {
}
}
func BenchmarkEncoderEncode(b *testing.B) {
b.ReportAllocs()
type T struct {
X, Y string
}
v := &T{"foo", "bar"}
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
if err := NewEncoder(io.Discard).Encode(v); err != nil {
b.Fatal(err)
}
}
})
}
type tokenStreamCase struct {
json string
expTokens []any
@@ -472,3 +495,45 @@ func TestHTTPDecoding(t *testing.T) {
t.Errorf("err = %v; want io.EOF", err)
}
}
func TestEncoderSetNilSafeCollection(t *testing.T) {
var (
nilSlice []interface{}
pNilSlice *[]interface{}
nilMap map[string]interface{}
pNilMap *map[string]interface{}
)
for _, tt := range []struct {
name string
v interface{}
want string
rescuedWant string
}{
{"nilSlice", nilSlice, "null", "[]"},
{"nonNilSlice", []interface{}{}, "[]", "[]"},
{"sliceWithValues", []interface{}{1, 2, 3}, "[1,2,3]", "[1,2,3]"},
{"pNilSlice", pNilSlice, "null", "null"},
{"nilMap", nilMap, "null", "{}"},
{"nonNilMap", map[string]interface{}{}, "{}", "{}"},
{"mapWithValues", map[string]interface{}{"1": 1, "2": 2, "3": 3}, "{\"1\":1,\"2\":2,\"3\":3}", "{\"1\":1,\"2\":2,\"3\":3}"},
{"pNilMap", pNilMap, "null", "null"},
} {
var buf bytes.Buffer
enc := NewEncoder(&buf)
if err := enc.Encode(tt.v); err != nil {
t.Fatalf("Encode(%s): %s", tt.name, err)
}
if got := strings.TrimSpace(buf.String()); got != tt.want {
t.Errorf("Encode(%s) = %#q, want %#q", tt.name, got, tt.want)
}
buf.Reset()
enc.SetNilSafeCollection(true, true)
if err := enc.Encode(tt.v); err != nil {
t.Fatalf("SetNilSafeCollection(true) Encode(%s): %s", tt.name, err)
}
if got := strings.TrimSpace(buf.String()); got != tt.rescuedWant {
t.Errorf("SetNilSafeCollection(true) Encode(%s) = %#q, want %#q",
tt.name, got, tt.want)
}
}
}

View File

@@ -400,7 +400,7 @@ func ArrCastErr[T1 any, T2 any](arr []T1) ([]T2, error) {
if vcast, ok := any(v).(T2); ok {
r[i] = vcast
} else {
return nil, errors.New(fmt.Sprintf("Cannot cast element %d of type %T to type %v", i, v, *new(T2)))
return nil, errors.New(fmt.Sprintf("Cannot cast element %d of type %T to type %s", i, v, *new(T2)))
}
}
return r, nil
@@ -412,7 +412,7 @@ func ArrCastPanic[T1 any, T2 any](arr []T1) []T2 {
if vcast, ok := any(v).(T2); ok {
r[i] = vcast
} else {
panic(fmt.Sprintf("Cannot cast element %d of type %T to type %v", i, v, *new(T2)))
panic(fmt.Sprintf("Cannot cast element %d of type %T to type %s", i, v, *new(T2)))
}
}
return r
@@ -440,42 +440,3 @@ func ArrCopy[T any](in []T) []T {
copy(out, in)
return out
}
func ArrRemove[T comparable](arr []T, needle T) []T {
idx := ArrFirstIndex(arr, needle)
if idx >= 0 {
return append(arr[:idx], arr[idx+1:]...)
}
return arr
}
func ArrExcept[T comparable](arr []T, needles ...T) []T {
r := make([]T, 0, len(arr))
rmlist := ArrToSet(needles)
for _, v := range arr {
if _, ok := rmlist[v]; !ok {
r = append(r, v)
}
}
return r
}
func ArrayToInterface[T any](t []T) []interface{} {
res := make([]interface{}, 0, len(t))
for i, _ := range t {
res = append(res, t[i])
}
return res
}
func JoinString(arr []string, delimiter string) string {
str := ""
for i, v := range arr {
str += v
if i < len(arr)-1 {
str += delimiter
}
}
return str
}

View File

@@ -1,12 +0,0 @@
package langext
import (
"gogs.mikescher.com/BlackForestBytes/goext/tst"
"testing"
)
func TestJoinString(t *testing.T) {
ids := []string{"1", "2", "3"}
res := JoinString(ids, ",")
tst.AssertEqual(t, res, "1,2,3")
}

View File

@@ -1,7 +1,6 @@
package langext
import (
"gogs.mikescher.com/BlackForestBytes/goext/tst"
"testing"
)
@@ -60,3 +59,9 @@ func TestBase58FlickrDecoding(t *testing.T) {
tst.AssertEqual(t, _decStr(t, Base58FlickrEncoding, "9aJCVZR"), "Hello")
tst.AssertEqual(t, _decStr(t, Base58FlickrEncoding, "48638rmBiUzG5NKQoX4KcuE5C8paCFACnE28F7qDx13PRtennAmYSSJQ5gJSRihf5ZDyEQS4UimtihR7uARt4wbty2fW9duTQTM9n1DwUBevreyzGwu6W4YSgrvQgCPDxsiE1mCdZsF8VEBpuHHEiJyw"), "If debugging is the process of removing software bugs, then programming must be the process of putting them in.")
}
func tst.AssertEqual(t *testing.T, actual string, expected string) {
if actual != expected {
t.Errorf("values differ: Actual: '%v', Expected: '%v'", actual, expected)
}
}

View File

@@ -29,14 +29,6 @@ func ArrToMap[T comparable, V any](a []V, keyfunc func(V) T) map[T]V {
return result
}
func ArrToSet[T comparable](a []T) map[T]bool {
result := make(map[T]bool, len(a))
for _, v := range a {
result[v] = true
}
return result
}
func MapToArr[T comparable, V any](v map[T]V) []MapEntry[T, V] {
result := make([]MapEntry[T, V], 0, len(v))
for mk, mv := range v {

View File

@@ -1,10 +1,7 @@
package langext
import "runtime/debug"
type PanicWrappedErr struct {
panic any
Stack string
}
func (p PanicWrappedErr) Error() string {
@@ -18,7 +15,7 @@ func (p PanicWrappedErr) ReoveredObj() any {
func RunPanicSafe(fn func()) (err error) {
defer func() {
if rec := recover(); rec != nil {
err = PanicWrappedErr{panic: rec, Stack: string(debug.Stack())}
err = PanicWrappedErr{panic: rec}
}
}()
@@ -30,7 +27,7 @@ func RunPanicSafe(fn func()) (err error) {
func RunPanicSafeR1(fn func() error) (err error) {
defer func() {
if rec := recover(); rec != nil {
err = PanicWrappedErr{panic: rec, Stack: string(debug.Stack())}
err = PanicWrappedErr{panic: rec}
}
}()
@@ -41,7 +38,7 @@ func RunPanicSafeR2[T1 any](fn func() (T1, error)) (r1 T1, err error) {
defer func() {
if rec := recover(); rec != nil {
r1 = *new(T1)
err = PanicWrappedErr{panic: rec, Stack: string(debug.Stack())}
err = PanicWrappedErr{panic: rec}
}
}()
@@ -53,7 +50,7 @@ func RunPanicSafeR3[T1 any, T2 any](fn func() (T1, T2, error)) (r1 T1, r2 T2, er
if rec := recover(); rec != nil {
r1 = *new(T1)
r2 = *new(T2)
err = PanicWrappedErr{panic: rec, Stack: string(debug.Stack())}
err = PanicWrappedErr{panic: rec}
}
}()
@@ -66,7 +63,7 @@ func RunPanicSafeR4[T1 any, T2 any, T3 any](fn func() (T1, T2, T3, error)) (r1 T
r1 = *new(T1)
r2 = *new(T2)
r3 = *new(T3)
err = PanicWrappedErr{panic: rec, Stack: string(debug.Stack())}
err = PanicWrappedErr{panic: rec}
}
}()

16
mongo/.errcheck-excludes Normal file
View File

@@ -0,0 +1,16 @@
(go.mongodb.org/mongo-driver/x/mongo/driver.Connection).Close
(*go.mongodb.org/mongo-driver/x/network/connection.connection).Close
(go.mongodb.org/mongo-driver/x/network/connection.Connection).Close
(*go.mongodb.org/mongo-driver/x/mongo/driver/topology.connection).close
(*go.mongodb.org/mongo-driver/x/mongo/driver/topology.Topology).Unsubscribe
(*go.mongodb.org/mongo-driver/x/mongo/driver/topology.Server).Close
(*go.mongodb.org/mongo-driver/x/network/connection.pool).closeConnection
(*go.mongodb.org/mongo-driver/x/mongo/driver/topology.pool).close
(go.mongodb.org/mongo-driver/x/network/wiremessage.ReadWriteCloser).Close
(*go.mongodb.org/mongo-driver/mongo.Cursor).Close
(*go.mongodb.org/mongo-driver/mongo.ChangeStream).Close
(*go.mongodb.org/mongo-driver/mongo.Client).Disconnect
(net.Conn).Close
encoding/pem.Encode
fmt.Fprintf
fmt.Fprint

13
mongo/.gitignore vendored Normal file
View File

@@ -0,0 +1,13 @@
.vscode
debug
.idea
*.iml
*.ipr
*.iws
.idea
*.sublime-project
*.sublime-workspace
driver-test-data.tar.gz
perf
**mongocryptd.pid
*.test

3
mongo/.gitmodules vendored Normal file
View File

@@ -0,0 +1,3 @@
[submodule "specifications"]
path = specifications
url = git@github.com:mongodb/specifications.git

123
mongo/.golangci.yml Normal file
View File

@@ -0,0 +1,123 @@
run:
timeout: 5m
linters:
disable-all: true
# TODO(GODRIVER-2156): Enable all commented-out linters.
enable:
- errcheck
# - errorlint
- gocritic
- goimports
- gosimple
- gosec
- govet
- ineffassign
- makezero
- misspell
- nakedret
- paralleltest
- prealloc
- revive
- staticcheck
- typecheck
- unused
- unconvert
- unparam
linters-settings:
errcheck:
exclude: .errcheck-excludes
gocritic:
enabled-checks:
# Detects suspicious append result assignments. E.g. "b := append(a, 1, 2, 3)"
- appendAssign
govet:
disable:
- cgocall
- composites
paralleltest:
# Ignore missing calls to `t.Parallel()` and only report incorrect uses of `t.Parallel()`.
ignore-missing: true
staticcheck:
checks: [
"all",
"-SA1019", # Disable deprecation warnings for now.
"-SA1012", # Disable "do not pass a nil Context" to allow testing nil contexts in tests.
]
issues:
exclude-use-default: false
exclude:
# Add all default excluded issues except issues related to exported types/functions not having
# comments; we want those warnings. The defaults are copied from the "--exclude-use-default"
# documentation on https://golangci-lint.run/usage/configuration/#command-line-options
## Defaults ##
# EXC0001 errcheck: Almost all programs ignore errors on these functions and in most cases it's ok
- Error return value of .((os\.)?std(out|err)\..*|.*Close|.*Flush|os\.Remove(All)?|.*print(f|ln)?|os\.(Un)?Setenv). is not checked
# EXC0003 golint: False positive when tests are defined in package 'test'
- func name will be used as test\.Test.* by other packages, and that stutters; consider calling this
# EXC0004 govet: Common false positives
- (possible misuse of unsafe.Pointer|should have signature)
# EXC0005 staticcheck: Developers tend to write in C-style with an explicit 'break' in a 'switch', so it's ok to ignore
- ineffective break statement. Did you mean to break out of the outer loop
# EXC0006 gosec: Too many false-positives on 'unsafe' usage
- Use of unsafe calls should be audited
# EXC0007 gosec: Too many false-positives for parametrized shell calls
- Subprocess launch(ed with variable|ing should be audited)
# EXC0008 gosec: Duplicated errcheck checks
- (G104|G307)
# EXC0009 gosec: Too many issues in popular repos
- (Expect directory permissions to be 0750 or less|Expect file permissions to be 0600 or less)
# EXC0010 gosec: False positive is triggered by 'src, err := ioutil.ReadFile(filename)'
- Potential file inclusion via variable
## End Defaults ##
# Ignore capitalization warning for this weird field name.
- "var-naming: struct field CqCssWxW should be CqCSSWxW"
# Ignore warnings for common "wiremessage.Read..." usage because the safest way to use that API
# is by assigning possibly unused returned byte buffers.
- "SA4006: this value of `wm` is never used"
- "SA4006: this value of `rem` is never used"
- "ineffectual assignment to wm"
- "ineffectual assignment to rem"
skip-dirs-use-default: false
skip-dirs:
- (^|/)vendor($|/)
- (^|/)testdata($|/)
- (^|/)etc($|/)
exclude-rules:
# Ignore some linters for example code that is intentionally simplified.
- path: examples/
linters:
- revive
- errcheck
# Disable unused code linters for the copy/pasted "awsv4" package.
- path: x/mongo/driver/auth/internal/awsv4
linters:
- unused
# Disable "unused" linter for code files that depend on the "mongocrypt.MongoCrypt" type because
# the linter build doesn't work correctly with CGO enabled. As a result, all calls to a
# "mongocrypt.MongoCrypt" API appear to always panic (see mongocrypt_not_enabled.go), leading
# to confusing messages about unused code.
- path: x/mongo/driver/crypt.go|mongo/(crypt_retrievers|mongocryptd).go
linters:
- unused
# Ignore "TLS MinVersion too low", "TLS InsecureSkipVerify set true", and "Use of weak random
# number generator (math/rand instead of crypto/rand)" in tests.
- path: _test\.go
text: G401|G402|G404
linters:
- gosec
# Ignore missing comments for exported variable/function/type for code in the "internal" and
# "benchmark" directories.
- path: (internal\/|benchmark\/)
text: exported (.+) should have comment( \(or a comment on this block\))? or be unexported
# Ignore missing package comments for directories that aren't frequently used by external users.
- path: (internal\/|benchmark\/|x\/|cmd\/|mongo\/integration\/)
text: should have a package comment
# Disable unused linter for "golang.org/x/exp/rand" package in internal/randutil/rand.
- path: internal/randutil/rand
linters:
- unused

201
mongo/LICENSE Normal file
View File

@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

210
mongo/Makefile Normal file
View File

@@ -0,0 +1,210 @@
ATLAS_URIS = "$(ATLAS_FREE)" "$(ATLAS_REPLSET)" "$(ATLAS_SHARD)" "$(ATLAS_TLS11)" "$(ATLAS_TLS12)" "$(ATLAS_FREE_SRV)" "$(ATLAS_REPLSET_SRV)" "$(ATLAS_SHARD_SRV)" "$(ATLAS_TLS11_SRV)" "$(ATLAS_TLS12_SRV)" "$(ATLAS_SERVERLESS)" "$(ATLAS_SERVERLESS_SRV)"
TEST_TIMEOUT = 1800
### Utility targets. ###
.PHONY: default
default: build check-license check-fmt check-modules lint test-short
.PHONY: add-license
add-license:
etc/check_license.sh -a
.PHONY: check-license
check-license:
etc/check_license.sh
.PHONY: build
build: cross-compile build-tests build-compile-check
go build ./...
go build $(BUILD_TAGS) ./...
# Use ^$ to match no tests so that no tests are actually run but all tests are
# compiled. Run with -short to ensure none of the TestMain functions try to
# connect to a server.
.PHONY: build-tests
build-tests:
go test -short $(BUILD_TAGS) -run ^$$ ./...
.PHONY: build-compile-check
build-compile-check:
etc/compile_check.sh
# Cross-compiling on Linux for architectures 386, arm, arm64, amd64, ppc64le, and s390x.
# Omit any build tags because we don't expect our build environment to support compiling the C
# libraries for other architectures.
.PHONY: cross-compile
cross-compile:
GOOS=linux GOARCH=386 go build ./...
GOOS=linux GOARCH=arm go build ./...
GOOS=linux GOARCH=arm64 go build ./...
GOOS=linux GOARCH=amd64 go build ./...
GOOS=linux GOARCH=ppc64le go build ./...
GOOS=linux GOARCH=s390x go build ./...
.PHONY: install-lll
install-lll:
go install github.com/walle/lll/...@latest
.PHONY: check-fmt
check-fmt: install-lll
etc/check_fmt.sh
# check-modules runs "go mod tidy" then "go mod vendor" and exits with a non-zero exit code if there
# are any module or vendored modules changes. The intent is to confirm two properties:
#
# 1. Exactly the required modules are declared as dependencies. We should always be able to run
# "go mod tidy" and expect that no unrelated changes are made to the "go.mod" file.
#
# 2. All required modules are copied into the vendor/ directory and are an exact copy of the
# original module source code (i.e. the vendored modules are not modified from their original code).
.PHONY: check-modules
check-modules:
go mod tidy -v
go mod vendor
git diff --exit-code go.mod go.sum ./vendor
.PHONY: doc
doc:
godoc -http=:6060 -index
.PHONY: fmt
fmt:
go fmt ./...
.PHONY: install-golangci-lint
install-golangci-lint:
go install github.com/golangci/golangci-lint/cmd/golangci-lint@v1.51.0
# Lint with various GOOS and GOARCH targets to catch static analysis failures that may only affect
# specific operating systems or architectures. For example, staticcheck will only check for 64-bit
# alignment of atomically accessed variables on 32-bit architectures (see
# https://staticcheck.io/docs/checks#SA1027)
.PHONY: lint
lint: install-golangci-lint
GOOS=linux GOARCH=386 golangci-lint run --config .golangci.yml ./...
GOOS=linux GOARCH=arm golangci-lint run --config .golangci.yml ./...
GOOS=linux GOARCH=arm64 golangci-lint run --config .golangci.yml ./...
GOOS=linux GOARCH=amd64 golangci-lint run --config .golangci.yml ./...
GOOS=linux GOARCH=ppc64le golangci-lint run --config .golangci.yml ./...
GOOS=linux GOARCH=s390x golangci-lint run --config .golangci.yml ./...
.PHONY: update-notices
update-notices:
etc/generate_notices.pl > THIRD-PARTY-NOTICES
### Local testing targets. ###
.PHONY: test
test:
go test $(BUILD_TAGS) -timeout $(TEST_TIMEOUT)s -p 1 ./...
.PHONY: test-cover
test-cover:
go test $(BUILD_TAGS) -timeout $(TEST_TIMEOUT)s -cover $(COVER_ARGS) -p 1 ./...
.PHONY: test-race
test-race:
go test $(BUILD_TAGS) -timeout $(TEST_TIMEOUT)s -race -p 1 ./...
.PHONY: test-short
test-short:
go test $(BUILD_TAGS) -timeout 60s -short ./...
### Evergreen specific targets. ###
.PHONY: build-aws-ecs-test
build-aws-ecs-test:
go build $(BUILD_TAGS) ./cmd/testaws/main.go
.PHONY: evg-test
evg-test:
go test -exec "env PKG_CONFIG_PATH=$(PKG_CONFIG_PATH) LD_LIBRARY_PATH=$(LD_LIBRARY_PATH)" $(BUILD_TAGS) -v -timeout $(TEST_TIMEOUT)s -p 1 ./... >> test.suite
.PHONY: evg-test-atlas
evg-test-atlas:
go run ./cmd/testatlas/main.go $(ATLAS_URIS)
.PHONY: evg-test-atlas-data-lake
evg-test-atlas-data-lake:
ATLAS_DATA_LAKE_INTEGRATION_TEST=true go test -v ./mongo/integration -run TestUnifiedSpecs/atlas-data-lake-testing >> spec_test.suite
ATLAS_DATA_LAKE_INTEGRATION_TEST=true go test -v ./mongo/integration -run TestAtlasDataLake >> spec_test.suite
.PHONY: evg-test-enterprise-auth
evg-test-enterprise-auth:
go run -tags gssapi ./cmd/testentauth/main.go
.PHONY: evg-test-kmip
evg-test-kmip:
go test -exec "env PKG_CONFIG_PATH=$(PKG_CONFIG_PATH) LD_LIBRARY_PATH=$(LD_LIBRARY_PATH)" $(BUILD_TAGS) -v -timeout $(TEST_TIMEOUT)s ./mongo/integration -run TestClientSideEncryptionSpec/kmipKMS >> test.suite
go test -exec "env PKG_CONFIG_PATH=$(PKG_CONFIG_PATH) LD_LIBRARY_PATH=$(LD_LIBRARY_PATH)" $(BUILD_TAGS) -v -timeout $(TEST_TIMEOUT)s ./mongo/integration -run TestClientSideEncryptionProse/data_key_and_double_encryption >> test.suite
go test -exec "env PKG_CONFIG_PATH=$(PKG_CONFIG_PATH) LD_LIBRARY_PATH=$(LD_LIBRARY_PATH)" $(BUILD_TAGS) -v -timeout $(TEST_TIMEOUT)s ./mongo/integration -run TestClientSideEncryptionProse/corpus >> test.suite
go test -exec "env PKG_CONFIG_PATH=$(PKG_CONFIG_PATH) LD_LIBRARY_PATH=$(LD_LIBRARY_PATH)" $(BUILD_TAGS) -v -timeout $(TEST_TIMEOUT)s ./mongo/integration -run TestClientSideEncryptionProse/custom_endpoint >> test.suite
go test -exec "env PKG_CONFIG_PATH=$(PKG_CONFIG_PATH) LD_LIBRARY_PATH=$(LD_LIBRARY_PATH)" $(BUILD_TAGS) -v -timeout $(TEST_TIMEOUT)s ./mongo/integration -run TestClientSideEncryptionProse/kms_tls_options_test >> test.suite
.PHONY: evg-test-kms
evg-test-kms:
go test -exec "env PKG_CONFIG_PATH=$(PKG_CONFIG_PATH) LD_LIBRARY_PATH=$(LD_LIBRARY_PATH)" $(BUILD_TAGS) -v -timeout $(TEST_TIMEOUT)s ./mongo/integration -run TestClientSideEncryptionProse/kms_tls_tests >> test.suite
.PHONY: evg-test-load-balancers
evg-test-load-balancers:
# Load balancer should be tested with all unified tests as well as tests in the following
# components: retryable reads, retryable writes, change streams, initial DNS seedlist discovery.
go test $(BUILD_TAGS) ./mongo/integration -run TestUnifiedSpecs/retryable-reads -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration -run TestRetryableWritesSpec -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration -run TestChangeStreamSpec -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration -run TestInitialDNSSeedlistDiscoverySpec/load_balanced -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration -run TestLoadBalancerSupport -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration/unified -run TestUnifiedSpec -v -timeout $(TEST_TIMEOUT)s >> test.suite
.PHONY: evg-test-ocsp
evg-test-ocsp:
go test -v ./mongo -run TestOCSP $(OCSP_TLS_SHOULD_SUCCEED) >> test.suite
.PHONY: evg-test-serverless
evg-test-serverless:
# Serverless should be tested with all unified tests as well as tests in the following components: CRUD, load balancer,
# retryable reads, retryable writes, sessions, transactions and cursor behavior.
go test $(BUILD_TAGS) ./mongo/integration -run TestCrudSpec -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration -run TestWriteErrorsWithLabels -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration -run TestWriteErrorsDetails -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration -run TestHintErrors -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration -run TestWriteConcernError -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration -run TestErrorsCodeNamePropagated -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration -run TestLoadBalancerSupport -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration -run TestUnifiedSpecs/retryable-reads -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration -run TestRetryableReadsProse -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration -run TestRetryableWritesSpec -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration -run TestRetryableWritesProse -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration -run TestUnifiedSpecs/sessions -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration -run TestSessionsProse -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration -run TestUnifiedSpecs/transactions/legacy -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration -run TestConvenientTransactions -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration -run TestCursor -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test $(BUILD_TAGS) ./mongo/integration/unified -run TestUnifiedSpec -v -timeout $(TEST_TIMEOUT)s >> test.suite
go test -exec "env PKG_CONFIG_PATH=$(PKG_CONFIG_PATH) LD_LIBRARY_PATH=$(LD_LIBRARY_PATH)" $(BUILD_TAGS) -v -timeout $(TEST_TIMEOUT)s ./mongo/integration -run TestClientSideEncryptionSpec >> test.suite
go test -exec "env PKG_CONFIG_PATH=$(PKG_CONFIG_PATH) LD_LIBRARY_PATH=$(LD_LIBRARY_PATH)" $(BUILD_TAGS) -v -timeout $(TEST_TIMEOUT)s ./mongo/integration -run TestClientSideEncryptionProse >> test.suite
.PHONY: evg-test-versioned-api
evg-test-versioned-api:
# Versioned API related tests are in the mongo, integration and unified packages.
go test -exec "env PKG_CONFIG_PATH=$(PKG_CONFIG_PATH) LD_LIBRARY_PATH=$(LD_LIBRARY_PATH)" $(BUILD_TAGS) -v -timeout $(TEST_TIMEOUT)s ./mongo >> test.suite
go test -exec "env PKG_CONFIG_PATH=$(PKG_CONFIG_PATH) LD_LIBRARY_PATH=$(LD_LIBRARY_PATH)" $(BUILD_TAGS) -v -timeout $(TEST_TIMEOUT)s ./mongo/integration >> test.suite
go test -exec "env PKG_CONFIG_PATH=$(PKG_CONFIG_PATH) LD_LIBRARY_PATH=$(LD_LIBRARY_PATH)" $(BUILD_TAGS) -v -timeout $(TEST_TIMEOUT)s ./mongo/integration/unified >> test.suite
.PHONY: build-gcpkms-test
build-gcpkms-test:
go build $(BUILD_TAGS) ./cmd/testgcpkms
### Benchmark specific targets and support. ###
.PHONY: benchmark
benchmark:perf
go test $(BUILD_TAGS) -benchmem -bench=. ./benchmark
.PHONY: driver-benchmark
driver-benchmark:perf
@go run cmd/godriver-benchmark/main.go | tee perf.suite
perf:driver-test-data.tar.gz
tar -zxf $< $(if $(eq $(UNAME_S),Darwin),-s , --transform=s)/testdata/perf/
@touch $@
driver-test-data.tar.gz:
curl --retry 5 "https://s3.amazonaws.com/boxes.10gen.com/build/driver-test-data.tar.gz" -o driver-test-data.tar.gz --silent --max-time 120

251
mongo/README.md Normal file
View File

@@ -0,0 +1,251 @@
<p align="center"><img src="etc/assets/mongo-gopher.png" width="250"></p>
<p align="center">
<a href="https://goreportcard.com/report/go.mongodb.org/mongo-driver"><img src="https://goreportcard.com/badge/go.mongodb.org/mongo-driver"></a>
<a href="https://pkg.go.dev/go.mongodb.org/mongo-driver/mongo"><img src="etc/assets/godev-mongo-blue.svg" alt="docs"></a>
<a href="https://pkg.go.dev/go.mongodb.org/mongo-driver/bson"><img src="etc/assets/godev-bson-blue.svg" alt="docs"></a>
<a href="https://www.mongodb.com/docs/drivers/go/current/"><img src="etc/assets/docs-mongodb-green.svg"></a>
</p>
# MongoDB Go Driver
The MongoDB supported driver for Go.
-------------------------
- [Requirements](#requirements)
- [Installation](#installation)
- [Usage](#usage)
- [Feedback](#feedback)
- [Testing / Development](#testing--development)
- [Continuous Integration](#continuous-integration)
- [License](#license)
-------------------------
## Requirements
- Go 1.13 or higher. We aim to support the latest versions of Go.
- Go 1.20 or higher is required to run the driver test suite.
- MongoDB 3.6 and higher.
-------------------------
## Installation
The recommended way to get started using the MongoDB Go driver is by using Go modules to install the dependency in
your project. This can be done either by importing packages from `go.mongodb.org/mongo-driver` and having the build
step install the dependency or by explicitly running
```bash
go get go.mongodb.org/mongo-driver/mongo
```
When using a version of Go that does not support modules, the driver can be installed using `dep` by running
```bash
dep ensure -add "go.mongodb.org/mongo-driver/mongo"
```
-------------------------
## Usage
To get started with the driver, import the `mongo` package and create a `mongo.Client` with the `Connect` function:
```go
import (
"context"
"time"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/mongo/readpref"
)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
client, err := mongo.Connect(ctx, options.Client().ApplyURI("mongodb://localhost:27017"))
```
Make sure to defer a call to `Disconnect` after instantiating your client:
```go
defer func() {
if err = client.Disconnect(ctx); err != nil {
panic(err)
}
}()
```
For more advanced configuration and authentication, see the [documentation for mongo.Connect](https://pkg.go.dev/go.mongodb.org/mongo-driver/mongo#Connect).
Calling `Connect` does not block for server discovery. If you wish to know if a MongoDB server has been found and connected to,
use the `Ping` method:
```go
ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
err = client.Ping(ctx, readpref.Primary())
```
To insert a document into a collection, first retrieve a `Database` and then `Collection` instance from the `Client`:
```go
collection := client.Database("testing").Collection("numbers")
```
The `Collection` instance can then be used to insert documents:
```go
ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
res, err := collection.InsertOne(ctx, bson.D{{"name", "pi"}, {"value", 3.14159}})
id := res.InsertedID
```
To use `bson.D`, you will need to add `"go.mongodb.org/mongo-driver/bson"` to your imports.
Your import statement should now look like this:
```go
import (
"context"
"log"
"time"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/mongo/readpref"
)
```
Several query methods return a cursor, which can be used like this:
```go
ctx, cancel = context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
cur, err := collection.Find(ctx, bson.D{})
if err != nil { log.Fatal(err) }
defer cur.Close(ctx)
for cur.Next(ctx) {
var result bson.D
err := cur.Decode(&result)
if err != nil { log.Fatal(err) }
// do something with result....
}
if err := cur.Err(); err != nil {
log.Fatal(err)
}
```
For methods that return a single item, a `SingleResult` instance is returned:
```go
var result struct {
Value float64
}
filter := bson.D{{"name", "pi"}}
ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err = collection.FindOne(ctx, filter).Decode(&result)
if err == mongo.ErrNoDocuments {
// Do something when no record was found
fmt.Println("record does not exist")
} else if err != nil {
log.Fatal(err)
}
// Do something with result...
```
Additional examples and documentation can be found under the examples directory and [on the MongoDB Documentation website](https://www.mongodb.com/docs/drivers/go/current/).
-------------------------
## Feedback
For help with the driver, please post in the [MongoDB Community Forums](https://developer.mongodb.com/community/forums/tag/golang/).
New features and bugs can be reported on jira: https://jira.mongodb.org/browse/GODRIVER
-------------------------
## Testing / Development
The driver tests can be run against several database configurations. The most simple configuration is a standalone mongod with no auth, no ssl, and no compression. To run these basic driver tests, make sure a standalone MongoDB server instance is running at localhost:27017. To run the tests, you can run `make` (on Windows, run `nmake`). This will run coverage, run go-lint, run go-vet, and build the examples.
### Testing Different Topologies
To test a **replica set** or **sharded cluster**, set `MONGODB_URI="<connection-string>"` for the `make` command.
For example, for a local replica set named `rs1` comprised of three nodes on ports 27017, 27018, and 27019:
```
MONGODB_URI="mongodb://localhost:27017,localhost:27018,localhost:27019/?replicaSet=rs1" make
```
### Testing Auth and TLS
To test authentication and TLS, first set up a MongoDB cluster with auth and TLS configured. Testing authentication requires a user with the `root` role on the `admin` database. Here is an example command that would run a mongod with TLS correctly configured for tests. Either set or replace PATH_TO_SERVER_KEY_FILE and PATH_TO_CA_FILE with paths to their respective files:
```
mongod \
--auth \
--tlsMode requireTLS \
--tlsCertificateKeyFile $PATH_TO_SERVER_KEY_FILE \
--tlsCAFile $PATH_TO_CA_FILE \
--tlsAllowInvalidCertificates
```
To run the tests with `make`, set:
- `MONGO_GO_DRIVER_CA_FILE` to the location of the CA file used by the database
- `MONGO_GO_DRIVER_KEY_FILE` to the location of the client key file
- `MONGO_GO_DRIVER_PKCS8_ENCRYPTED_KEY_FILE` to the location of the pkcs8 client key file encrypted with the password string: `password`
- `MONGO_GO_DRIVER_PKCS8_UNENCRYPTED_KEY_FILE` to the location of the unencrypted pkcs8 key file
- `MONGODB_URI` to the connection string of the server
- `AUTH=auth`
- `SSL=ssl`
For example:
```
AUTH=auth SSL=ssl \
MONGO_GO_DRIVER_CA_FILE=$PATH_TO_CA_FILE \
MONGO_GO_DRIVER_KEY_FILE=$PATH_TO_CLIENT_KEY_FILE \
MONGO_GO_DRIVER_PKCS8_ENCRYPTED_KEY_FILE=$PATH_TO_ENCRYPTED_KEY_FILE \
MONGO_GO_DRIVER_PKCS8_UNENCRYPTED_KEY_FILE=$PATH_TO_UNENCRYPTED_KEY_FILE \
MONGODB_URI="mongodb://user:password@localhost:27017/?authSource=admin" \
make
```
Notes:
- The `--tlsAllowInvalidCertificates` flag is required on the server for the test suite to work correctly.
- The test suite requires the auth database to be set with `?authSource=admin`, not `/admin`.
### Testing Compression
The MongoDB Go Driver supports wire protocol compression using Snappy, zLib, or zstd. To run tests with wire protocol compression, set `MONGO_GO_DRIVER_COMPRESSOR` to `snappy`, `zlib`, or `zstd`. For example:
```
MONGO_GO_DRIVER_COMPRESSOR=snappy make
```
Ensure the [`--networkMessageCompressors` flag](https://www.mongodb.com/docs/manual/reference/program/mongod/#cmdoption-mongod-networkmessagecompressors) on mongod or mongos includes `zlib` if testing zLib compression.
-------------------------
## Contribution
Check out the [project page](https://jira.mongodb.org/browse/GODRIVER) for tickets that need completing. See our [contribution guidelines](docs/CONTRIBUTING.md) for details.
-------------------------
## Continuous Integration
Commits to master are run automatically on [evergreen](https://evergreen.mongodb.com/waterfall/mongo-go-driver).
-------------------------
## Frequently Encountered Issues
See our [common issues](docs/common-issues.md) documentation for troubleshooting frequently encountered issues.
-------------------------
## Thanks and Acknowledgement
<a href="https://github.com/ashleymcnamara">@ashleymcnamara</a> - Mongo Gopher Artwork
-------------------------
## License
The MongoDB Go Driver is licensed under the [Apache License](LICENSE).

1554
mongo/THIRD-PARTY-NOTICES Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,307 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"compress/gzip"
"encoding/json"
"fmt"
"io/ioutil"
"os"
"path"
"testing"
)
type encodetest struct {
Field1String string
Field1Int64 int64
Field1Float64 float64
Field2String string
Field2Int64 int64
Field2Float64 float64
Field3String string
Field3Int64 int64
Field3Float64 float64
Field4String string
Field4Int64 int64
Field4Float64 float64
}
type nestedtest1 struct {
Nested nestedtest2
}
type nestedtest2 struct {
Nested nestedtest3
}
type nestedtest3 struct {
Nested nestedtest4
}
type nestedtest4 struct {
Nested nestedtest5
}
type nestedtest5 struct {
Nested nestedtest6
}
type nestedtest6 struct {
Nested nestedtest7
}
type nestedtest7 struct {
Nested nestedtest8
}
type nestedtest8 struct {
Nested nestedtest9
}
type nestedtest9 struct {
Nested nestedtest10
}
type nestedtest10 struct {
Nested nestedtest11
}
type nestedtest11 struct {
Nested encodetest
}
var encodetestInstance = encodetest{
Field1String: "foo",
Field1Int64: 1,
Field1Float64: 3.0,
Field2String: "bar",
Field2Int64: 2,
Field2Float64: 3.1,
Field3String: "baz",
Field3Int64: 3,
Field3Float64: 3.14,
Field4String: "qux",
Field4Int64: 4,
Field4Float64: 3.141,
}
var nestedInstance = nestedtest1{
nestedtest2{
nestedtest3{
nestedtest4{
nestedtest5{
nestedtest6{
nestedtest7{
nestedtest8{
nestedtest9{
nestedtest10{
nestedtest11{
encodetest{
Field1String: "foo",
Field1Int64: 1,
Field1Float64: 3.0,
Field2String: "bar",
Field2Int64: 2,
Field2Float64: 3.1,
Field3String: "baz",
Field3Int64: 3,
Field3Float64: 3.14,
Field4String: "qux",
Field4Int64: 4,
Field4Float64: 3.141,
},
},
},
},
},
},
},
},
},
},
},
}
const extendedBSONDir = "../testdata/extended_bson"
// readExtJSONFile reads the GZIP-compressed extended JSON document from the given filename in the
// "extended BSON" test data directory (../testdata/extended_bson) and returns it as a
// map[string]interface{}. It panics on any errors.
func readExtJSONFile(filename string) map[string]interface{} {
filePath := path.Join(extendedBSONDir, filename)
file, err := os.Open(filePath)
if err != nil {
panic(fmt.Sprintf("error opening file %q: %s", filePath, err))
}
defer func() {
_ = file.Close()
}()
gz, err := gzip.NewReader(file)
if err != nil {
panic(fmt.Sprintf("error creating GZIP reader: %s", err))
}
defer func() {
_ = gz.Close()
}()
data, err := ioutil.ReadAll(gz)
if err != nil {
panic(fmt.Sprintf("error reading GZIP contents of file: %s", err))
}
var v map[string]interface{}
err = UnmarshalExtJSON(data, false, &v)
if err != nil {
panic(fmt.Sprintf("error unmarshalling extended JSON: %s", err))
}
return v
}
func BenchmarkMarshal(b *testing.B) {
cases := []struct {
desc string
value interface{}
}{
{
desc: "simple struct",
value: encodetestInstance,
},
{
desc: "nested struct",
value: nestedInstance,
},
{
desc: "deep_bson.json.gz",
value: readExtJSONFile("deep_bson.json.gz"),
},
{
desc: "flat_bson.json.gz",
value: readExtJSONFile("flat_bson.json.gz"),
},
{
desc: "full_bson.json.gz",
value: readExtJSONFile("full_bson.json.gz"),
},
}
for _, tc := range cases {
b.Run(tc.desc, func(b *testing.B) {
b.Run("BSON", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_, err := Marshal(tc.value)
if err != nil {
b.Errorf("error marshalling BSON: %s", err)
}
}
})
b.Run("extJSON", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_, err := MarshalExtJSON(tc.value, true, false)
if err != nil {
b.Errorf("error marshalling extended JSON: %s", err)
}
}
})
b.Run("JSON", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_, err := json.Marshal(tc.value)
if err != nil {
b.Errorf("error marshalling JSON: %s", err)
}
}
})
})
}
}
func BenchmarkUnmarshal(b *testing.B) {
cases := []struct {
desc string
value interface{}
}{
{
desc: "simple struct",
value: encodetestInstance,
},
{
desc: "nested struct",
value: nestedInstance,
},
{
desc: "deep_bson.json.gz",
value: readExtJSONFile("deep_bson.json.gz"),
},
{
desc: "flat_bson.json.gz",
value: readExtJSONFile("flat_bson.json.gz"),
},
{
desc: "full_bson.json.gz",
value: readExtJSONFile("full_bson.json.gz"),
},
}
for _, tc := range cases {
b.Run(tc.desc, func(b *testing.B) {
b.Run("BSON", func(b *testing.B) {
data, err := Marshal(tc.value)
if err != nil {
b.Errorf("error marshalling BSON: %s", err)
return
}
b.ResetTimer()
var v2 map[string]interface{}
for i := 0; i < b.N; i++ {
err := Unmarshal(data, &v2)
if err != nil {
b.Errorf("error unmarshalling BSON: %s", err)
}
}
})
b.Run("extJSON", func(b *testing.B) {
data, err := MarshalExtJSON(tc.value, true, false)
if err != nil {
b.Errorf("error marshalling extended JSON: %s", err)
return
}
b.ResetTimer()
var v2 map[string]interface{}
for i := 0; i < b.N; i++ {
err := UnmarshalExtJSON(data, true, &v2)
if err != nil {
b.Errorf("error unmarshalling extended JSON: %s", err)
}
}
})
b.Run("JSON", func(b *testing.B) {
data, err := json.Marshal(tc.value)
if err != nil {
b.Errorf("error marshalling JSON: %s", err)
return
}
b.ResetTimer()
var v2 map[string]interface{}
for i := 0; i < b.N; i++ {
err := json.Unmarshal(data, &v2)
if err != nil {
b.Errorf("error unmarshalling JSON: %s", err)
}
}
})
})
}
}

50
mongo/bson/bson.go Normal file
View File

@@ -0,0 +1,50 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
//
// Based on gopkg.in/mgo.v2/bson by Gustavo Niemeyer
// See THIRD-PARTY-NOTICES for original license terms.
package bson // import "go.mongodb.org/mongo-driver/bson"
import (
"go.mongodb.org/mongo-driver/bson/primitive"
)
// Zeroer allows custom struct types to implement a report of zero
// state. All struct types that don't implement Zeroer or where IsZero
// returns false are considered to be not zero.
type Zeroer interface {
IsZero() bool
}
// D is an ordered representation of a BSON document. This type should be used when the order of the elements matters,
// such as MongoDB command documents. If the order of the elements does not matter, an M should be used instead.
//
// A D should not be constructed with duplicate key names, as that can cause undefined server behavior.
//
// Example usage:
//
// bson.D{{"foo", "bar"}, {"hello", "world"}, {"pi", 3.14159}}
type D = primitive.D
// E represents a BSON element for a D. It is usually used inside a D.
type E = primitive.E
// M is an unordered representation of a BSON document. This type should be used when the order of the elements does not
// matter. This type is handled as a regular map[string]interface{} when encoding and decoding. Elements will be
// serialized in an undefined, random order. If the order of the elements matters, a D should be used instead.
//
// Example usage:
//
// bson.M{"foo": "bar", "hello": "world", "pi": 3.14159}
type M = primitive.M
// An A is an ordered representation of a BSON array.
//
// Example usage:
//
// bson.A{"bar", "world", 3.14159, bson.D{{"qux", 12345}}}
type A = primitive.A

View File

@@ -0,0 +1,530 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"encoding/hex"
"encoding/json"
"fmt"
"math"
"os"
"path"
"strconv"
"strings"
"testing"
"unicode"
"unicode/utf8"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/require"
"github.com/tidwall/pretty"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/internal/testutil/assert"
)
type testCase struct {
Description string `json:"description"`
BsonType string `json:"bson_type"`
TestKey *string `json:"test_key"`
Valid []validityTestCase `json:"valid"`
DecodeErrors []decodeErrorTestCase `json:"decodeErrors"`
ParseErrors []parseErrorTestCase `json:"parseErrors"`
Deprecated *bool `json:"deprecated"`
}
type validityTestCase struct {
Description string `json:"description"`
CanonicalBson string `json:"canonical_bson"`
CanonicalExtJSON string `json:"canonical_extjson"`
RelaxedExtJSON *string `json:"relaxed_extjson"`
DegenerateBSON *string `json:"degenerate_bson"`
DegenerateExtJSON *string `json:"degenerate_extjson"`
ConvertedBSON *string `json:"converted_bson"`
ConvertedExtJSON *string `json:"converted_extjson"`
Lossy *bool `json:"lossy"`
}
type decodeErrorTestCase struct {
Description string `json:"description"`
Bson string `json:"bson"`
}
type parseErrorTestCase struct {
Description string `json:"description"`
String string `json:"string"`
}
const dataDir = "../testdata/bson-corpus/"
func findJSONFilesInDir(dir string) ([]string, error) {
files := make([]string, 0)
entries, err := os.ReadDir(dir)
if err != nil {
return nil, err
}
for _, entry := range entries {
if entry.IsDir() || path.Ext(entry.Name()) != ".json" {
continue
}
files = append(files, entry.Name())
}
return files, nil
}
// seedExtJSON will add the byte representation of the "extJSON" string to the fuzzer's coprus.
func seedExtJSON(f *testing.F, extJSON string, extJSONType string, desc string) {
jbytes, err := jsonToBytes(extJSON, extJSONType, desc)
if err != nil {
f.Fatalf("failed to convert JSON to bytes: %v", err)
}
f.Add(jbytes)
}
// seedTestCase will add the byte representation for each "extJSON" string of each valid test case to the fuzzer's
// corpus.
func seedTestCase(f *testing.F, tcase *testCase) {
for _, vtc := range tcase.Valid {
seedExtJSON(f, vtc.CanonicalExtJSON, "canonical", vtc.Description)
// Seed the relaxed extended JSON.
if vtc.RelaxedExtJSON != nil {
seedExtJSON(f, *vtc.RelaxedExtJSON, "relaxed", vtc.Description)
}
// Seed the degenerate extended JSON.
if vtc.DegenerateExtJSON != nil {
seedExtJSON(f, *vtc.DegenerateExtJSON, "degenerate", vtc.Description)
}
// Seed the converted extended JSON.
if vtc.ConvertedExtJSON != nil {
seedExtJSON(f, *vtc.ConvertedExtJSON, "converted", vtc.Description)
}
}
}
// seedBSONCorpus will unmarshal the data from "testdata/bson-corpus" into a slice of "testCase" structs and then
// marshal the "*_extjson" field of each "validityTestCase" into a slice of bytes to seed the fuzz corpus.
func seedBSONCorpus(f *testing.F) {
fileNames, err := findJSONFilesInDir(dataDir)
if err != nil {
f.Fatalf("failed to find JSON files in directory %q: %v", dataDir, err)
}
for _, fileName := range fileNames {
filePath := path.Join(dataDir, fileName)
file, err := os.Open(filePath)
if err != nil {
f.Fatalf("failed to open file %q: %v", filePath, err)
}
var tcase testCase
if err := json.NewDecoder(file).Decode(&tcase); err != nil {
f.Fatal(err)
}
seedTestCase(f, &tcase)
}
}
func needsEscapedUnicode(bsonType string) bool {
return bsonType == "0x02" || bsonType == "0x0D" || bsonType == "0x0E" || bsonType == "0x0F"
}
func unescapeUnicode(s, bsonType string) string {
if !needsEscapedUnicode(bsonType) {
return s
}
newS := ""
for i := 0; i < len(s); i++ {
c := s[i]
switch c {
case '\\':
switch s[i+1] {
case 'u':
us := s[i : i+6]
u, err := strconv.Unquote(strings.Replace(strconv.Quote(us), `\\u`, `\u`, 1))
if err != nil {
return ""
}
for _, r := range u {
if r < ' ' {
newS += fmt.Sprintf(`\u%04x`, r)
} else {
newS += string(r)
}
}
i += 5
default:
newS += string(c)
}
default:
if c > unicode.MaxASCII {
r, size := utf8.DecodeRune([]byte(s[i:]))
newS += string(r)
i += size - 1
} else {
newS += string(c)
}
}
}
return newS
}
func formatDouble(f float64) string {
var s string
if math.IsInf(f, 1) {
s = "Infinity"
} else if math.IsInf(f, -1) {
s = "-Infinity"
} else if math.IsNaN(f) {
s = "NaN"
} else {
// Print exactly one decimalType place for integers; otherwise, print as many are necessary to
// perfectly represent it.
s = strconv.FormatFloat(f, 'G', -1, 64)
if !strings.ContainsRune(s, 'E') && !strings.ContainsRune(s, '.') {
s += ".0"
}
}
return s
}
func normalizeCanonicalDouble(t *testing.T, key string, cEJ string) string {
// Unmarshal string into map
cEJMap := make(map[string]map[string]string)
err := json.Unmarshal([]byte(cEJ), &cEJMap)
require.NoError(t, err)
// Parse the float contained by the map.
expectedString := cEJMap[key]["$numberDouble"]
expectedFloat, err := strconv.ParseFloat(expectedString, 64)
require.NoError(t, err)
// Normalize the string
return fmt.Sprintf(`{"%s":{"$numberDouble":"%s"}}`, key, formatDouble(expectedFloat))
}
func normalizeRelaxedDouble(t *testing.T, key string, rEJ string) string {
// Unmarshal string into map
rEJMap := make(map[string]float64)
err := json.Unmarshal([]byte(rEJ), &rEJMap)
if err != nil {
return normalizeCanonicalDouble(t, key, rEJ)
}
// Parse the float contained by the map.
expectedFloat := rEJMap[key]
// Normalize the string
return fmt.Sprintf(`{"%s":%s}`, key, formatDouble(expectedFloat))
}
// bsonToNative decodes the BSON bytes (b) into a native Document
func bsonToNative(t *testing.T, b []byte, bType, testDesc string) D {
var doc D
err := Unmarshal(b, &doc)
expectNoError(t, err, fmt.Sprintf("%s: decoding %s BSON", testDesc, bType))
return doc
}
// nativeToBSON encodes the native Document (doc) into canonical BSON and compares it to the expected
// canonical BSON (cB)
func nativeToBSON(t *testing.T, cB []byte, doc D, testDesc, bType, docSrcDesc string) {
actual, err := Marshal(doc)
expectNoError(t, err, fmt.Sprintf("%s: encoding %s BSON", testDesc, bType))
if diff := cmp.Diff(cB, actual); diff != "" {
t.Errorf("%s: 'native_to_bson(%s) = cB' failed (-want, +got):\n-%v\n+%v\n",
testDesc, docSrcDesc, cB, actual)
t.FailNow()
}
}
// jsonToNative decodes the extended JSON string (ej) into a native Document
func jsonToNative(ej, ejType, testDesc string) (D, error) {
var doc D
if err := UnmarshalExtJSON([]byte(ej), ejType != "relaxed", &doc); err != nil {
return nil, fmt.Errorf("%s: decoding %s extended JSON: %w", testDesc, ejType, err)
}
return doc, nil
}
// jsonToBytes decodes the extended JSON string (ej) into canonical BSON and then encodes it into a byte slice.
func jsonToBytes(ej, ejType, testDesc string) ([]byte, error) {
native, err := jsonToNative(ej, ejType, testDesc)
if err != nil {
return nil, err
}
b, err := Marshal(native)
if err != nil {
return nil, fmt.Errorf("%s: encoding %s BSON: %w", testDesc, ejType, err)
}
return b, nil
}
// nativeToJSON encodes the native Document (doc) into an extended JSON string
func nativeToJSON(t *testing.T, ej string, doc D, testDesc, ejType, ejShortName, docSrcDesc string) {
actualEJ, err := MarshalExtJSON(doc, ejType != "relaxed", true)
expectNoError(t, err, fmt.Sprintf("%s: encoding %s extended JSON", testDesc, ejType))
if diff := cmp.Diff(ej, string(actualEJ)); diff != "" {
t.Errorf("%s: 'native_to_%s_extended_json(%s) = %s' failed (-want, +got):\n%s\n",
testDesc, ejType, docSrcDesc, ejShortName, diff)
t.FailNow()
}
}
func runTest(t *testing.T, file string) {
filepath := path.Join(dataDir, file)
content, err := os.ReadFile(filepath)
require.NoError(t, err)
// Remove ".json" from filename.
file = file[:len(file)-5]
testName := "bson_corpus--" + file
t.Run(testName, func(t *testing.T) {
var test testCase
require.NoError(t, json.Unmarshal(content, &test))
t.Run("valid", func(t *testing.T) {
for _, v := range test.Valid {
t.Run(v.Description, func(t *testing.T) {
// get canonical BSON
cB, err := hex.DecodeString(v.CanonicalBson)
expectNoError(t, err, fmt.Sprintf("%s: reading canonical BSON", v.Description))
// get canonical extended JSON
cEJ := unescapeUnicode(string(pretty.Ugly([]byte(v.CanonicalExtJSON))), test.BsonType)
if test.BsonType == "0x01" {
cEJ = normalizeCanonicalDouble(t, *test.TestKey, cEJ)
}
/*** canonical BSON round-trip tests ***/
doc := bsonToNative(t, cB, "canonical", v.Description)
// native_to_bson(bson_to_native(cB)) = cB
nativeToBSON(t, cB, doc, v.Description, "canonical", "bson_to_native(cB)")
// native_to_canonical_extended_json(bson_to_native(cB)) = cEJ
nativeToJSON(t, cEJ, doc, v.Description, "canonical", "cEJ", "bson_to_native(cB)")
// native_to_relaxed_extended_json(bson_to_native(cB)) = rEJ (if rEJ exists)
if v.RelaxedExtJSON != nil {
rEJ := unescapeUnicode(string(pretty.Ugly([]byte(*v.RelaxedExtJSON))), test.BsonType)
if test.BsonType == "0x01" {
rEJ = normalizeRelaxedDouble(t, *test.TestKey, rEJ)
}
nativeToJSON(t, rEJ, doc, v.Description, "relaxed", "rEJ", "bson_to_native(cB)")
/*** relaxed extended JSON round-trip tests (if exists) ***/
doc, err = jsonToNative(rEJ, "relaxed", v.Description)
require.NoError(t, err)
// native_to_relaxed_extended_json(json_to_native(rEJ)) = rEJ
nativeToJSON(t, rEJ, doc, v.Description, "relaxed", "eJR", "json_to_native(rEJ)")
}
/*** canonical extended JSON round-trip tests ***/
doc, err = jsonToNative(cEJ, "canonical", v.Description)
require.NoError(t, err)
// native_to_canonical_extended_json(json_to_native(cEJ)) = cEJ
nativeToJSON(t, cEJ, doc, v.Description, "canonical", "cEJ", "json_to_native(cEJ)")
// native_to_bson(json_to_native(cEJ)) = cb (unless lossy)
if v.Lossy == nil || !*v.Lossy {
nativeToBSON(t, cB, doc, v.Description, "canonical", "json_to_native(cEJ)")
}
/*** degenerate BSON round-trip tests (if exists) ***/
if v.DegenerateBSON != nil {
dB, err := hex.DecodeString(*v.DegenerateBSON)
expectNoError(t, err, fmt.Sprintf("%s: reading degenerate BSON", v.Description))
doc = bsonToNative(t, dB, "degenerate", v.Description)
// native_to_bson(bson_to_native(dB)) = cB
nativeToBSON(t, cB, doc, v.Description, "degenerate", "bson_to_native(dB)")
}
/*** degenerate JSON round-trip tests (if exists) ***/
if v.DegenerateExtJSON != nil {
dEJ := unescapeUnicode(string(pretty.Ugly([]byte(*v.DegenerateExtJSON))), test.BsonType)
if test.BsonType == "0x01" {
dEJ = normalizeCanonicalDouble(t, *test.TestKey, dEJ)
}
doc, err = jsonToNative(dEJ, "degenerate canonical", v.Description)
require.NoError(t, err)
// native_to_canonical_extended_json(json_to_native(dEJ)) = cEJ
nativeToJSON(t, cEJ, doc, v.Description, "degenerate canonical", "cEJ", "json_to_native(dEJ)")
// native_to_bson(json_to_native(dEJ)) = cB (unless lossy)
if v.Lossy == nil || !*v.Lossy {
nativeToBSON(t, cB, doc, v.Description, "canonical", "json_to_native(dEJ)")
}
}
})
}
})
t.Run("decode error", func(t *testing.T) {
for _, d := range test.DecodeErrors {
t.Run(d.Description, func(t *testing.T) {
b, err := hex.DecodeString(d.Bson)
expectNoError(t, err, d.Description)
var doc D
err = Unmarshal(b, &doc)
// The driver unmarshals invalid UTF-8 strings without error. Loop over the unmarshalled elements
// and assert that there was no error if any of the string or DBPointer values contain invalid UTF-8
// characters.
for _, elem := range doc {
str, ok := elem.Value.(string)
invalidString := ok && !utf8.ValidString(str)
dbPtr, ok := elem.Value.(primitive.DBPointer)
invalidDBPtr := ok && !utf8.ValidString(dbPtr.DB)
if invalidString || invalidDBPtr {
expectNoError(t, err, d.Description)
return
}
}
expectError(t, err, fmt.Sprintf("%s: expected decode error", d.Description))
})
}
})
t.Run("parse error", func(t *testing.T) {
for _, p := range test.ParseErrors {
t.Run(p.Description, func(t *testing.T) {
s := unescapeUnicode(p.String, test.BsonType)
if test.BsonType == "0x13" {
s = fmt.Sprintf(`{"decimal128": {"$numberDecimal": "%s"}}`, s)
}
switch test.BsonType {
case "0x00", "0x05", "0x13":
var doc D
err := UnmarshalExtJSON([]byte(s), true, &doc)
// Null bytes are validated when marshaling to BSON
if strings.Contains(p.Description, "Null") {
_, err = Marshal(doc)
}
expectError(t, err, fmt.Sprintf("%s: expected parse error", p.Description))
default:
t.Errorf("Update test to check for parse errors for type %s", test.BsonType)
t.Fail()
}
})
}
})
})
}
func Test_BsonCorpus(t *testing.T) {
jsonFiles, err := findJSONFilesInDir(dataDir)
if err != nil {
t.Fatalf("error finding JSON files in %s: %v", dataDir, err)
}
for _, file := range jsonFiles {
runTest(t, file)
}
}
func expectNoError(t *testing.T, err error, desc string) {
if err != nil {
t.Helper()
t.Errorf("%s: Unepexted error: %v", desc, err)
t.FailNow()
}
}
func expectError(t *testing.T, err error, desc string) {
if err == nil {
t.Helper()
t.Errorf("%s: Expected error", desc)
t.FailNow()
}
}
func TestRelaxedUUIDValidation(t *testing.T) {
testCases := []struct {
description string
canonicalExtJSON string
degenerateExtJSON string
expectedErr string
}{
{
"valid uuid",
"{\"x\" : { \"$binary\" : {\"base64\" : \"c//SZESzTGmQ6OfR38A11A==\", \"subType\" : \"04\"}}}",
"{\"x\" : { \"$uuid\" : \"73ffd264-44b3-4c69-90e8-e7d1dfc035d4\"}}",
"",
},
{
"invalid uuid--no hyphens",
"",
"{\"x\" : { \"$uuid\" : \"73ffd26444b34c6990e8e7d1dfc035d4\"}}",
"$uuid value does not follow RFC 4122 format regarding length and hyphens",
},
{
"invalid uuid--trailing hyphens",
"",
"{\"x\" : { \"$uuid\" : \"73ffd264-44b3-4c69-90e8-e7d1dfc035--\"}}",
"$uuid value does not follow RFC 4122 format regarding length and hyphens",
},
{
"invalid uuid--malformed hex",
"",
"{\"x\" : { \"$uuid\" : \"q3@fd26l-44b3-4c69-90e8-e7d1dfc035d4\"}}",
"$uuid value does not follow RFC 4122 format regarding hex bytes: encoding/hex: invalid byte: U+0071 'q'",
},
}
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
// get canonical extended JSON
cEJ := unescapeUnicode(string(pretty.Ugly([]byte(tc.canonicalExtJSON))), "0x05")
// get degenerate extended JSON
dEJ := unescapeUnicode(string(pretty.Ugly([]byte(tc.degenerateExtJSON))), "0x05")
// convert dEJ to native doc
var doc D
err := UnmarshalExtJSON([]byte(dEJ), true, &doc)
if tc.expectedErr != "" {
assert.Equal(t, tc.expectedErr, err.Error(), "expected error %v, got %v", tc.expectedErr, err)
} else {
assert.Nil(t, err, "expected no error, got error: %v", err)
// Marshal doc into extended JSON and compare with cEJ
nativeToJSON(t, cEJ, doc, tc.description, "degenerate canonical", "cEJ", "json_to_native(dEJ)")
}
})
}
}

279
mongo/bson/bson_test.go Normal file
View File

@@ -0,0 +1,279 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bson
import (
"bytes"
"fmt"
"reflect"
"strconv"
"strings"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"go.mongodb.org/mongo-driver/bson/bsoncodec"
"go.mongodb.org/mongo-driver/bson/bsonoptions"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/internal/testutil/assert"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
func noerr(t *testing.T, err error) {
if err != nil {
t.Helper()
t.Errorf("Unexpected error: (%T)%v", err, err)
t.FailNow()
}
}
func TestTimeRoundTrip(t *testing.T) {
val := struct {
Value time.Time
ID string
}{
ID: "time-rt-test",
}
if !val.Value.IsZero() {
t.Errorf("Did not get zero time as expected.")
}
bsonOut, err := Marshal(val)
noerr(t, err)
rtval := struct {
Value time.Time
ID string
}{}
err = Unmarshal(bsonOut, &rtval)
noerr(t, err)
if !cmp.Equal(val, rtval) {
t.Errorf("Did not round trip properly. got %v; want %v", val, rtval)
}
if !rtval.Value.IsZero() {
t.Errorf("Did not get zero time as expected.")
}
}
func TestNonNullTimeRoundTrip(t *testing.T) {
now := time.Now()
now = time.Unix(now.Unix(), 0)
val := struct {
Value time.Time
ID string
}{
ID: "time-rt-test",
Value: now,
}
bsonOut, err := Marshal(val)
noerr(t, err)
rtval := struct {
Value time.Time
ID string
}{}
err = Unmarshal(bsonOut, &rtval)
noerr(t, err)
if !cmp.Equal(val, rtval) {
t.Errorf("Did not round trip properly. got %v; want %v", val, rtval)
}
}
func TestD(t *testing.T) {
t.Run("can marshal", func(t *testing.T) {
d := D{{"foo", "bar"}, {"hello", "world"}, {"pi", 3.14159}}
idx, want := bsoncore.AppendDocumentStart(nil)
want = bsoncore.AppendStringElement(want, "foo", "bar")
want = bsoncore.AppendStringElement(want, "hello", "world")
want = bsoncore.AppendDoubleElement(want, "pi", 3.14159)
want, err := bsoncore.AppendDocumentEnd(want, idx)
noerr(t, err)
got, err := Marshal(d)
noerr(t, err)
if !bytes.Equal(got, want) {
t.Errorf("Marshaled documents do not match. got %v; want %v", Raw(got), Raw(want))
}
})
t.Run("can unmarshal", func(t *testing.T) {
want := D{{"foo", "bar"}, {"hello", "world"}, {"pi", 3.14159}}
idx, doc := bsoncore.AppendDocumentStart(nil)
doc = bsoncore.AppendStringElement(doc, "foo", "bar")
doc = bsoncore.AppendStringElement(doc, "hello", "world")
doc = bsoncore.AppendDoubleElement(doc, "pi", 3.14159)
doc, err := bsoncore.AppendDocumentEnd(doc, idx)
noerr(t, err)
var got D
err = Unmarshal(doc, &got)
noerr(t, err)
if !cmp.Equal(got, want) {
t.Errorf("Unmarshaled documents do not match. got %v; want %v", got, want)
}
})
}
type stringerString string
func (ss stringerString) String() string {
return "bar"
}
type keyBool bool
func (kb keyBool) MarshalKey() (string, error) {
return fmt.Sprintf("%v", kb), nil
}
func (kb *keyBool) UnmarshalKey(key string) error {
switch key {
case "true":
*kb = true
case "false":
*kb = false
default:
return fmt.Errorf("invalid bool value %v", key)
}
return nil
}
type keyStruct struct {
val int64
}
func (k keyStruct) MarshalText() (text []byte, err error) {
str := strconv.FormatInt(k.val, 10)
return []byte(str), nil
}
func (k *keyStruct) UnmarshalText(text []byte) error {
val, err := strconv.ParseInt(string(text), 10, 64)
if err != nil {
return err
}
*k = keyStruct{
val: val,
}
return nil
}
func TestMapCodec(t *testing.T) {
t.Run("EncodeKeysWithStringer", func(t *testing.T) {
strstr := stringerString("foo")
mapObj := map[stringerString]int{strstr: 1}
testCases := []struct {
name string
opts *bsonoptions.MapCodecOptions
key string
}{
{"default", bsonoptions.MapCodec(), "foo"},
{"true", bsonoptions.MapCodec().SetEncodeKeysWithStringer(true), "bar"},
{"false", bsonoptions.MapCodec().SetEncodeKeysWithStringer(false), "foo"},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
mapCodec := bsoncodec.NewMapCodec(tc.opts)
mapRegistry := NewRegistryBuilder().RegisterDefaultEncoder(reflect.Map, mapCodec).Build()
val, err := MarshalWithRegistry(mapRegistry, mapObj)
assert.Nil(t, err, "Marshal error: %v", err)
assert.True(t, strings.Contains(string(val), tc.key), "expected result to contain %v, got: %v", tc.key, string(val))
})
}
})
t.Run("keys implements keyMarshaler and keyUnmarshaler", func(t *testing.T) {
mapObj := map[keyBool]int{keyBool(true): 1}
doc, err := Marshal(mapObj)
assert.Nil(t, err, "Marshal error: %v", err)
idx, want := bsoncore.AppendDocumentStart(nil)
want = bsoncore.AppendInt32Element(want, "true", 1)
want, _ = bsoncore.AppendDocumentEnd(want, idx)
assert.Equal(t, want, doc, "expected result %v, got %v", string(want), string(doc))
var got map[keyBool]int
err = Unmarshal(doc, &got)
assert.Nil(t, err, "Unmarshal error: %v", err)
assert.Equal(t, mapObj, got, "expected result %v, got %v", mapObj, got)
})
t.Run("keys implements encoding.TextMarshaler and encoding.TextUnmarshaler", func(t *testing.T) {
mapObj := map[keyStruct]int{
{val: 10}: 100,
}
doc, err := Marshal(mapObj)
assert.Nil(t, err, "Marshal error: %v", err)
idx, want := bsoncore.AppendDocumentStart(nil)
want = bsoncore.AppendInt32Element(want, "10", 100)
want, _ = bsoncore.AppendDocumentEnd(want, idx)
assert.Equal(t, want, doc, "expected result %v, got %v", string(want), string(doc))
var got map[keyStruct]int
err = Unmarshal(doc, &got)
assert.Nil(t, err, "Unmarshal error: %v", err)
assert.Equal(t, mapObj, got, "expected result %v, got %v", mapObj, got)
})
}
func TestExtJSONEscapeKey(t *testing.T) {
doc := D{{Key: "\\usb#", Value: int32(1)}}
b, err := MarshalExtJSON(&doc, false, false)
noerr(t, err)
want := "{\"\\\\usb#\":1}"
if diff := cmp.Diff(want, string(b)); diff != "" {
t.Errorf("Marshaled documents do not match. got %v, want %v", string(b), want)
}
var got D
err = UnmarshalExtJSON(b, false, &got)
noerr(t, err)
if !cmp.Equal(got, doc) {
t.Errorf("Unmarshaled documents do not match. got %v; want %v", got, doc)
}
}
func TestBsoncoreArray(t *testing.T) {
type BSONDocumentArray struct {
Array []D `bson:"array"`
}
type BSONArray struct {
Array bsoncore.Array `bson:"array"`
}
bda := BSONDocumentArray{
Array: []D{
{{"x", 1}},
{{"x", 2}},
{{"x", 3}},
},
}
expectedBSON, err := Marshal(bda)
assert.Nil(t, err, "Marshal bsoncore.Document array error: %v", err)
var ba BSONArray
err = Unmarshal(expectedBSON, &ba)
assert.Nil(t, err, "Unmarshal error: %v", err)
actualBSON, err := Marshal(ba)
assert.Nil(t, err, "Marshal bsoncore.Array error: %v", err)
assert.Equal(t, expectedBSON, actualBSON,
"expected BSON to be %v after Marshalling again; got %v", expectedBSON, actualBSON)
doc := bsoncore.Document(actualBSON)
v := doc.Lookup("array")
assert.Equal(t, bsontype.Array, v.Type, "expected type array, got %v", v.Type)
}

View File

@@ -0,0 +1,50 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import (
"reflect"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
// ArrayCodec is the Codec used for bsoncore.Array values.
type ArrayCodec struct{}
var defaultArrayCodec = NewArrayCodec()
// NewArrayCodec returns an ArrayCodec.
func NewArrayCodec() *ArrayCodec {
return &ArrayCodec{}
}
// EncodeValue is the ValueEncoder for bsoncore.Array values.
func (ac *ArrayCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tCoreArray {
return ValueEncoderError{Name: "CoreArrayEncodeValue", Types: []reflect.Type{tCoreArray}, Received: val}
}
arr := val.Interface().(bsoncore.Array)
return bsonrw.Copier{}.CopyArrayFromBytes(vw, arr)
}
// DecodeValue is the ValueDecoder for bsoncore.Array values.
func (ac *ArrayCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
if !val.CanSet() || val.Type() != tCoreArray {
return ValueDecoderError{Name: "CoreArrayDecodeValue", Types: []reflect.Type{tCoreArray}, Received: val}
}
if val.IsNil() {
val.Set(reflect.MakeSlice(val.Type(), 0, 0))
}
val.SetLen(0)
arr, err := bsonrw.Copier{}.AppendArrayBytes(val.Interface().(bsoncore.Array), vr)
val.Set(reflect.ValueOf(arr))
return err
}

View File

@@ -0,0 +1,238 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec // import "go.mongodb.org/mongo-driver/bson/bsoncodec"
import (
"fmt"
"reflect"
"strings"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
)
var (
emptyValue = reflect.Value{}
)
// Marshaler is an interface implemented by types that can marshal themselves
// into a BSON document represented as bytes. The bytes returned must be a valid
// BSON document if the error is nil.
type Marshaler interface {
MarshalBSON() ([]byte, error)
}
// ValueMarshaler is an interface implemented by types that can marshal
// themselves into a BSON value as bytes. The type must be the valid type for
// the bytes returned. The bytes and byte type together must be valid if the
// error is nil.
type ValueMarshaler interface {
MarshalBSONValue() (bsontype.Type, []byte, error)
}
// Unmarshaler is an interface implemented by types that can unmarshal a BSON
// document representation of themselves. The BSON bytes can be assumed to be
// valid. UnmarshalBSON must copy the BSON bytes if it wishes to retain the data
// after returning.
type Unmarshaler interface {
UnmarshalBSON([]byte) error
}
// ValueUnmarshaler is an interface implemented by types that can unmarshal a
// BSON value representation of themselves. The BSON bytes and type can be
// assumed to be valid. UnmarshalBSONValue must copy the BSON value bytes if it
// wishes to retain the data after returning.
type ValueUnmarshaler interface {
UnmarshalBSONValue(bsontype.Type, []byte) error
}
// ValueEncoderError is an error returned from a ValueEncoder when the provided value can't be
// encoded by the ValueEncoder.
type ValueEncoderError struct {
Name string
Types []reflect.Type
Kinds []reflect.Kind
Received reflect.Value
}
func (vee ValueEncoderError) Error() string {
typeKinds := make([]string, 0, len(vee.Types)+len(vee.Kinds))
for _, t := range vee.Types {
typeKinds = append(typeKinds, t.String())
}
for _, k := range vee.Kinds {
if k == reflect.Map {
typeKinds = append(typeKinds, "map[string]*")
continue
}
typeKinds = append(typeKinds, k.String())
}
received := vee.Received.Kind().String()
if vee.Received.IsValid() {
received = vee.Received.Type().String()
}
return fmt.Sprintf("%s can only encode valid %s, but got %s", vee.Name, strings.Join(typeKinds, ", "), received)
}
// ValueDecoderError is an error returned from a ValueDecoder when the provided value can't be
// decoded by the ValueDecoder.
type ValueDecoderError struct {
Name string
Types []reflect.Type
Kinds []reflect.Kind
Received reflect.Value
}
func (vde ValueDecoderError) Error() string {
typeKinds := make([]string, 0, len(vde.Types)+len(vde.Kinds))
for _, t := range vde.Types {
typeKinds = append(typeKinds, t.String())
}
for _, k := range vde.Kinds {
if k == reflect.Map {
typeKinds = append(typeKinds, "map[string]*")
continue
}
typeKinds = append(typeKinds, k.String())
}
received := vde.Received.Kind().String()
if vde.Received.IsValid() {
received = vde.Received.Type().String()
}
return fmt.Sprintf("%s can only decode valid and settable %s, but got %s", vde.Name, strings.Join(typeKinds, ", "), received)
}
// EncodeContext is the contextual information required for a Codec to encode a
// value.
type EncodeContext struct {
*Registry
MinSize bool
}
// DecodeContext is the contextual information required for a Codec to decode a
// value.
type DecodeContext struct {
*Registry
Truncate bool
// Ancestor is the type of a containing document. This is mainly used to determine what type
// should be used when decoding an embedded document into an empty interface. For example, if
// Ancestor is a bson.M, BSON embedded document values being decoded into an empty interface
// will be decoded into a bson.M.
//
// Deprecated: Use DefaultDocumentM or DefaultDocumentD instead.
Ancestor reflect.Type
// defaultDocumentType specifies the Go type to decode top-level and nested BSON documents into. In particular, the
// usage for this field is restricted to data typed as "interface{}" or "map[string]interface{}". If DocumentType is
// set to a type that a BSON document cannot be unmarshaled into (e.g. "string"), unmarshalling will result in an
// error. DocumentType overrides the Ancestor field.
defaultDocumentType reflect.Type
}
// DefaultDocumentM will decode empty documents using the primitive.M type. This behavior is restricted to data typed as
// "interface{}" or "map[string]interface{}".
func (dc *DecodeContext) DefaultDocumentM() {
dc.defaultDocumentType = reflect.TypeOf(primitive.M{})
}
// DefaultDocumentD will decode empty documents using the primitive.D type. This behavior is restricted to data typed as
// "interface{}" or "map[string]interface{}".
func (dc *DecodeContext) DefaultDocumentD() {
dc.defaultDocumentType = reflect.TypeOf(primitive.D{})
}
// ValueCodec is the interface that groups the methods to encode and decode
// values.
type ValueCodec interface {
ValueEncoder
ValueDecoder
}
// ValueEncoder is the interface implemented by types that can handle the encoding of a value.
type ValueEncoder interface {
EncodeValue(EncodeContext, bsonrw.ValueWriter, reflect.Value) error
}
// ValueEncoderFunc is an adapter function that allows a function with the correct signature to be
// used as a ValueEncoder.
type ValueEncoderFunc func(EncodeContext, bsonrw.ValueWriter, reflect.Value) error
// EncodeValue implements the ValueEncoder interface.
func (fn ValueEncoderFunc) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
return fn(ec, vw, val)
}
// ValueDecoder is the interface implemented by types that can handle the decoding of a value.
type ValueDecoder interface {
DecodeValue(DecodeContext, bsonrw.ValueReader, reflect.Value) error
}
// ValueDecoderFunc is an adapter function that allows a function with the correct signature to be
// used as a ValueDecoder.
type ValueDecoderFunc func(DecodeContext, bsonrw.ValueReader, reflect.Value) error
// DecodeValue implements the ValueDecoder interface.
func (fn ValueDecoderFunc) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
return fn(dc, vr, val)
}
// typeDecoder is the interface implemented by types that can handle the decoding of a value given its type.
type typeDecoder interface {
decodeType(DecodeContext, bsonrw.ValueReader, reflect.Type) (reflect.Value, error)
}
// typeDecoderFunc is an adapter function that allows a function with the correct signature to be used as a typeDecoder.
type typeDecoderFunc func(DecodeContext, bsonrw.ValueReader, reflect.Type) (reflect.Value, error)
func (fn typeDecoderFunc) decodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) {
return fn(dc, vr, t)
}
// decodeAdapter allows two functions with the correct signatures to be used as both a ValueDecoder and typeDecoder.
type decodeAdapter struct {
ValueDecoderFunc
typeDecoderFunc
}
var _ ValueDecoder = decodeAdapter{}
var _ typeDecoder = decodeAdapter{}
// decodeTypeOrValue calls decoder.decodeType is decoder is a typeDecoder. Otherwise, it allocates a new element of type
// t and calls decoder.DecodeValue on it.
func decodeTypeOrValue(decoder ValueDecoder, dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) {
td, _ := decoder.(typeDecoder)
return decodeTypeOrValueWithInfo(decoder, td, dc, vr, t, true)
}
func decodeTypeOrValueWithInfo(vd ValueDecoder, td typeDecoder, dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type, convert bool) (reflect.Value, error) {
if td != nil {
val, err := td.decodeType(dc, vr, t)
if err == nil && convert && val.Type() != t {
// This conversion step is necessary for slices and maps. If a user declares variables like:
//
// type myBool bool
// var m map[string]myBool
//
// and tries to decode BSON bytes into the map, the decoding will fail if this conversion is not present
// because we'll try to assign a value of type bool to one of type myBool.
val = val.Convert(t)
}
return val, err
}
val := reflect.New(t).Elem()
err := vd.DecodeValue(dc, vr, val)
return val, err
}
// CodecZeroer is the interface implemented by Codecs that can also determine if
// a value of the type that would be encoded is zero.
type CodecZeroer interface {
IsTypeZero(interface{}) bool
}

View File

@@ -0,0 +1,143 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import (
"fmt"
"reflect"
"testing"
"time"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
)
func ExampleValueEncoder() {
var _ ValueEncoderFunc = func(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if val.Kind() != reflect.String {
return ValueEncoderError{Name: "StringEncodeValue", Kinds: []reflect.Kind{reflect.String}, Received: val}
}
return vw.WriteString(val.String())
}
}
func ExampleValueDecoder() {
var _ ValueDecoderFunc = func(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
if !val.CanSet() || val.Kind() != reflect.String {
return ValueDecoderError{Name: "StringDecodeValue", Kinds: []reflect.Kind{reflect.String}, Received: val}
}
if vr.Type() != bsontype.String {
return fmt.Errorf("cannot decode %v into a string type", vr.Type())
}
str, err := vr.ReadString()
if err != nil {
return err
}
val.SetString(str)
return nil
}
}
func noerr(t *testing.T, err error) {
if err != nil {
t.Helper()
t.Errorf("Unexpected error: (%T)%v", err, err)
t.FailNow()
}
}
func compareTime(t1, t2 time.Time) bool {
if t1.Location() != t2.Location() {
return false
}
return t1.Equal(t2)
}
func compareErrors(err1, err2 error) bool {
if err1 == nil && err2 == nil {
return true
}
if err1 == nil || err2 == nil {
return false
}
if err1.Error() != err2.Error() {
return false
}
return true
}
func compareDecimal128(d1, d2 primitive.Decimal128) bool {
d1H, d1L := d1.GetBytes()
d2H, d2L := d2.GetBytes()
if d1H != d2H {
return false
}
if d1L != d2L {
return false
}
return true
}
type noPrivateFields struct {
a string
}
func compareNoPrivateFields(npf1, npf2 noPrivateFields) bool {
return npf1.a != npf2.a // We don't want these to be equal
}
type zeroTest struct {
reportZero bool
}
func (z zeroTest) IsZero() bool { return z.reportZero }
func compareZeroTest(_, _ zeroTest) bool { return true }
type nonZeroer struct {
value bool
}
type llCodec struct {
t *testing.T
decodeval interface{}
encodeval interface{}
err error
}
func (llc *llCodec) EncodeValue(_ EncodeContext, _ bsonrw.ValueWriter, i interface{}) error {
if llc.err != nil {
return llc.err
}
llc.encodeval = i
return nil
}
func (llc *llCodec) DecodeValue(_ DecodeContext, _ bsonrw.ValueReader, val reflect.Value) error {
if llc.err != nil {
return llc.err
}
if !reflect.TypeOf(llc.decodeval).AssignableTo(val.Type()) {
llc.t.Errorf("decodeval must be assignable to val provided to DecodeValue, but is not. decodeval %T; val %T", llc.decodeval, val)
return nil
}
val.Set(reflect.ValueOf(llc.decodeval))
return nil
}

View File

@@ -0,0 +1,111 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import (
"fmt"
"reflect"
"go.mongodb.org/mongo-driver/bson/bsonoptions"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/bsontype"
)
// ByteSliceCodec is the Codec used for []byte values.
type ByteSliceCodec struct {
EncodeNilAsEmpty bool
}
var (
defaultByteSliceCodec = NewByteSliceCodec()
_ ValueCodec = defaultByteSliceCodec
_ typeDecoder = defaultByteSliceCodec
)
// NewByteSliceCodec returns a StringCodec with options opts.
func NewByteSliceCodec(opts ...*bsonoptions.ByteSliceCodecOptions) *ByteSliceCodec {
byteSliceOpt := bsonoptions.MergeByteSliceCodecOptions(opts...)
codec := ByteSliceCodec{}
if byteSliceOpt.EncodeNilAsEmpty != nil {
codec.EncodeNilAsEmpty = *byteSliceOpt.EncodeNilAsEmpty
}
return &codec
}
// EncodeValue is the ValueEncoder for []byte.
func (bsc *ByteSliceCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tByteSlice {
return ValueEncoderError{Name: "ByteSliceEncodeValue", Types: []reflect.Type{tByteSlice}, Received: val}
}
if val.IsNil() && !bsc.EncodeNilAsEmpty {
return vw.WriteNull()
}
return vw.WriteBinary(val.Interface().([]byte))
}
func (bsc *ByteSliceCodec) decodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) {
if t != tByteSlice {
return emptyValue, ValueDecoderError{
Name: "ByteSliceDecodeValue",
Types: []reflect.Type{tByteSlice},
Received: reflect.Zero(t),
}
}
var data []byte
var err error
switch vrType := vr.Type(); vrType {
case bsontype.String:
str, err := vr.ReadString()
if err != nil {
return emptyValue, err
}
data = []byte(str)
case bsontype.Symbol:
sym, err := vr.ReadSymbol()
if err != nil {
return emptyValue, err
}
data = []byte(sym)
case bsontype.Binary:
var subtype byte
data, subtype, err = vr.ReadBinary()
if err != nil {
return emptyValue, err
}
if subtype != bsontype.BinaryGeneric && subtype != bsontype.BinaryBinaryOld {
return emptyValue, decodeBinaryError{subtype: subtype, typeName: "[]byte"}
}
case bsontype.Null:
err = vr.ReadNull()
case bsontype.Undefined:
err = vr.ReadUndefined()
default:
return emptyValue, fmt.Errorf("cannot decode %v into a []byte", vrType)
}
if err != nil {
return emptyValue, err
}
return reflect.ValueOf(data), nil
}
// DecodeValue is the ValueDecoder for []byte.
func (bsc *ByteSliceCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
if !val.CanSet() || val.Type() != tByteSlice {
return ValueDecoderError{Name: "ByteSliceDecodeValue", Types: []reflect.Type{tByteSlice}, Received: val}
}
elem, err := bsc.decodeType(dc, vr, tByteSlice)
if err != nil {
return err
}
val.Set(elem)
return nil
}

View File

@@ -0,0 +1,63 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import (
"reflect"
"go.mongodb.org/mongo-driver/bson/bsonrw"
)
// condAddrEncoder is the encoder used when a pointer to the encoding value has an encoder.
type condAddrEncoder struct {
canAddrEnc ValueEncoder
elseEnc ValueEncoder
}
var _ ValueEncoder = (*condAddrEncoder)(nil)
// newCondAddrEncoder returns an condAddrEncoder.
func newCondAddrEncoder(canAddrEnc, elseEnc ValueEncoder) *condAddrEncoder {
encoder := condAddrEncoder{canAddrEnc: canAddrEnc, elseEnc: elseEnc}
return &encoder
}
// EncodeValue is the ValueEncoderFunc for a value that may be addressable.
func (cae *condAddrEncoder) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if val.CanAddr() {
return cae.canAddrEnc.EncodeValue(ec, vw, val)
}
if cae.elseEnc != nil {
return cae.elseEnc.EncodeValue(ec, vw, val)
}
return ErrNoEncoder{Type: val.Type()}
}
// condAddrDecoder is the decoder used when a pointer to the value has a decoder.
type condAddrDecoder struct {
canAddrDec ValueDecoder
elseDec ValueDecoder
}
var _ ValueDecoder = (*condAddrDecoder)(nil)
// newCondAddrDecoder returns an CondAddrDecoder.
func newCondAddrDecoder(canAddrDec, elseDec ValueDecoder) *condAddrDecoder {
decoder := condAddrDecoder{canAddrDec: canAddrDec, elseDec: elseDec}
return &decoder
}
// DecodeValue is the ValueDecoderFunc for a value that may be addressable.
func (cad *condAddrDecoder) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
if val.CanAddr() {
return cad.canAddrDec.DecodeValue(dc, vr, val)
}
if cad.elseDec != nil {
return cad.elseDec.DecodeValue(dc, vr, val)
}
return ErrNoDecoder{Type: val.Type()}
}

View File

@@ -0,0 +1,97 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import (
"reflect"
"testing"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/bsonrw/bsonrwtest"
"go.mongodb.org/mongo-driver/internal/testutil/assert"
)
func TestCondAddrCodec(t *testing.T) {
var inner int
canAddrVal := reflect.ValueOf(&inner)
addressable := canAddrVal.Elem()
unaddressable := reflect.ValueOf(inner)
rw := &bsonrwtest.ValueReaderWriter{}
t.Run("addressEncode", func(t *testing.T) {
invoked := 0
encode1 := ValueEncoderFunc(func(EncodeContext, bsonrw.ValueWriter, reflect.Value) error {
invoked = 1
return nil
})
encode2 := ValueEncoderFunc(func(EncodeContext, bsonrw.ValueWriter, reflect.Value) error {
invoked = 2
return nil
})
condEncoder := newCondAddrEncoder(encode1, encode2)
testCases := []struct {
name string
val reflect.Value
invoked int
}{
{"canAddr", addressable, 1},
{"else", unaddressable, 2},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := condEncoder.EncodeValue(EncodeContext{}, rw, tc.val)
assert.Nil(t, err, "CondAddrEncoder error: %v", err)
assert.Equal(t, invoked, tc.invoked, "Expected function %v to be called, called %v", tc.invoked, invoked)
})
}
t.Run("error", func(t *testing.T) {
errEncoder := newCondAddrEncoder(encode1, nil)
err := errEncoder.EncodeValue(EncodeContext{}, rw, unaddressable)
want := ErrNoEncoder{Type: unaddressable.Type()}
assert.Equal(t, err, want, "expected error %v, got %v", want, err)
})
})
t.Run("addressDecode", func(t *testing.T) {
invoked := 0
decode1 := ValueDecoderFunc(func(DecodeContext, bsonrw.ValueReader, reflect.Value) error {
invoked = 1
return nil
})
decode2 := ValueDecoderFunc(func(DecodeContext, bsonrw.ValueReader, reflect.Value) error {
invoked = 2
return nil
})
condDecoder := newCondAddrDecoder(decode1, decode2)
testCases := []struct {
name string
val reflect.Value
invoked int
}{
{"canAddr", addressable, 1},
{"else", unaddressable, 2},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := condDecoder.DecodeValue(DecodeContext{}, rw, tc.val)
assert.Nil(t, err, "CondAddrDecoder error: %v", err)
assert.Equal(t, invoked, tc.invoked, "Expected function %v to be called, called %v", tc.invoked, invoked)
})
}
t.Run("error", func(t *testing.T) {
errDecoder := newCondAddrDecoder(decode1, nil)
err := errDecoder.DecodeValue(DecodeContext{}, rw, unaddressable)
want := ErrNoDecoder{Type: unaddressable.Type()}
assert.Equal(t, err, want, "expected error %v, got %v", want, err)
})
})
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,766 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import (
"encoding/json"
"errors"
"fmt"
"math"
"net/url"
"reflect"
"sync"
"time"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
var defaultValueEncoders DefaultValueEncoders
var bvwPool = bsonrw.NewBSONValueWriterPool()
var errInvalidValue = errors.New("cannot encode invalid element")
var sliceWriterPool = sync.Pool{
New: func() interface{} {
sw := make(bsonrw.SliceWriter, 0)
return &sw
},
}
func encodeElement(ec EncodeContext, dw bsonrw.DocumentWriter, e primitive.E) error {
vw, err := dw.WriteDocumentElement(e.Key)
if err != nil {
return err
}
if e.Value == nil {
return vw.WriteNull()
}
encoder, err := ec.LookupEncoder(reflect.TypeOf(e.Value))
if err != nil {
return err
}
err = encoder.EncodeValue(ec, vw, reflect.ValueOf(e.Value))
if err != nil {
return err
}
return nil
}
// DefaultValueEncoders is a namespace type for the default ValueEncoders used
// when creating a registry.
type DefaultValueEncoders struct{}
// RegisterDefaultEncoders will register the encoder methods attached to DefaultValueEncoders with
// the provided RegistryBuilder.
func (dve DefaultValueEncoders) RegisterDefaultEncoders(rb *RegistryBuilder) {
if rb == nil {
panic(errors.New("argument to RegisterDefaultEncoders must not be nil"))
}
rb.
RegisterTypeEncoder(tByteSlice, defaultByteSliceCodec).
RegisterTypeEncoder(tTime, defaultTimeCodec).
RegisterTypeEncoder(tEmpty, defaultEmptyInterfaceCodec).
RegisterTypeEncoder(tCoreArray, defaultArrayCodec).
RegisterTypeEncoder(tOID, ValueEncoderFunc(dve.ObjectIDEncodeValue)).
RegisterTypeEncoder(tDecimal, ValueEncoderFunc(dve.Decimal128EncodeValue)).
RegisterTypeEncoder(tJSONNumber, ValueEncoderFunc(dve.JSONNumberEncodeValue)).
RegisterTypeEncoder(tURL, ValueEncoderFunc(dve.URLEncodeValue)).
RegisterTypeEncoder(tJavaScript, ValueEncoderFunc(dve.JavaScriptEncodeValue)).
RegisterTypeEncoder(tSymbol, ValueEncoderFunc(dve.SymbolEncodeValue)).
RegisterTypeEncoder(tBinary, ValueEncoderFunc(dve.BinaryEncodeValue)).
RegisterTypeEncoder(tUndefined, ValueEncoderFunc(dve.UndefinedEncodeValue)).
RegisterTypeEncoder(tDateTime, ValueEncoderFunc(dve.DateTimeEncodeValue)).
RegisterTypeEncoder(tNull, ValueEncoderFunc(dve.NullEncodeValue)).
RegisterTypeEncoder(tRegex, ValueEncoderFunc(dve.RegexEncodeValue)).
RegisterTypeEncoder(tDBPointer, ValueEncoderFunc(dve.DBPointerEncodeValue)).
RegisterTypeEncoder(tTimestamp, ValueEncoderFunc(dve.TimestampEncodeValue)).
RegisterTypeEncoder(tMinKey, ValueEncoderFunc(dve.MinKeyEncodeValue)).
RegisterTypeEncoder(tMaxKey, ValueEncoderFunc(dve.MaxKeyEncodeValue)).
RegisterTypeEncoder(tCoreDocument, ValueEncoderFunc(dve.CoreDocumentEncodeValue)).
RegisterTypeEncoder(tCodeWithScope, ValueEncoderFunc(dve.CodeWithScopeEncodeValue)).
RegisterDefaultEncoder(reflect.Bool, ValueEncoderFunc(dve.BooleanEncodeValue)).
RegisterDefaultEncoder(reflect.Int, ValueEncoderFunc(dve.IntEncodeValue)).
RegisterDefaultEncoder(reflect.Int8, ValueEncoderFunc(dve.IntEncodeValue)).
RegisterDefaultEncoder(reflect.Int16, ValueEncoderFunc(dve.IntEncodeValue)).
RegisterDefaultEncoder(reflect.Int32, ValueEncoderFunc(dve.IntEncodeValue)).
RegisterDefaultEncoder(reflect.Int64, ValueEncoderFunc(dve.IntEncodeValue)).
RegisterDefaultEncoder(reflect.Uint, defaultUIntCodec).
RegisterDefaultEncoder(reflect.Uint8, defaultUIntCodec).
RegisterDefaultEncoder(reflect.Uint16, defaultUIntCodec).
RegisterDefaultEncoder(reflect.Uint32, defaultUIntCodec).
RegisterDefaultEncoder(reflect.Uint64, defaultUIntCodec).
RegisterDefaultEncoder(reflect.Float32, ValueEncoderFunc(dve.FloatEncodeValue)).
RegisterDefaultEncoder(reflect.Float64, ValueEncoderFunc(dve.FloatEncodeValue)).
RegisterDefaultEncoder(reflect.Array, ValueEncoderFunc(dve.ArrayEncodeValue)).
RegisterDefaultEncoder(reflect.Map, defaultMapCodec).
RegisterDefaultEncoder(reflect.Slice, defaultSliceCodec).
RegisterDefaultEncoder(reflect.String, defaultStringCodec).
RegisterDefaultEncoder(reflect.Struct, newDefaultStructCodec()).
RegisterDefaultEncoder(reflect.Ptr, NewPointerCodec()).
RegisterHookEncoder(tValueMarshaler, ValueEncoderFunc(dve.ValueMarshalerEncodeValue)).
RegisterHookEncoder(tMarshaler, ValueEncoderFunc(dve.MarshalerEncodeValue)).
RegisterHookEncoder(tProxy, ValueEncoderFunc(dve.ProxyEncodeValue))
}
// BooleanEncodeValue is the ValueEncoderFunc for bool types.
func (dve DefaultValueEncoders) BooleanEncodeValue(ectx EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Kind() != reflect.Bool {
return ValueEncoderError{Name: "BooleanEncodeValue", Kinds: []reflect.Kind{reflect.Bool}, Received: val}
}
return vw.WriteBoolean(val.Bool())
}
func fitsIn32Bits(i int64) bool {
return math.MinInt32 <= i && i <= math.MaxInt32
}
// IntEncodeValue is the ValueEncoderFunc for int types.
func (dve DefaultValueEncoders) IntEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
switch val.Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32:
return vw.WriteInt32(int32(val.Int()))
case reflect.Int:
i64 := val.Int()
if fitsIn32Bits(i64) {
return vw.WriteInt32(int32(i64))
}
return vw.WriteInt64(i64)
case reflect.Int64:
i64 := val.Int()
if ec.MinSize && fitsIn32Bits(i64) {
return vw.WriteInt32(int32(i64))
}
return vw.WriteInt64(i64)
}
return ValueEncoderError{
Name: "IntEncodeValue",
Kinds: []reflect.Kind{reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int},
Received: val,
}
}
// UintEncodeValue is the ValueEncoderFunc for uint types.
//
// Deprecated: UintEncodeValue is not registered by default. Use UintCodec.EncodeValue instead.
func (dve DefaultValueEncoders) UintEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
switch val.Kind() {
case reflect.Uint8, reflect.Uint16:
return vw.WriteInt32(int32(val.Uint()))
case reflect.Uint, reflect.Uint32, reflect.Uint64:
u64 := val.Uint()
if ec.MinSize && u64 <= math.MaxInt32 {
return vw.WriteInt32(int32(u64))
}
if u64 > math.MaxInt64 {
return fmt.Errorf("%d overflows int64", u64)
}
return vw.WriteInt64(int64(u64))
}
return ValueEncoderError{
Name: "UintEncodeValue",
Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint},
Received: val,
}
}
// FloatEncodeValue is the ValueEncoderFunc for float types.
func (dve DefaultValueEncoders) FloatEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
switch val.Kind() {
case reflect.Float32, reflect.Float64:
return vw.WriteDouble(val.Float())
}
return ValueEncoderError{Name: "FloatEncodeValue", Kinds: []reflect.Kind{reflect.Float32, reflect.Float64}, Received: val}
}
// StringEncodeValue is the ValueEncoderFunc for string types.
//
// Deprecated: StringEncodeValue is not registered by default. Use StringCodec.EncodeValue instead.
func (dve DefaultValueEncoders) StringEncodeValue(ectx EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if val.Kind() != reflect.String {
return ValueEncoderError{
Name: "StringEncodeValue",
Kinds: []reflect.Kind{reflect.String},
Received: val,
}
}
return vw.WriteString(val.String())
}
// ObjectIDEncodeValue is the ValueEncoderFunc for primitive.ObjectID.
func (dve DefaultValueEncoders) ObjectIDEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tOID {
return ValueEncoderError{Name: "ObjectIDEncodeValue", Types: []reflect.Type{tOID}, Received: val}
}
return vw.WriteObjectID(val.Interface().(primitive.ObjectID))
}
// Decimal128EncodeValue is the ValueEncoderFunc for primitive.Decimal128.
func (dve DefaultValueEncoders) Decimal128EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tDecimal {
return ValueEncoderError{Name: "Decimal128EncodeValue", Types: []reflect.Type{tDecimal}, Received: val}
}
return vw.WriteDecimal128(val.Interface().(primitive.Decimal128))
}
// JSONNumberEncodeValue is the ValueEncoderFunc for json.Number.
func (dve DefaultValueEncoders) JSONNumberEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tJSONNumber {
return ValueEncoderError{Name: "JSONNumberEncodeValue", Types: []reflect.Type{tJSONNumber}, Received: val}
}
jsnum := val.Interface().(json.Number)
// Attempt int first, then float64
if i64, err := jsnum.Int64(); err == nil {
return dve.IntEncodeValue(ec, vw, reflect.ValueOf(i64))
}
f64, err := jsnum.Float64()
if err != nil {
return err
}
return dve.FloatEncodeValue(ec, vw, reflect.ValueOf(f64))
}
// URLEncodeValue is the ValueEncoderFunc for url.URL.
func (dve DefaultValueEncoders) URLEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tURL {
return ValueEncoderError{Name: "URLEncodeValue", Types: []reflect.Type{tURL}, Received: val}
}
u := val.Interface().(url.URL)
return vw.WriteString(u.String())
}
// TimeEncodeValue is the ValueEncoderFunc for time.TIme.
//
// Deprecated: TimeEncodeValue is not registered by default. Use TimeCodec.EncodeValue instead.
func (dve DefaultValueEncoders) TimeEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tTime {
return ValueEncoderError{Name: "TimeEncodeValue", Types: []reflect.Type{tTime}, Received: val}
}
tt := val.Interface().(time.Time)
dt := primitive.NewDateTimeFromTime(tt)
return vw.WriteDateTime(int64(dt))
}
// ByteSliceEncodeValue is the ValueEncoderFunc for []byte.
//
// Deprecated: ByteSliceEncodeValue is not registered by default. Use ByteSliceCodec.EncodeValue instead.
func (dve DefaultValueEncoders) ByteSliceEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tByteSlice {
return ValueEncoderError{Name: "ByteSliceEncodeValue", Types: []reflect.Type{tByteSlice}, Received: val}
}
if val.IsNil() {
return vw.WriteNull()
}
return vw.WriteBinary(val.Interface().([]byte))
}
// MapEncodeValue is the ValueEncoderFunc for map[string]* types.
//
// Deprecated: MapEncodeValue is not registered by default. Use MapCodec.EncodeValue instead.
func (dve DefaultValueEncoders) MapEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Kind() != reflect.Map || val.Type().Key().Kind() != reflect.String {
return ValueEncoderError{Name: "MapEncodeValue", Kinds: []reflect.Kind{reflect.Map}, Received: val}
}
if val.IsNil() {
// If we have a nill map but we can't WriteNull, that means we're probably trying to encode
// to a TopLevel document. We can't currently tell if this is what actually happened, but if
// there's a deeper underlying problem, the error will also be returned from WriteDocument,
// so just continue. The operations on a map reflection value are valid, so we can call
// MapKeys within mapEncodeValue without a problem.
err := vw.WriteNull()
if err == nil {
return nil
}
}
dw, err := vw.WriteDocument()
if err != nil {
return err
}
return dve.mapEncodeValue(ec, dw, val, nil)
}
// mapEncodeValue handles encoding of the values of a map. The collisionFn returns
// true if the provided key exists, this is mainly used for inline maps in the
// struct codec.
func (dve DefaultValueEncoders) mapEncodeValue(ec EncodeContext, dw bsonrw.DocumentWriter, val reflect.Value, collisionFn func(string) bool) error {
elemType := val.Type().Elem()
encoder, err := ec.LookupEncoder(elemType)
if err != nil && elemType.Kind() != reflect.Interface {
return err
}
keys := val.MapKeys()
for _, key := range keys {
if collisionFn != nil && collisionFn(key.String()) {
return fmt.Errorf("Key %s of inlined map conflicts with a struct field name", key)
}
currEncoder, currVal, lookupErr := dve.lookupElementEncoder(ec, encoder, val.MapIndex(key))
if lookupErr != nil && lookupErr != errInvalidValue {
return lookupErr
}
vw, err := dw.WriteDocumentElement(key.String())
if err != nil {
return err
}
if lookupErr == errInvalidValue {
err = vw.WriteNull()
if err != nil {
return err
}
continue
}
err = currEncoder.EncodeValue(ec, vw, currVal)
if err != nil {
return err
}
}
return dw.WriteDocumentEnd()
}
// ArrayEncodeValue is the ValueEncoderFunc for array types.
func (dve DefaultValueEncoders) ArrayEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Kind() != reflect.Array {
return ValueEncoderError{Name: "ArrayEncodeValue", Kinds: []reflect.Kind{reflect.Array}, Received: val}
}
// If we have a []primitive.E we want to treat it as a document instead of as an array.
if val.Type().Elem() == tE {
dw, err := vw.WriteDocument()
if err != nil {
return err
}
for idx := 0; idx < val.Len(); idx++ {
e := val.Index(idx).Interface().(primitive.E)
err = encodeElement(ec, dw, e)
if err != nil {
return err
}
}
return dw.WriteDocumentEnd()
}
// If we have a []byte we want to treat it as a binary instead of as an array.
if val.Type().Elem() == tByte {
var byteSlice []byte
for idx := 0; idx < val.Len(); idx++ {
byteSlice = append(byteSlice, val.Index(idx).Interface().(byte))
}
return vw.WriteBinary(byteSlice)
}
aw, err := vw.WriteArray()
if err != nil {
return err
}
elemType := val.Type().Elem()
encoder, err := ec.LookupEncoder(elemType)
if err != nil && elemType.Kind() != reflect.Interface {
return err
}
for idx := 0; idx < val.Len(); idx++ {
currEncoder, currVal, lookupErr := dve.lookupElementEncoder(ec, encoder, val.Index(idx))
if lookupErr != nil && lookupErr != errInvalidValue {
return lookupErr
}
vw, err := aw.WriteArrayElement()
if err != nil {
return err
}
if lookupErr == errInvalidValue {
err = vw.WriteNull()
if err != nil {
return err
}
continue
}
err = currEncoder.EncodeValue(ec, vw, currVal)
if err != nil {
return err
}
}
return aw.WriteArrayEnd()
}
// SliceEncodeValue is the ValueEncoderFunc for slice types.
//
// Deprecated: SliceEncodeValue is not registered by default. Use SliceCodec.EncodeValue instead.
func (dve DefaultValueEncoders) SliceEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Kind() != reflect.Slice {
return ValueEncoderError{Name: "SliceEncodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: val}
}
if val.IsNil() {
return vw.WriteNull()
}
// If we have a []primitive.E we want to treat it as a document instead of as an array.
if val.Type().ConvertibleTo(tD) {
d := val.Convert(tD).Interface().(primitive.D)
dw, err := vw.WriteDocument()
if err != nil {
return err
}
for _, e := range d {
err = encodeElement(ec, dw, e)
if err != nil {
return err
}
}
return dw.WriteDocumentEnd()
}
aw, err := vw.WriteArray()
if err != nil {
return err
}
elemType := val.Type().Elem()
encoder, err := ec.LookupEncoder(elemType)
if err != nil && elemType.Kind() != reflect.Interface {
return err
}
for idx := 0; idx < val.Len(); idx++ {
currEncoder, currVal, lookupErr := dve.lookupElementEncoder(ec, encoder, val.Index(idx))
if lookupErr != nil && lookupErr != errInvalidValue {
return lookupErr
}
vw, err := aw.WriteArrayElement()
if err != nil {
return err
}
if lookupErr == errInvalidValue {
err = vw.WriteNull()
if err != nil {
return err
}
continue
}
err = currEncoder.EncodeValue(ec, vw, currVal)
if err != nil {
return err
}
}
return aw.WriteArrayEnd()
}
func (dve DefaultValueEncoders) lookupElementEncoder(ec EncodeContext, origEncoder ValueEncoder, currVal reflect.Value) (ValueEncoder, reflect.Value, error) {
if origEncoder != nil || (currVal.Kind() != reflect.Interface) {
return origEncoder, currVal, nil
}
currVal = currVal.Elem()
if !currVal.IsValid() {
return nil, currVal, errInvalidValue
}
currEncoder, err := ec.LookupEncoder(currVal.Type())
return currEncoder, currVal, err
}
// EmptyInterfaceEncodeValue is the ValueEncoderFunc for interface{}.
//
// Deprecated: EmptyInterfaceEncodeValue is not registered by default. Use EmptyInterfaceCodec.EncodeValue instead.
func (dve DefaultValueEncoders) EmptyInterfaceEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tEmpty {
return ValueEncoderError{Name: "EmptyInterfaceEncodeValue", Types: []reflect.Type{tEmpty}, Received: val}
}
if val.IsNil() {
return vw.WriteNull()
}
encoder, err := ec.LookupEncoder(val.Elem().Type())
if err != nil {
return err
}
return encoder.EncodeValue(ec, vw, val.Elem())
}
// ValueMarshalerEncodeValue is the ValueEncoderFunc for ValueMarshaler implementations.
func (dve DefaultValueEncoders) ValueMarshalerEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
// Either val or a pointer to val must implement ValueMarshaler
switch {
case !val.IsValid():
return ValueEncoderError{Name: "ValueMarshalerEncodeValue", Types: []reflect.Type{tValueMarshaler}, Received: val}
case val.Type().Implements(tValueMarshaler):
// If ValueMarshaler is implemented on a concrete type, make sure that val isn't a nil pointer
if isImplementationNil(val, tValueMarshaler) {
return vw.WriteNull()
}
case reflect.PtrTo(val.Type()).Implements(tValueMarshaler) && val.CanAddr():
val = val.Addr()
default:
return ValueEncoderError{Name: "ValueMarshalerEncodeValue", Types: []reflect.Type{tValueMarshaler}, Received: val}
}
fn := val.Convert(tValueMarshaler).MethodByName("MarshalBSONValue")
returns := fn.Call(nil)
if !returns[2].IsNil() {
return returns[2].Interface().(error)
}
t, data := returns[0].Interface().(bsontype.Type), returns[1].Interface().([]byte)
return bsonrw.Copier{}.CopyValueFromBytes(vw, t, data)
}
// MarshalerEncodeValue is the ValueEncoderFunc for Marshaler implementations.
func (dve DefaultValueEncoders) MarshalerEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
// Either val or a pointer to val must implement Marshaler
switch {
case !val.IsValid():
return ValueEncoderError{Name: "MarshalerEncodeValue", Types: []reflect.Type{tMarshaler}, Received: val}
case val.Type().Implements(tMarshaler):
// If Marshaler is implemented on a concrete type, make sure that val isn't a nil pointer
if isImplementationNil(val, tMarshaler) {
return vw.WriteNull()
}
case reflect.PtrTo(val.Type()).Implements(tMarshaler) && val.CanAddr():
val = val.Addr()
default:
return ValueEncoderError{Name: "MarshalerEncodeValue", Types: []reflect.Type{tMarshaler}, Received: val}
}
fn := val.Convert(tMarshaler).MethodByName("MarshalBSON")
returns := fn.Call(nil)
if !returns[1].IsNil() {
return returns[1].Interface().(error)
}
data := returns[0].Interface().([]byte)
return bsonrw.Copier{}.CopyValueFromBytes(vw, bsontype.EmbeddedDocument, data)
}
// ProxyEncodeValue is the ValueEncoderFunc for Proxy implementations.
func (dve DefaultValueEncoders) ProxyEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
// Either val or a pointer to val must implement Proxy
switch {
case !val.IsValid():
return ValueEncoderError{Name: "ProxyEncodeValue", Types: []reflect.Type{tProxy}, Received: val}
case val.Type().Implements(tProxy):
// If Proxy is implemented on a concrete type, make sure that val isn't a nil pointer
if isImplementationNil(val, tProxy) {
return vw.WriteNull()
}
case reflect.PtrTo(val.Type()).Implements(tProxy) && val.CanAddr():
val = val.Addr()
default:
return ValueEncoderError{Name: "ProxyEncodeValue", Types: []reflect.Type{tProxy}, Received: val}
}
fn := val.Convert(tProxy).MethodByName("ProxyBSON")
returns := fn.Call(nil)
if !returns[1].IsNil() {
return returns[1].Interface().(error)
}
data := returns[0]
var encoder ValueEncoder
var err error
if data.Elem().IsValid() {
encoder, err = ec.LookupEncoder(data.Elem().Type())
} else {
encoder, err = ec.LookupEncoder(nil)
}
if err != nil {
return err
}
return encoder.EncodeValue(ec, vw, data.Elem())
}
// JavaScriptEncodeValue is the ValueEncoderFunc for the primitive.JavaScript type.
func (DefaultValueEncoders) JavaScriptEncodeValue(ectx EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tJavaScript {
return ValueEncoderError{Name: "JavaScriptEncodeValue", Types: []reflect.Type{tJavaScript}, Received: val}
}
return vw.WriteJavascript(val.String())
}
// SymbolEncodeValue is the ValueEncoderFunc for the primitive.Symbol type.
func (DefaultValueEncoders) SymbolEncodeValue(ectx EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tSymbol {
return ValueEncoderError{Name: "SymbolEncodeValue", Types: []reflect.Type{tSymbol}, Received: val}
}
return vw.WriteSymbol(val.String())
}
// BinaryEncodeValue is the ValueEncoderFunc for Binary.
func (DefaultValueEncoders) BinaryEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tBinary {
return ValueEncoderError{Name: "BinaryEncodeValue", Types: []reflect.Type{tBinary}, Received: val}
}
b := val.Interface().(primitive.Binary)
return vw.WriteBinaryWithSubtype(b.Data, b.Subtype)
}
// UndefinedEncodeValue is the ValueEncoderFunc for Undefined.
func (DefaultValueEncoders) UndefinedEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tUndefined {
return ValueEncoderError{Name: "UndefinedEncodeValue", Types: []reflect.Type{tUndefined}, Received: val}
}
return vw.WriteUndefined()
}
// DateTimeEncodeValue is the ValueEncoderFunc for DateTime.
func (DefaultValueEncoders) DateTimeEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tDateTime {
return ValueEncoderError{Name: "DateTimeEncodeValue", Types: []reflect.Type{tDateTime}, Received: val}
}
return vw.WriteDateTime(val.Int())
}
// NullEncodeValue is the ValueEncoderFunc for Null.
func (DefaultValueEncoders) NullEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tNull {
return ValueEncoderError{Name: "NullEncodeValue", Types: []reflect.Type{tNull}, Received: val}
}
return vw.WriteNull()
}
// RegexEncodeValue is the ValueEncoderFunc for Regex.
func (DefaultValueEncoders) RegexEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tRegex {
return ValueEncoderError{Name: "RegexEncodeValue", Types: []reflect.Type{tRegex}, Received: val}
}
regex := val.Interface().(primitive.Regex)
return vw.WriteRegex(regex.Pattern, regex.Options)
}
// DBPointerEncodeValue is the ValueEncoderFunc for DBPointer.
func (DefaultValueEncoders) DBPointerEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tDBPointer {
return ValueEncoderError{Name: "DBPointerEncodeValue", Types: []reflect.Type{tDBPointer}, Received: val}
}
dbp := val.Interface().(primitive.DBPointer)
return vw.WriteDBPointer(dbp.DB, dbp.Pointer)
}
// TimestampEncodeValue is the ValueEncoderFunc for Timestamp.
func (DefaultValueEncoders) TimestampEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tTimestamp {
return ValueEncoderError{Name: "TimestampEncodeValue", Types: []reflect.Type{tTimestamp}, Received: val}
}
ts := val.Interface().(primitive.Timestamp)
return vw.WriteTimestamp(ts.T, ts.I)
}
// MinKeyEncodeValue is the ValueEncoderFunc for MinKey.
func (DefaultValueEncoders) MinKeyEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tMinKey {
return ValueEncoderError{Name: "MinKeyEncodeValue", Types: []reflect.Type{tMinKey}, Received: val}
}
return vw.WriteMinKey()
}
// MaxKeyEncodeValue is the ValueEncoderFunc for MaxKey.
func (DefaultValueEncoders) MaxKeyEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tMaxKey {
return ValueEncoderError{Name: "MaxKeyEncodeValue", Types: []reflect.Type{tMaxKey}, Received: val}
}
return vw.WriteMaxKey()
}
// CoreDocumentEncodeValue is the ValueEncoderFunc for bsoncore.Document.
func (DefaultValueEncoders) CoreDocumentEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tCoreDocument {
return ValueEncoderError{Name: "CoreDocumentEncodeValue", Types: []reflect.Type{tCoreDocument}, Received: val}
}
cdoc := val.Interface().(bsoncore.Document)
return bsonrw.Copier{}.CopyDocumentFromBytes(vw, cdoc)
}
// CodeWithScopeEncodeValue is the ValueEncoderFunc for CodeWithScope.
func (dve DefaultValueEncoders) CodeWithScopeEncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tCodeWithScope {
return ValueEncoderError{Name: "CodeWithScopeEncodeValue", Types: []reflect.Type{tCodeWithScope}, Received: val}
}
cws := val.Interface().(primitive.CodeWithScope)
dw, err := vw.WriteCodeWithScope(string(cws.Code))
if err != nil {
return err
}
sw := sliceWriterPool.Get().(*bsonrw.SliceWriter)
defer sliceWriterPool.Put(sw)
*sw = (*sw)[:0]
scopeVW := bvwPool.Get(sw)
defer bvwPool.Put(scopeVW)
encoder, err := ec.LookupEncoder(reflect.TypeOf(cws.Scope))
if err != nil {
return err
}
err = encoder.EncodeValue(ec, scopeVW, reflect.ValueOf(cws.Scope))
if err != nil {
return err
}
err = bsonrw.Copier{}.CopyBytesToDocumentWriter(dw, *sw)
if err != nil {
return err
}
return dw.WriteDocumentEnd()
}
// isImplementationNil returns if val is a nil pointer and inter is implemented on a concrete type
func isImplementationNil(val reflect.Value, inter reflect.Type) bool {
vt := val.Type()
for vt.Kind() == reflect.Ptr {
vt = vt.Elem()
}
return vt.Implements(inter) && val.Kind() == reflect.Ptr && val.IsNil()
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,90 @@
// Copyright (C) MongoDB, Inc. 2022-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
// Package bsoncodec provides a system for encoding values to BSON representations and decoding
// values from BSON representations. This package considers both binary BSON and ExtendedJSON as
// BSON representations. The types in this package enable a flexible system for handling this
// encoding and decoding.
//
// The codec system is composed of two parts:
//
// 1) ValueEncoders and ValueDecoders that handle encoding and decoding Go values to and from BSON
// representations.
//
// 2) A Registry that holds these ValueEncoders and ValueDecoders and provides methods for
// retrieving them.
//
// # ValueEncoders and ValueDecoders
//
// The ValueEncoder interface is implemented by types that can encode a provided Go type to BSON.
// The value to encode is provided as a reflect.Value and a bsonrw.ValueWriter is used within the
// EncodeValue method to actually create the BSON representation. For convenience, ValueEncoderFunc
// is provided to allow use of a function with the correct signature as a ValueEncoder. An
// EncodeContext instance is provided to allow implementations to lookup further ValueEncoders and
// to provide configuration information.
//
// The ValueDecoder interface is the inverse of the ValueEncoder. Implementations should ensure that
// the value they receive is settable. Similar to ValueEncoderFunc, ValueDecoderFunc is provided to
// allow the use of a function with the correct signature as a ValueDecoder. A DecodeContext
// instance is provided and serves similar functionality to the EncodeContext.
//
// # Registry and RegistryBuilder
//
// A Registry is an immutable store for ValueEncoders, ValueDecoders, and a type map. See the Registry type
// documentation for examples of registering various custom encoders and decoders. A Registry can be constructed using a
// RegistryBuilder, which handles three main types of codecs:
//
// 1. Type encoders/decoders - These can be registered using the RegisterTypeEncoder and RegisterTypeDecoder methods.
// The registered codec will be invoked when encoding/decoding a value whose type matches the registered type exactly.
// If the registered type is an interface, the codec will be invoked when encoding or decoding values whose type is the
// interface, but not for values with concrete types that implement the interface.
//
// 2. Hook encoders/decoders - These can be registered using the RegisterHookEncoder and RegisterHookDecoder methods.
// These methods only accept interface types and the registered codecs will be invoked when encoding or decoding values
// whose types implement the interface. An example of a hook defined by the driver is bson.Marshaler. The driver will
// call the MarshalBSON method for any value whose type implements bson.Marshaler, regardless of the value's concrete
// type.
//
// 3. Type map entries - This can be used to associate a BSON type with a Go type. These type associations are used when
// decoding into a bson.D/bson.M or a struct field of type interface{}. For example, by default, BSON int32 and int64
// values decode as Go int32 and int64 instances, respectively, when decoding into a bson.D. The following code would
// change the behavior so these values decode as Go int instances instead:
//
// intType := reflect.TypeOf(int(0))
// registryBuilder.RegisterTypeMapEntry(bsontype.Int32, intType).RegisterTypeMapEntry(bsontype.Int64, intType)
//
// 4. Kind encoder/decoders - These can be registered using the RegisterDefaultEncoder and RegisterDefaultDecoder
// methods. The registered codec will be invoked when encoding or decoding values whose reflect.Kind matches the
// registered reflect.Kind as long as the value's type doesn't match a registered type or hook encoder/decoder first.
// These methods should be used to change the behavior for all values for a specific kind.
//
// # Registry Lookup Procedure
//
// When looking up an encoder in a Registry, the precedence rules are as follows:
//
// 1. A type encoder registered for the exact type of the value.
//
// 2. A hook encoder registered for an interface that is implemented by the value or by a pointer to the value. If the
// value matches multiple hooks (e.g. the type implements bsoncodec.Marshaler and bsoncodec.ValueMarshaler), the first
// one registered will be selected. Note that registries constructed using bson.NewRegistryBuilder have driver-defined
// hooks registered for the bsoncodec.Marshaler, bsoncodec.ValueMarshaler, and bsoncodec.Proxy interfaces, so those
// will take precedence over any new hooks.
//
// 3. A kind encoder registered for the value's kind.
//
// If all of these lookups fail to find an encoder, an error of type ErrNoEncoder is returned. The same precedence
// rules apply for decoders, with the exception that an error of type ErrNoDecoder will be returned if no decoder is
// found.
//
// # DefaultValueEncoders and DefaultValueDecoders
//
// The DefaultValueEncoders and DefaultValueDecoders types provide a full set of ValueEncoders and
// ValueDecoders for handling a wide range of Go types, including all of the types within the
// primitive package. To make registering these codecs easier, a helper method on each type is
// provided. For the DefaultValueEncoders type the method is called RegisterDefaultEncoders and for
// the DefaultValueDecoders type the method is called RegisterDefaultDecoders, this method also
// handles registering type map entries for each BSON type.
package bsoncodec

View File

@@ -0,0 +1,147 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import (
"reflect"
"go.mongodb.org/mongo-driver/bson/bsonoptions"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
)
// EmptyInterfaceCodec is the Codec used for interface{} values.
type EmptyInterfaceCodec struct {
DecodeBinaryAsSlice bool
}
var (
defaultEmptyInterfaceCodec = NewEmptyInterfaceCodec()
_ ValueCodec = defaultEmptyInterfaceCodec
_ typeDecoder = defaultEmptyInterfaceCodec
)
// NewEmptyInterfaceCodec returns a EmptyInterfaceCodec with options opts.
func NewEmptyInterfaceCodec(opts ...*bsonoptions.EmptyInterfaceCodecOptions) *EmptyInterfaceCodec {
interfaceOpt := bsonoptions.MergeEmptyInterfaceCodecOptions(opts...)
codec := EmptyInterfaceCodec{}
if interfaceOpt.DecodeBinaryAsSlice != nil {
codec.DecodeBinaryAsSlice = *interfaceOpt.DecodeBinaryAsSlice
}
return &codec
}
// EncodeValue is the ValueEncoderFunc for interface{}.
func (eic EmptyInterfaceCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tEmpty {
return ValueEncoderError{Name: "EmptyInterfaceEncodeValue", Types: []reflect.Type{tEmpty}, Received: val}
}
if val.IsNil() {
return vw.WriteNull()
}
encoder, err := ec.LookupEncoder(val.Elem().Type())
if err != nil {
return err
}
return encoder.EncodeValue(ec, vw, val.Elem())
}
func (eic EmptyInterfaceCodec) getEmptyInterfaceDecodeType(dc DecodeContext, valueType bsontype.Type) (reflect.Type, error) {
isDocument := valueType == bsontype.Type(0) || valueType == bsontype.EmbeddedDocument
if isDocument {
if dc.defaultDocumentType != nil {
// If the bsontype is an embedded document and the DocumentType is set on the DecodeContext, then return
// that type.
return dc.defaultDocumentType, nil
}
if dc.Ancestor != nil {
// Using ancestor information rather than looking up the type map entry forces consistent decoding.
// If we're decoding into a bson.D, subdocuments should also be decoded as bson.D, even if a type map entry
// has been registered.
return dc.Ancestor, nil
}
}
rtype, err := dc.LookupTypeMapEntry(valueType)
if err == nil {
return rtype, nil
}
if isDocument {
// For documents, fallback to looking up a type map entry for bsontype.Type(0) or bsontype.EmbeddedDocument,
// depending on the original valueType.
var lookupType bsontype.Type
switch valueType {
case bsontype.Type(0):
lookupType = bsontype.EmbeddedDocument
case bsontype.EmbeddedDocument:
lookupType = bsontype.Type(0)
}
rtype, err = dc.LookupTypeMapEntry(lookupType)
if err == nil {
return rtype, nil
}
}
return nil, err
}
func (eic EmptyInterfaceCodec) decodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) {
if t != tEmpty {
return emptyValue, ValueDecoderError{Name: "EmptyInterfaceDecodeValue", Types: []reflect.Type{tEmpty}, Received: reflect.Zero(t)}
}
rtype, err := eic.getEmptyInterfaceDecodeType(dc, vr.Type())
if err != nil {
switch vr.Type() {
case bsontype.Null:
return reflect.Zero(t), vr.ReadNull()
default:
return emptyValue, err
}
}
decoder, err := dc.LookupDecoder(rtype)
if err != nil {
return emptyValue, err
}
elem, err := decodeTypeOrValue(decoder, dc, vr, rtype)
if err != nil {
return emptyValue, err
}
if eic.DecodeBinaryAsSlice && rtype == tBinary {
binElem := elem.Interface().(primitive.Binary)
if binElem.Subtype == bsontype.BinaryGeneric || binElem.Subtype == bsontype.BinaryBinaryOld {
elem = reflect.ValueOf(binElem.Data)
}
}
return elem, nil
}
// DecodeValue is the ValueDecoderFunc for interface{}.
func (eic EmptyInterfaceCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
if !val.CanSet() || val.Type() != tEmpty {
return ValueDecoderError{Name: "EmptyInterfaceDecodeValue", Types: []reflect.Type{tEmpty}, Received: val}
}
elem, err := eic.decodeType(dc, vr, val.Type())
if err != nil {
return err
}
val.Set(elem)
return nil
}

View File

@@ -0,0 +1,309 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import (
"encoding"
"fmt"
"reflect"
"strconv"
"go.mongodb.org/mongo-driver/bson/bsonoptions"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/bsontype"
)
var defaultMapCodec = NewMapCodec()
// MapCodec is the Codec used for map values.
type MapCodec struct {
DecodeZerosMap bool
EncodeNilAsEmpty bool
EncodeKeysWithStringer bool
}
var _ ValueCodec = &MapCodec{}
// KeyMarshaler is the interface implemented by an object that can marshal itself into a string key.
// This applies to types used as map keys and is similar to encoding.TextMarshaler.
type KeyMarshaler interface {
MarshalKey() (key string, err error)
}
// KeyUnmarshaler is the interface implemented by an object that can unmarshal a string representation
// of itself. This applies to types used as map keys and is similar to encoding.TextUnmarshaler.
//
// UnmarshalKey must be able to decode the form generated by MarshalKey.
// UnmarshalKey must copy the text if it wishes to retain the text
// after returning.
type KeyUnmarshaler interface {
UnmarshalKey(key string) error
}
// NewMapCodec returns a MapCodec with options opts.
func NewMapCodec(opts ...*bsonoptions.MapCodecOptions) *MapCodec {
mapOpt := bsonoptions.MergeMapCodecOptions(opts...)
codec := MapCodec{}
if mapOpt.DecodeZerosMap != nil {
codec.DecodeZerosMap = *mapOpt.DecodeZerosMap
}
if mapOpt.EncodeNilAsEmpty != nil {
codec.EncodeNilAsEmpty = *mapOpt.EncodeNilAsEmpty
}
if mapOpt.EncodeKeysWithStringer != nil {
codec.EncodeKeysWithStringer = *mapOpt.EncodeKeysWithStringer
}
return &codec
}
// EncodeValue is the ValueEncoder for map[*]* types.
func (mc *MapCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Kind() != reflect.Map {
return ValueEncoderError{Name: "MapEncodeValue", Kinds: []reflect.Kind{reflect.Map}, Received: val}
}
if val.IsNil() && !mc.EncodeNilAsEmpty {
// If we have a nil map but we can't WriteNull, that means we're probably trying to encode
// to a TopLevel document. We can't currently tell if this is what actually happened, but if
// there's a deeper underlying problem, the error will also be returned from WriteDocument,
// so just continue. The operations on a map reflection value are valid, so we can call
// MapKeys within mapEncodeValue without a problem.
err := vw.WriteNull()
if err == nil {
return nil
}
}
dw, err := vw.WriteDocument()
if err != nil {
return err
}
return mc.mapEncodeValue(ec, dw, val, nil)
}
// mapEncodeValue handles encoding of the values of a map. The collisionFn returns
// true if the provided key exists, this is mainly used for inline maps in the
// struct codec.
func (mc *MapCodec) mapEncodeValue(ec EncodeContext, dw bsonrw.DocumentWriter, val reflect.Value, collisionFn func(string) bool) error {
elemType := val.Type().Elem()
encoder, err := ec.LookupEncoder(elemType)
if err != nil && elemType.Kind() != reflect.Interface {
return err
}
keys := val.MapKeys()
for _, key := range keys {
keyStr, err := mc.encodeKey(key)
if err != nil {
return err
}
if collisionFn != nil && collisionFn(keyStr) {
return fmt.Errorf("Key %s of inlined map conflicts with a struct field name", key)
}
currEncoder, currVal, lookupErr := defaultValueEncoders.lookupElementEncoder(ec, encoder, val.MapIndex(key))
if lookupErr != nil && lookupErr != errInvalidValue {
return lookupErr
}
vw, err := dw.WriteDocumentElement(keyStr)
if err != nil {
return err
}
if lookupErr == errInvalidValue {
err = vw.WriteNull()
if err != nil {
return err
}
continue
}
err = currEncoder.EncodeValue(ec, vw, currVal)
if err != nil {
return err
}
}
return dw.WriteDocumentEnd()
}
// DecodeValue is the ValueDecoder for map[string/decimal]* types.
func (mc *MapCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
if val.Kind() != reflect.Map || (!val.CanSet() && val.IsNil()) {
return ValueDecoderError{Name: "MapDecodeValue", Kinds: []reflect.Kind{reflect.Map}, Received: val}
}
switch vrType := vr.Type(); vrType {
case bsontype.Type(0), bsontype.EmbeddedDocument:
case bsontype.Null:
val.Set(reflect.Zero(val.Type()))
return vr.ReadNull()
case bsontype.Undefined:
val.Set(reflect.Zero(val.Type()))
return vr.ReadUndefined()
default:
return fmt.Errorf("cannot decode %v into a %s", vrType, val.Type())
}
dr, err := vr.ReadDocument()
if err != nil {
return err
}
if val.IsNil() {
val.Set(reflect.MakeMap(val.Type()))
}
if val.Len() > 0 && mc.DecodeZerosMap {
clearMap(val)
}
eType := val.Type().Elem()
decoder, err := dc.LookupDecoder(eType)
if err != nil {
return err
}
eTypeDecoder, _ := decoder.(typeDecoder)
if eType == tEmpty {
dc.Ancestor = val.Type()
}
keyType := val.Type().Key()
for {
key, vr, err := dr.ReadElement()
if err == bsonrw.ErrEOD {
break
}
if err != nil {
return err
}
k, err := mc.decodeKey(key, keyType)
if err != nil {
return err
}
elem, err := decodeTypeOrValueWithInfo(decoder, eTypeDecoder, dc, vr, eType, true)
if err != nil {
return newDecodeError(key, err)
}
val.SetMapIndex(k, elem)
}
return nil
}
func clearMap(m reflect.Value) {
var none reflect.Value
for _, k := range m.MapKeys() {
m.SetMapIndex(k, none)
}
}
func (mc *MapCodec) encodeKey(val reflect.Value) (string, error) {
if mc.EncodeKeysWithStringer {
return fmt.Sprint(val), nil
}
// keys of any string type are used directly
if val.Kind() == reflect.String {
return val.String(), nil
}
// KeyMarshalers are marshaled
if km, ok := val.Interface().(KeyMarshaler); ok {
if val.Kind() == reflect.Ptr && val.IsNil() {
return "", nil
}
buf, err := km.MarshalKey()
if err == nil {
return buf, nil
}
return "", err
}
// keys implement encoding.TextMarshaler are marshaled.
if km, ok := val.Interface().(encoding.TextMarshaler); ok {
if val.Kind() == reflect.Ptr && val.IsNil() {
return "", nil
}
buf, err := km.MarshalText()
if err != nil {
return "", err
}
return string(buf), nil
}
switch val.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return strconv.FormatInt(val.Int(), 10), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return strconv.FormatUint(val.Uint(), 10), nil
}
return "", fmt.Errorf("unsupported key type: %v", val.Type())
}
var keyUnmarshalerType = reflect.TypeOf((*KeyUnmarshaler)(nil)).Elem()
var textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
func (mc *MapCodec) decodeKey(key string, keyType reflect.Type) (reflect.Value, error) {
keyVal := reflect.ValueOf(key)
var err error
switch {
// First, if EncodeKeysWithStringer is not enabled, try to decode withKeyUnmarshaler
case !mc.EncodeKeysWithStringer && reflect.PtrTo(keyType).Implements(keyUnmarshalerType):
keyVal = reflect.New(keyType)
v := keyVal.Interface().(KeyUnmarshaler)
err = v.UnmarshalKey(key)
keyVal = keyVal.Elem()
// Try to decode encoding.TextUnmarshalers.
case reflect.PtrTo(keyType).Implements(textUnmarshalerType):
keyVal = reflect.New(keyType)
v := keyVal.Interface().(encoding.TextUnmarshaler)
err = v.UnmarshalText([]byte(key))
keyVal = keyVal.Elem()
// Otherwise, go to type specific behavior
default:
switch keyType.Kind() {
case reflect.String:
keyVal = reflect.ValueOf(key).Convert(keyType)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
n, parseErr := strconv.ParseInt(key, 10, 64)
if parseErr != nil || reflect.Zero(keyType).OverflowInt(n) {
err = fmt.Errorf("failed to unmarshal number key %v", key)
}
keyVal = reflect.ValueOf(n).Convert(keyType)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
n, parseErr := strconv.ParseUint(key, 10, 64)
if parseErr != nil || reflect.Zero(keyType).OverflowUint(n) {
err = fmt.Errorf("failed to unmarshal number key %v", key)
break
}
keyVal = reflect.ValueOf(n).Convert(keyType)
case reflect.Float32, reflect.Float64:
if mc.EncodeKeysWithStringer {
parsed, err := strconv.ParseFloat(key, 64)
if err != nil {
return keyVal, fmt.Errorf("Map key is defined to be a decimal type (%v) but got error %v", keyType.Kind(), err)
}
keyVal = reflect.ValueOf(parsed)
break
}
fallthrough
default:
return keyVal, fmt.Errorf("unsupported key type: %v", keyType)
}
}
return keyVal, err
}

View File

@@ -0,0 +1,65 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import "fmt"
type mode int
const (
_ mode = iota
mTopLevel
mDocument
mArray
mValue
mElement
mCodeWithScope
mSpacer
)
func (m mode) String() string {
var str string
switch m {
case mTopLevel:
str = "TopLevel"
case mDocument:
str = "DocumentMode"
case mArray:
str = "ArrayMode"
case mValue:
str = "ValueMode"
case mElement:
str = "ElementMode"
case mCodeWithScope:
str = "CodeWithScopeMode"
case mSpacer:
str = "CodeWithScopeSpacerFrame"
default:
str = "UnknownMode"
}
return str
}
// TransitionError is an error returned when an invalid progressing a
// ValueReader or ValueWriter state machine occurs.
type TransitionError struct {
parent mode
current mode
destination mode
}
func (te TransitionError) Error() string {
if te.destination == mode(0) {
return fmt.Sprintf("invalid state transition: cannot read/write value while in %s", te.current)
}
if te.parent == mode(0) {
return fmt.Sprintf("invalid state transition: %s -> %s", te.current, te.destination)
}
return fmt.Sprintf("invalid state transition: %s -> %s; parent %s", te.current, te.destination, te.parent)
}

View File

@@ -0,0 +1,109 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import (
"reflect"
"sync"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/bsontype"
)
var _ ValueEncoder = &PointerCodec{}
var _ ValueDecoder = &PointerCodec{}
// PointerCodec is the Codec used for pointers.
type PointerCodec struct {
ecache map[reflect.Type]ValueEncoder
dcache map[reflect.Type]ValueDecoder
l sync.RWMutex
}
// NewPointerCodec returns a PointerCodec that has been initialized.
func NewPointerCodec() *PointerCodec {
return &PointerCodec{
ecache: make(map[reflect.Type]ValueEncoder),
dcache: make(map[reflect.Type]ValueDecoder),
}
}
// EncodeValue handles encoding a pointer by either encoding it to BSON Null if the pointer is nil
// or looking up an encoder for the type of value the pointer points to.
func (pc *PointerCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if val.Kind() != reflect.Ptr {
if !val.IsValid() {
return vw.WriteNull()
}
return ValueEncoderError{Name: "PointerCodec.EncodeValue", Kinds: []reflect.Kind{reflect.Ptr}, Received: val}
}
if val.IsNil() {
return vw.WriteNull()
}
pc.l.RLock()
enc, ok := pc.ecache[val.Type()]
pc.l.RUnlock()
if ok {
if enc == nil {
return ErrNoEncoder{Type: val.Type()}
}
return enc.EncodeValue(ec, vw, val.Elem())
}
enc, err := ec.LookupEncoder(val.Type().Elem())
pc.l.Lock()
pc.ecache[val.Type()] = enc
pc.l.Unlock()
if err != nil {
return err
}
return enc.EncodeValue(ec, vw, val.Elem())
}
// DecodeValue handles decoding a pointer by looking up a decoder for the type it points to and
// using that to decode. If the BSON value is Null, this method will set the pointer to nil.
func (pc *PointerCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
if !val.CanSet() || val.Kind() != reflect.Ptr {
return ValueDecoderError{Name: "PointerCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Ptr}, Received: val}
}
if vr.Type() == bsontype.Null {
val.Set(reflect.Zero(val.Type()))
return vr.ReadNull()
}
if vr.Type() == bsontype.Undefined {
val.Set(reflect.Zero(val.Type()))
return vr.ReadUndefined()
}
if val.IsNil() {
val.Set(reflect.New(val.Type().Elem()))
}
pc.l.RLock()
dec, ok := pc.dcache[val.Type()]
pc.l.RUnlock()
if ok {
if dec == nil {
return ErrNoDecoder{Type: val.Type()}
}
return dec.DecodeValue(dc, vr, val.Elem())
}
dec, err := dc.LookupDecoder(val.Type().Elem())
pc.l.Lock()
pc.dcache[val.Type()] = dec
pc.l.Unlock()
if err != nil {
return err
}
return dec.DecodeValue(dc, vr, val.Elem())
}

View File

@@ -0,0 +1,14 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
// Proxy is an interface implemented by types that cannot themselves be directly encoded. Types
// that implement this interface with have ProxyBSON called during the encoding process and that
// value will be encoded in place for the implementer.
type Proxy interface {
ProxyBSON() (interface{}, error)
}

View File

@@ -0,0 +1,469 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import (
"errors"
"fmt"
"reflect"
"sync"
"go.mongodb.org/mongo-driver/bson/bsontype"
)
// ErrNilType is returned when nil is passed to either LookupEncoder or LookupDecoder.
var ErrNilType = errors.New("cannot perform a decoder lookup on <nil>")
// ErrNotPointer is returned when a non-pointer type is provided to LookupDecoder.
var ErrNotPointer = errors.New("non-pointer provided to LookupDecoder")
// ErrNoEncoder is returned when there wasn't an encoder available for a type.
type ErrNoEncoder struct {
Type reflect.Type
}
func (ene ErrNoEncoder) Error() string {
if ene.Type == nil {
return "no encoder found for <nil>"
}
return "no encoder found for " + ene.Type.String()
}
// ErrNoDecoder is returned when there wasn't a decoder available for a type.
type ErrNoDecoder struct {
Type reflect.Type
}
func (end ErrNoDecoder) Error() string {
return "no decoder found for " + end.Type.String()
}
// ErrNoTypeMapEntry is returned when there wasn't a type available for the provided BSON type.
type ErrNoTypeMapEntry struct {
Type bsontype.Type
}
func (entme ErrNoTypeMapEntry) Error() string {
return "no type map entry found for " + entme.Type.String()
}
// ErrNotInterface is returned when the provided type is not an interface.
var ErrNotInterface = errors.New("The provided type is not an interface")
// A RegistryBuilder is used to build a Registry. This type is not goroutine
// safe.
type RegistryBuilder struct {
typeEncoders map[reflect.Type]ValueEncoder
interfaceEncoders []interfaceValueEncoder
kindEncoders map[reflect.Kind]ValueEncoder
typeDecoders map[reflect.Type]ValueDecoder
interfaceDecoders []interfaceValueDecoder
kindDecoders map[reflect.Kind]ValueDecoder
typeMap map[bsontype.Type]reflect.Type
}
// A Registry is used to store and retrieve codecs for types and interfaces. This type is the main
// typed passed around and Encoders and Decoders are constructed from it.
type Registry struct {
typeEncoders map[reflect.Type]ValueEncoder
typeDecoders map[reflect.Type]ValueDecoder
interfaceEncoders []interfaceValueEncoder
interfaceDecoders []interfaceValueDecoder
kindEncoders map[reflect.Kind]ValueEncoder
kindDecoders map[reflect.Kind]ValueDecoder
typeMap map[bsontype.Type]reflect.Type
mu sync.RWMutex
}
// NewRegistryBuilder creates a new empty RegistryBuilder.
func NewRegistryBuilder() *RegistryBuilder {
return &RegistryBuilder{
typeEncoders: make(map[reflect.Type]ValueEncoder),
typeDecoders: make(map[reflect.Type]ValueDecoder),
interfaceEncoders: make([]interfaceValueEncoder, 0),
interfaceDecoders: make([]interfaceValueDecoder, 0),
kindEncoders: make(map[reflect.Kind]ValueEncoder),
kindDecoders: make(map[reflect.Kind]ValueDecoder),
typeMap: make(map[bsontype.Type]reflect.Type),
}
}
func buildDefaultRegistry() *Registry {
rb := NewRegistryBuilder()
defaultValueEncoders.RegisterDefaultEncoders(rb)
defaultValueDecoders.RegisterDefaultDecoders(rb)
return rb.Build()
}
// RegisterCodec will register the provided ValueCodec for the provided type.
func (rb *RegistryBuilder) RegisterCodec(t reflect.Type, codec ValueCodec) *RegistryBuilder {
rb.RegisterTypeEncoder(t, codec)
rb.RegisterTypeDecoder(t, codec)
return rb
}
// RegisterTypeEncoder will register the provided ValueEncoder for the provided type.
//
// The type will be used directly, so an encoder can be registered for a type and a different encoder can be registered
// for a pointer to that type.
//
// If the given type is an interface, the encoder will be called when marshalling a type that is that interface. It
// will not be called when marshalling a non-interface type that implements the interface.
func (rb *RegistryBuilder) RegisterTypeEncoder(t reflect.Type, enc ValueEncoder) *RegistryBuilder {
rb.typeEncoders[t] = enc
return rb
}
// RegisterHookEncoder will register an encoder for the provided interface type t. This encoder will be called when
// marshalling a type if the type implements t or a pointer to the type implements t. If the provided type is not
// an interface (i.e. t.Kind() != reflect.Interface), this method will panic.
func (rb *RegistryBuilder) RegisterHookEncoder(t reflect.Type, enc ValueEncoder) *RegistryBuilder {
if t.Kind() != reflect.Interface {
panicStr := fmt.Sprintf("RegisterHookEncoder expects a type with kind reflect.Interface, "+
"got type %s with kind %s", t, t.Kind())
panic(panicStr)
}
for idx, encoder := range rb.interfaceEncoders {
if encoder.i == t {
rb.interfaceEncoders[idx].ve = enc
return rb
}
}
rb.interfaceEncoders = append(rb.interfaceEncoders, interfaceValueEncoder{i: t, ve: enc})
return rb
}
// RegisterTypeDecoder will register the provided ValueDecoder for the provided type.
//
// The type will be used directly, so a decoder can be registered for a type and a different decoder can be registered
// for a pointer to that type.
//
// If the given type is an interface, the decoder will be called when unmarshalling into a type that is that interface.
// It will not be called when unmarshalling into a non-interface type that implements the interface.
func (rb *RegistryBuilder) RegisterTypeDecoder(t reflect.Type, dec ValueDecoder) *RegistryBuilder {
rb.typeDecoders[t] = dec
return rb
}
// RegisterHookDecoder will register an decoder for the provided interface type t. This decoder will be called when
// unmarshalling into a type if the type implements t or a pointer to the type implements t. If the provided type is not
// an interface (i.e. t.Kind() != reflect.Interface), this method will panic.
func (rb *RegistryBuilder) RegisterHookDecoder(t reflect.Type, dec ValueDecoder) *RegistryBuilder {
if t.Kind() != reflect.Interface {
panicStr := fmt.Sprintf("RegisterHookDecoder expects a type with kind reflect.Interface, "+
"got type %s with kind %s", t, t.Kind())
panic(panicStr)
}
for idx, decoder := range rb.interfaceDecoders {
if decoder.i == t {
rb.interfaceDecoders[idx].vd = dec
return rb
}
}
rb.interfaceDecoders = append(rb.interfaceDecoders, interfaceValueDecoder{i: t, vd: dec})
return rb
}
// RegisterEncoder registers the provided type and encoder pair.
//
// Deprecated: Use RegisterTypeEncoder or RegisterHookEncoder instead.
func (rb *RegistryBuilder) RegisterEncoder(t reflect.Type, enc ValueEncoder) *RegistryBuilder {
if t == tEmpty {
rb.typeEncoders[t] = enc
return rb
}
switch t.Kind() {
case reflect.Interface:
for idx, ir := range rb.interfaceEncoders {
if ir.i == t {
rb.interfaceEncoders[idx].ve = enc
return rb
}
}
rb.interfaceEncoders = append(rb.interfaceEncoders, interfaceValueEncoder{i: t, ve: enc})
default:
rb.typeEncoders[t] = enc
}
return rb
}
// RegisterDecoder registers the provided type and decoder pair.
//
// Deprecated: Use RegisterTypeDecoder or RegisterHookDecoder instead.
func (rb *RegistryBuilder) RegisterDecoder(t reflect.Type, dec ValueDecoder) *RegistryBuilder {
if t == nil {
rb.typeDecoders[nil] = dec
return rb
}
if t == tEmpty {
rb.typeDecoders[t] = dec
return rb
}
switch t.Kind() {
case reflect.Interface:
for idx, ir := range rb.interfaceDecoders {
if ir.i == t {
rb.interfaceDecoders[idx].vd = dec
return rb
}
}
rb.interfaceDecoders = append(rb.interfaceDecoders, interfaceValueDecoder{i: t, vd: dec})
default:
rb.typeDecoders[t] = dec
}
return rb
}
// RegisterDefaultEncoder will registr the provided ValueEncoder to the provided
// kind.
func (rb *RegistryBuilder) RegisterDefaultEncoder(kind reflect.Kind, enc ValueEncoder) *RegistryBuilder {
rb.kindEncoders[kind] = enc
return rb
}
// RegisterDefaultDecoder will register the provided ValueDecoder to the
// provided kind.
func (rb *RegistryBuilder) RegisterDefaultDecoder(kind reflect.Kind, dec ValueDecoder) *RegistryBuilder {
rb.kindDecoders[kind] = dec
return rb
}
// RegisterTypeMapEntry will register the provided type to the BSON type. The primary usage for this
// mapping is decoding situations where an empty interface is used and a default type needs to be
// created and decoded into.
//
// By default, BSON documents will decode into interface{} values as bson.D. To change the default type for BSON
// documents, a type map entry for bsontype.EmbeddedDocument should be registered. For example, to force BSON documents
// to decode to bson.Raw, use the following code:
//
// rb.RegisterTypeMapEntry(bsontype.EmbeddedDocument, reflect.TypeOf(bson.Raw{}))
func (rb *RegistryBuilder) RegisterTypeMapEntry(bt bsontype.Type, rt reflect.Type) *RegistryBuilder {
rb.typeMap[bt] = rt
return rb
}
// Build creates a Registry from the current state of this RegistryBuilder.
func (rb *RegistryBuilder) Build() *Registry {
registry := new(Registry)
registry.typeEncoders = make(map[reflect.Type]ValueEncoder)
for t, enc := range rb.typeEncoders {
registry.typeEncoders[t] = enc
}
registry.typeDecoders = make(map[reflect.Type]ValueDecoder)
for t, dec := range rb.typeDecoders {
registry.typeDecoders[t] = dec
}
registry.interfaceEncoders = make([]interfaceValueEncoder, len(rb.interfaceEncoders))
copy(registry.interfaceEncoders, rb.interfaceEncoders)
registry.interfaceDecoders = make([]interfaceValueDecoder, len(rb.interfaceDecoders))
copy(registry.interfaceDecoders, rb.interfaceDecoders)
registry.kindEncoders = make(map[reflect.Kind]ValueEncoder)
for kind, enc := range rb.kindEncoders {
registry.kindEncoders[kind] = enc
}
registry.kindDecoders = make(map[reflect.Kind]ValueDecoder)
for kind, dec := range rb.kindDecoders {
registry.kindDecoders[kind] = dec
}
registry.typeMap = make(map[bsontype.Type]reflect.Type)
for bt, rt := range rb.typeMap {
registry.typeMap[bt] = rt
}
return registry
}
// LookupEncoder inspects the registry for an encoder for the given type. The lookup precedence works as follows:
//
// 1. An encoder registered for the exact type. If the given type represents an interface, an encoder registered using
// RegisterTypeEncoder for the interface will be selected.
//
// 2. An encoder registered using RegisterHookEncoder for an interface implemented by the type or by a pointer to the
// type.
//
// 3. An encoder registered for the reflect.Kind of the value.
//
// If no encoder is found, an error of type ErrNoEncoder is returned.
func (r *Registry) LookupEncoder(t reflect.Type) (ValueEncoder, error) {
encodererr := ErrNoEncoder{Type: t}
r.mu.RLock()
enc, found := r.lookupTypeEncoder(t)
r.mu.RUnlock()
if found {
if enc == nil {
return nil, ErrNoEncoder{Type: t}
}
return enc, nil
}
enc, found = r.lookupInterfaceEncoder(t, true)
if found {
r.mu.Lock()
r.typeEncoders[t] = enc
r.mu.Unlock()
return enc, nil
}
if t == nil {
r.mu.Lock()
r.typeEncoders[t] = nil
r.mu.Unlock()
return nil, encodererr
}
enc, found = r.kindEncoders[t.Kind()]
if !found {
r.mu.Lock()
r.typeEncoders[t] = nil
r.mu.Unlock()
return nil, encodererr
}
r.mu.Lock()
r.typeEncoders[t] = enc
r.mu.Unlock()
return enc, nil
}
func (r *Registry) lookupTypeEncoder(t reflect.Type) (ValueEncoder, bool) {
enc, found := r.typeEncoders[t]
return enc, found
}
func (r *Registry) lookupInterfaceEncoder(t reflect.Type, allowAddr bool) (ValueEncoder, bool) {
if t == nil {
return nil, false
}
for _, ienc := range r.interfaceEncoders {
if t.Implements(ienc.i) {
return ienc.ve, true
}
if allowAddr && t.Kind() != reflect.Ptr && reflect.PtrTo(t).Implements(ienc.i) {
// if *t implements an interface, this will catch if t implements an interface further ahead
// in interfaceEncoders
defaultEnc, found := r.lookupInterfaceEncoder(t, false)
if !found {
defaultEnc = r.kindEncoders[t.Kind()]
}
return newCondAddrEncoder(ienc.ve, defaultEnc), true
}
}
return nil, false
}
// LookupDecoder inspects the registry for an decoder for the given type. The lookup precedence works as follows:
//
// 1. A decoder registered for the exact type. If the given type represents an interface, a decoder registered using
// RegisterTypeDecoder for the interface will be selected.
//
// 2. A decoder registered using RegisterHookDecoder for an interface implemented by the type or by a pointer to the
// type.
//
// 3. A decoder registered for the reflect.Kind of the value.
//
// If no decoder is found, an error of type ErrNoDecoder is returned.
func (r *Registry) LookupDecoder(t reflect.Type) (ValueDecoder, error) {
if t == nil {
return nil, ErrNilType
}
decodererr := ErrNoDecoder{Type: t}
r.mu.RLock()
dec, found := r.lookupTypeDecoder(t)
r.mu.RUnlock()
if found {
if dec == nil {
return nil, ErrNoDecoder{Type: t}
}
return dec, nil
}
dec, found = r.lookupInterfaceDecoder(t, true)
if found {
r.mu.Lock()
r.typeDecoders[t] = dec
r.mu.Unlock()
return dec, nil
}
dec, found = r.kindDecoders[t.Kind()]
if !found {
r.mu.Lock()
r.typeDecoders[t] = nil
r.mu.Unlock()
return nil, decodererr
}
r.mu.Lock()
r.typeDecoders[t] = dec
r.mu.Unlock()
return dec, nil
}
func (r *Registry) lookupTypeDecoder(t reflect.Type) (ValueDecoder, bool) {
dec, found := r.typeDecoders[t]
return dec, found
}
func (r *Registry) lookupInterfaceDecoder(t reflect.Type, allowAddr bool) (ValueDecoder, bool) {
for _, idec := range r.interfaceDecoders {
if t.Implements(idec.i) {
return idec.vd, true
}
if allowAddr && t.Kind() != reflect.Ptr && reflect.PtrTo(t).Implements(idec.i) {
// if *t implements an interface, this will catch if t implements an interface further ahead
// in interfaceDecoders
defaultDec, found := r.lookupInterfaceDecoder(t, false)
if !found {
defaultDec = r.kindDecoders[t.Kind()]
}
return newCondAddrDecoder(idec.vd, defaultDec), true
}
}
return nil, false
}
// LookupTypeMapEntry inspects the registry's type map for a Go type for the corresponding BSON
// type. If no type is found, ErrNoTypeMapEntry is returned.
func (r *Registry) LookupTypeMapEntry(bt bsontype.Type) (reflect.Type, error) {
t, ok := r.typeMap[bt]
if !ok || t == nil {
return nil, ErrNoTypeMapEntry{Type: bt}
}
return t, nil
}
type interfaceValueEncoder struct {
i reflect.Type
ve ValueEncoder
}
type interfaceValueDecoder struct {
i reflect.Type
vd ValueDecoder
}

View File

@@ -0,0 +1,125 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec_test
import (
"fmt"
"math"
"reflect"
"go.mongodb.org/mongo-driver/bson/bsoncodec"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/bsontype"
)
func ExampleRegistry_customEncoder() {
// Write a custom encoder for an integer type that is multiplied by -1 when
// encoding.
// To register the default encoders and decoders in addition to this custom
// one, use bson.NewRegistryBuilder instead.
rb := bsoncodec.NewRegistryBuilder()
type negatedInt int
niType := reflect.TypeOf(negatedInt(0))
encoder := func(
ec bsoncodec.EncodeContext,
vw bsonrw.ValueWriter,
val reflect.Value,
) error {
// All encoder implementations should check that val is valid and is of
// the correct type before proceeding.
if !val.IsValid() || val.Type() != niType {
return bsoncodec.ValueEncoderError{
Name: "negatedIntEncodeValue",
Types: []reflect.Type{niType},
Received: val,
}
}
// Negate val and encode as a BSON int32 if it can fit in 32 bits and a
// BSON int64 otherwise.
negatedVal := val.Int() * -1
if math.MinInt32 <= negatedVal && negatedVal <= math.MaxInt32 {
return vw.WriteInt32(int32(negatedVal))
}
return vw.WriteInt64(negatedVal)
}
rb.RegisterTypeEncoder(niType, bsoncodec.ValueEncoderFunc(encoder))
}
func ExampleRegistry_customDecoder() {
// Write a custom decoder for a boolean type that can be stored in the
// database as a BSON boolean, int32, int64, double, or null. BSON int32,
// int64, and double values are considered "true" in this decoder if they
// are non-zero. BSON null values are always considered false.
// To register the default encoders and decoders in addition to this custom
// one, use bson.NewRegistryBuilder instead.
rb := bsoncodec.NewRegistryBuilder()
type lenientBool bool
decoder := func(
dc bsoncodec.DecodeContext,
vr bsonrw.ValueReader,
val reflect.Value,
) error {
// All decoder implementations should check that val is valid, settable,
// and is of the correct kind before proceeding.
if !val.IsValid() || !val.CanSet() || val.Kind() != reflect.Bool {
return bsoncodec.ValueDecoderError{
Name: "lenientBoolDecodeValue",
Kinds: []reflect.Kind{reflect.Bool},
Received: val,
}
}
var result bool
switch vr.Type() {
case bsontype.Boolean:
b, err := vr.ReadBoolean()
if err != nil {
return err
}
result = b
case bsontype.Int32:
i32, err := vr.ReadInt32()
if err != nil {
return err
}
result = i32 != 0
case bsontype.Int64:
i64, err := vr.ReadInt64()
if err != nil {
return err
}
result = i64 != 0
case bsontype.Double:
f64, err := vr.ReadDouble()
if err != nil {
return err
}
result = f64 != 0
case bsontype.Null:
if err := vr.ReadNull(); err != nil {
return err
}
result = false
default:
return fmt.Errorf(
"received invalid BSON type to decode into lenientBool: %s",
vr.Type())
}
val.SetBool(result)
return nil
}
lbType := reflect.TypeOf(lenientBool(true))
rb.RegisterTypeDecoder(lbType, bsoncodec.ValueDecoderFunc(decoder))
}

View File

@@ -0,0 +1,452 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import (
"reflect"
"testing"
"github.com/google/go-cmp/cmp"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/internal/testutil/assert"
)
func TestRegistry(t *testing.T) {
t.Run("Register", func(t *testing.T) {
fc1, fc2, fc3, fc4 := new(fakeCodec), new(fakeCodec), new(fakeCodec), new(fakeCodec)
t.Run("interface", func(t *testing.T) {
var t1f *testInterface1
var t2f *testInterface2
var t4f *testInterface4
ips := []interfaceValueEncoder{
{i: reflect.TypeOf(t1f).Elem(), ve: fc1},
{i: reflect.TypeOf(t2f).Elem(), ve: fc2},
{i: reflect.TypeOf(t1f).Elem(), ve: fc3},
{i: reflect.TypeOf(t4f).Elem(), ve: fc4},
}
want := []interfaceValueEncoder{
{i: reflect.TypeOf(t1f).Elem(), ve: fc3},
{i: reflect.TypeOf(t2f).Elem(), ve: fc2},
{i: reflect.TypeOf(t4f).Elem(), ve: fc4},
}
rb := NewRegistryBuilder()
for _, ip := range ips {
rb.RegisterHookEncoder(ip.i, ip.ve)
}
got := rb.interfaceEncoders
if !cmp.Equal(got, want, cmp.AllowUnexported(interfaceValueEncoder{}, fakeCodec{}), cmp.Comparer(typeComparer)) {
t.Errorf("The registered interfaces are not correct. got %v; want %v", got, want)
}
})
t.Run("type", func(t *testing.T) {
ft1, ft2, ft4 := fakeType1{}, fakeType2{}, fakeType4{}
rb := NewRegistryBuilder().
RegisterTypeEncoder(reflect.TypeOf(ft1), fc1).
RegisterTypeEncoder(reflect.TypeOf(ft2), fc2).
RegisterTypeEncoder(reflect.TypeOf(ft1), fc3).
RegisterTypeEncoder(reflect.TypeOf(ft4), fc4)
want := []struct {
t reflect.Type
c ValueEncoder
}{
{reflect.TypeOf(ft1), fc3},
{reflect.TypeOf(ft2), fc2},
{reflect.TypeOf(ft4), fc4},
}
got := rb.typeEncoders
for _, s := range want {
wantT, wantC := s.t, s.c
gotC, exists := got[wantT]
if !exists {
t.Errorf("Did not find type in the type registry: %v", wantT)
}
if !cmp.Equal(gotC, wantC, cmp.AllowUnexported(fakeCodec{})) {
t.Errorf("Codecs did not match. got %#v; want %#v", gotC, wantC)
}
}
})
t.Run("kind", func(t *testing.T) {
k1, k2, k4 := reflect.Struct, reflect.Slice, reflect.Map
rb := NewRegistryBuilder().
RegisterDefaultEncoder(k1, fc1).
RegisterDefaultEncoder(k2, fc2).
RegisterDefaultEncoder(k1, fc3).
RegisterDefaultEncoder(k4, fc4)
want := []struct {
k reflect.Kind
c ValueEncoder
}{
{k1, fc3},
{k2, fc2},
{k4, fc4},
}
got := rb.kindEncoders
for _, s := range want {
wantK, wantC := s.k, s.c
gotC, exists := got[wantK]
if !exists {
t.Errorf("Did not find kind in the kind registry: %v", wantK)
}
if !cmp.Equal(gotC, wantC, cmp.AllowUnexported(fakeCodec{})) {
t.Errorf("Codecs did not match. got %#v; want %#v", gotC, wantC)
}
}
})
t.Run("RegisterDefault", func(t *testing.T) {
t.Run("MapCodec", func(t *testing.T) {
codec := fakeCodec{num: 1}
codec2 := fakeCodec{num: 2}
rb := NewRegistryBuilder()
rb.RegisterDefaultEncoder(reflect.Map, codec)
if rb.kindEncoders[reflect.Map] != codec {
t.Errorf("Did not properly set the map codec. got %v; want %v", rb.kindEncoders[reflect.Map], codec)
}
rb.RegisterDefaultEncoder(reflect.Map, codec2)
if rb.kindEncoders[reflect.Map] != codec2 {
t.Errorf("Did not properly set the map codec. got %v; want %v", rb.kindEncoders[reflect.Map], codec2)
}
})
t.Run("StructCodec", func(t *testing.T) {
codec := fakeCodec{num: 1}
codec2 := fakeCodec{num: 2}
rb := NewRegistryBuilder()
rb.RegisterDefaultEncoder(reflect.Struct, codec)
if rb.kindEncoders[reflect.Struct] != codec {
t.Errorf("Did not properly set the struct codec. got %v; want %v", rb.kindEncoders[reflect.Struct], codec)
}
rb.RegisterDefaultEncoder(reflect.Struct, codec2)
if rb.kindEncoders[reflect.Struct] != codec2 {
t.Errorf("Did not properly set the struct codec. got %v; want %v", rb.kindEncoders[reflect.Struct], codec2)
}
})
t.Run("SliceCodec", func(t *testing.T) {
codec := fakeCodec{num: 1}
codec2 := fakeCodec{num: 2}
rb := NewRegistryBuilder()
rb.RegisterDefaultEncoder(reflect.Slice, codec)
if rb.kindEncoders[reflect.Slice] != codec {
t.Errorf("Did not properly set the slice codec. got %v; want %v", rb.kindEncoders[reflect.Slice], codec)
}
rb.RegisterDefaultEncoder(reflect.Slice, codec2)
if rb.kindEncoders[reflect.Slice] != codec2 {
t.Errorf("Did not properly set the slice codec. got %v; want %v", rb.kindEncoders[reflect.Slice], codec2)
}
})
t.Run("ArrayCodec", func(t *testing.T) {
codec := fakeCodec{num: 1}
codec2 := fakeCodec{num: 2}
rb := NewRegistryBuilder()
rb.RegisterDefaultEncoder(reflect.Array, codec)
if rb.kindEncoders[reflect.Array] != codec {
t.Errorf("Did not properly set the slice codec. got %v; want %v", rb.kindEncoders[reflect.Array], codec)
}
rb.RegisterDefaultEncoder(reflect.Array, codec2)
if rb.kindEncoders[reflect.Array] != codec2 {
t.Errorf("Did not properly set the slice codec. got %v; want %v", rb.kindEncoders[reflect.Array], codec2)
}
})
})
t.Run("Lookup", func(t *testing.T) {
type Codec interface {
ValueEncoder
ValueDecoder
}
var arrinstance [12]int
arr := reflect.TypeOf(arrinstance)
slc := reflect.TypeOf(make([]int, 12))
m := reflect.TypeOf(make(map[string]int))
strct := reflect.TypeOf(struct{ Foo string }{})
ft1 := reflect.PtrTo(reflect.TypeOf(fakeType1{}))
ft2 := reflect.TypeOf(fakeType2{})
ft3 := reflect.TypeOf(fakeType5(func(string, string) string { return "fakeType5" }))
ti1 := reflect.TypeOf((*testInterface1)(nil)).Elem()
ti2 := reflect.TypeOf((*testInterface2)(nil)).Elem()
ti1Impl := reflect.TypeOf(testInterface1Impl{})
ti2Impl := reflect.TypeOf(testInterface2Impl{})
ti3 := reflect.TypeOf((*testInterface3)(nil)).Elem()
ti3Impl := reflect.TypeOf(testInterface3Impl{})
ti3ImplPtr := reflect.TypeOf((*testInterface3Impl)(nil))
fc1, fc2 := fakeCodec{num: 1}, fakeCodec{num: 2}
fsc, fslcc, fmc := new(fakeStructCodec), new(fakeSliceCodec), new(fakeMapCodec)
pc := NewPointerCodec()
reg := NewRegistryBuilder().
RegisterTypeEncoder(ft1, fc1).
RegisterTypeEncoder(ft2, fc2).
RegisterTypeEncoder(ti1, fc1).
RegisterDefaultEncoder(reflect.Struct, fsc).
RegisterDefaultEncoder(reflect.Slice, fslcc).
RegisterDefaultEncoder(reflect.Array, fslcc).
RegisterDefaultEncoder(reflect.Map, fmc).
RegisterDefaultEncoder(reflect.Ptr, pc).
RegisterTypeDecoder(ft1, fc1).
RegisterTypeDecoder(ft2, fc2).
RegisterTypeDecoder(ti1, fc1). // values whose exact type is testInterface1 will use fc1 encoder
RegisterDefaultDecoder(reflect.Struct, fsc).
RegisterDefaultDecoder(reflect.Slice, fslcc).
RegisterDefaultDecoder(reflect.Array, fslcc).
RegisterDefaultDecoder(reflect.Map, fmc).
RegisterDefaultDecoder(reflect.Ptr, pc).
RegisterHookEncoder(ti2, fc2).
RegisterHookDecoder(ti2, fc2).
RegisterHookEncoder(ti3, fc3).
RegisterHookDecoder(ti3, fc3).
Build()
testCases := []struct {
name string
t reflect.Type
wantcodec Codec
wanterr error
testcache bool
}{
{
"type registry (pointer)",
ft1,
fc1,
nil,
false,
},
{
"type registry (non-pointer)",
ft2,
fc2,
nil,
false,
},
{
// lookup an interface type and expect that the registered encoder is returned
"interface with type encoder",
ti1,
fc1,
nil,
true,
},
{
// lookup a type that implements an interface and expect that the default struct codec is returned
"interface implementation with type encoder",
ti1Impl,
fsc,
nil,
false,
},
{
// lookup an interface type and expect that the registered hook is returned
"interface with hook",
ti2,
fc2,
nil,
false,
},
{
// lookup a type that implements an interface and expect that the registered hook is returned
"interface implementation with hook",
ti2Impl,
fc2,
nil,
false,
},
{
// lookup a pointer to a type where the pointer implements an interface and expect that the
// registered hook is returned
"interface pointer to implementation with hook (pointer)",
ti3ImplPtr,
fc3,
nil,
false,
},
{
"default struct codec (pointer)",
reflect.PtrTo(strct),
pc,
nil,
false,
},
{
"default struct codec (non-pointer)",
strct,
fsc,
nil,
false,
},
{
"default array codec",
arr,
fslcc,
nil,
false,
},
{
"default slice codec",
slc,
fslcc,
nil,
false,
},
{
"default map",
m,
fmc,
nil,
false,
},
{
"map non-string key",
reflect.TypeOf(map[int]int{}),
fmc,
nil,
false,
},
{
"No Codec Registered",
ft3,
nil,
ErrNoEncoder{Type: ft3},
false,
},
}
allowunexported := cmp.AllowUnexported(fakeCodec{}, fakeStructCodec{}, fakeSliceCodec{}, fakeMapCodec{})
comparepc := func(pc1, pc2 *PointerCodec) bool { return true }
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Run("Encoder", func(t *testing.T) {
gotcodec, goterr := reg.LookupEncoder(tc.t)
if !cmp.Equal(goterr, tc.wanterr, cmp.Comparer(compareErrors)) {
t.Errorf("Errors did not match. got %v; want %v", goterr, tc.wanterr)
}
if !cmp.Equal(gotcodec, tc.wantcodec, allowunexported, cmp.Comparer(comparepc)) {
t.Errorf("Codecs did not match. got %v; want %v", gotcodec, tc.wantcodec)
}
})
t.Run("Decoder", func(t *testing.T) {
var wanterr error
if ene, ok := tc.wanterr.(ErrNoEncoder); ok {
wanterr = ErrNoDecoder(ene)
} else {
wanterr = tc.wanterr
}
gotcodec, goterr := reg.LookupDecoder(tc.t)
if !cmp.Equal(goterr, wanterr, cmp.Comparer(compareErrors)) {
t.Errorf("Errors did not match. got %v; want %v", goterr, wanterr)
}
if !cmp.Equal(gotcodec, tc.wantcodec, allowunexported, cmp.Comparer(comparepc)) {
t.Errorf("Codecs did not match. got %v; want %v", gotcodec, tc.wantcodec)
t.Errorf("Codecs did not match. got %T; want %T", gotcodec, tc.wantcodec)
}
})
})
}
// lookup a type whose pointer implements an interface and expect that the registered hook is
// returned
t.Run("interface implementation with hook (pointer)", func(t *testing.T) {
t.Run("Encoder", func(t *testing.T) {
gotEnc, err := reg.LookupEncoder(ti3Impl)
assert.Nil(t, err, "LookupEncoder error: %v", err)
cae, ok := gotEnc.(*condAddrEncoder)
assert.True(t, ok, "Expected CondAddrEncoder, got %T", gotEnc)
if !cmp.Equal(cae.canAddrEnc, fc3, allowunexported, cmp.Comparer(comparepc)) {
t.Errorf("expected canAddrEnc %v, got %v", cae.canAddrEnc, fc3)
}
if !cmp.Equal(cae.elseEnc, fsc, allowunexported, cmp.Comparer(comparepc)) {
t.Errorf("expected elseEnc %v, got %v", cae.elseEnc, fsc)
}
})
t.Run("Decoder", func(t *testing.T) {
gotDec, err := reg.LookupDecoder(ti3Impl)
assert.Nil(t, err, "LookupDecoder error: %v", err)
cad, ok := gotDec.(*condAddrDecoder)
assert.True(t, ok, "Expected CondAddrDecoder, got %T", gotDec)
if !cmp.Equal(cad.canAddrDec, fc3, allowunexported, cmp.Comparer(comparepc)) {
t.Errorf("expected canAddrDec %v, got %v", cad.canAddrDec, fc3)
}
if !cmp.Equal(cad.elseDec, fsc, allowunexported, cmp.Comparer(comparepc)) {
t.Errorf("expected elseDec %v, got %v", cad.elseDec, fsc)
}
})
})
})
})
t.Run("Type Map", func(t *testing.T) {
reg := NewRegistryBuilder().
RegisterTypeMapEntry(bsontype.String, reflect.TypeOf("")).
RegisterTypeMapEntry(bsontype.Int32, reflect.TypeOf(int(0))).
Build()
var got, want reflect.Type
want = reflect.TypeOf("")
got, err := reg.LookupTypeMapEntry(bsontype.String)
noerr(t, err)
if got != want {
t.Errorf("Did not get expected type. got %v; want %v", got, want)
}
want = reflect.TypeOf(int(0))
got, err = reg.LookupTypeMapEntry(bsontype.Int32)
noerr(t, err)
if got != want {
t.Errorf("Did not get expected type. got %v; want %v", got, want)
}
want = nil
wanterr := ErrNoTypeMapEntry{Type: bsontype.ObjectID}
got, err = reg.LookupTypeMapEntry(bsontype.ObjectID)
if err != wanterr {
t.Errorf("Did not get expected error. got %v; want %v", err, wanterr)
}
if got != want {
t.Errorf("Did not get expected type. got %v; want %v", got, want)
}
})
}
type fakeType1 struct{}
type fakeType2 struct{}
type fakeType4 struct{}
type fakeType5 func(string, string) string
type fakeStructCodec struct{ fakeCodec }
type fakeSliceCodec struct{ fakeCodec }
type fakeMapCodec struct{ fakeCodec }
type fakeCodec struct{ num int }
func (fc fakeCodec) EncodeValue(EncodeContext, bsonrw.ValueWriter, reflect.Value) error {
return nil
}
func (fc fakeCodec) DecodeValue(DecodeContext, bsonrw.ValueReader, reflect.Value) error {
return nil
}
type testInterface1 interface{ test1() }
type testInterface2 interface{ test2() }
type testInterface3 interface{ test3() }
type testInterface4 interface{ test4() }
type testInterface1Impl struct{}
var _ testInterface1 = testInterface1Impl{}
func (testInterface1Impl) test1() {}
type testInterface2Impl struct{}
var _ testInterface2 = testInterface2Impl{}
func (testInterface2Impl) test2() {}
type testInterface3Impl struct{}
var _ testInterface3 = (*testInterface3Impl)(nil)
func (*testInterface3Impl) test3() {}
func typeComparer(i1, i2 reflect.Type) bool { return i1 == i2 }

View File

@@ -0,0 +1,199 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import (
"fmt"
"reflect"
"go.mongodb.org/mongo-driver/bson/bsonoptions"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
)
var defaultSliceCodec = NewSliceCodec()
// SliceCodec is the Codec used for slice values.
type SliceCodec struct {
EncodeNilAsEmpty bool
}
var _ ValueCodec = &MapCodec{}
// NewSliceCodec returns a MapCodec with options opts.
func NewSliceCodec(opts ...*bsonoptions.SliceCodecOptions) *SliceCodec {
sliceOpt := bsonoptions.MergeSliceCodecOptions(opts...)
codec := SliceCodec{}
if sliceOpt.EncodeNilAsEmpty != nil {
codec.EncodeNilAsEmpty = *sliceOpt.EncodeNilAsEmpty
}
return &codec
}
// EncodeValue is the ValueEncoder for slice types.
func (sc SliceCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Kind() != reflect.Slice {
return ValueEncoderError{Name: "SliceEncodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: val}
}
if val.IsNil() && !sc.EncodeNilAsEmpty {
return vw.WriteNull()
}
// If we have a []byte we want to treat it as a binary instead of as an array.
if val.Type().Elem() == tByte {
var byteSlice []byte
for idx := 0; idx < val.Len(); idx++ {
byteSlice = append(byteSlice, val.Index(idx).Interface().(byte))
}
return vw.WriteBinary(byteSlice)
}
// If we have a []primitive.E we want to treat it as a document instead of as an array.
if val.Type().ConvertibleTo(tD) {
d := val.Convert(tD).Interface().(primitive.D)
dw, err := vw.WriteDocument()
if err != nil {
return err
}
for _, e := range d {
err = encodeElement(ec, dw, e)
if err != nil {
return err
}
}
return dw.WriteDocumentEnd()
}
aw, err := vw.WriteArray()
if err != nil {
return err
}
elemType := val.Type().Elem()
encoder, err := ec.LookupEncoder(elemType)
if err != nil && elemType.Kind() != reflect.Interface {
return err
}
for idx := 0; idx < val.Len(); idx++ {
currEncoder, currVal, lookupErr := defaultValueEncoders.lookupElementEncoder(ec, encoder, val.Index(idx))
if lookupErr != nil && lookupErr != errInvalidValue {
return lookupErr
}
vw, err := aw.WriteArrayElement()
if err != nil {
return err
}
if lookupErr == errInvalidValue {
err = vw.WriteNull()
if err != nil {
return err
}
continue
}
err = currEncoder.EncodeValue(ec, vw, currVal)
if err != nil {
return err
}
}
return aw.WriteArrayEnd()
}
// DecodeValue is the ValueDecoder for slice types.
func (sc *SliceCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
if !val.CanSet() || val.Kind() != reflect.Slice {
return ValueDecoderError{Name: "SliceDecodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: val}
}
switch vrType := vr.Type(); vrType {
case bsontype.Array:
case bsontype.Null:
val.Set(reflect.Zero(val.Type()))
return vr.ReadNull()
case bsontype.Undefined:
val.Set(reflect.Zero(val.Type()))
return vr.ReadUndefined()
case bsontype.Type(0), bsontype.EmbeddedDocument:
if val.Type().Elem() != tE {
return fmt.Errorf("cannot decode document into %s", val.Type())
}
case bsontype.Binary:
if val.Type().Elem() != tByte {
return fmt.Errorf("SliceDecodeValue can only decode a binary into a byte array, got %v", vrType)
}
data, subtype, err := vr.ReadBinary()
if err != nil {
return err
}
if subtype != bsontype.BinaryGeneric && subtype != bsontype.BinaryBinaryOld {
return fmt.Errorf("SliceDecodeValue can only be used to decode subtype 0x00 or 0x02 for %s, got %v", bsontype.Binary, subtype)
}
if val.IsNil() {
val.Set(reflect.MakeSlice(val.Type(), 0, len(data)))
}
val.SetLen(0)
for _, elem := range data {
val.Set(reflect.Append(val, reflect.ValueOf(elem)))
}
return nil
case bsontype.String:
if sliceType := val.Type().Elem(); sliceType != tByte {
return fmt.Errorf("SliceDecodeValue can only decode a string into a byte array, got %v", sliceType)
}
str, err := vr.ReadString()
if err != nil {
return err
}
byteStr := []byte(str)
if val.IsNil() {
val.Set(reflect.MakeSlice(val.Type(), 0, len(byteStr)))
}
val.SetLen(0)
for _, elem := range byteStr {
val.Set(reflect.Append(val, reflect.ValueOf(elem)))
}
return nil
default:
return fmt.Errorf("cannot decode %v into a slice", vrType)
}
var elemsFunc func(DecodeContext, bsonrw.ValueReader, reflect.Value) ([]reflect.Value, error)
switch val.Type().Elem() {
case tE:
dc.Ancestor = val.Type()
elemsFunc = defaultValueDecoders.decodeD
default:
elemsFunc = defaultValueDecoders.decodeDefault
}
elems, err := elemsFunc(dc, vr, val)
if err != nil {
return err
}
if val.IsNil() {
val.Set(reflect.MakeSlice(val.Type(), 0, len(elems)))
}
val.SetLen(0)
val.Set(reflect.Append(val, elems...))
return nil
}

View File

@@ -0,0 +1,119 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import (
"fmt"
"reflect"
"go.mongodb.org/mongo-driver/bson/bsonoptions"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/bsontype"
)
// StringCodec is the Codec used for struct values.
type StringCodec struct {
DecodeObjectIDAsHex bool
}
var (
defaultStringCodec = NewStringCodec()
_ ValueCodec = defaultStringCodec
_ typeDecoder = defaultStringCodec
)
// NewStringCodec returns a StringCodec with options opts.
func NewStringCodec(opts ...*bsonoptions.StringCodecOptions) *StringCodec {
stringOpt := bsonoptions.MergeStringCodecOptions(opts...)
return &StringCodec{*stringOpt.DecodeObjectIDAsHex}
}
// EncodeValue is the ValueEncoder for string types.
func (sc *StringCodec) EncodeValue(ectx EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if val.Kind() != reflect.String {
return ValueEncoderError{
Name: "StringEncodeValue",
Kinds: []reflect.Kind{reflect.String},
Received: val,
}
}
return vw.WriteString(val.String())
}
func (sc *StringCodec) decodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) {
if t.Kind() != reflect.String {
return emptyValue, ValueDecoderError{
Name: "StringDecodeValue",
Kinds: []reflect.Kind{reflect.String},
Received: reflect.Zero(t),
}
}
var str string
var err error
switch vr.Type() {
case bsontype.String:
str, err = vr.ReadString()
if err != nil {
return emptyValue, err
}
case bsontype.ObjectID:
oid, err := vr.ReadObjectID()
if err != nil {
return emptyValue, err
}
if sc.DecodeObjectIDAsHex {
str = oid.Hex()
} else {
byteArray := [12]byte(oid)
str = string(byteArray[:])
}
case bsontype.Symbol:
str, err = vr.ReadSymbol()
if err != nil {
return emptyValue, err
}
case bsontype.Binary:
data, subtype, err := vr.ReadBinary()
if err != nil {
return emptyValue, err
}
if subtype != bsontype.BinaryGeneric && subtype != bsontype.BinaryBinaryOld {
return emptyValue, decodeBinaryError{subtype: subtype, typeName: "string"}
}
str = string(data)
case bsontype.Null:
if err = vr.ReadNull(); err != nil {
return emptyValue, err
}
case bsontype.Undefined:
if err = vr.ReadUndefined(); err != nil {
return emptyValue, err
}
default:
return emptyValue, fmt.Errorf("cannot decode %v into a string type", vr.Type())
}
return reflect.ValueOf(str), nil
}
// DecodeValue is the ValueDecoder for string types.
func (sc *StringCodec) DecodeValue(dctx DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
if !val.CanSet() || val.Kind() != reflect.String {
return ValueDecoderError{Name: "StringDecodeValue", Kinds: []reflect.Kind{reflect.String}, Received: val}
}
elem, err := sc.decodeType(dctx, vr, val.Type())
if err != nil {
return err
}
val.SetString(elem.String())
return nil
}

View File

@@ -0,0 +1,48 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import (
"reflect"
"testing"
"go.mongodb.org/mongo-driver/bson/bsonoptions"
"go.mongodb.org/mongo-driver/bson/bsonrw/bsonrwtest"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/internal/testutil/assert"
)
func TestStringCodec(t *testing.T) {
t.Run("ObjectIDAsHex", func(t *testing.T) {
oid := primitive.NewObjectID()
byteArray := [12]byte(oid)
reader := &bsonrwtest.ValueReaderWriter{BSONType: bsontype.ObjectID, Return: oid}
testCases := []struct {
name string
opts *bsonoptions.StringCodecOptions
hex bool
result string
}{
{"default", bsonoptions.StringCodec(), true, oid.Hex()},
{"true", bsonoptions.StringCodec().SetDecodeObjectIDAsHex(true), true, oid.Hex()},
{"false", bsonoptions.StringCodec().SetDecodeObjectIDAsHex(false), false, string(byteArray[:])},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
stringCodec := NewStringCodec(tc.opts)
actual := reflect.New(reflect.TypeOf("")).Elem()
err := stringCodec.DecodeValue(DecodeContext{}, reader, actual)
assert.Nil(t, err, "StringCodec.DecodeValue error: %v", err)
actualString := actual.Interface().(string)
assert.Equal(t, tc.result, actualString, "Expected string %v, got %v", tc.result, actualString)
})
}
})
}

View File

@@ -0,0 +1,675 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import (
"errors"
"fmt"
"reflect"
"sort"
"strings"
"sync"
"time"
"go.mongodb.org/mongo-driver/bson/bsonoptions"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/bsontype"
)
// DecodeError represents an error that occurs when unmarshalling BSON bytes into a native Go type.
type DecodeError struct {
keys []string
wrapped error
}
// Unwrap returns the underlying error
func (de *DecodeError) Unwrap() error {
return de.wrapped
}
// Error implements the error interface.
func (de *DecodeError) Error() string {
// The keys are stored in reverse order because the de.keys slice is builtup while propagating the error up the
// stack of BSON keys, so we call de.Keys(), which reverses them.
keyPath := strings.Join(de.Keys(), ".")
return fmt.Sprintf("error decoding key %s: %v", keyPath, de.wrapped)
}
// Keys returns the BSON key path that caused an error as a slice of strings. The keys in the slice are in top-down
// order. For example, if the document being unmarshalled was {a: {b: {c: 1}}} and the value for c was supposed to be
// a string, the keys slice will be ["a", "b", "c"].
func (de *DecodeError) Keys() []string {
reversedKeys := make([]string, 0, len(de.keys))
for idx := len(de.keys) - 1; idx >= 0; idx-- {
reversedKeys = append(reversedKeys, de.keys[idx])
}
return reversedKeys
}
// Zeroer allows custom struct types to implement a report of zero
// state. All struct types that don't implement Zeroer or where IsZero
// returns false are considered to be not zero.
type Zeroer interface {
IsZero() bool
}
// StructCodec is the Codec used for struct values.
type StructCodec struct {
cache map[reflect.Type]*structDescription
l sync.RWMutex
parser StructTagParser
DecodeZeroStruct bool
DecodeDeepZeroInline bool
EncodeOmitDefaultStruct bool
AllowUnexportedFields bool
OverwriteDuplicatedInlinedFields bool
}
var _ ValueEncoder = &StructCodec{}
var _ ValueDecoder = &StructCodec{}
// NewStructCodec returns a StructCodec that uses p for struct tag parsing.
func NewStructCodec(p StructTagParser, opts ...*bsonoptions.StructCodecOptions) (*StructCodec, error) {
if p == nil {
return nil, errors.New("a StructTagParser must be provided to NewStructCodec")
}
structOpt := bsonoptions.MergeStructCodecOptions(opts...)
codec := &StructCodec{
cache: make(map[reflect.Type]*structDescription),
parser: p,
}
if structOpt.DecodeZeroStruct != nil {
codec.DecodeZeroStruct = *structOpt.DecodeZeroStruct
}
if structOpt.DecodeDeepZeroInline != nil {
codec.DecodeDeepZeroInline = *structOpt.DecodeDeepZeroInline
}
if structOpt.EncodeOmitDefaultStruct != nil {
codec.EncodeOmitDefaultStruct = *structOpt.EncodeOmitDefaultStruct
}
if structOpt.OverwriteDuplicatedInlinedFields != nil {
codec.OverwriteDuplicatedInlinedFields = *structOpt.OverwriteDuplicatedInlinedFields
}
if structOpt.AllowUnexportedFields != nil {
codec.AllowUnexportedFields = *structOpt.AllowUnexportedFields
}
return codec, nil
}
// EncodeValue handles encoding generic struct types.
func (sc *StructCodec) EncodeValue(r EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Kind() != reflect.Struct {
return ValueEncoderError{Name: "StructCodec.EncodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val}
}
sd, err := sc.describeStruct(r.Registry, val.Type())
if err != nil {
return err
}
dw, err := vw.WriteDocument()
if err != nil {
return err
}
var rv reflect.Value
for _, desc := range sd.fl {
if desc.omitAlways {
continue
}
if desc.inline == nil {
rv = val.Field(desc.idx)
} else {
rv, err = fieldByIndexErr(val, desc.inline)
if err != nil {
continue
}
}
desc.encoder, rv, err = defaultValueEncoders.lookupElementEncoder(r, desc.encoder, rv)
if err != nil && err != errInvalidValue {
return err
}
if err == errInvalidValue {
if desc.omitEmpty {
continue
}
vw2, err := dw.WriteDocumentElement(desc.name)
if err != nil {
return err
}
err = vw2.WriteNull()
if err != nil {
return err
}
continue
}
if desc.encoder == nil {
return ErrNoEncoder{Type: rv.Type()}
}
encoder := desc.encoder
var isZero bool
rvInterface := rv.Interface()
if cz, ok := encoder.(CodecZeroer); ok {
isZero = cz.IsTypeZero(rvInterface)
} else if rv.Kind() == reflect.Interface {
// sc.isZero will not treat an interface rv as an interface, so we need to check for the zero interface separately.
isZero = rv.IsNil()
} else {
isZero = sc.isZero(rvInterface)
}
if desc.omitEmpty && isZero {
continue
}
vw2, err := dw.WriteDocumentElement(desc.name)
if err != nil {
return err
}
ectx := EncodeContext{Registry: r.Registry, MinSize: desc.minSize}
err = encoder.EncodeValue(ectx, vw2, rv)
if err != nil {
return err
}
}
if sd.inlineMap >= 0 {
rv := val.Field(sd.inlineMap)
collisionFn := func(key string) bool {
_, exists := sd.fm[key]
return exists
}
return defaultMapCodec.mapEncodeValue(r, dw, rv, collisionFn)
}
return dw.WriteDocumentEnd()
}
func newDecodeError(key string, original error) error {
de, ok := original.(*DecodeError)
if !ok {
return &DecodeError{
keys: []string{key},
wrapped: original,
}
}
de.keys = append(de.keys, key)
return de
}
// DecodeValue implements the Codec interface.
// By default, map types in val will not be cleared. If a map has existing key/value pairs, it will be extended with the new ones from vr.
// For slices, the decoder will set the length of the slice to zero and append all elements. The underlying array will not be cleared.
func (sc *StructCodec) DecodeValue(r DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
if !val.CanSet() || val.Kind() != reflect.Struct {
return ValueDecoderError{Name: "StructCodec.DecodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val}
}
switch vrType := vr.Type(); vrType {
case bsontype.Type(0), bsontype.EmbeddedDocument:
case bsontype.Null:
if err := vr.ReadNull(); err != nil {
return err
}
val.Set(reflect.Zero(val.Type()))
return nil
case bsontype.Undefined:
if err := vr.ReadUndefined(); err != nil {
return err
}
val.Set(reflect.Zero(val.Type()))
return nil
default:
return fmt.Errorf("cannot decode %v into a %s", vrType, val.Type())
}
sd, err := sc.describeStruct(r.Registry, val.Type())
if err != nil {
return err
}
if sc.DecodeZeroStruct {
val.Set(reflect.Zero(val.Type()))
}
if sc.DecodeDeepZeroInline && sd.inline {
val.Set(deepZero(val.Type()))
}
var decoder ValueDecoder
var inlineMap reflect.Value
if sd.inlineMap >= 0 {
inlineMap = val.Field(sd.inlineMap)
decoder, err = r.LookupDecoder(inlineMap.Type().Elem())
if err != nil {
return err
}
}
dr, err := vr.ReadDocument()
if err != nil {
return err
}
for {
name, vr, err := dr.ReadElement()
if err == bsonrw.ErrEOD {
break
}
if err != nil {
return err
}
fd, exists := sd.fm[name]
if !exists {
// if the original name isn't found in the struct description, try again with the name in lowercase
// this could match if a BSON tag isn't specified because by default, describeStruct lowercases all field
// names
fd, exists = sd.fm[strings.ToLower(name)]
}
if !exists {
if sd.inlineMap < 0 {
// The encoding/json package requires a flag to return on error for non-existent fields.
// This functionality seems appropriate for the struct codec.
err = vr.Skip()
if err != nil {
return err
}
continue
}
if inlineMap.IsNil() {
inlineMap.Set(reflect.MakeMap(inlineMap.Type()))
}
elem := reflect.New(inlineMap.Type().Elem()).Elem()
r.Ancestor = inlineMap.Type()
err = decoder.DecodeValue(r, vr, elem)
if err != nil {
return err
}
inlineMap.SetMapIndex(reflect.ValueOf(name), elem)
continue
}
var field reflect.Value
if fd.inline == nil {
field = val.Field(fd.idx)
} else {
field, err = getInlineField(val, fd.inline)
if err != nil {
return err
}
}
if !field.CanSet() { // Being settable is a super set of being addressable.
innerErr := fmt.Errorf("field %v is not settable", field)
return newDecodeError(fd.name, innerErr)
}
if field.Kind() == reflect.Ptr && field.IsNil() {
field.Set(reflect.New(field.Type().Elem()))
}
field = field.Addr()
dctx := DecodeContext{
Registry: r.Registry,
Truncate: fd.truncate || r.Truncate,
defaultDocumentType: r.defaultDocumentType,
}
if fd.decoder == nil {
return newDecodeError(fd.name, ErrNoDecoder{Type: field.Elem().Type()})
}
err = fd.decoder.DecodeValue(dctx, vr, field.Elem())
if err != nil {
return newDecodeError(fd.name, err)
}
}
return nil
}
func (sc *StructCodec) isZero(i interface{}) bool {
v := reflect.ValueOf(i)
// check the value validity
if !v.IsValid() {
return true
}
if z, ok := v.Interface().(Zeroer); ok && (v.Kind() != reflect.Ptr || !v.IsNil()) {
return z.IsZero()
}
switch v.Kind() {
case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
return v.Len() == 0
case reflect.Bool:
return !v.Bool()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return v.Int() == 0
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return v.Uint() == 0
case reflect.Float32, reflect.Float64:
return v.Float() == 0
case reflect.Interface, reflect.Ptr:
return v.IsNil()
case reflect.Struct:
if sc.EncodeOmitDefaultStruct {
vt := v.Type()
if vt == tTime {
return v.Interface().(time.Time).IsZero()
}
for i := 0; i < v.NumField(); i++ {
if vt.Field(i).PkgPath != "" && !vt.Field(i).Anonymous {
continue // Private field
}
fld := v.Field(i)
if !sc.isZero(fld.Interface()) {
return false
}
}
return true
}
}
return false
}
type structDescription struct {
fm map[string]fieldDescription
fl []fieldDescription
inlineMap int
inline bool
}
type fieldDescription struct {
name string // BSON key name
fieldName string // struct field name
idx int
omitEmpty bool
omitAlways bool
minSize bool
truncate bool
inline []int
encoder ValueEncoder
decoder ValueDecoder
}
type byIndex []fieldDescription
func (bi byIndex) Len() int { return len(bi) }
func (bi byIndex) Swap(i, j int) { bi[i], bi[j] = bi[j], bi[i] }
func (bi byIndex) Less(i, j int) bool {
// If a field is inlined, its index in the top level struct is stored at inline[0]
iIdx, jIdx := bi[i].idx, bi[j].idx
if len(bi[i].inline) > 0 {
iIdx = bi[i].inline[0]
}
if len(bi[j].inline) > 0 {
jIdx = bi[j].inline[0]
}
if iIdx != jIdx {
return iIdx < jIdx
}
for k, biik := range bi[i].inline {
if k >= len(bi[j].inline) {
return false
}
if biik != bi[j].inline[k] {
return biik < bi[j].inline[k]
}
}
return len(bi[i].inline) < len(bi[j].inline)
}
func (sc *StructCodec) describeStruct(r *Registry, t reflect.Type) (*structDescription, error) {
// We need to analyze the struct, including getting the tags, collecting
// information about inlining, and create a map of the field name to the field.
sc.l.RLock()
ds, exists := sc.cache[t]
sc.l.RUnlock()
if exists {
return ds, nil
}
numFields := t.NumField()
sd := &structDescription{
fm: make(map[string]fieldDescription, numFields),
fl: make([]fieldDescription, 0, numFields),
inlineMap: -1,
}
var fields []fieldDescription
for i := 0; i < numFields; i++ {
sf := t.Field(i)
if sf.PkgPath != "" && (!sc.AllowUnexportedFields || !sf.Anonymous) {
// field is private or unexported fields aren't allowed, ignore
continue
}
sfType := sf.Type
encoder, err := r.LookupEncoder(sfType)
if err != nil {
encoder = nil
}
decoder, err := r.LookupDecoder(sfType)
if err != nil {
decoder = nil
}
description := fieldDescription{
fieldName: sf.Name,
idx: i,
encoder: encoder,
decoder: decoder,
}
stags, err := sc.parser.ParseStructTags(sf)
if err != nil {
return nil, err
}
if stags.Skip {
continue
}
description.name = stags.Name
description.omitEmpty = stags.OmitEmpty
description.omitAlways = stags.OmitAlways
description.minSize = stags.MinSize
description.truncate = stags.Truncate
if stags.Inline {
sd.inline = true
switch sfType.Kind() {
case reflect.Map:
if sd.inlineMap >= 0 {
return nil, errors.New("(struct " + t.String() + ") multiple inline maps")
}
if sfType.Key() != tString {
return nil, errors.New("(struct " + t.String() + ") inline map must have a string keys")
}
sd.inlineMap = description.idx
case reflect.Ptr:
sfType = sfType.Elem()
if sfType.Kind() != reflect.Struct {
return nil, fmt.Errorf("(struct %s) inline fields must be a struct, a struct pointer, or a map", t.String())
}
fallthrough
case reflect.Struct:
inlinesf, err := sc.describeStruct(r, sfType)
if err != nil {
return nil, err
}
for _, fd := range inlinesf.fl {
if fd.inline == nil {
fd.inline = []int{i, fd.idx}
} else {
fd.inline = append([]int{i}, fd.inline...)
}
fields = append(fields, fd)
}
default:
return nil, fmt.Errorf("(struct %s) inline fields must be a struct, a struct pointer, or a map", t.String())
}
continue
}
fields = append(fields, description)
}
// Sort fieldDescriptions by name and use dominance rules to determine which should be added for each name
sort.Slice(fields, func(i, j int) bool {
x := fields
// sort field by name, breaking ties with depth, then
// breaking ties with index sequence.
if x[i].name != x[j].name {
return x[i].name < x[j].name
}
if len(x[i].inline) != len(x[j].inline) {
return len(x[i].inline) < len(x[j].inline)
}
return byIndex(x).Less(i, j)
})
for advance, i := 0, 0; i < len(fields); i += advance {
// One iteration per name.
// Find the sequence of fields with the name of this first field.
fi := fields[i]
name := fi.name
for advance = 1; i+advance < len(fields); advance++ {
fj := fields[i+advance]
if fj.name != name {
break
}
}
if advance == 1 { // Only one field with this name
sd.fl = append(sd.fl, fi)
sd.fm[name] = fi
continue
}
dominant, ok := dominantField(fields[i : i+advance])
if !ok || !sc.OverwriteDuplicatedInlinedFields {
return nil, fmt.Errorf("struct %s has duplicated key %s", t.String(), name)
}
sd.fl = append(sd.fl, dominant)
sd.fm[name] = dominant
}
sort.Sort(byIndex(sd.fl))
sc.l.Lock()
sc.cache[t] = sd
sc.l.Unlock()
return sd, nil
}
// dominantField looks through the fields, all of which are known to
// have the same name, to find the single field that dominates the
// others using Go's inlining rules. If there are multiple top-level
// fields, the boolean will be false: This condition is an error in Go
// and we skip all the fields.
func dominantField(fields []fieldDescription) (fieldDescription, bool) {
// The fields are sorted in increasing index-length order, then by presence of tag.
// That means that the first field is the dominant one. We need only check
// for error cases: two fields at top level.
if len(fields) > 1 &&
len(fields[0].inline) == len(fields[1].inline) {
return fieldDescription{}, false
}
return fields[0], true
}
func fieldByIndexErr(v reflect.Value, index []int) (result reflect.Value, err error) {
defer func() {
if recovered := recover(); recovered != nil {
switch r := recovered.(type) {
case string:
err = fmt.Errorf("%s", r)
case error:
err = r
}
}
}()
result = v.FieldByIndex(index)
return
}
func getInlineField(val reflect.Value, index []int) (reflect.Value, error) {
field, err := fieldByIndexErr(val, index)
if err == nil {
return field, nil
}
// if parent of this element doesn't exist, fix its parent
inlineParent := index[:len(index)-1]
var fParent reflect.Value
if fParent, err = fieldByIndexErr(val, inlineParent); err != nil {
fParent, err = getInlineField(val, inlineParent)
if err != nil {
return fParent, err
}
}
fParent.Set(reflect.New(fParent.Type().Elem()))
return fieldByIndexErr(val, index)
}
// DeepZero returns recursive zero object
func deepZero(st reflect.Type) (result reflect.Value) {
result = reflect.Indirect(reflect.New(st))
if result.Kind() == reflect.Struct {
for i := 0; i < result.NumField(); i++ {
if f := result.Field(i); f.Kind() == reflect.Ptr {
if f.CanInterface() {
if ft := reflect.TypeOf(f.Interface()); ft.Elem().Kind() == reflect.Struct {
result.Field(i).Set(recursivePointerTo(deepZero(ft.Elem())))
}
}
}
}
}
return
}
// recursivePointerTo calls reflect.New(v.Type) but recursively for its fields inside
func recursivePointerTo(v reflect.Value) reflect.Value {
v = reflect.Indirect(v)
result := reflect.New(v.Type())
if v.Kind() == reflect.Struct {
for i := 0; i < v.NumField(); i++ {
if f := v.Field(i); f.Kind() == reflect.Ptr {
if f.Elem().Kind() == reflect.Struct {
result.Elem().Field(i).Set(recursivePointerTo(f))
}
}
}
}
return result
}

View File

@@ -0,0 +1,47 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestZeoerInterfaceUsedByDecoder(t *testing.T) {
enc := &StructCodec{}
// cases that are zero, because they are known types or pointers
var st *nonZeroer
assert.True(t, enc.isZero(st))
assert.True(t, enc.isZero(0))
assert.True(t, enc.isZero(false))
// cases that shouldn't be zero
st = &nonZeroer{value: false}
assert.False(t, enc.isZero(struct{ val bool }{val: true}))
assert.False(t, enc.isZero(struct{ val bool }{val: false}))
assert.False(t, enc.isZero(st))
st.value = true
assert.False(t, enc.isZero(st))
// a test to see if the interface impacts the outcome
z := zeroTest{}
assert.False(t, enc.isZero(z))
z.reportZero = true
assert.True(t, enc.isZero(z))
// *time.Time with nil should be zero
var tp *time.Time
assert.True(t, enc.isZero(tp))
// actually all zeroer if nil should also be zero
var zp *zeroTest
assert.True(t, enc.isZero(zp))
}

View File

@@ -0,0 +1,142 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import (
"reflect"
"strings"
)
// StructTagParser returns the struct tags for a given struct field.
type StructTagParser interface {
ParseStructTags(reflect.StructField) (StructTags, error)
}
// StructTagParserFunc is an adapter that allows a generic function to be used
// as a StructTagParser.
type StructTagParserFunc func(reflect.StructField) (StructTags, error)
// ParseStructTags implements the StructTagParser interface.
func (stpf StructTagParserFunc) ParseStructTags(sf reflect.StructField) (StructTags, error) {
return stpf(sf)
}
// StructTags represents the struct tag fields that the StructCodec uses during
// the encoding and decoding process.
//
// In the case of a struct, the lowercased field name is used as the key for each exported
// field but this behavior may be changed using a struct tag. The tag may also contain flags to
// adjust the marshalling behavior for the field.
//
// The properties are defined below:
//
// OmitEmpty Only include the field if it's not set to the zero value for the type or to
// empty slices or maps.
//
// MinSize Marshal an integer of a type larger than 32 bits value as an int32, if that's
// feasible while preserving the numeric value.
//
// Truncate When unmarshaling a BSON double, it is permitted to lose precision to fit within
// a float32.
//
// Inline Inline the field, which must be a struct or a map, causing all of its fields
// or keys to be processed as if they were part of the outer struct. For maps,
// keys must not conflict with the bson keys of other struct fields.
//
// Skip This struct field should be skipped. This is usually denoted by parsing a "-"
// for the name.
//
// TODO(skriptble): Add tags for undefined as nil and for null as nil.
type StructTags struct {
Name string
OmitEmpty bool
OmitAlways bool
MinSize bool
Truncate bool
Inline bool
Skip bool
}
// DefaultStructTagParser is the StructTagParser used by the StructCodec by default.
// It will handle the bson struct tag. See the documentation for StructTags to see
// what each of the returned fields means.
//
// If there is no name in the struct tag fields, the struct field name is lowercased.
// The tag formats accepted are:
//
// "[<key>][,<flag1>[,<flag2>]]"
//
// `(...) bson:"[<key>][,<flag1>[,<flag2>]]" (...)`
//
// An example:
//
// type T struct {
// A bool
// B int "myb"
// C string "myc,omitempty"
// D string `bson:",omitempty" json:"jsonkey"`
// E int64 ",minsize"
// F int64 "myf,omitempty,minsize"
// }
//
// A struct tag either consisting entirely of '-' or with a bson key with a
// value consisting entirely of '-' will return a StructTags with Skip true and
// the remaining fields will be their default values.
var DefaultStructTagParser StructTagParserFunc = func(sf reflect.StructField) (StructTags, error) {
key := strings.ToLower(sf.Name)
tag, ok := sf.Tag.Lookup("bson")
if !ok && !strings.Contains(string(sf.Tag), ":") && len(sf.Tag) > 0 {
tag = string(sf.Tag)
}
return parseTags(key, tag)
}
func parseTags(key string, tag string) (StructTags, error) {
var st StructTags
if tag == "-" {
st.Skip = true
return st, nil
}
for idx, str := range strings.Split(tag, ",") {
if idx == 0 && str != "" {
key = str
}
switch str {
case "omitempty":
st.OmitEmpty = true
case "omitalways":
st.OmitAlways = true
case "minsize":
st.MinSize = true
case "truncate":
st.Truncate = true
case "inline":
st.Inline = true
}
}
st.Name = key
return st, nil
}
// JSONFallbackStructTagParser has the same behavior as DefaultStructTagParser
// but will also fallback to parsing the json tag instead on a field where the
// bson tag isn't available.
var JSONFallbackStructTagParser StructTagParserFunc = func(sf reflect.StructField) (StructTags, error) {
key := strings.ToLower(sf.Name)
tag, ok := sf.Tag.Lookup("bson")
if !ok {
tag, ok = sf.Tag.Lookup("json")
}
if !ok && !strings.Contains(string(sf.Tag), ":") && len(sf.Tag) > 0 {
tag = string(sf.Tag)
}
return parseTags(key, tag)
}

View File

@@ -0,0 +1,160 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import (
"reflect"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestStructTagParsers(t *testing.T) {
testCases := []struct {
name string
sf reflect.StructField
want StructTags
parser StructTagParserFunc
}{
{
"default no bson tag",
reflect.StructField{Name: "foo", Tag: reflect.StructTag("bar")},
StructTags{Name: "bar"},
DefaultStructTagParser,
},
{
"default empty",
reflect.StructField{Name: "foo", Tag: reflect.StructTag("")},
StructTags{Name: "foo"},
DefaultStructTagParser,
},
{
"default tag only dash",
reflect.StructField{Name: "foo", Tag: reflect.StructTag("-")},
StructTags{Skip: true},
DefaultStructTagParser,
},
{
"default bson tag only dash",
reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bson:"-"`)},
StructTags{Skip: true},
DefaultStructTagParser,
},
{
"default all options",
reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bar,omitempty,minsize,truncate,inline`)},
StructTags{Name: "bar", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true},
DefaultStructTagParser,
},
{
"default all options default name",
reflect.StructField{Name: "foo", Tag: reflect.StructTag(`,omitempty,minsize,truncate,inline`)},
StructTags{Name: "foo", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true},
DefaultStructTagParser,
},
{
"default bson tag all options",
reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bson:"bar,omitempty,minsize,truncate,inline"`)},
StructTags{Name: "bar", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true},
DefaultStructTagParser,
},
{
"default bson tag all options default name",
reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bson:",omitempty,minsize,truncate,inline"`)},
StructTags{Name: "foo", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true},
DefaultStructTagParser,
},
{
"default ignore xml",
reflect.StructField{Name: "foo", Tag: reflect.StructTag(`xml:"bar"`)},
StructTags{Name: "foo"},
DefaultStructTagParser,
},
{
"JSONFallback no bson tag",
reflect.StructField{Name: "foo", Tag: reflect.StructTag("bar")},
StructTags{Name: "bar"},
JSONFallbackStructTagParser,
},
{
"JSONFallback empty",
reflect.StructField{Name: "foo", Tag: reflect.StructTag("")},
StructTags{Name: "foo"},
JSONFallbackStructTagParser,
},
{
"JSONFallback tag only dash",
reflect.StructField{Name: "foo", Tag: reflect.StructTag("-")},
StructTags{Skip: true},
JSONFallbackStructTagParser,
},
{
"JSONFallback bson tag only dash",
reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bson:"-"`)},
StructTags{Skip: true},
JSONFallbackStructTagParser,
},
{
"JSONFallback all options",
reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bar,omitempty,minsize,truncate,inline`)},
StructTags{Name: "bar", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true},
JSONFallbackStructTagParser,
},
{
"JSONFallback all options default name",
reflect.StructField{Name: "foo", Tag: reflect.StructTag(`,omitempty,minsize,truncate,inline`)},
StructTags{Name: "foo", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true},
JSONFallbackStructTagParser,
},
{
"JSONFallback bson tag all options",
reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bson:"bar,omitempty,minsize,truncate,inline"`)},
StructTags{Name: "bar", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true},
JSONFallbackStructTagParser,
},
{
"JSONFallback bson tag all options default name",
reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bson:",omitempty,minsize,truncate,inline"`)},
StructTags{Name: "foo", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true},
JSONFallbackStructTagParser,
},
{
"JSONFallback json tag all options",
reflect.StructField{Name: "foo", Tag: reflect.StructTag(`json:"bar,omitempty,minsize,truncate,inline"`)},
StructTags{Name: "bar", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true},
JSONFallbackStructTagParser,
},
{
"JSONFallback json tag all options default name",
reflect.StructField{Name: "foo", Tag: reflect.StructTag(`json:",omitempty,minsize,truncate,inline"`)},
StructTags{Name: "foo", OmitEmpty: true, MinSize: true, Truncate: true, Inline: true},
JSONFallbackStructTagParser,
},
{
"JSONFallback bson tag overrides other tags",
reflect.StructField{Name: "foo", Tag: reflect.StructTag(`bson:"bar" json:"qux,truncate"`)},
StructTags{Name: "bar"},
JSONFallbackStructTagParser,
},
{
"JSONFallback ignore xml",
reflect.StructField{Name: "foo", Tag: reflect.StructTag(`xml:"bar"`)},
StructTags{Name: "foo"},
JSONFallbackStructTagParser,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
got, err := tc.parser(tc.sf)
noerr(t, err)
if !cmp.Equal(got, tc.want) {
t.Errorf("Returned struct tags do not match. got %#v; want %#v", got, tc.want)
}
})
}
}

View File

@@ -0,0 +1,127 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import (
"fmt"
"reflect"
"time"
"go.mongodb.org/mongo-driver/bson/bsonoptions"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
)
const (
timeFormatString = "2006-01-02T15:04:05.999Z07:00"
)
// TimeCodec is the Codec used for time.Time values.
type TimeCodec struct {
UseLocalTimeZone bool
}
var (
defaultTimeCodec = NewTimeCodec()
_ ValueCodec = defaultTimeCodec
_ typeDecoder = defaultTimeCodec
)
// NewTimeCodec returns a TimeCodec with options opts.
func NewTimeCodec(opts ...*bsonoptions.TimeCodecOptions) *TimeCodec {
timeOpt := bsonoptions.MergeTimeCodecOptions(opts...)
codec := TimeCodec{}
if timeOpt.UseLocalTimeZone != nil {
codec.UseLocalTimeZone = *timeOpt.UseLocalTimeZone
}
return &codec
}
func (tc *TimeCodec) decodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) {
if t != tTime {
return emptyValue, ValueDecoderError{
Name: "TimeDecodeValue",
Types: []reflect.Type{tTime},
Received: reflect.Zero(t),
}
}
var timeVal time.Time
switch vrType := vr.Type(); vrType {
case bsontype.DateTime:
dt, err := vr.ReadDateTime()
if err != nil {
return emptyValue, err
}
timeVal = time.Unix(dt/1000, dt%1000*1000000)
case bsontype.String:
// assume strings are in the isoTimeFormat
timeStr, err := vr.ReadString()
if err != nil {
return emptyValue, err
}
timeVal, err = time.Parse(timeFormatString, timeStr)
if err != nil {
return emptyValue, err
}
case bsontype.Int64:
i64, err := vr.ReadInt64()
if err != nil {
return emptyValue, err
}
timeVal = time.Unix(i64/1000, i64%1000*1000000)
case bsontype.Timestamp:
t, _, err := vr.ReadTimestamp()
if err != nil {
return emptyValue, err
}
timeVal = time.Unix(int64(t), 0)
case bsontype.Null:
if err := vr.ReadNull(); err != nil {
return emptyValue, err
}
case bsontype.Undefined:
if err := vr.ReadUndefined(); err != nil {
return emptyValue, err
}
default:
return emptyValue, fmt.Errorf("cannot decode %v into a time.Time", vrType)
}
if !tc.UseLocalTimeZone {
timeVal = timeVal.UTC()
}
return reflect.ValueOf(timeVal), nil
}
// DecodeValue is the ValueDecoderFunc for time.Time.
func (tc *TimeCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
if !val.CanSet() || val.Type() != tTime {
return ValueDecoderError{Name: "TimeDecodeValue", Types: []reflect.Type{tTime}, Received: val}
}
elem, err := tc.decodeType(dc, vr, tTime)
if err != nil {
return err
}
val.Set(elem)
return nil
}
// EncodeValue is the ValueEncoderFunc for time.TIme.
func (tc *TimeCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tTime {
return ValueEncoderError{Name: "TimeEncodeValue", Types: []reflect.Type{tTime}, Received: val}
}
tt := val.Interface().(time.Time)
dt := primitive.NewDateTimeFromTime(tt)
return vw.WriteDateTime(int64(dt))
}

View File

@@ -0,0 +1,79 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import (
"reflect"
"testing"
"time"
"go.mongodb.org/mongo-driver/bson/bsonoptions"
"go.mongodb.org/mongo-driver/bson/bsonrw/bsonrwtest"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/internal/testutil/assert"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
func TestTimeCodec(t *testing.T) {
now := time.Now().Truncate(time.Millisecond)
t.Run("UseLocalTimeZone", func(t *testing.T) {
reader := &bsonrwtest.ValueReaderWriter{BSONType: bsontype.DateTime, Return: now.UnixNano() / int64(time.Millisecond)}
testCases := []struct {
name string
opts *bsonoptions.TimeCodecOptions
utc bool
}{
{"default", bsonoptions.TimeCodec(), true},
{"false", bsonoptions.TimeCodec().SetUseLocalTimeZone(false), true},
{"true", bsonoptions.TimeCodec().SetUseLocalTimeZone(true), false},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
timeCodec := NewTimeCodec(tc.opts)
actual := reflect.New(reflect.TypeOf(now)).Elem()
err := timeCodec.DecodeValue(DecodeContext{}, reader, actual)
assert.Nil(t, err, "TimeCodec.DecodeValue error: %v", err)
actualTime := actual.Interface().(time.Time)
assert.Equal(t, actualTime.Location().String() == "UTC", tc.utc,
"Expected UTC: %v, got %v", tc.utc, actualTime.Location())
assert.Equal(t, now, actualTime, "expected time %v, got %v", now, actualTime)
})
}
})
t.Run("DecodeFromBsontype", func(t *testing.T) {
testCases := []struct {
name string
reader *bsonrwtest.ValueReaderWriter
}{
{"string", &bsonrwtest.ValueReaderWriter{BSONType: bsontype.String, Return: now.Format(timeFormatString)}},
{"int64", &bsonrwtest.ValueReaderWriter{BSONType: bsontype.Int64, Return: now.Unix()*1000 + int64(now.Nanosecond()/1e6)}},
{"timestamp", &bsonrwtest.ValueReaderWriter{BSONType: bsontype.Timestamp,
Return: bsoncore.Value{
Type: bsontype.Timestamp,
Data: bsoncore.AppendTimestamp(nil, uint32(now.Unix()), 0),
}},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
actual := reflect.New(reflect.TypeOf(now)).Elem()
err := defaultTimeCodec.DecodeValue(DecodeContext{}, tc.reader, actual)
assert.Nil(t, err, "DecodeValue error: %v", err)
actualTime := actual.Interface().(time.Time)
if tc.name == "timestamp" {
now = time.Unix(now.Unix(), 0)
}
assert.Equal(t, now, actualTime, "expected time %v, got %v", now, actualTime)
})
}
})
}

View File

@@ -0,0 +1,57 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import (
"encoding/json"
"net/url"
"reflect"
"time"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
var tBool = reflect.TypeOf(false)
var tFloat64 = reflect.TypeOf(float64(0))
var tInt32 = reflect.TypeOf(int32(0))
var tInt64 = reflect.TypeOf(int64(0))
var tString = reflect.TypeOf("")
var tTime = reflect.TypeOf(time.Time{})
var tEmpty = reflect.TypeOf((*interface{})(nil)).Elem()
var tByteSlice = reflect.TypeOf([]byte(nil))
var tByte = reflect.TypeOf(byte(0x00))
var tURL = reflect.TypeOf(url.URL{})
var tJSONNumber = reflect.TypeOf(json.Number(""))
var tValueMarshaler = reflect.TypeOf((*ValueMarshaler)(nil)).Elem()
var tValueUnmarshaler = reflect.TypeOf((*ValueUnmarshaler)(nil)).Elem()
var tMarshaler = reflect.TypeOf((*Marshaler)(nil)).Elem()
var tUnmarshaler = reflect.TypeOf((*Unmarshaler)(nil)).Elem()
var tProxy = reflect.TypeOf((*Proxy)(nil)).Elem()
var tBinary = reflect.TypeOf(primitive.Binary{})
var tUndefined = reflect.TypeOf(primitive.Undefined{})
var tOID = reflect.TypeOf(primitive.ObjectID{})
var tDateTime = reflect.TypeOf(primitive.DateTime(0))
var tNull = reflect.TypeOf(primitive.Null{})
var tRegex = reflect.TypeOf(primitive.Regex{})
var tCodeWithScope = reflect.TypeOf(primitive.CodeWithScope{})
var tDBPointer = reflect.TypeOf(primitive.DBPointer{})
var tJavaScript = reflect.TypeOf(primitive.JavaScript(""))
var tSymbol = reflect.TypeOf(primitive.Symbol(""))
var tTimestamp = reflect.TypeOf(primitive.Timestamp{})
var tDecimal = reflect.TypeOf(primitive.Decimal128{})
var tMinKey = reflect.TypeOf(primitive.MinKey{})
var tMaxKey = reflect.TypeOf(primitive.MaxKey{})
var tD = reflect.TypeOf(primitive.D{})
var tA = reflect.TypeOf(primitive.A{})
var tE = reflect.TypeOf(primitive.E{})
var tCoreDocument = reflect.TypeOf(bsoncore.Document{})
var tCoreArray = reflect.TypeOf(bsoncore.Array{})

View File

@@ -0,0 +1,173 @@
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package bsoncodec
import (
"fmt"
"math"
"reflect"
"go.mongodb.org/mongo-driver/bson/bsonoptions"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/bsontype"
)
// UIntCodec is the Codec used for uint values.
type UIntCodec struct {
EncodeToMinSize bool
}
var (
defaultUIntCodec = NewUIntCodec()
_ ValueCodec = defaultUIntCodec
_ typeDecoder = defaultUIntCodec
)
// NewUIntCodec returns a UIntCodec with options opts.
func NewUIntCodec(opts ...*bsonoptions.UIntCodecOptions) *UIntCodec {
uintOpt := bsonoptions.MergeUIntCodecOptions(opts...)
codec := UIntCodec{}
if uintOpt.EncodeToMinSize != nil {
codec.EncodeToMinSize = *uintOpt.EncodeToMinSize
}
return &codec
}
// EncodeValue is the ValueEncoder for uint types.
func (uic *UIntCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
switch val.Kind() {
case reflect.Uint8, reflect.Uint16:
return vw.WriteInt32(int32(val.Uint()))
case reflect.Uint, reflect.Uint32, reflect.Uint64:
u64 := val.Uint()
// If ec.MinSize or if encodeToMinSize is true for a non-uint64 value we should write val as an int32
useMinSize := ec.MinSize || (uic.EncodeToMinSize && val.Kind() != reflect.Uint64)
if u64 <= math.MaxInt32 && useMinSize {
return vw.WriteInt32(int32(u64))
}
if u64 > math.MaxInt64 {
return fmt.Errorf("%d overflows int64", u64)
}
return vw.WriteInt64(int64(u64))
}
return ValueEncoderError{
Name: "UintEncodeValue",
Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint},
Received: val,
}
}
func (uic *UIntCodec) decodeType(dc DecodeContext, vr bsonrw.ValueReader, t reflect.Type) (reflect.Value, error) {
var i64 int64
var err error
switch vrType := vr.Type(); vrType {
case bsontype.Int32:
i32, err := vr.ReadInt32()
if err != nil {
return emptyValue, err
}
i64 = int64(i32)
case bsontype.Int64:
i64, err = vr.ReadInt64()
if err != nil {
return emptyValue, err
}
case bsontype.Double:
f64, err := vr.ReadDouble()
if err != nil {
return emptyValue, err
}
if !dc.Truncate && math.Floor(f64) != f64 {
return emptyValue, errCannotTruncate
}
if f64 > float64(math.MaxInt64) {
return emptyValue, fmt.Errorf("%g overflows int64", f64)
}
i64 = int64(f64)
case bsontype.Boolean:
b, err := vr.ReadBoolean()
if err != nil {
return emptyValue, err
}
if b {
i64 = 1
}
case bsontype.Null:
if err = vr.ReadNull(); err != nil {
return emptyValue, err
}
case bsontype.Undefined:
if err = vr.ReadUndefined(); err != nil {
return emptyValue, err
}
default:
return emptyValue, fmt.Errorf("cannot decode %v into an integer type", vrType)
}
switch t.Kind() {
case reflect.Uint8:
if i64 < 0 || i64 > math.MaxUint8 {
return emptyValue, fmt.Errorf("%d overflows uint8", i64)
}
return reflect.ValueOf(uint8(i64)), nil
case reflect.Uint16:
if i64 < 0 || i64 > math.MaxUint16 {
return emptyValue, fmt.Errorf("%d overflows uint16", i64)
}
return reflect.ValueOf(uint16(i64)), nil
case reflect.Uint32:
if i64 < 0 || i64 > math.MaxUint32 {
return emptyValue, fmt.Errorf("%d overflows uint32", i64)
}
return reflect.ValueOf(uint32(i64)), nil
case reflect.Uint64:
if i64 < 0 {
return emptyValue, fmt.Errorf("%d overflows uint64", i64)
}
return reflect.ValueOf(uint64(i64)), nil
case reflect.Uint:
if i64 < 0 || int64(uint(i64)) != i64 { // Can we fit this inside of an uint
return emptyValue, fmt.Errorf("%d overflows uint", i64)
}
return reflect.ValueOf(uint(i64)), nil
default:
return emptyValue, ValueDecoderError{
Name: "UintDecodeValue",
Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint},
Received: reflect.Zero(t),
}
}
}
// DecodeValue is the ValueDecoder for uint types.
func (uic *UIntCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
if !val.CanSet() {
return ValueDecoderError{
Name: "UintDecodeValue",
Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint},
Received: val,
}
}
elem, err := uic.decodeType(dc, vr, val.Type())
if err != nil {
return err
}
val.SetUint(elem.Uint())
return nil
}

Some files were not shown because too many files have changed in this diff Show More