Compare commits

...

103 Commits

Author SHA1 Message Date
021465e524 v0.0.121 2023-05-24 21:55:21 +02:00
cf9c73aa4a v0.0.120 2023-05-24 21:42:10 +02:00
0652bf22dc v0.0.119 2023-05-24 21:32:00 +02:00
b196adffc7 v0.0.118 2023-05-09 11:33:01 +02:00
717065e62d v0.0.117 2023-05-09 09:57:05 +02:00
e7b2b040b2 v0.0.116 2023-05-05 18:22:15 +02:00
05d0f9e469 v0.0.115 2023-05-05 18:18:20 +02:00
ccd03e50c8 v0.0.114 2023-05-05 18:17:15 +02:00
1c77c2b8e8 v0.0.113 2023-05-05 18:05:58 +02:00
9f6f967299 v0.0.112 2023-05-05 18:00:25 +02:00
18c83f0f76 v0.0.111 2023-05-05 17:57:21 +02:00
a64f336e24 v0.0.110 2023-05-05 17:47:30 +02:00
14bbd205f8 v0.0.109 2023-05-05 15:04:08 +02:00
cecfb0d788 v0.0.108 2023-05-05 14:43:40 +02:00
a445e6f623 v0.0.107 2023-04-26 11:35:28 +02:00
0aa6310971 v0.0.106 2023-04-26 11:34:46 +02:00
2f66ab1cf0 v0.0.105 2023-04-23 19:31:48 +02:00
304e779470 v0.0.104 2023-04-23 14:54:23 +02:00
5e295d65c5 v0.0.103 2023-04-20 14:35:55 +02:00
ef3705937c gojson: added MarshalSafeCollections 2023-04-20 14:34:57 +02:00
d780c7965f added gojson as a go/json fork (tag go1.20.2) 2023-04-20 14:30:24 +02:00
c13db6802e v0.0.102 2023-04-13 14:40:07 +02:00
c5e23ab451 v0.0.101 2023-04-08 19:39:13 +02:00
c266d9204b v0.0.100 2023-04-04 17:10:38 +02:00
2550691e2e v0.0.99 2023-03-31 13:33:06 +02:00
ca24e1d5bf v0.0.98 2023-03-29 20:25:03 +02:00
b156052e6f v0.0.97 2023-03-29 19:53:53 +02:00
dda2418255 v0.0.96 2023-03-29 19:53:10 +02:00
8e40deae6a add git-pull to Makefile 2023-03-28 16:30:56 +02:00
289b9f47a2 v0.0.95 2023-03-28 16:29:16 +02:00
007c44df85 v0.0.94 2023-03-21 16:00:15 +01:00
a6252f0743 v0.0.93 2023-03-15 15:41:55 +01:00
86c01659d7 base58 2023-03-15 14:00:48 +01:00
62acddda5e v0.0.91 2023-03-11 14:38:19 +01:00
ee325f67fd v0.0.90 2023-03-09 14:51:53 +01:00
dba0cd229e v0.0.89 2023-03-07 10:43:30 +01:00
ec4dba173f v0.0.88 2023-02-16 13:27:34 +01:00
22ce2d26f3 v0.0.87 2023-02-16 13:22:15 +01:00
4fd768e573 v0.0.86 2023-02-14 17:18:58 +01:00
bf16a8165f v0.0.85 2023-02-14 16:25:45 +01:00
9f5612248a fix fd0 read error on long stdout output (scanner buffer was too small) 2023-02-13 01:41:33 +01:00
4a2b830252 added more tests to cmdrunner (reproduce another ?? cmdrunner bug...) 2023-02-09 16:49:33 +01:00
c492c80881 v0.0.83 2023-02-09 15:06:37 +01:00
26dd16d021 v0.0.82 2023-02-09 15:01:54 +01:00
b0b43de8ca v0.0.81 2023-02-09 11:27:49 +01:00
94f72e4ddf v0.0.80 2023-02-09 11:16:23 +01:00
df4388e6dc v0.0.79 2023-02-08 18:55:51 +01:00
fd33b43f31 v0.0.78 2023-02-03 01:05:36 +01:00
be4de07eb8 v0.0.77 2023-02-03 00:59:54 +01:00
36ed474bfe v0.0.76 2023-01-31 23:46:35 +01:00
fdc590c8c3 v0.0.75 2023-01-31 22:41:12 +01:00
1990e5d32d v0.0.74 2023-01-31 11:01:45 +01:00
72883cf6bd v0.0.73 2023-01-31 10:56:30 +01:00
ff08d5f180 v0.0.72 2023-01-30 19:55:55 +01:00
72d6b538f7 v0.0.71 2023-01-29 22:28:08 +01:00
48dd30fb94 v0.0.70 2023-01-29 22:07:28 +01:00
b7c5756f11 v0.0.69 2023-01-29 22:00:40 +01:00
2070a432a5 v0.0.68 2023-01-29 21:27:55 +01:00
34e6d1819d v0.0.67 2023-01-29 20:42:02 +01:00
87fa6021e4 v0.0.66 2023-01-29 06:02:58 +01:00
297d6c52a8 v0.0.65 2023-01-29 05:45:29 +01:00
b9c46947d2 v0.0.64 2023-01-29 01:10:14 +01:00
412277b3e0 v0.0.63 2023-01-28 22:29:45 +01:00
e46f8019ec v0.0.62 2023-01-28 22:29:21 +01:00
ae952b2166 v0.0.61 2023-01-28 22:28:20 +01:00
b24dba9a45 v0.0.60 2023-01-28 14:44:12 +01:00
cfbc20367d v0.0.59 2023-01-15 02:27:08 +01:00
e25912758e v0.0.58 2023-01-15 02:05:05 +01:00
e1ae77a9db v0.0.57 2023-01-15 01:56:40 +01:00
9d07b3955f v0.0.56 2023-01-13 16:05:39 +01:00
02be696c25 v0.0.55 2023-01-06 02:02:22 +01:00
ba07625b7c v0.0.54 2022-12-29 22:52:52 +01:00
aeded3fb37 v0.0.53 2022-12-24 03:11:09 +01:00
1a1cd6d0aa v0.0.52 2022-12-24 02:50:46 +01:00
64cc1342a0 v0.0.51 2022-12-24 01:14:58 +01:00
8431b6adf5 v0.0.50 2022-12-23 20:08:59 +01:00
24e923fe84 v0.0.49 2022-12-23 19:11:18 +01:00
10ddc7c190 v0.0.48 2022-12-23 14:47:16 +01:00
7f88a0726c v0.0.47 2022-12-23 10:11:01 +01:00
2224db8e85 v0.0.46 2022-12-22 15:59:12 +01:00
c60afc89bb v0.0.45 2022-12-22 15:55:32 +01:00
bbb33e9fd6 v0.0.44 2022-12-22 15:49:10 +01:00
ac05eff1e8 v0.0.43 2022-12-22 10:23:34 +01:00
1aaad66233 v0.0.42 2022-12-22 10:06:25 +01:00
d4994b8c8d v0.0.41 2022-12-21 15:41:41 +01:00
e3b8d2cc0f v0.0.40 2022-12-21 15:34:59 +01:00
fff609db4a v0.0.39 2022-12-21 14:39:59 +01:00
5e99e07f40 v0.0.38 2022-12-21 13:00:39 +01:00
bdb181cb3a v0.0.37 2022-12-20 09:50:13 +01:00
3552acd38b v0.0.36 2022-12-15 12:36:24 +01:00
c42324c58f v0.0.35 2022-12-14 18:25:07 +01:00
3a9c3f4e9e v0.0.34 2022-12-11 03:12:02 +01:00
becd8f1ebc v0.0.33 2022-12-11 02:34:38 +01:00
e733f30c38 v0.0.32 2022-12-10 03:33:13 +01:00
1a9e5c70fc v0.0.31 2022-12-07 23:25:21 +01:00
f3700a772d v0.0.30 2022-12-07 23:21:36 +01:00
2c69b33547 v0.0.29 2022-12-07 23:15:16 +01:00
6a304b875a v0.0.28 2022-11-30 23:52:19 +01:00
d12bf23b46 v0.0.27 2022-11-30 23:40:59 +01:00
52f7f6e690 v0.0.26 2022-11-30 23:38:42 +01:00
b1e3891256 v0.0.25 2022-11-30 23:35:47 +01:00
bdf5b53c20 v0.0.24 2022-11-30 23:34:16 +01:00
496c4e4f59 v0.0.23 2022-11-30 23:08:50 +01:00
91 changed files with 16529 additions and 138 deletions

View File

@@ -1,6 +1,7 @@
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="GoMixedReceiverTypes" enabled="false" level="WEAK WARNING" enabled_by_default="false" />
<inspection_tool class="LanguageDetectionInspection" enabled="false" level="WARNING" enabled_by_default="false" />
<inspection_tool class="SpellCheckingInspection" enabled="false" level="TYPO" enabled_by_default="false">
<option name="processCode" value="true" />

View File

@@ -7,6 +7,22 @@ set -o pipefail # Return value of a pipeline is the value of the last (rightmos
IFS=$'\n\t' # Set $IFS to only newline and tab.
function black() { echo -e "\x1B[30m $1 \x1B[0m"; }
function red() { echo -e "\x1B[31m $1 \x1B[0m"; }
function green() { echo -e "\x1B[32m $1 \x1B[0m"; }
function yellow(){ echo -e "\x1B[33m $1 \x1B[0m"; }
function blue() { echo -e "\x1B[34m $1 \x1B[0m"; }
function purple(){ echo -e "\x1B[35m $1 \x1B[0m"; }
function cyan() { echo -e "\x1B[36m $1 \x1B[0m"; }
function white() { echo -e "\x1B[37m $1 \x1B[0m"; }
if [ "$( git rev-parse --abbrev-ref HEAD )" != "master" ]; then
>&2 red "[ERROR] Can only create versions of <master>"
exit 1
fi
git pull --ff
curr_vers=$(git describe --tags --abbrev=0 | sed 's/v//g')
next_ver=$(echo "$curr_vers" | awk -F. -v OFS=. 'NF==1{print ++$NF}; NF>1{if(length($NF+1)>length($NF))$(NF-1)++; $NF=sprintf("%0*d", length($NF), ($NF+1)%(10^length($NF))); print}')
@@ -18,7 +34,13 @@ echo ""
git add --verbose .
git commit -a -m "v${next_ver}"
msg="v${next_ver}"
if [ $# -gt 0 ]; then
msg="$1"
fi
git commit -a -m "${msg}"
git tag "v${next_ver}"

93
cmdext/builder.go Normal file
View File

@@ -0,0 +1,93 @@
package cmdext
import (
"fmt"
"gogs.mikescher.com/BlackForestBytes/goext/langext"
"time"
)
type CommandRunner struct {
program string
args []string
timeout *time.Duration
env []string
listener []CommandListener
enforceExitCodes *[]int
enforceNoTimeout bool
}
func Runner(program string) *CommandRunner {
return &CommandRunner{
program: program,
args: make([]string, 0),
timeout: nil,
env: make([]string, 0),
listener: make([]CommandListener, 0),
enforceExitCodes: nil,
enforceNoTimeout: false,
}
}
func (r *CommandRunner) Arg(arg string) *CommandRunner {
r.args = append(r.args, arg)
return r
}
func (r *CommandRunner) Args(arg []string) *CommandRunner {
r.args = append(r.args, arg...)
return r
}
func (r *CommandRunner) Timeout(timeout time.Duration) *CommandRunner {
r.timeout = &timeout
return r
}
func (r *CommandRunner) Env(key, value string) *CommandRunner {
r.env = append(r.env, fmt.Sprintf("%s=%s", key, value))
return r
}
func (r *CommandRunner) RawEnv(env string) *CommandRunner {
r.env = append(r.env, env)
return r
}
func (r *CommandRunner) Envs(env []string) *CommandRunner {
r.env = append(r.env, env...)
return r
}
func (r *CommandRunner) EnsureExitcode(arg ...int) *CommandRunner {
r.enforceExitCodes = langext.Ptr(langext.ForceArray(arg))
return r
}
func (r *CommandRunner) FailOnExitCode() *CommandRunner {
r.enforceExitCodes = langext.Ptr([]int{0})
return r
}
func (r *CommandRunner) FailOnTimeout() *CommandRunner {
r.enforceNoTimeout = true
return r
}
func (r *CommandRunner) Listen(lstr CommandListener) *CommandRunner {
r.listener = append(r.listener, lstr)
return r
}
func (r *CommandRunner) ListenStdout(lstr func(string)) *CommandRunner {
r.listener = append(r.listener, genericCommandListener{_readStdoutLine: &lstr})
return r
}
func (r *CommandRunner) ListenStderr(lstr func(string)) *CommandRunner {
r.listener = append(r.listener, genericCommandListener{_readStderrLine: &lstr})
return r
}
func (r *CommandRunner) Run() (CommandResult, error) {
return run(*r)
}

154
cmdext/cmdrunner.go Normal file
View File

@@ -0,0 +1,154 @@
package cmdext
import (
"errors"
"gogs.mikescher.com/BlackForestBytes/goext/langext"
"gogs.mikescher.com/BlackForestBytes/goext/mathext"
"gogs.mikescher.com/BlackForestBytes/goext/syncext"
"os/exec"
"time"
)
var ErrExitCode = errors.New("process exited with an unexpected exitcode")
var ErrTimeout = errors.New("process did not exit after the specified timeout")
type CommandResult struct {
StdOut string
StdErr string
StdCombined string
ExitCode int
CommandTimedOut bool
}
func run(opt CommandRunner) (CommandResult, error) {
cmd := exec.Command(opt.program, opt.args...)
cmd.Env = append(cmd.Env, opt.env...)
stdoutPipe, err := cmd.StdoutPipe()
if err != nil {
return CommandResult{}, err
}
stderrPipe, err := cmd.StderrPipe()
if err != nil {
return CommandResult{}, err
}
preader := pipeReader{
lineBufferSize: langext.Ptr(128 * 1024 * 1024), // 128MB max size of a single line, is hopefully enough....
stdout: stdoutPipe,
stderr: stderrPipe,
}
err = cmd.Start()
if err != nil {
return CommandResult{}, err
}
type resultObj struct {
stdout string
stderr string
stdcombined string
err error
}
outputChan := make(chan resultObj)
go func() {
// we need to first fully read the pipes and then call Wait
// see https://pkg.go.dev/os/exec#Cmd.StdoutPipe
stdout, stderr, stdcombined, err := preader.Read(opt.listener)
if err != nil {
outputChan <- resultObj{stdout, stderr, stdcombined, err}
_ = cmd.Process.Kill()
return
}
err = cmd.Wait()
if err != nil {
outputChan <- resultObj{stdout, stderr, stdcombined, err}
} else {
outputChan <- resultObj{stdout, stderr, stdcombined, nil}
}
}()
var timeoutChan <-chan time.Time = make(chan time.Time, 1)
if opt.timeout != nil {
timeoutChan = time.After(*opt.timeout)
}
select {
case <-timeoutChan:
_ = cmd.Process.Kill()
for _, lstr := range opt.listener {
lstr.Timeout()
}
if fallback, ok := syncext.ReadChannelWithTimeout(outputChan, mathext.Min(32*time.Millisecond, *opt.timeout)); ok {
// most of the time the cmd.Process.Kill() should also ahve finished the pipereader
// and we can at least return the already collected stdout, stderr, etc
res := CommandResult{
StdOut: fallback.stdout,
StdErr: fallback.stderr,
StdCombined: fallback.stdcombined,
ExitCode: -1,
CommandTimedOut: true,
}
if opt.enforceNoTimeout {
return res, ErrTimeout
}
return res, nil
} else {
res := CommandResult{
StdOut: "",
StdErr: "",
StdCombined: "",
ExitCode: -1,
CommandTimedOut: true,
}
if opt.enforceNoTimeout {
return res, ErrTimeout
}
return res, nil
}
case outobj := <-outputChan:
if exiterr, ok := outobj.err.(*exec.ExitError); ok {
excode := exiterr.ExitCode()
for _, lstr := range opt.listener {
lstr.Finished(excode)
}
res := CommandResult{
StdOut: outobj.stdout,
StdErr: outobj.stderr,
StdCombined: outobj.stdcombined,
ExitCode: excode,
CommandTimedOut: false,
}
if opt.enforceExitCodes != nil && !langext.InArray(excode, *opt.enforceExitCodes) {
return res, ErrExitCode
}
return res, nil
} else if err != nil {
return CommandResult{}, err
} else {
for _, lstr := range opt.listener {
lstr.Finished(0)
}
res := CommandResult{
StdOut: outobj.stdout,
StdErr: outobj.stderr,
StdCombined: outobj.stdcombined,
ExitCode: 0,
CommandTimedOut: false,
}
if opt.enforceExitCodes != nil && !langext.InArray(0, *opt.enforceExitCodes) {
return res, ErrExitCode
}
return res, nil
}
}
}

323
cmdext/cmdrunner_test.go Normal file
View File

@@ -0,0 +1,323 @@
package cmdext
import (
"fmt"
"testing"
"time"
)
func TestStdout(t *testing.T) {
res1, err := Runner("printf").Arg("hello").Run()
if err != nil {
t.Errorf("%v", err)
}
if res1.CommandTimedOut {
t.Errorf("Timeout")
}
if res1.ExitCode != 0 {
t.Errorf("res1.ExitCode == %v", res1.ExitCode)
}
if res1.StdErr != "" {
t.Errorf("res1.StdErr == '%v'", res1.StdErr)
}
if res1.StdOut != "hello" {
t.Errorf("res1.StdOut == '%v'", res1.StdOut)
}
if res1.StdCombined != "hello\n" {
t.Errorf("res1.StdCombined == '%v'", res1.StdCombined)
}
}
func TestStderr(t *testing.T) {
res1, err := Runner("python").Arg("-c").Arg("import sys; print(\"error\", file=sys.stderr, end='')").Run()
if err != nil {
t.Errorf("%v", err)
}
if res1.CommandTimedOut {
t.Errorf("Timeout")
}
if res1.ExitCode != 0 {
t.Errorf("res1.ExitCode == %v", res1.ExitCode)
}
if res1.StdErr != "error" {
t.Errorf("res1.StdErr == '%v'", res1.StdErr)
}
if res1.StdOut != "" {
t.Errorf("res1.StdOut == '%v'", res1.StdOut)
}
if res1.StdCombined != "error\n" {
t.Errorf("res1.StdCombined == '%v'", res1.StdCombined)
}
}
func TestStdcombined(t *testing.T) {
res1, err := Runner("python").
Arg("-c").
Arg("import sys; import time; print(\"1\", file=sys.stderr, flush=True); time.sleep(0.1); print(\"2\", file=sys.stdout, flush=True); time.sleep(0.1); print(\"3\", file=sys.stderr, flush=True)").
Run()
if err != nil {
t.Errorf("%v", err)
}
if res1.CommandTimedOut {
t.Errorf("Timeout")
}
if res1.ExitCode != 0 {
t.Errorf("res1.ExitCode == %v", res1.ExitCode)
}
if res1.StdErr != "1\n3\n" {
t.Errorf("res1.StdErr == '%v'", res1.StdErr)
}
if res1.StdOut != "2\n" {
t.Errorf("res1.StdOut == '%v'", res1.StdOut)
}
if res1.StdCombined != "1\n2\n3\n" {
t.Errorf("res1.StdCombined == '%v'", res1.StdCombined)
}
}
func TestPartialRead(t *testing.T) {
res1, err := Runner("python").
Arg("-c").
Arg("import sys; import time; print(\"first message\", flush=True); time.sleep(5); print(\"cant see me\", flush=True);").
Timeout(100 * time.Millisecond).
Run()
if err != nil {
t.Errorf("%v", err)
}
if !res1.CommandTimedOut {
t.Errorf("!CommandTimedOut")
}
if res1.StdErr != "" {
t.Errorf("res1.StdErr == '%v'", res1.StdErr)
}
if res1.StdOut != "first message\n" {
t.Errorf("res1.StdOut == '%v'", res1.StdOut)
}
if res1.StdCombined != "first message\n" {
t.Errorf("res1.StdCombined == '%v'", res1.StdCombined)
}
}
func TestPartialReadStderr(t *testing.T) {
res1, err := Runner("python").
Arg("-c").
Arg("import sys; import time; print(\"first message\", file=sys.stderr, flush=True); time.sleep(5); print(\"cant see me\", file=sys.stderr, flush=True);").
Timeout(100 * time.Millisecond).
Run()
if err != nil {
t.Errorf("%v", err)
}
if !res1.CommandTimedOut {
t.Errorf("!CommandTimedOut")
}
if res1.StdErr != "first message\n" {
t.Errorf("res1.StdErr == '%v'", res1.StdErr)
}
if res1.StdOut != "" {
t.Errorf("res1.StdOut == '%v'", res1.StdOut)
}
if res1.StdCombined != "first message\n" {
t.Errorf("res1.StdCombined == '%v'", res1.StdCombined)
}
}
func TestReadUnflushedStdout(t *testing.T) {
res1, err := Runner("python").Arg("-c").Arg("import sys; print(\"message101\", file=sys.stdout, end='')").Run()
if err != nil {
t.Errorf("%v", err)
}
if res1.CommandTimedOut {
t.Errorf("Timeout")
}
if res1.ExitCode != 0 {
t.Errorf("res1.ExitCode == %v", res1.ExitCode)
}
if res1.StdErr != "" {
t.Errorf("res1.StdErr == '%v'", res1.StdErr)
}
if res1.StdOut != "message101" {
t.Errorf("res1.StdOut == '%v'", res1.StdOut)
}
if res1.StdCombined != "message101\n" {
t.Errorf("res1.StdCombined == '%v'", res1.StdCombined)
}
}
func TestReadUnflushedStderr(t *testing.T) {
res1, err := Runner("python").Arg("-c").Arg("import sys; print(\"message101\", file=sys.stderr, end='')").Run()
if err != nil {
t.Errorf("%v", err)
}
if res1.CommandTimedOut {
t.Errorf("Timeout")
}
if res1.ExitCode != 0 {
t.Errorf("res1.ExitCode == %v", res1.ExitCode)
}
if res1.StdErr != "message101" {
t.Errorf("res1.StdErr == '%v'", res1.StdErr)
}
if res1.StdOut != "" {
t.Errorf("res1.StdOut == '%v'", res1.StdOut)
}
if res1.StdCombined != "message101\n" {
t.Errorf("res1.StdCombined == '%v'", res1.StdCombined)
}
}
func TestPartialReadUnflushed(t *testing.T) {
t.SkipNow()
res1, err := Runner("python").
Arg("-c").
Arg("import sys; import time; print(\"first message\", end=''); time.sleep(5); print(\"cant see me\", end='');").
Timeout(100 * time.Millisecond).
Run()
if err != nil {
t.Errorf("%v", err)
}
if !res1.CommandTimedOut {
t.Errorf("!CommandTimedOut")
}
if res1.StdErr != "" {
t.Errorf("res1.StdErr == '%v'", res1.StdErr)
}
if res1.StdOut != "first message" {
t.Errorf("res1.StdOut == '%v'", res1.StdOut)
}
if res1.StdCombined != "first message" {
t.Errorf("res1.StdCombined == '%v'", res1.StdCombined)
}
}
func TestPartialReadUnflushedStderr(t *testing.T) {
t.SkipNow()
res1, err := Runner("python").
Arg("-c").
Arg("import sys; import time; print(\"first message\", file=sys.stderr, end=''); time.sleep(5); print(\"cant see me\", file=sys.stderr, end='');").
Timeout(100 * time.Millisecond).
Run()
if err != nil {
t.Errorf("%v", err)
}
if !res1.CommandTimedOut {
t.Errorf("!CommandTimedOut")
}
if res1.StdErr != "first message" {
t.Errorf("res1.StdErr == '%v'", res1.StdErr)
}
if res1.StdOut != "" {
t.Errorf("res1.StdOut == '%v'", res1.StdOut)
}
if res1.StdCombined != "first message" {
t.Errorf("res1.StdCombined == '%v'", res1.StdCombined)
}
}
func TestListener(t *testing.T) {
res1, err := Runner("python").
Arg("-c").
Arg("import sys;" +
"import time;" +
"print(\"message 1\", flush=True);" +
"time.sleep(1);" +
"print(\"message 2\", flush=True);" +
"time.sleep(1);" +
"print(\"message 3\", flush=True);" +
"time.sleep(1);" +
"print(\"message 4\", file=sys.stderr, flush=True);" +
"time.sleep(1);" +
"print(\"message 5\", flush=True);" +
"time.sleep(1);" +
"print(\"final\");").
ListenStdout(func(s string) { fmt.Printf("@@STDOUT <<- %v (%v)\n", s, time.Now().Format(time.RFC3339Nano)) }).
ListenStderr(func(s string) { fmt.Printf("@@STDERR <<- %v (%v)\n", s, time.Now().Format(time.RFC3339Nano)) }).
Timeout(10 * time.Second).
Run()
if err != nil {
panic(err)
}
if res1.CommandTimedOut {
t.Errorf("Timeout")
}
if res1.ExitCode != 0 {
t.Errorf("res1.ExitCode == %v", res1.ExitCode)
}
}
func TestLongStdout(t *testing.T) {
res1, err := Runner("python").
Arg("-c").
Arg("import sys; import time; print(\"X\" * 125001 + \"\\n\"); print(\"Y\" * 125001 + \"\\n\"); print(\"Z\" * 125001 + \"\\n\");").
Timeout(5000 * time.Millisecond).
Run()
if err != nil {
t.Errorf("%v", err)
}
if res1.CommandTimedOut {
t.Errorf("Timeout")
}
if res1.ExitCode != 0 {
t.Errorf("res1.ExitCode == %v", res1.ExitCode)
}
if res1.StdErr != "" {
t.Errorf("res1.StdErr == '%v'", res1.StdErr)
}
if len(res1.StdOut) != 375009 {
t.Errorf("len(res1.StdOut) == '%v'", len(res1.StdOut))
}
}
func TestFailOnTimeout(t *testing.T) {
_, err := Runner("sleep").Arg("2").Timeout(200 * time.Millisecond).FailOnTimeout().Run()
if err != ErrTimeout {
t.Errorf("wrong err := %v", err)
}
}
func TestFailOnExitcode(t *testing.T) {
_, err := Runner("false").Timeout(200 * time.Millisecond).FailOnExitCode().Run()
if err != ErrExitCode {
t.Errorf("wrong err := %v", err)
}
}
func TestEnsureExitcode1(t *testing.T) {
_, err := Runner("false").Timeout(200 * time.Millisecond).EnsureExitcode(1).Run()
if err != nil {
t.Errorf("wrong err := %v", err)
}
}
func TestEnsureExitcode2(t *testing.T) {
_, err := Runner("false").Timeout(200*time.Millisecond).EnsureExitcode(0, 2, 3).Run()
if err != ErrExitCode {
t.Errorf("wrong err := %v", err)
}
}

12
cmdext/helper.go Normal file
View File

@@ -0,0 +1,12 @@
package cmdext
import "time"
func RunCommand(program string, args []string, timeout *time.Duration) (CommandResult, error) {
b := Runner(program)
b = b.Args(args)
if timeout != nil {
b = b.Timeout(*timeout)
}
return b.Run()
}

57
cmdext/listener.go Normal file
View File

@@ -0,0 +1,57 @@
package cmdext
type CommandListener interface {
ReadRawStdout([]byte)
ReadRawStderr([]byte)
ReadStdoutLine(string)
ReadStderrLine(string)
Finished(int)
Timeout()
}
type genericCommandListener struct {
_readRawStdout *func([]byte)
_readRawStderr *func([]byte)
_readStdoutLine *func(string)
_readStderrLine *func(string)
_finished *func(int)
_timeout *func()
}
func (g genericCommandListener) ReadRawStdout(v []byte) {
if g._readRawStdout != nil {
(*g._readRawStdout)(v)
}
}
func (g genericCommandListener) ReadRawStderr(v []byte) {
if g._readRawStderr != nil {
(*g._readRawStderr)(v)
}
}
func (g genericCommandListener) ReadStdoutLine(v string) {
if g._readStdoutLine != nil {
(*g._readStdoutLine)(v)
}
}
func (g genericCommandListener) ReadStderrLine(v string) {
if g._readStderrLine != nil {
(*g._readStderrLine)(v)
}
}
func (g genericCommandListener) Finished(v int) {
if g._finished != nil {
(*g._finished)(v)
}
}
func (g genericCommandListener) Timeout() {
if g._timeout != nil {
(*g._timeout)()
}
}

158
cmdext/pipereader.go Normal file
View File

@@ -0,0 +1,158 @@
package cmdext
import (
"bufio"
"gogs.mikescher.com/BlackForestBytes/goext/syncext"
"io"
"sync"
)
type pipeReader struct {
lineBufferSize *int
stdout io.ReadCloser
stderr io.ReadCloser
}
// Read ready stdout and stdin until finished
// also splits both pipes into lines and calld the listener
func (pr *pipeReader) Read(listener []CommandListener) (string, string, string, error) {
type combevt struct {
line string
stop bool
}
errch := make(chan error, 8)
wg := sync.WaitGroup{}
// [1] read raw stdout
wg.Add(1)
stdoutBufferReader, stdoutBufferWriter := io.Pipe()
stdout := ""
go func() {
buf := make([]byte, 128)
for true {
n, out := pr.stdout.Read(buf)
if n > 0 {
txt := string(buf[:n])
stdout += txt
_, _ = stdoutBufferWriter.Write(buf[:n])
for _, lstr := range listener {
lstr.ReadRawStdout(buf[:n])
}
}
if out == io.EOF {
break
}
if out != nil {
errch <- out
break
}
}
_ = stdoutBufferWriter.Close()
wg.Done()
}()
// [2] read raw stderr
wg.Add(1)
stderrBufferReader, stderrBufferWriter := io.Pipe()
stderr := ""
go func() {
buf := make([]byte, 128)
for true {
n, err := pr.stderr.Read(buf)
if n > 0 {
txt := string(buf[:n])
stderr += txt
_, _ = stderrBufferWriter.Write(buf[:n])
for _, lstr := range listener {
lstr.ReadRawStderr(buf[:n])
}
}
if err == io.EOF {
break
}
if err != nil {
errch <- err
break
}
}
_ = stderrBufferWriter.Close()
wg.Done()
}()
combch := make(chan combevt, 32)
// [3] collect stdout line-by-line
wg.Add(1)
go func() {
scanner := bufio.NewScanner(stdoutBufferReader)
if pr.lineBufferSize != nil {
scanner.Buffer([]byte{}, *pr.lineBufferSize)
}
for scanner.Scan() {
txt := scanner.Text()
for _, lstr := range listener {
lstr.ReadStdoutLine(txt)
}
combch <- combevt{txt, false}
}
if err := scanner.Err(); err != nil {
errch <- err
}
combch <- combevt{"", true}
wg.Done()
}()
// [4] collect stderr line-by-line
wg.Add(1)
go func() {
scanner := bufio.NewScanner(stderrBufferReader)
if pr.lineBufferSize != nil {
scanner.Buffer([]byte{}, *pr.lineBufferSize)
}
for scanner.Scan() {
txt := scanner.Text()
for _, lstr := range listener {
lstr.ReadStderrLine(txt)
}
combch <- combevt{txt, false}
}
if err := scanner.Err(); err != nil {
errch <- err
}
combch <- combevt{"", true}
wg.Done()
}()
// [5] combine stdcombined
wg.Add(1)
stdcombined := ""
go func() {
stopctr := 0
for stopctr < 2 {
vvv := <-combch
if vvv.stop {
stopctr++
} else {
stdcombined += vvv.line + "\n" // this comes from bufio.Scanner and has no newlines...
}
}
wg.Done()
}()
// wait for all (5) goroutines to finish
wg.Wait()
if err, ok := syncext.ReadNonBlocking(errch); ok {
return "", "", "", err
}
return stdout, stderr, stdcombined, nil
}

183
confext/confParser.go Normal file
View File

@@ -0,0 +1,183 @@
package confext
import (
"errors"
"fmt"
"gogs.mikescher.com/BlackForestBytes/goext/timeext"
"math/bits"
"os"
"reflect"
"strconv"
"time"
)
// ApplyEnvOverrides overrides field values from environment variables
//
// fields must be tagged with `env:"env_key"`
//
// only works on exported fields
//
// fields without an env tag are ignored
// fields with an `env:"-"` tag are ignore
//
// sub-structs are recursively parsed (if they have an env tag) and the env-variable keys are delimited by the delim parameter
// sub-structs with `env:""` are also parsed, but the delimited is skipped (they are handled as if they were one level higher)
func ApplyEnvOverrides[T any](prefix string, c *T, delim string) error {
rval := reflect.ValueOf(c).Elem()
return processEnvOverrides(rval, delim, prefix)
}
func processEnvOverrides(rval reflect.Value, delim string, prefix string) error {
rtyp := rval.Type()
for i := 0; i < rtyp.NumField(); i++ {
rsfield := rtyp.Field(i)
rvfield := rval.Field(i)
if !rsfield.IsExported() {
continue
}
if rvfield.Kind() == reflect.Struct {
envkey, found := rsfield.Tag.Lookup("env")
if !found || envkey == "-" {
continue
}
subPrefix := prefix
if envkey != "" {
subPrefix = subPrefix + envkey + delim
}
err := processEnvOverrides(rvfield, delim, subPrefix)
if err != nil {
return err
}
}
envkey := rsfield.Tag.Get("env")
if envkey == "" || envkey == "-" {
continue
}
fullEnvKey := prefix + envkey
envval, efound := os.LookupEnv(fullEnvKey)
if !efound {
continue
}
if rvfield.Type().Kind() == reflect.Pointer {
newval, err := parseEnvToValue(envval, fullEnvKey, rvfield.Type().Elem())
if err != nil {
return err
}
// converts reflect.Value to pointer
ptrval := reflect.New(rvfield.Type().Elem())
ptrval.Elem().Set(newval)
rvfield.Set(ptrval)
fmt.Printf("[CONF] Overwrite config '%s' with '%s'\n", fullEnvKey, envval)
} else {
newval, err := parseEnvToValue(envval, fullEnvKey, rvfield.Type())
if err != nil {
return err
}
rvfield.Set(newval)
fmt.Printf("[CONF] Overwrite config '%s' with '%s'\n", fullEnvKey, envval)
}
}
return nil
}
func parseEnvToValue(envval string, fullEnvKey string, rvtype reflect.Type) (reflect.Value, error) {
if rvtype == reflect.TypeOf("") {
return reflect.ValueOf(envval), nil
} else if rvtype == reflect.TypeOf(int(0)) {
envint, err := strconv.ParseInt(envval, 10, bits.UintSize)
if err != nil {
return reflect.Value{}, errors.New(fmt.Sprintf("Failed to parse env-config variable '%s' to int (value := '%s')", fullEnvKey, envval))
}
return reflect.ValueOf(int(envint)), nil
} else if rvtype == reflect.TypeOf(int64(0)) {
envint, err := strconv.ParseInt(envval, 10, 64)
if err != nil {
return reflect.Value{}, errors.New(fmt.Sprintf("Failed to parse env-config variable '%s' to int64 (value := '%s')", fullEnvKey, envval))
}
return reflect.ValueOf(int64(envint)), nil
} else if rvtype == reflect.TypeOf(int32(0)) {
envint, err := strconv.ParseInt(envval, 10, 32)
if err != nil {
return reflect.Value{}, errors.New(fmt.Sprintf("Failed to parse env-config variable '%s' to int32 (value := '%s')", fullEnvKey, envval))
}
return reflect.ValueOf(int32(envint)), nil
} else if rvtype == reflect.TypeOf(int8(0)) {
envint, err := strconv.ParseInt(envval, 10, 8)
if err != nil {
return reflect.Value{}, errors.New(fmt.Sprintf("Failed to parse env-config variable '%s' to int32 (value := '%s')", fullEnvKey, envval))
}
return reflect.ValueOf(int8(envint)), nil
} else if rvtype == reflect.TypeOf(time.Duration(0)) {
dur, err := timeext.ParseDurationShortString(envval)
if err != nil {
return reflect.Value{}, errors.New(fmt.Sprintf("Failed to parse env-config variable '%s' to duration (value := '%s')", fullEnvKey, envval))
}
return reflect.ValueOf(dur), nil
} else if rvtype == reflect.TypeOf(time.UnixMilli(0)) {
tim, err := time.Parse(time.RFC3339Nano, envval)
if err != nil {
return reflect.Value{}, errors.New(fmt.Sprintf("Failed to parse env-config variable '%s' to time.time (value := '%s')", fullEnvKey, envval))
}
return reflect.ValueOf(tim), nil
} else if rvtype.ConvertibleTo(reflect.TypeOf(int(0))) {
envint, err := strconv.ParseInt(envval, 10, 8)
if err != nil {
return reflect.Value{}, errors.New(fmt.Sprintf("Failed to parse env-config variable '%s' to <%s, ,int> (value := '%s')", rvtype.Name(), fullEnvKey, envval))
}
envcvl := reflect.ValueOf(envint).Convert(rvtype)
return envcvl, nil
} else if rvtype.ConvertibleTo(reflect.TypeOf("")) {
envcvl := reflect.ValueOf(envval).Convert(rvtype)
return envcvl, nil
} else {
return reflect.Value{}, errors.New(fmt.Sprintf("Unknown kind/type in config: [ %s | %s ]", rvtype.Kind().String(), rvtype.String()))
}
}

278
confext/confParser_test.go Normal file
View File

@@ -0,0 +1,278 @@
package confext
import (
"gogs.mikescher.com/BlackForestBytes/goext/timeext"
"gogs.mikescher.com/BlackForestBytes/goext/tst"
"testing"
"time"
)
func TestApplyEnvOverridesNoop(t *testing.T) {
type aliasint int
type aliasstring string
type testdata struct {
V1 int `env:"TEST_V1"`
VX string ``
V2 string `env:"TEST_V2"`
V3 int8 `env:"TEST_V3"`
V4 int32 `env:"TEST_V4"`
V5 int64 `env:"TEST_V5"`
V6 aliasint `env:"TEST_V6"`
VY aliasint ``
V7 aliasstring `env:"TEST_V7"`
V8 time.Duration `env:"TEST_V8"`
V9 time.Time `env:"TEST_V9"`
}
input := testdata{
V1: 1,
VX: "X",
V2: "2",
V3: 3,
V4: 4,
V5: 5,
V6: 6,
VY: 99,
V7: "7",
V8: 9,
V9: time.Unix(1671102873, 0),
}
output := input
err := ApplyEnvOverrides("", &output, ".")
if err != nil {
t.Errorf("%v", err)
t.FailNow()
}
tst.AssertEqual(t, input, output)
}
func TestApplyEnvOverridesSimple(t *testing.T) {
type aliasint int
type aliasstring string
type testdata struct {
V1 int `env:"TEST_V1"`
VX string ``
V2 string `env:"TEST_V2"`
V3 int8 `env:"TEST_V3"`
V4 int32 `env:"TEST_V4"`
V5 int64 `env:"TEST_V5"`
V6 aliasint `env:"TEST_V6"`
VY aliasint ``
V7 aliasstring `env:"TEST_V7"`
V8 time.Duration `env:"TEST_V8"`
V9 time.Time `env:"TEST_V9"`
}
data := testdata{
V1: 1,
VX: "X",
V2: "2",
V3: 3,
V4: 4,
V5: 5,
V6: 6,
VY: 99,
V7: "7",
V8: 9,
V9: time.Unix(1671102873, 0),
}
t.Setenv("TEST_V1", "846")
t.Setenv("TEST_V2", "hello_world")
t.Setenv("TEST_V3", "6")
t.Setenv("TEST_V4", "333")
t.Setenv("TEST_V5", "-937")
t.Setenv("TEST_V6", "070")
t.Setenv("TEST_V7", "AAAAAA")
t.Setenv("TEST_V8", "1min4s")
t.Setenv("TEST_V9", "2009-11-10T23:00:00Z")
err := ApplyEnvOverrides("", &data, ".")
if err != nil {
t.Errorf("%v", err)
t.FailNow()
}
tst.AssertEqual(t, data.V1, 846)
tst.AssertEqual(t, data.V2, "hello_world")
tst.AssertEqual(t, data.V3, 6)
tst.AssertEqual(t, data.V4, 333)
tst.AssertEqual(t, data.V5, -937)
tst.AssertEqual(t, data.V6, 70)
tst.AssertEqual(t, data.V7, "AAAAAA")
tst.AssertEqual(t, data.V8, time.Second*64)
tst.AssertEqual(t, data.V9, time.Unix(1257894000, 0).UTC())
}
func TestApplyEnvOverridesRecursive(t *testing.T) {
type subdata struct {
V1 int `env:"SUB_V1"`
VX string ``
V2 string `env:"SUB_V2"`
V8 time.Duration `env:"SUB_V3"`
V9 time.Time `env:"SUB_V4"`
}
type testdata struct {
V1 int `env:"TEST_V1"`
VX string ``
Sub1 subdata ``
Sub2 subdata `env:"TEST_V2"`
Sub3 subdata `env:"TEST_V3"`
Sub4 subdata `env:""`
V5 string `env:"-"`
}
data := testdata{
V1: 1,
VX: "2",
V5: "no",
Sub1: subdata{
V1: 3,
VX: "4",
V2: "5",
V8: 6 * time.Second,
V9: time.Date(2000, 1, 7, 1, 1, 1, 0, time.UTC),
},
Sub2: subdata{
V1: 8,
VX: "9",
V2: "10",
V8: 11 * time.Second,
V9: time.Date(2000, 1, 12, 1, 1, 1, 0, timeext.TimezoneBerlin),
},
Sub3: subdata{
V1: 13,
VX: "14",
V2: "15",
V8: 16 * time.Second,
V9: time.Date(2000, 1, 17, 1, 1, 1, 0, timeext.TimezoneBerlin),
},
Sub4: subdata{
V1: 18,
VX: "19",
V2: "20",
V8: 21 * time.Second,
V9: time.Date(2000, 1, 22, 1, 1, 1, 0, timeext.TimezoneBerlin),
},
}
t.Setenv("TEST_V1", "999")
t.Setenv("-", "yes")
t.Setenv("TEST_V2_SUB_V1", "846")
t.Setenv("TEST_V2_SUB_V2", "222_hello_world")
t.Setenv("TEST_V2_SUB_V3", "1min4s")
t.Setenv("TEST_V2_SUB_V4", "2009-11-10T23:00:00Z")
t.Setenv("TEST_V3_SUB_V1", "33846")
t.Setenv("TEST_V3_SUB_V2", "33_hello_world")
t.Setenv("TEST_V3_SUB_V3", "33min4s")
t.Setenv("TEST_V3_SUB_V4", "2033-11-10T23:00:00Z")
t.Setenv("SUB_V1", "11")
t.Setenv("SUB_V2", "22")
t.Setenv("SUB_V3", "33min")
t.Setenv("SUB_V4", "2044-01-01T00:00:00Z")
err := ApplyEnvOverrides("", &data, "_")
if err != nil {
t.Errorf("%v", err)
t.FailNow()
}
tst.AssertEqual(t, data.V1, 999)
tst.AssertEqual(t, data.VX, "2")
tst.AssertEqual(t, data.V5, "no")
tst.AssertEqual(t, data.Sub1.V1, 3)
tst.AssertEqual(t, data.Sub1.VX, "4")
tst.AssertEqual(t, data.Sub1.V2, "5")
tst.AssertEqual(t, data.Sub1.V8, time.Second*6)
tst.AssertEqual(t, data.Sub1.V9, time.Unix(947206861, 0).UTC())
tst.AssertEqual(t, data.Sub2.V1, 846)
tst.AssertEqual(t, data.Sub2.VX, "9")
tst.AssertEqual(t, data.Sub2.V2, "222_hello_world")
tst.AssertEqual(t, data.Sub2.V8, time.Second*64)
tst.AssertEqual(t, data.Sub2.V9, time.Unix(1257894000, 0).UTC())
tst.AssertEqual(t, data.Sub3.V1, 33846)
tst.AssertEqual(t, data.Sub3.VX, "14")
tst.AssertEqual(t, data.Sub3.V2, "33_hello_world")
tst.AssertEqual(t, data.Sub3.V8, time.Second*1984)
tst.AssertEqual(t, data.Sub3.V9, time.Unix(2015276400, 0).UTC())
tst.AssertEqual(t, data.Sub4.V1, 11)
tst.AssertEqual(t, data.Sub4.VX, "19")
tst.AssertEqual(t, data.Sub4.V2, "22")
tst.AssertEqual(t, data.Sub4.V8, time.Second*1980)
tst.AssertEqual(t, data.Sub4.V9, time.Unix(2335219200, 0).UTC())
}
func TestApplyEnvOverridesPointer(t *testing.T) {
type aliasint int
type aliasstring string
type testdata struct {
V1 *int `env:"TEST_V1"`
VX *string ``
V2 *string `env:"TEST_V2"`
V3 *int8 `env:"TEST_V3"`
V4 *int32 `env:"TEST_V4"`
V5 *int64 `env:"TEST_V5"`
V6 *aliasint `env:"TEST_V6"`
VY *aliasint ``
V7 *aliasstring `env:"TEST_V7"`
V8 *time.Duration `env:"TEST_V8"`
V9 *time.Time `env:"TEST_V9"`
}
data := testdata{}
t.Setenv("TEST_V1", "846")
t.Setenv("TEST_V2", "hello_world")
t.Setenv("TEST_V3", "6")
t.Setenv("TEST_V4", "333")
t.Setenv("TEST_V5", "-937")
t.Setenv("TEST_V6", "070")
t.Setenv("TEST_V7", "AAAAAA")
t.Setenv("TEST_V8", "1min4s")
t.Setenv("TEST_V9", "2009-11-10T23:00:00Z")
err := ApplyEnvOverrides("", &data, ".")
if err != nil {
t.Errorf("%v", err)
t.FailNow()
}
tst.AssertDeRefEqual(t, data.V1, 846)
tst.AssertDeRefEqual(t, data.V2, "hello_world")
tst.AssertDeRefEqual(t, data.V3, 6)
tst.AssertDeRefEqual(t, data.V4, 333)
tst.AssertDeRefEqual(t, data.V5, -937)
tst.AssertDeRefEqual(t, data.V6, 70)
tst.AssertDeRefEqual(t, data.V7, "AAAAAA")
tst.AssertDeRefEqual(t, data.V8, time.Second*64)
tst.AssertDeRefEqual(t, data.V9, time.Unix(1257894000, 0).UTC())
}
func assertEqual[T comparable](t *testing.T, actual T, expected T) {
if actual != expected {
t.Errorf("values differ: Actual: '%v', Expected: '%v'", actual, expected)
}
}
func assertPtrEqual[T comparable](t *testing.T, actual *T, expected T) {
if actual == nil {
t.Errorf("values differ: Actual: NIL, Expected: '%v'", expected)
}
if *actual != expected {
t.Errorf("values differ: Actual: '%v', Expected: '%v'", actual, expected)
}
}

132
cryptext/aes.go Normal file
View File

@@ -0,0 +1,132 @@
package cryptext
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"encoding/base32"
"encoding/json"
"errors"
"golang.org/x/crypto/scrypt"
"io"
)
// https://stackoverflow.com/a/18819040/1761622
type aesPayload struct {
Salt []byte `json:"s"`
IV []byte `json:"i"`
Data []byte `json:"d"`
Rounds int `json:"r"`
Version uint `json:"v"`
}
func EncryptAESSimple(password []byte, data []byte, rounds int) (string, error) {
salt := make([]byte, 8)
_, err := io.ReadFull(rand.Reader, salt)
if err != nil {
return "", err
}
key, err := scrypt.Key(password, salt, rounds, 8, 1, 32)
if err != nil {
return "", err
}
block, err := aes.NewCipher(key)
if err != nil {
return "", err
}
h := sha256.New()
h.Write(data)
checksum := h.Sum(nil)
if len(checksum) != 32 {
return "", errors.New("wrong cs size")
}
ciphertext := make([]byte, 32+len(data))
iv := make([]byte, aes.BlockSize)
_, err = io.ReadFull(rand.Reader, iv)
if err != nil {
return "", err
}
combinedData := make([]byte, 0, 32+len(data))
combinedData = append(combinedData, checksum...)
combinedData = append(combinedData, data...)
cfb := cipher.NewCFBEncrypter(block, iv)
cfb.XORKeyStream(ciphertext, combinedData)
pl := aesPayload{
Salt: salt,
IV: iv,
Data: ciphertext,
Version: 1,
Rounds: rounds,
}
jbin, err := json.Marshal(pl)
if err != nil {
return "", err
}
res := base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(jbin)
return res, nil
}
func DecryptAESSimple(password []byte, encText string) ([]byte, error) {
jbin, err := base32.StdEncoding.WithPadding(base32.NoPadding).DecodeString(encText)
if err != nil {
return nil, err
}
var pl aesPayload
err = json.Unmarshal(jbin, &pl)
if err != nil {
return nil, err
}
if pl.Version != 1 {
return nil, errors.New("unsupported version")
}
key, err := scrypt.Key(password, pl.Salt, pl.Rounds, 8, 1, 32) // this is not 100% correct, rounds too low and salt is missing
if err != nil {
return nil, err
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
dest := make([]byte, len(pl.Data))
cfb := cipher.NewCFBDecrypter(block, pl.IV)
cfb.XORKeyStream(dest, pl.Data)
if len(dest) < 32 {
return nil, errors.New("payload too small")
}
chck := dest[:32]
data := dest[32:]
h := sha256.New()
h.Write(data)
chck2 := h.Sum(nil)
if !bytes.Equal(chck, chck2) {
return nil, errors.New("checksum mismatch")
}
return data, nil
}

36
cryptext/aes_test.go Normal file
View File

@@ -0,0 +1,36 @@
package cryptext
import (
"fmt"
"gogs.mikescher.com/BlackForestBytes/goext/tst"
"testing"
)
func TestEncryptAESSimple(t *testing.T) {
pw := []byte("hunter12")
str1 := []byte("Hello World")
str2, err := EncryptAESSimple(pw, str1, 512)
if err != nil {
panic(err)
}
fmt.Printf("%s\n", str2)
str3, err := DecryptAESSimple(pw, str2)
if err != nil {
panic(err)
}
tst.AssertEqual(t, string(str1), string(str3))
str4, err := EncryptAESSimple(pw, str3, 512)
if err != nil {
panic(err)
}
tst.AssertNotEqual(t, string(str2), string(str4))
}

View File

@@ -1,25 +1,20 @@
package cryptext
import (
"gogs.mikescher.com/BlackForestBytes/goext/tst"
"testing"
)
func TestStrSha256(t *testing.T) {
assertEqual(t, StrSha256(""), "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855")
assertEqual(t, StrSha256("0"), "5feceb66ffc86f38d952786c6d696c79c2dbc239dd4e91b46729d73a27fb57e9")
assertEqual(t, StrSha256("80085"), "b3786e141d65638ad8a98173e26b5f6a53c927737b23ff31fb1843937250f44b")
assertEqual(t, StrSha256("Hello World"), "a591a6d40bf420404a011733cfb7b190d62c65bf0bcda32b57b277d9ad9f146e")
tst.AssertEqual(t, StrSha256(""), "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855")
tst.AssertEqual(t, StrSha256("0"), "5feceb66ffc86f38d952786c6d696c79c2dbc239dd4e91b46729d73a27fb57e9")
tst.AssertEqual(t, StrSha256("80085"), "b3786e141d65638ad8a98173e26b5f6a53c927737b23ff31fb1843937250f44b")
tst.AssertEqual(t, StrSha256("Hello World"), "a591a6d40bf420404a011733cfb7b190d62c65bf0bcda32b57b277d9ad9f146e")
}
func TestBytesSha256(t *testing.T) {
assertEqual(t, BytesSha256([]byte{}), "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855")
assertEqual(t, BytesSha256([]byte{0}), "6e340b9cffb37a989ca544e6bb780a2c78901d3fb33738768511a30617afa01d")
assertEqual(t, BytesSha256([]byte{128}), "76be8b528d0075f7aae98d6fa57a6d3c83ae480a8469e668d7b0af968995ac71")
assertEqual(t, BytesSha256([]byte{0, 1, 2, 4, 8, 16, 32, 64, 128, 255}), "55016a318ba538e00123c736b2a8b6db368d00e7e25727547655b653e5853603")
}
func assertEqual(t *testing.T, actual string, expected string) {
if actual != expected {
t.Errorf("values differ: Actual: '%v', Expected: '%v'", actual, expected)
}
tst.AssertEqual(t, BytesSha256([]byte{}), "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855")
tst.AssertEqual(t, BytesSha256([]byte{0}), "6e340b9cffb37a989ca544e6bb780a2c78901d3fb33738768511a30617afa01d")
tst.AssertEqual(t, BytesSha256([]byte{128}), "76be8b528d0075f7aae98d6fa57a6d3c83ae480a8469e668d7b0af968995ac71")
tst.AssertEqual(t, BytesSha256([]byte{0, 1, 2, 4, 8, 16, 32, 64, 128, 255}), "55016a318ba538e00123c736b2a8b6db368d00e7e25727547655b653e5853603")
}

365
cryptext/passHash.go Normal file
View File

@@ -0,0 +1,365 @@
package cryptext
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"gogs.mikescher.com/BlackForestBytes/goext/langext"
"gogs.mikescher.com/BlackForestBytes/goext/totpext"
"golang.org/x/crypto/bcrypt"
"strconv"
"strings"
)
const LatestPassHashVersion = 4
// PassHash
// - [v0]: plaintext password ( `0|...` )
// - [v1]: sha256(plaintext)
// - [v2]: seed | sha256<seed>(plaintext)
// - [v3]: seed | sha256<seed>(plaintext) | [hex(totp)]
// - [v4]: bcrypt(plaintext) | [hex(totp)]
type PassHash string
func (ph PassHash) Valid() bool {
_, _, _, _, _, valid := ph.Data()
return valid
}
func (ph PassHash) HasTOTP() bool {
_, _, _, otp, _, _ := ph.Data()
return otp
}
func (ph PassHash) Data() (_version int, _seed []byte, _payload []byte, _totp bool, _totpsecret []byte, _valid bool) {
split := strings.Split(string(ph), "|")
if len(split) == 0 {
return -1, nil, nil, false, nil, false
}
version, err := strconv.ParseInt(split[0], 10, 32)
if err != nil {
return -1, nil, nil, false, nil, false
}
if version == 0 {
if len(split) != 2 {
return -1, nil, nil, false, nil, false
}
return int(version), nil, []byte(split[1]), false, nil, true
}
if version == 1 {
if len(split) != 2 {
return -1, nil, nil, false, nil, false
}
payload, err := base64.RawStdEncoding.DecodeString(split[1])
if err != nil {
return -1, nil, nil, false, nil, false
}
return int(version), nil, payload, false, nil, true
}
//
if version == 2 {
if len(split) != 3 {
return -1, nil, nil, false, nil, false
}
seed, err := base64.RawStdEncoding.DecodeString(split[1])
if err != nil {
return -1, nil, nil, false, nil, false
}
payload, err := base64.RawStdEncoding.DecodeString(split[2])
if err != nil {
return -1, nil, nil, false, nil, false
}
return int(version), seed, payload, false, nil, true
}
if version == 3 {
if len(split) != 4 {
return -1, nil, nil, false, nil, false
}
seed, err := base64.RawStdEncoding.DecodeString(split[1])
if err != nil {
return -1, nil, nil, false, nil, false
}
payload, err := base64.RawStdEncoding.DecodeString(split[2])
if err != nil {
return -1, nil, nil, false, nil, false
}
totp := false
totpsecret := make([]byte, 0)
if split[3] != "0" {
totpsecret, err = hex.DecodeString(split[3])
totp = true
}
return int(version), seed, payload, totp, totpsecret, true
}
if version == 4 {
if len(split) != 3 {
return -1, nil, nil, false, nil, false
}
payload := []byte(split[1])
totp := false
totpsecret := make([]byte, 0)
if split[2] != "0" {
totpsecret, err = hex.DecodeString(split[3])
totp = true
}
return int(version), nil, payload, totp, totpsecret, true
}
return -1, nil, nil, false, nil, false
}
func (ph PassHash) Verify(plainpass string, totp *string) bool {
version, seed, payload, hastotp, totpsecret, valid := ph.Data()
if !valid {
return false
}
if hastotp && totp == nil {
return false
}
if version == 0 {
return langext.ArrEqualsExact([]byte(plainpass), payload)
}
if version == 1 {
return langext.ArrEqualsExact(hash256(plainpass), payload)
}
if version == 2 {
return langext.ArrEqualsExact(hash256Seeded(plainpass, seed), payload)
}
if version == 3 {
if !hastotp {
return langext.ArrEqualsExact(hash256Seeded(plainpass, seed), payload)
} else {
return langext.ArrEqualsExact(hash256Seeded(plainpass, seed), payload) && totpext.Validate(totpsecret, *totp)
}
}
if version == 4 {
if !hastotp {
return bcrypt.CompareHashAndPassword(payload, []byte(plainpass)) == nil
} else {
return bcrypt.CompareHashAndPassword(payload, []byte(plainpass)) == nil && totpext.Validate(totpsecret, *totp)
}
}
return false
}
func (ph PassHash) NeedsPasswordUpgrade() bool {
version, _, _, _, _, valid := ph.Data()
return valid && version < LatestPassHashVersion
}
func (ph PassHash) Upgrade(plainpass string) (PassHash, error) {
version, _, _, hastotp, totpsecret, valid := ph.Data()
if !valid {
return "", errors.New("invalid password")
}
if version == LatestPassHashVersion {
return ph, nil
}
if hastotp {
return HashPassword(plainpass, totpsecret)
} else {
return HashPassword(plainpass, nil)
}
}
func (ph PassHash) ClearTOTP() (PassHash, error) {
version, _, _, _, _, valid := ph.Data()
if !valid {
return "", errors.New("invalid PassHash")
}
if version == 0 {
return ph, nil
}
if version == 1 {
return ph, nil
}
if version == 2 {
return ph, nil
}
if version == 3 {
split := strings.Split(string(ph), "|")
split[3] = "0"
return PassHash(strings.Join(split, "|")), nil
}
if version == 4 {
split := strings.Split(string(ph), "|")
split[2] = "0"
return PassHash(strings.Join(split, "|")), nil
}
return "", errors.New("unknown version")
}
func (ph PassHash) WithTOTP(totpSecret []byte) (PassHash, error) {
version, _, _, _, _, valid := ph.Data()
if !valid {
return "", errors.New("invalid PassHash")
}
if version == 0 {
return "", errors.New("version does not support totp, needs upgrade")
}
if version == 1 {
return "", errors.New("version does not support totp, needs upgrade")
}
if version == 2 {
return "", errors.New("version does not support totp, needs upgrade")
}
if version == 3 {
split := strings.Split(string(ph), "|")
split[3] = hex.EncodeToString(totpSecret)
return PassHash(strings.Join(split, "|")), nil
}
if version == 4 {
split := strings.Split(string(ph), "|")
split[2] = hex.EncodeToString(totpSecret)
return PassHash(strings.Join(split, "|")), nil
}
return "", errors.New("unknown version")
}
func (ph PassHash) Change(newPlainPass string) (PassHash, error) {
version, _, _, hastotp, totpsecret, valid := ph.Data()
if !valid {
return "", errors.New("invalid PassHash")
}
if version == 0 {
return HashPasswordV0(newPlainPass)
}
if version == 1 {
return HashPasswordV1(newPlainPass)
}
if version == 2 {
return HashPasswordV2(newPlainPass)
}
if version == 3 {
return HashPasswordV3(newPlainPass, langext.Conditional(hastotp, totpsecret, nil))
}
if version == 4 {
return HashPasswordV4(newPlainPass, langext.Conditional(hastotp, totpsecret, nil))
}
return "", errors.New("unknown version")
}
func (ph PassHash) String() string {
return string(ph)
}
func HashPassword(plainpass string, totpSecret []byte) (PassHash, error) {
return HashPasswordV4(plainpass, totpSecret)
}
func HashPasswordV4(plainpass string, totpSecret []byte) (PassHash, error) {
var strtotp string
if totpSecret == nil {
strtotp = "0"
} else {
strtotp = hex.EncodeToString(totpSecret)
}
payload, err := bcrypt.GenerateFromPassword([]byte(plainpass), bcrypt.MinCost)
if err != nil {
return "", err
}
return PassHash(fmt.Sprintf("4|%s|%s", string(payload), strtotp)), nil
}
func HashPasswordV3(plainpass string, totpSecret []byte) (PassHash, error) {
var strtotp string
if totpSecret == nil {
strtotp = "0"
} else {
strtotp = hex.EncodeToString(totpSecret)
}
seed, err := newSeed()
if err != nil {
return "", err
}
checksum := hash256Seeded(plainpass, seed)
return PassHash(fmt.Sprintf("3|%s|%s|%s",
base64.RawStdEncoding.EncodeToString(seed),
base64.RawStdEncoding.EncodeToString(checksum),
strtotp)), nil
}
func HashPasswordV2(plainpass string) (PassHash, error) {
seed, err := newSeed()
if err != nil {
return "", err
}
checksum := hash256Seeded(plainpass, seed)
return PassHash(fmt.Sprintf("2|%s|%s", base64.RawStdEncoding.EncodeToString(seed), base64.RawStdEncoding.EncodeToString(checksum))), nil
}
func HashPasswordV1(plainpass string) (PassHash, error) {
return PassHash(fmt.Sprintf("1|%s", base64.RawStdEncoding.EncodeToString(hash256(plainpass)))), nil
}
func HashPasswordV0(plainpass string) (PassHash, error) {
return PassHash(fmt.Sprintf("0|%s", plainpass)), nil
}
func hash256(s string) []byte {
h := sha256.New()
h.Write([]byte(s))
bs := h.Sum(nil)
return bs
}
func hash256Seeded(s string, seed []byte) []byte {
h := sha256.New()
h.Write(seed)
h.Write([]byte(s))
bs := h.Sum(nil)
return bs
}
func newSeed() ([]byte, error) {
secret := make([]byte, 32)
_, err := rand.Read(secret)
if err != nil {
return nil, err
}
return secret, nil
}

View File

@@ -1,56 +1,157 @@
package dataext
import "io"
import (
"errors"
"io"
)
type brcMode int
const (
modeSourceReading brcMode = 0
modeSourceFinished brcMode = 1
modeBufferReading brcMode = 2
modeBufferFinished brcMode = 3
)
type BufferedReadCloser interface {
io.ReadCloser
BufferedAll() ([]byte, error)
Reset() error
}
type bufferedReadCloser struct {
buffer []byte
inner io.ReadCloser
finished bool
}
func (b *bufferedReadCloser) Read(p []byte) (int, error) {
n, err := b.inner.Read(p)
if n > 0 {
b.buffer = append(b.buffer, p[0:n]...)
}
if err == io.EOF {
b.finished = true
}
return n, err
mode brcMode
off int
}
func NewBufferedReadCloser(sub io.ReadCloser) BufferedReadCloser {
return &bufferedReadCloser{
buffer: make([]byte, 0, 1024),
inner: sub,
finished: false,
mode: modeSourceReading,
off: 0,
}
}
func (b *bufferedReadCloser) Read(p []byte) (int, error) {
switch b.mode {
case modeSourceReading:
n, err := b.inner.Read(p)
if n > 0 {
b.buffer = append(b.buffer, p[0:n]...)
}
if err == io.EOF {
b.mode = modeSourceFinished
}
return n, err
case modeSourceFinished:
return 0, io.EOF
case modeBufferReading:
if len(b.buffer) <= b.off {
b.mode = modeBufferFinished
if len(p) == 0 {
return 0, nil
}
return 0, io.EOF
}
n := copy(p, b.buffer[b.off:])
b.off += n
return n, nil
case modeBufferFinished:
return 0, io.EOF
default:
return 0, errors.New("object in undefined status")
}
}
func (b *bufferedReadCloser) Close() error {
err := b.inner.Close()
switch b.mode {
case modeSourceReading:
_, err := b.BufferedAll()
if err != nil {
b.finished = true
}
return err
}
err = b.inner.Close()
if err != nil {
return err
}
b.mode = modeSourceFinished
return nil
case modeSourceFinished:
return nil
case modeBufferReading:
b.mode = modeBufferFinished
return nil
case modeBufferFinished:
return nil
default:
return errors.New("object in undefined status")
}
}
func (b *bufferedReadCloser) BufferedAll() ([]byte, error) {
switch b.mode {
case modeSourceReading:
arr := make([]byte, 1024)
for !b.finished {
for b.mode == modeSourceReading {
_, err := b.Read(arr)
if err != nil && err != io.EOF {
return nil, err
}
}
return b.buffer, nil
case modeSourceFinished:
return b.buffer, nil
case modeBufferReading:
return b.buffer, nil
case modeBufferFinished:
return b.buffer, nil
default:
return nil, errors.New("object in undefined status")
}
}
func (b *bufferedReadCloser) Reset() error {
switch b.mode {
case modeSourceReading:
fallthrough
case modeSourceFinished:
err := b.Close()
if err != nil {
return err
}
b.mode = modeBufferReading
b.off = 0
return nil
case modeBufferReading:
fallthrough
case modeBufferFinished:
b.mode = modeBufferReading
b.off = 0
return nil
default:
return errors.New("object in undefined status")
}
}

View File

@@ -19,40 +19,38 @@ import (
// There are also a bunch of unit tests to ensure that the cache is always in a consistent state
//
type LRUData interface{}
type LRUMap struct {
type LRUMap[TKey comparable, TData any] struct {
maxsize int
lock sync.Mutex
cache map[string]*cacheNode
cache map[TKey]*cacheNode[TKey, TData]
lfuHead *cacheNode
lfuTail *cacheNode
lfuHead *cacheNode[TKey, TData]
lfuTail *cacheNode[TKey, TData]
}
type cacheNode struct {
key string
data LRUData
parent *cacheNode
child *cacheNode
type cacheNode[TKey comparable, TData any] struct {
key TKey
data TData
parent *cacheNode[TKey, TData]
child *cacheNode[TKey, TData]
}
func NewLRUMap(size int) *LRUMap {
func NewLRUMap[TKey comparable, TData any](size int) *LRUMap[TKey, TData] {
if size <= 2 && size != 0 {
panic("Size must be > 2 (or 0)")
}
return &LRUMap{
return &LRUMap[TKey, TData]{
maxsize: size,
lock: sync.Mutex{},
cache: make(map[string]*cacheNode, size+1),
cache: make(map[TKey]*cacheNode[TKey, TData], size+1),
lfuHead: nil,
lfuTail: nil,
}
}
func (c *LRUMap) Put(key string, value LRUData) {
func (c *LRUMap[TKey, TData]) Put(key TKey, value TData) {
if c.maxsize == 0 {
return // cache disabled
}
@@ -70,7 +68,7 @@ func (c *LRUMap) Put(key string, value LRUData) {
}
// key does not exist: insert into map and add to top of LFU
node = &cacheNode{
node = &cacheNode[TKey, TData]{
key: key,
data: value,
parent: nil,
@@ -95,9 +93,9 @@ func (c *LRUMap) Put(key string, value LRUData) {
}
}
func (c *LRUMap) TryGet(key string) (LRUData, bool) {
func (c *LRUMap[TKey, TData]) TryGet(key TKey) (TData, bool) {
if c.maxsize == 0 {
return nil, false // cache disabled
return *new(TData), false // cache disabled
}
c.lock.Lock()
@@ -105,13 +103,13 @@ func (c *LRUMap) TryGet(key string) (LRUData, bool) {
val, ok := c.cache[key]
if !ok {
return nil, false
return *new(TData), false
}
c.moveNodeToTop(val)
return val.data, ok
}
func (c *LRUMap) moveNodeToTop(node *cacheNode) {
func (c *LRUMap[TKey, TData]) moveNodeToTop(node *cacheNode[TKey, TData]) {
// (only called in critical section !)
if c.lfuHead == node { // fast case
@@ -144,7 +142,7 @@ func (c *LRUMap) moveNodeToTop(node *cacheNode) {
}
}
func (c *LRUMap) Size() int {
func (c *LRUMap[TKey, TData]) Size() int {
c.lock.Lock()
defer c.lock.Unlock()
return len(c.cache)

View File

@@ -12,7 +12,7 @@ func init() {
}
func TestResultCache1(t *testing.T) {
cache := NewLRUMap(8)
cache := NewLRUMap[string, string](8)
verifyLRUList(cache, t)
key := randomKey()
@@ -39,7 +39,7 @@ func TestResultCache1(t *testing.T) {
if !ok {
t.Errorf("cache TryGet returned no value")
}
if !eq(cacheval, val) {
if cacheval != val {
t.Errorf("cache TryGet returned different value (%+v <> %+v)", cacheval, val)
}
@@ -50,7 +50,7 @@ func TestResultCache1(t *testing.T) {
}
func TestResultCache2(t *testing.T) {
cache := NewLRUMap(8)
cache := NewLRUMap[string, string](8)
verifyLRUList(cache, t)
key1 := "key1"
@@ -150,7 +150,7 @@ func TestResultCache2(t *testing.T) {
}
func TestResultCache3(t *testing.T) {
cache := NewLRUMap(8)
cache := NewLRUMap[string, string](8)
verifyLRUList(cache, t)
key1 := "key1"
@@ -160,20 +160,20 @@ func TestResultCache3(t *testing.T) {
cache.Put(key1, val1)
verifyLRUList(cache, t)
if val, ok := cache.TryGet(key1); !ok || !eq(val, val1) {
if val, ok := cache.TryGet(key1); !ok || val != val1 {
t.Errorf("Value in cache should be [val1]")
}
cache.Put(key1, val2)
verifyLRUList(cache, t)
if val, ok := cache.TryGet(key1); !ok || !eq(val, val2) {
if val, ok := cache.TryGet(key1); !ok || val != val2 {
t.Errorf("Value in cache should be [val2]")
}
}
// does a basic consistency check over the internal cache representation
func verifyLRUList(cache *LRUMap, t *testing.T) {
func verifyLRUList[TKey comparable, TData any](cache *LRUMap[TKey, TData], t *testing.T) {
size := 0
tailFound := false
@@ -250,23 +250,10 @@ func randomKey() string {
return strconv.FormatInt(rand.Int63(), 16)
}
func randomVal() LRUData {
func randomVal() string {
v, err := langext.NewHexUUID()
if err != nil {
panic(err)
}
return &v
}
func eq(a LRUData, b LRUData) bool {
v1, ok1 := a.(*string)
v2, ok2 := b.(*string)
if ok1 && ok2 {
if v1 == nil || v2 == nil {
return false
}
return v1 == v2
}
return false
return v
}

View File

@@ -2,6 +2,7 @@ package dataext
import (
"gogs.mikescher.com/BlackForestBytes/goext/langext"
"gogs.mikescher.com/BlackForestBytes/goext/tst"
"testing"
)
@@ -43,10 +44,10 @@ func TestObjectMerge(t *testing.T) {
valueMerge := ObjectMerge(valueA, valueB)
assertPtrEqual(t, "Field1", valueMerge.Field1, valueB.Field1)
assertPtrEqual(t, "Field2", valueMerge.Field2, valueA.Field2)
assertPtrEqual(t, "Field3", valueMerge.Field3, valueB.Field3)
assertPtrEqual(t, "Field4", valueMerge.Field4, nil)
tst.AssertIdentPtrEqual(t, "Field1", valueMerge.Field1, valueB.Field1)
tst.AssertIdentPtrEqual(t, "Field2", valueMerge.Field2, valueA.Field2)
tst.AssertIdentPtrEqual(t, "Field3", valueMerge.Field3, valueB.Field3)
tst.AssertIdentPtrEqual(t, "Field4", valueMerge.Field4, nil)
}

116
dataext/stack.go Normal file
View File

@@ -0,0 +1,116 @@
package dataext
import (
"errors"
"gogs.mikescher.com/BlackForestBytes/goext/langext"
"sync"
)
var ErrEmptyStack = errors.New("stack is empty")
type Stack[T any] struct {
lock *sync.Mutex
data []T
}
func NewStack[T any](threadsafe bool, initialCapacity int) *Stack[T] {
var lck *sync.Mutex = nil
if threadsafe {
lck = &sync.Mutex{}
}
return &Stack[T]{
lock: lck,
data: make([]T, 0, initialCapacity),
}
}
func (s *Stack[T]) Push(v T) {
if s.lock != nil {
s.lock.Lock()
defer s.lock.Unlock()
}
s.data = append(s.data, v)
}
func (s *Stack[T]) Pop() (T, error) {
if s.lock != nil {
s.lock.Lock()
defer s.lock.Unlock()
}
l := len(s.data)
if l == 0 {
return *new(T), ErrEmptyStack
}
result := s.data[l-1]
s.data = s.data[:l-1]
return result, nil
}
func (s *Stack[T]) OptPop() *T {
if s.lock != nil {
s.lock.Lock()
defer s.lock.Unlock()
}
l := len(s.data)
if l == 0 {
return nil
}
result := s.data[l-1]
s.data = s.data[:l-1]
return langext.Ptr(result)
}
func (s *Stack[T]) Peek() (T, error) {
if s.lock != nil {
s.lock.Lock()
defer s.lock.Unlock()
}
l := len(s.data)
if l == 0 {
return *new(T), ErrEmptyStack
}
return s.data[l-1], nil
}
func (s *Stack[T]) OptPeek() *T {
if s.lock != nil {
s.lock.Lock()
defer s.lock.Unlock()
}
l := len(s.data)
if l == 0 {
return nil
}
return langext.Ptr(s.data[l-1])
}
func (s *Stack[T]) Length() int {
if s.lock != nil {
s.lock.Lock()
defer s.lock.Unlock()
}
return len(s.data)
}
func (s *Stack[T]) Empty() bool {
if s.lock != nil {
s.lock.Lock()
defer s.lock.Unlock()
}
return len(s.data) == 0
}

254
dataext/structHash.go Normal file
View File

@@ -0,0 +1,254 @@
package dataext
import (
"bytes"
"crypto/sha256"
"encoding/binary"
"errors"
"fmt"
"gogs.mikescher.com/BlackForestBytes/goext/langext"
"hash"
"io"
"reflect"
"sort"
)
type StructHashOptions struct {
HashAlgo hash.Hash
Tag *string
SkipChannel bool
SkipFunc bool
}
func StructHash(dat any, opt ...StructHashOptions) (r []byte, err error) {
defer func() {
if rec := recover(); rec != nil {
r = nil
err = errors.New(fmt.Sprintf("recovered panic: %v", rec))
}
}()
shopt := StructHashOptions{}
if len(opt) > 1 {
return nil, errors.New("multiple options supplied")
} else if len(opt) == 1 {
shopt = opt[0]
}
if shopt.HashAlgo == nil {
shopt.HashAlgo = sha256.New()
}
writer := new(bytes.Buffer)
if langext.IsNil(dat) {
shopt.HashAlgo.Reset()
shopt.HashAlgo.Write(writer.Bytes())
res := shopt.HashAlgo.Sum(nil)
return res, nil
}
err = binarize(writer, reflect.ValueOf(dat), shopt)
if err != nil {
return nil, err
}
shopt.HashAlgo.Reset()
shopt.HashAlgo.Write(writer.Bytes())
res := shopt.HashAlgo.Sum(nil)
return res, nil
}
func writeBinarized(writer io.Writer, dat any) error {
tmp := bytes.Buffer{}
err := binary.Write(&tmp, binary.LittleEndian, dat)
if err != nil {
return err
}
err = binary.Write(writer, binary.LittleEndian, uint64(tmp.Len()))
if err != nil {
return err
}
_, err = writer.Write(tmp.Bytes())
if err != nil {
return err
}
return nil
}
func binarize(writer io.Writer, dat reflect.Value, opt StructHashOptions) error {
var err error
err = binary.Write(writer, binary.LittleEndian, uint8(dat.Kind()))
switch dat.Kind() {
case reflect.Ptr, reflect.Map, reflect.Array, reflect.Chan, reflect.Slice, reflect.Interface:
if dat.IsNil() {
err = binary.Write(writer, binary.LittleEndian, uint64(0))
if err != nil {
return err
}
return nil
}
}
err = binary.Write(writer, binary.LittleEndian, uint64(len(dat.Type().String())))
if err != nil {
return err
}
_, err = writer.Write([]byte(dat.Type().String()))
if err != nil {
return err
}
switch dat.Type().Kind() {
case reflect.Invalid:
return errors.New("cannot binarize value of kind <Invalid>")
case reflect.Bool:
return writeBinarized(writer, dat.Bool())
case reflect.Int:
return writeBinarized(writer, int64(dat.Int()))
case reflect.Int8:
fallthrough
case reflect.Int16:
fallthrough
case reflect.Int32:
fallthrough
case reflect.Int64:
return writeBinarized(writer, dat.Interface())
case reflect.Uint:
return writeBinarized(writer, uint64(dat.Int()))
case reflect.Uint8:
fallthrough
case reflect.Uint16:
fallthrough
case reflect.Uint32:
fallthrough
case reflect.Uint64:
return writeBinarized(writer, dat.Interface())
case reflect.Uintptr:
return errors.New("cannot binarize value of kind <Uintptr>")
case reflect.Float32:
fallthrough
case reflect.Float64:
return writeBinarized(writer, dat.Interface())
case reflect.Complex64:
return errors.New("cannot binarize value of kind <Complex64>")
case reflect.Complex128:
return errors.New("cannot binarize value of kind <Complex128>")
case reflect.Slice:
fallthrough
case reflect.Array:
return binarizeArrayOrSlice(writer, dat, opt)
case reflect.Chan:
if opt.SkipChannel {
return nil
}
return errors.New("cannot binarize value of kind <Chan>")
case reflect.Func:
if opt.SkipFunc {
return nil
}
return errors.New("cannot binarize value of kind <Func>")
case reflect.Interface:
return binarize(writer, dat.Elem(), opt)
case reflect.Map:
return binarizeMap(writer, dat, opt)
case reflect.Pointer:
return binarize(writer, dat.Elem(), opt)
case reflect.String:
v := dat.String()
err = binary.Write(writer, binary.LittleEndian, uint64(len(v)))
if err != nil {
return err
}
_, err = writer.Write([]byte(v))
if err != nil {
return err
}
return nil
case reflect.Struct:
return binarizeStruct(writer, dat, opt)
case reflect.UnsafePointer:
return errors.New("cannot binarize value of kind <UnsafePointer>")
default:
return errors.New("cannot binarize value of unknown kind <" + dat.Type().Kind().String() + ">")
}
}
func binarizeStruct(writer io.Writer, dat reflect.Value, opt StructHashOptions) error {
err := binary.Write(writer, binary.LittleEndian, uint64(dat.NumField()))
if err != nil {
return err
}
for i := 0; i < dat.NumField(); i++ {
if opt.Tag != nil {
if _, ok := dat.Type().Field(i).Tag.Lookup(*opt.Tag); !ok {
continue
}
}
err = binary.Write(writer, binary.LittleEndian, uint64(len(dat.Type().Field(i).Name)))
if err != nil {
return err
}
_, err = writer.Write([]byte(dat.Type().Field(i).Name))
if err != nil {
return err
}
err = binarize(writer, dat.Field(i), opt)
if err != nil {
return err
}
}
return nil
}
func binarizeArrayOrSlice(writer io.Writer, dat reflect.Value, opt StructHashOptions) error {
err := binary.Write(writer, binary.LittleEndian, uint64(dat.Len()))
if err != nil {
return err
}
for i := 0; i < dat.Len(); i++ {
err := binarize(writer, dat.Index(i), opt)
if err != nil {
return err
}
}
return nil
}
func binarizeMap(writer io.Writer, dat reflect.Value, opt StructHashOptions) error {
err := binary.Write(writer, binary.LittleEndian, uint64(dat.Len()))
if err != nil {
return err
}
sub := make([][]byte, 0, dat.Len())
for _, k := range dat.MapKeys() {
tmp := bytes.Buffer{}
err = binarize(&tmp, dat.MapIndex(k), opt)
if err != nil {
return err
}
sub = append(sub, tmp.Bytes())
}
sort.Slice(sub, func(i1, i2 int) bool { return bytes.Compare(sub[i1], sub[i2]) < 0 })
for _, v := range sub {
_, err = writer.Write(v)
if err != nil {
return err
}
}
return nil
}

136
dataext/structHash_test.go Normal file
View File

@@ -0,0 +1,136 @@
package dataext
import (
"gogs.mikescher.com/BlackForestBytes/goext/langext"
"gogs.mikescher.com/BlackForestBytes/goext/tst"
"testing"
)
func noErrStructHash(t *testing.T, dat any, opt ...StructHashOptions) []byte {
res, err := StructHash(dat, opt...)
if err != nil {
t.Error(err)
t.FailNow()
return nil
}
return res
}
func TestStructHashSimple(t *testing.T) {
tst.AssertHexEqual(t, "209bf774af36cc3a045c152d9f1269ef3684ad819c1359ee73ff0283a308fefa", noErrStructHash(t, "Hello"))
tst.AssertHexEqual(t, "c32f3626b981ae2997db656f3acad3f1dc9d30ef6b6d14296c023e391b25f71a", noErrStructHash(t, 0))
tst.AssertHexEqual(t, "01b781b03e9586b257d387057dfc70d9f06051e7d3c1e709a57e13cc8daf3e35", noErrStructHash(t, []byte{}))
tst.AssertHexEqual(t, "93e1dcd45c732fe0079b0fb3204c7c803f0921835f6bfee2e6ff263e73eed53c", noErrStructHash(t, []int{}))
tst.AssertHexEqual(t, "54f637a376aad55b3160d98ebbcae8099b70d91b9400df23fb3709855d59800a", noErrStructHash(t, []int{1, 2, 3}))
tst.AssertHexEqual(t, "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", noErrStructHash(t, nil))
tst.AssertHexEqual(t, "349a7db91aa78fd30bbaa7c7f9c7bfb2fcfe72869b4861162a96713a852f60d3", noErrStructHash(t, []any{1, "", nil}))
tst.AssertHexEqual(t, "ca51aab87808bf0062a4a024de6aac0c2bad54275cc857a4944569f89fd245ad", noErrStructHash(t, struct{}{}))
}
func TestStructHashSimpleStruct(t *testing.T) {
type t0 struct {
F1 int
F2 []string
F3 *int
}
tst.AssertHexEqual(t, "a90bff751c70c738bb5cfc9b108e783fa9c19c0bc9273458e0aaee6e74aa1b92", noErrStructHash(t, t0{
F1: 10,
F2: []string{"1", "2", "3"},
F3: nil,
}))
tst.AssertHexEqual(t, "5d09090dc34ac59dd645f197a255f653387723de3afa1b614721ea5a081c675f", noErrStructHash(t, t0{
F1: 10,
F2: []string{"1", "2", "3"},
F3: langext.Ptr(99),
}))
}
func TestStructHashLayeredStruct(t *testing.T) {
type t1_1 struct {
F10 float32
F12 float64
F15 bool
}
type t1_2 struct {
SV1 *t1_1
SV2 *t1_1
SV3 t1_1
}
tst.AssertHexEqual(t, "fd4ca071fb40a288fee4b7a3dfdaab577b30cb8f80f81ec511e7afd72dc3b469", noErrStructHash(t, t1_2{
SV1: nil,
SV2: nil,
SV3: t1_1{
F10: 1,
F12: 2,
F15: false,
},
}))
tst.AssertHexEqual(t, "3fbf7c67d8121deda075cc86319a4e32d71744feb2cebf89b43bc682f072a029", noErrStructHash(t, t1_2{
SV1: nil,
SV2: &t1_1{},
SV3: t1_1{
F10: 3,
F12: 4,
F15: true,
},
}))
tst.AssertHexEqual(t, "b1791ccd1b346c3ede5bbffda85555adcd8216b93ffca23f14fe175ec47c5104", noErrStructHash(t, t1_2{
SV1: &t1_1{},
SV2: &t1_1{},
SV3: t1_1{
F10: 5,
F12: 6,
F15: false,
},
}))
}
func TestStructHashMap(t *testing.T) {
type t0 struct {
F1 int
F2 map[string]int
}
tst.AssertHexEqual(t, "d50c53ad1fafb448c33fddd5aca01a86a2edf669ce2ecab07ba6fe877951d824", noErrStructHash(t, t0{
F1: 10,
F2: map[string]int{
"x": 1,
"0": 2,
"a": 99,
},
}))
tst.AssertHexEqual(t, "d50c53ad1fafb448c33fddd5aca01a86a2edf669ce2ecab07ba6fe877951d824", noErrStructHash(t, t0{
F1: 10,
F2: map[string]int{
"a": 99,
"x": 1,
"0": 2,
},
}))
m3 := make(map[string]int, 99)
m3["a"] = 0
m3["x"] = 0
m3["0"] = 0
m3["0"] = 99
m3["x"] = 1
m3["a"] = 2
tst.AssertHexEqual(t, "d50c53ad1fafb448c33fddd5aca01a86a2edf669ce2ecab07ba6fe877951d824", noErrStructHash(t, t0{
F1: 10,
F2: m3,
}))
}

View File

@@ -2,17 +2,17 @@ package dataext
import "sync"
type SyncStringSet struct {
data map[string]bool
type SyncSet[TData comparable] struct {
data map[TData]bool
lock sync.Mutex
}
func (s *SyncStringSet) Add(value string) bool {
func (s *SyncSet[TData]) Add(value TData) bool {
s.lock.Lock()
defer s.lock.Unlock()
if s.data == nil {
s.data = make(map[string]bool)
s.data = make(map[TData]bool)
}
_, ok := s.data[value]
@@ -21,12 +21,12 @@ func (s *SyncStringSet) Add(value string) bool {
return !ok
}
func (s *SyncStringSet) AddAll(values []string) {
func (s *SyncSet[TData]) AddAll(values []TData) {
s.lock.Lock()
defer s.lock.Unlock()
if s.data == nil {
s.data = make(map[string]bool)
s.data = make(map[TData]bool)
}
for _, value := range values {
@@ -34,12 +34,12 @@ func (s *SyncStringSet) AddAll(values []string) {
}
}
func (s *SyncStringSet) Contains(value string) bool {
func (s *SyncSet[TData]) Contains(value TData) bool {
s.lock.Lock()
defer s.lock.Unlock()
if s.data == nil {
s.data = make(map[string]bool)
s.data = make(map[TData]bool)
}
_, ok := s.data[value]
@@ -47,15 +47,15 @@ func (s *SyncStringSet) Contains(value string) bool {
return ok
}
func (s *SyncStringSet) Get() []string {
func (s *SyncSet[TData]) Get() []TData {
s.lock.Lock()
defer s.lock.Unlock()
if s.data == nil {
s.data = make(map[string]bool)
s.data = make(map[TData]bool)
}
r := make([]string, 0, len(s.data))
r := make([]TData, 0, len(s.data))
for k := range s.data {
r = append(r, k)

10
go.mod
View File

@@ -3,6 +3,12 @@ module gogs.mikescher.com/BlackForestBytes/goext
go 1.19
require (
golang.org/x/sys v0.1.0
golang.org/x/term v0.1.0
golang.org/x/sys v0.3.0
golang.org/x/term v0.3.0
)
require (
github.com/jmoiron/sqlx v1.3.5 // indirect
go.mongodb.org/mongo-driver v1.11.1 // indirect
golang.org/x/crypto v0.4.0 // indirect
)

46
go.sum
View File

@@ -1,4 +1,50 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/jmoiron/sqlx v1.3.5 h1:vFFPA71p1o5gAeqtEAwLU4dnX2napprKtHr7PYIcN3g=
github.com/jmoiron/sqlx v1.3.5/go.mod h1:nRVWtLre0KfCLJvgxzCsLVMogSvQ1zNJtpYr2Ccp0mQ=
github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk=
github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI=
github.com/xdg-go/scram v1.1.1/go.mod h1:RaEWvsqvNKKvBPvcKeFjrG2cJqOkHTiyTpzz23ni57g=
github.com/xdg-go/stringprep v1.0.3/go.mod h1:W3f5j4i+9rC0kuIEJL0ky1VpHXQU3ocBgklLGvcBnW8=
github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7JulP+udvsHwJoVG1YGAP6VLg4y9I5dyZdqmA=
go.mongodb.org/mongo-driver v1.11.1 h1:QP0znIRTuL0jf1oBQoAoM0C6ZJfBK4kx0Uumtv1A7w8=
go.mongodb.org/mongo-driver v1.11.1/go.mod h1:s7p5vEtfbeR1gYi6pnj3c3/urpbLv2T5Sfd6Rp2HBB8=
golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.4.0 h1:UVQgzMY87xqpKNgb+kDsll2Igd33HszWHFLmpaRMq/8=
golang.org/x/crypto v0.4.0/go.mod h1:3quD/ATkf6oY+rnes5c3ExXTbLc8mueNue5/DoinL80=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.3.0 h1:w8ZOecv6NaNa/zC8944JTU3vz4u6Lagfk4RPQxv92NQ=
golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.1.0 h1:g6Z6vPFA9dYBAF7DWcH6sCcOntplXsDKcliusYijMlw=
golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.3.0 h1:qoo4akIqOcDME5bhc/NgxUdovd6BSS2uMsVjB56q1xI=
golang.org/x/term v0.3.0/go.mod h1:q750SLmJuPmVoN1blW3UFBPREJfb1KmY3vwxfr+nFDA=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

27
gojson/LICENSE Normal file
View File

@@ -0,0 +1,27 @@
Copyright (c) 2009 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

12
gojson/README.md Normal file
View File

@@ -0,0 +1,12 @@
JSON serializer which serializes nil-Arrays as `[]` and nil-maps als `{}`.
Idea from: https://github.com/homelight/json
Forked from https://github.com/golang/go/tree/547e8e22fe565d65d1fd4d6e71436a5a855447b0/src/encoding/json ( tag go1.20.2 )
Added:
- `MarshalSafeCollections()` method
- `Encoder.nilSafeSlices` and `Encoder.nilSafeMaps` fields

1311
gojson/decode.go Normal file

File diff suppressed because it is too large Load Diff

2574
gojson/decode_test.go Normal file

File diff suppressed because it is too large Load Diff

1459
gojson/encode.go Normal file

File diff suppressed because it is too large Load Diff

1285
gojson/encode_test.go Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,73 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json_test
import (
"encoding/json"
"fmt"
"log"
"strings"
)
type Animal int
const (
Unknown Animal = iota
Gopher
Zebra
)
func (a *Animal) UnmarshalJSON(b []byte) error {
var s string
if err := json.Unmarshal(b, &s); err != nil {
return err
}
switch strings.ToLower(s) {
default:
*a = Unknown
case "gopher":
*a = Gopher
case "zebra":
*a = Zebra
}
return nil
}
func (a Animal) MarshalJSON() ([]byte, error) {
var s string
switch a {
default:
s = "unknown"
case Gopher:
s = "gopher"
case Zebra:
s = "zebra"
}
return json.Marshal(s)
}
func Example_customMarshalJSON() {
blob := `["gopher","armadillo","zebra","unknown","gopher","bee","gopher","zebra"]`
var zoo []Animal
if err := json.Unmarshal([]byte(blob), &zoo); err != nil {
log.Fatal(err)
}
census := make(map[Animal]int)
for _, animal := range zoo {
census[animal] += 1
}
fmt.Printf("Zoo Census:\n* Gophers: %d\n* Zebras: %d\n* Unknown: %d\n",
census[Gopher], census[Zebra], census[Unknown])
// Output:
// Zoo Census:
// * Gophers: 3
// * Zebras: 2
// * Unknown: 3
}

310
gojson/example_test.go Normal file
View File

@@ -0,0 +1,310 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json_test
import (
"bytes"
"encoding/json"
"fmt"
"io"
"log"
"os"
"strings"
)
func ExampleMarshal() {
type ColorGroup struct {
ID int
Name string
Colors []string
}
group := ColorGroup{
ID: 1,
Name: "Reds",
Colors: []string{"Crimson", "Red", "Ruby", "Maroon"},
}
b, err := json.Marshal(group)
if err != nil {
fmt.Println("error:", err)
}
os.Stdout.Write(b)
// Output:
// {"ID":1,"Name":"Reds","Colors":["Crimson","Red","Ruby","Maroon"]}
}
func ExampleUnmarshal() {
var jsonBlob = []byte(`[
{"Name": "Platypus", "Order": "Monotremata"},
{"Name": "Quoll", "Order": "Dasyuromorphia"}
]`)
type Animal struct {
Name string
Order string
}
var animals []Animal
err := json.Unmarshal(jsonBlob, &animals)
if err != nil {
fmt.Println("error:", err)
}
fmt.Printf("%+v", animals)
// Output:
// [{Name:Platypus Order:Monotremata} {Name:Quoll Order:Dasyuromorphia}]
}
// This example uses a Decoder to decode a stream of distinct JSON values.
func ExampleDecoder() {
const jsonStream = `
{"Name": "Ed", "Text": "Knock knock."}
{"Name": "Sam", "Text": "Who's there?"}
{"Name": "Ed", "Text": "Go fmt."}
{"Name": "Sam", "Text": "Go fmt who?"}
{"Name": "Ed", "Text": "Go fmt yourself!"}
`
type Message struct {
Name, Text string
}
dec := json.NewDecoder(strings.NewReader(jsonStream))
for {
var m Message
if err := dec.Decode(&m); err == io.EOF {
break
} else if err != nil {
log.Fatal(err)
}
fmt.Printf("%s: %s\n", m.Name, m.Text)
}
// Output:
// Ed: Knock knock.
// Sam: Who's there?
// Ed: Go fmt.
// Sam: Go fmt who?
// Ed: Go fmt yourself!
}
// This example uses a Decoder to decode a stream of distinct JSON values.
func ExampleDecoder_Token() {
const jsonStream = `
{"Message": "Hello", "Array": [1, 2, 3], "Null": null, "Number": 1.234}
`
dec := json.NewDecoder(strings.NewReader(jsonStream))
for {
t, err := dec.Token()
if err == io.EOF {
break
}
if err != nil {
log.Fatal(err)
}
fmt.Printf("%T: %v", t, t)
if dec.More() {
fmt.Printf(" (more)")
}
fmt.Printf("\n")
}
// Output:
// json.Delim: { (more)
// string: Message (more)
// string: Hello (more)
// string: Array (more)
// json.Delim: [ (more)
// float64: 1 (more)
// float64: 2 (more)
// float64: 3
// json.Delim: ] (more)
// string: Null (more)
// <nil>: <nil> (more)
// string: Number (more)
// float64: 1.234
// json.Delim: }
}
// This example uses a Decoder to decode a streaming array of JSON objects.
func ExampleDecoder_Decode_stream() {
const jsonStream = `
[
{"Name": "Ed", "Text": "Knock knock."},
{"Name": "Sam", "Text": "Who's there?"},
{"Name": "Ed", "Text": "Go fmt."},
{"Name": "Sam", "Text": "Go fmt who?"},
{"Name": "Ed", "Text": "Go fmt yourself!"}
]
`
type Message struct {
Name, Text string
}
dec := json.NewDecoder(strings.NewReader(jsonStream))
// read open bracket
t, err := dec.Token()
if err != nil {
log.Fatal(err)
}
fmt.Printf("%T: %v\n", t, t)
// while the array contains values
for dec.More() {
var m Message
// decode an array value (Message)
err := dec.Decode(&m)
if err != nil {
log.Fatal(err)
}
fmt.Printf("%v: %v\n", m.Name, m.Text)
}
// read closing bracket
t, err = dec.Token()
if err != nil {
log.Fatal(err)
}
fmt.Printf("%T: %v\n", t, t)
// Output:
// json.Delim: [
// Ed: Knock knock.
// Sam: Who's there?
// Ed: Go fmt.
// Sam: Go fmt who?
// Ed: Go fmt yourself!
// json.Delim: ]
}
// This example uses RawMessage to delay parsing part of a JSON message.
func ExampleRawMessage_unmarshal() {
type Color struct {
Space string
Point json.RawMessage // delay parsing until we know the color space
}
type RGB struct {
R uint8
G uint8
B uint8
}
type YCbCr struct {
Y uint8
Cb int8
Cr int8
}
var j = []byte(`[
{"Space": "YCbCr", "Point": {"Y": 255, "Cb": 0, "Cr": -10}},
{"Space": "RGB", "Point": {"R": 98, "G": 218, "B": 255}}
]`)
var colors []Color
err := json.Unmarshal(j, &colors)
if err != nil {
log.Fatalln("error:", err)
}
for _, c := range colors {
var dst any
switch c.Space {
case "RGB":
dst = new(RGB)
case "YCbCr":
dst = new(YCbCr)
}
err := json.Unmarshal(c.Point, dst)
if err != nil {
log.Fatalln("error:", err)
}
fmt.Println(c.Space, dst)
}
// Output:
// YCbCr &{255 0 -10}
// RGB &{98 218 255}
}
// This example uses RawMessage to use a precomputed JSON during marshal.
func ExampleRawMessage_marshal() {
h := json.RawMessage(`{"precomputed": true}`)
c := struct {
Header *json.RawMessage `json:"header"`
Body string `json:"body"`
}{Header: &h, Body: "Hello Gophers!"}
b, err := json.MarshalIndent(&c, "", "\t")
if err != nil {
fmt.Println("error:", err)
}
os.Stdout.Write(b)
// Output:
// {
// "header": {
// "precomputed": true
// },
// "body": "Hello Gophers!"
// }
}
func ExampleIndent() {
type Road struct {
Name string
Number int
}
roads := []Road{
{"Diamond Fork", 29},
{"Sheep Creek", 51},
}
b, err := json.Marshal(roads)
if err != nil {
log.Fatal(err)
}
var out bytes.Buffer
json.Indent(&out, b, "=", "\t")
out.WriteTo(os.Stdout)
// Output:
// [
// = {
// = "Name": "Diamond Fork",
// = "Number": 29
// = },
// = {
// = "Name": "Sheep Creek",
// = "Number": 51
// = }
// =]
}
func ExampleMarshalIndent() {
data := map[string]int{
"a": 1,
"b": 2,
}
b, err := json.MarshalIndent(data, "<prefix>", "<indent>")
if err != nil {
log.Fatal(err)
}
fmt.Println(string(b))
// Output:
// {
// <prefix><indent>"a": 1,
// <prefix><indent>"b": 2
// <prefix>}
}
func ExampleValid() {
goodJSON := `{"example": 1}`
badJSON := `{"example":2:]}}`
fmt.Println(json.Valid([]byte(goodJSON)), json.Valid([]byte(badJSON)))
// Output:
// true false
}
func ExampleHTMLEscape() {
var out bytes.Buffer
json.HTMLEscape(&out, []byte(`{"Name":"<b>HTML content</b>"}`))
out.WriteTo(os.Stdout)
// Output:
//{"Name":"\u003cb\u003eHTML content\u003c/b\u003e"}
}

View File

@@ -0,0 +1,67 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json_test
import (
"encoding/json"
"fmt"
"log"
"strings"
)
type Size int
const (
Unrecognized Size = iota
Small
Large
)
func (s *Size) UnmarshalText(text []byte) error {
switch strings.ToLower(string(text)) {
default:
*s = Unrecognized
case "small":
*s = Small
case "large":
*s = Large
}
return nil
}
func (s Size) MarshalText() ([]byte, error) {
var name string
switch s {
default:
name = "unrecognized"
case Small:
name = "small"
case Large:
name = "large"
}
return []byte(name), nil
}
func Example_textMarshalJSON() {
blob := `["small","regular","large","unrecognized","small","normal","small","large"]`
var inventory []Size
if err := json.Unmarshal([]byte(blob), &inventory); err != nil {
log.Fatal(err)
}
counts := make(map[Size]int)
for _, size := range inventory {
counts[size] += 1
}
fmt.Printf("Inventory Counts:\n* Small: %d\n* Large: %d\n* Unrecognized: %d\n",
counts[Small], counts[Large], counts[Unrecognized])
// Output:
// Inventory Counts:
// * Small: 3
// * Large: 2
// * Unrecognized: 3
}

141
gojson/fold.go Normal file
View File

@@ -0,0 +1,141 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
import (
"bytes"
"unicode/utf8"
)
const (
caseMask = ^byte(0x20) // Mask to ignore case in ASCII.
kelvin = '\u212a'
smallLongEss = '\u017f'
)
// foldFunc returns one of four different case folding equivalence
// functions, from most general (and slow) to fastest:
//
// 1) bytes.EqualFold, if the key s contains any non-ASCII UTF-8
// 2) equalFoldRight, if s contains special folding ASCII ('k', 'K', 's', 'S')
// 3) asciiEqualFold, no special, but includes non-letters (including _)
// 4) simpleLetterEqualFold, no specials, no non-letters.
//
// The letters S and K are special because they map to 3 runes, not just 2:
// - S maps to s and to U+017F 'ſ' Latin small letter long s
// - k maps to K and to U+212A '' Kelvin sign
//
// See https://play.golang.org/p/tTxjOc0OGo
//
// The returned function is specialized for matching against s and
// should only be given s. It's not curried for performance reasons.
func foldFunc(s []byte) func(s, t []byte) bool {
nonLetter := false
special := false // special letter
for _, b := range s {
if b >= utf8.RuneSelf {
return bytes.EqualFold
}
upper := b & caseMask
if upper < 'A' || upper > 'Z' {
nonLetter = true
} else if upper == 'K' || upper == 'S' {
// See above for why these letters are special.
special = true
}
}
if special {
return equalFoldRight
}
if nonLetter {
return asciiEqualFold
}
return simpleLetterEqualFold
}
// equalFoldRight is a specialization of bytes.EqualFold when s is
// known to be all ASCII (including punctuation), but contains an 's',
// 'S', 'k', or 'K', requiring a Unicode fold on the bytes in t.
// See comments on foldFunc.
func equalFoldRight(s, t []byte) bool {
for _, sb := range s {
if len(t) == 0 {
return false
}
tb := t[0]
if tb < utf8.RuneSelf {
if sb != tb {
sbUpper := sb & caseMask
if 'A' <= sbUpper && sbUpper <= 'Z' {
if sbUpper != tb&caseMask {
return false
}
} else {
return false
}
}
t = t[1:]
continue
}
// sb is ASCII and t is not. t must be either kelvin
// sign or long s; sb must be s, S, k, or K.
tr, size := utf8.DecodeRune(t)
switch sb {
case 's', 'S':
if tr != smallLongEss {
return false
}
case 'k', 'K':
if tr != kelvin {
return false
}
default:
return false
}
t = t[size:]
}
return len(t) == 0
}
// asciiEqualFold is a specialization of bytes.EqualFold for use when
// s is all ASCII (but may contain non-letters) and contains no
// special-folding letters.
// See comments on foldFunc.
func asciiEqualFold(s, t []byte) bool {
if len(s) != len(t) {
return false
}
for i, sb := range s {
tb := t[i]
if sb == tb {
continue
}
if ('a' <= sb && sb <= 'z') || ('A' <= sb && sb <= 'Z') {
if sb&caseMask != tb&caseMask {
return false
}
} else {
return false
}
}
return true
}
// simpleLetterEqualFold is a specialization of bytes.EqualFold for
// use when s is all ASCII letters (no underscores, etc) and also
// doesn't contain 'k', 'K', 's', or 'S'.
// See comments on foldFunc.
func simpleLetterEqualFold(s, t []byte) bool {
if len(s) != len(t) {
return false
}
for i, b := range s {
if b&caseMask != t[i]&caseMask {
return false
}
}
return true
}

110
gojson/fold_test.go Normal file
View File

@@ -0,0 +1,110 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
import (
"bytes"
"strings"
"testing"
"unicode/utf8"
)
var foldTests = []struct {
fn func(s, t []byte) bool
s, t string
want bool
}{
{equalFoldRight, "", "", true},
{equalFoldRight, "a", "a", true},
{equalFoldRight, "", "a", false},
{equalFoldRight, "a", "", false},
{equalFoldRight, "a", "A", true},
{equalFoldRight, "AB", "ab", true},
{equalFoldRight, "AB", "ac", false},
{equalFoldRight, "sbkKc", "ſbKc", true},
{equalFoldRight, "SbKkc", "ſbKc", true},
{equalFoldRight, "SbKkc", "ſbKK", false},
{equalFoldRight, "e", "é", false},
{equalFoldRight, "s", "S", true},
{simpleLetterEqualFold, "", "", true},
{simpleLetterEqualFold, "abc", "abc", true},
{simpleLetterEqualFold, "abc", "ABC", true},
{simpleLetterEqualFold, "abc", "ABCD", false},
{simpleLetterEqualFold, "abc", "xxx", false},
{asciiEqualFold, "a_B", "A_b", true},
{asciiEqualFold, "aa@", "aa`", false}, // verify 0x40 and 0x60 aren't case-equivalent
}
func TestFold(t *testing.T) {
for i, tt := range foldTests {
if got := tt.fn([]byte(tt.s), []byte(tt.t)); got != tt.want {
t.Errorf("%d. %q, %q = %v; want %v", i, tt.s, tt.t, got, tt.want)
}
truth := strings.EqualFold(tt.s, tt.t)
if truth != tt.want {
t.Errorf("strings.EqualFold doesn't agree with case %d", i)
}
}
}
func TestFoldAgainstUnicode(t *testing.T) {
var buf1, buf2 []byte
var runes []rune
for i := 0x20; i <= 0x7f; i++ {
runes = append(runes, rune(i))
}
runes = append(runes, kelvin, smallLongEss)
funcs := []struct {
name string
fold func(s, t []byte) bool
letter bool // must be ASCII letter
simple bool // must be simple ASCII letter (not 'S' or 'K')
}{
{
name: "equalFoldRight",
fold: equalFoldRight,
},
{
name: "asciiEqualFold",
fold: asciiEqualFold,
simple: true,
},
{
name: "simpleLetterEqualFold",
fold: simpleLetterEqualFold,
simple: true,
letter: true,
},
}
for _, ff := range funcs {
for _, r := range runes {
if r >= utf8.RuneSelf {
continue
}
if ff.letter && !isASCIILetter(byte(r)) {
continue
}
if ff.simple && (r == 's' || r == 'S' || r == 'k' || r == 'K') {
continue
}
for _, r2 := range runes {
buf1 = append(utf8.AppendRune(append(buf1[:0], 'x'), r), 'x')
buf2 = append(utf8.AppendRune(append(buf2[:0], 'x'), r2), 'x')
want := bytes.EqualFold(buf1, buf2)
if got := ff.fold(buf1, buf2); got != want {
t.Errorf("%s(%q, %q) = %v; want %v", ff.name, buf1, buf2, got, want)
}
}
}
}
}
func isASCIILetter(b byte) bool {
return ('A' <= b && b <= 'Z') || ('a' <= b && b <= 'z')
}

42
gojson/fuzz.go Normal file
View File

@@ -0,0 +1,42 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build gofuzz
package json
import (
"fmt"
)
func Fuzz(data []byte) (score int) {
for _, ctor := range []func() any{
func() any { return new(any) },
func() any { return new(map[string]any) },
func() any { return new([]any) },
} {
v := ctor()
err := Unmarshal(data, v)
if err != nil {
continue
}
score = 1
m, err := Marshal(v)
if err != nil {
fmt.Printf("v=%#v\n", v)
panic(err)
}
u := ctor()
err = Unmarshal(m, u)
if err != nil {
fmt.Printf("v=%#v\n", v)
fmt.Printf("m=%s\n", m)
panic(err)
}
}
return
}

83
gojson/fuzz_test.go Normal file
View File

@@ -0,0 +1,83 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
import (
"bytes"
"io"
"testing"
)
func FuzzUnmarshalJSON(f *testing.F) {
f.Add([]byte(`{
"object": {
"slice": [
1,
2.0,
"3",
[4],
{5: {}}
]
},
"slice": [[]],
"string": ":)",
"int": 1e5,
"float": 3e-9"
}`))
f.Fuzz(func(t *testing.T, b []byte) {
for _, typ := range []func() interface{}{
func() interface{} { return new(interface{}) },
func() interface{} { return new(map[string]interface{}) },
func() interface{} { return new([]interface{}) },
} {
i := typ()
if err := Unmarshal(b, i); err != nil {
return
}
encoded, err := Marshal(i)
if err != nil {
t.Fatalf("failed to marshal: %s", err)
}
if err := Unmarshal(encoded, i); err != nil {
t.Fatalf("failed to roundtrip: %s", err)
}
}
})
}
func FuzzDecoderToken(f *testing.F) {
f.Add([]byte(`{
"object": {
"slice": [
1,
2.0,
"3",
[4],
{5: {}}
]
},
"slice": [[]],
"string": ":)",
"int": 1e5,
"float": 3e-9"
}`))
f.Fuzz(func(t *testing.T, b []byte) {
r := bytes.NewReader(b)
d := NewDecoder(r)
for {
_, err := d.Token()
if err != nil {
if err == io.EOF {
break
}
return
}
}
})
}

44
gojson/gionic.go Normal file
View File

@@ -0,0 +1,44 @@
package json
import (
"net/http"
)
// Render interface is copied from github.com/gin-gonic/gin@v1.8.1/render/render.go
type Render interface {
// Render writes data with custom ContentType.
Render(http.ResponseWriter) error
// WriteContentType writes custom ContentType.
WriteContentType(w http.ResponseWriter)
}
type GoJsonRender struct {
Data any
NilSafeSlices bool
NilSafeMaps bool
Indent *IndentOpt
}
func (r GoJsonRender) Render(w http.ResponseWriter) error {
header := w.Header()
if val := header["Content-Type"]; len(val) == 0 {
header["Content-Type"] = []string{"application/json; charset=utf-8"}
}
jsonBytes, err := MarshalSafeCollections(r.Data, r.NilSafeSlices, r.NilSafeMaps, r.Indent)
if err != nil {
panic(err)
}
_, err = w.Write(jsonBytes)
if err != nil {
panic(err)
}
return nil
}
func (r GoJsonRender) WriteContentType(w http.ResponseWriter) {
header := w.Header()
if val := header["Content-Type"]; len(val) == 0 {
header["Content-Type"] = []string{"application/json; charset=utf-8"}
}
}

143
gojson/indent.go Normal file
View File

@@ -0,0 +1,143 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
import (
"bytes"
)
// Compact appends to dst the JSON-encoded src with
// insignificant space characters elided.
func Compact(dst *bytes.Buffer, src []byte) error {
return compact(dst, src, false)
}
func compact(dst *bytes.Buffer, src []byte, escape bool) error {
origLen := dst.Len()
scan := newScanner()
defer freeScanner(scan)
start := 0
for i, c := range src {
if escape && (c == '<' || c == '>' || c == '&') {
if start < i {
dst.Write(src[start:i])
}
dst.WriteString(`\u00`)
dst.WriteByte(hex[c>>4])
dst.WriteByte(hex[c&0xF])
start = i + 1
}
// Convert U+2028 and U+2029 (E2 80 A8 and E2 80 A9).
if escape && c == 0xE2 && i+2 < len(src) && src[i+1] == 0x80 && src[i+2]&^1 == 0xA8 {
if start < i {
dst.Write(src[start:i])
}
dst.WriteString(`\u202`)
dst.WriteByte(hex[src[i+2]&0xF])
start = i + 3
}
v := scan.step(scan, c)
if v >= scanSkipSpace {
if v == scanError {
break
}
if start < i {
dst.Write(src[start:i])
}
start = i + 1
}
}
if scan.eof() == scanError {
dst.Truncate(origLen)
return scan.err
}
if start < len(src) {
dst.Write(src[start:])
}
return nil
}
func newline(dst *bytes.Buffer, prefix, indent string, depth int) {
dst.WriteByte('\n')
dst.WriteString(prefix)
for i := 0; i < depth; i++ {
dst.WriteString(indent)
}
}
// Indent appends to dst an indented form of the JSON-encoded src.
// Each element in a JSON object or array begins on a new,
// indented line beginning with prefix followed by one or more
// copies of indent according to the indentation nesting.
// The data appended to dst does not begin with the prefix nor
// any indentation, to make it easier to embed inside other formatted JSON data.
// Although leading space characters (space, tab, carriage return, newline)
// at the beginning of src are dropped, trailing space characters
// at the end of src are preserved and copied to dst.
// For example, if src has no trailing spaces, neither will dst;
// if src ends in a trailing newline, so will dst.
func Indent(dst *bytes.Buffer, src []byte, prefix, indent string) error {
origLen := dst.Len()
scan := newScanner()
defer freeScanner(scan)
needIndent := false
depth := 0
for _, c := range src {
scan.bytes++
v := scan.step(scan, c)
if v == scanSkipSpace {
continue
}
if v == scanError {
break
}
if needIndent && v != scanEndObject && v != scanEndArray {
needIndent = false
depth++
newline(dst, prefix, indent, depth)
}
// Emit semantically uninteresting bytes
// (in particular, punctuation in strings) unmodified.
if v == scanContinue {
dst.WriteByte(c)
continue
}
// Add spacing around real punctuation.
switch c {
case '{', '[':
// delay indent so that empty object and array are formatted as {} and [].
needIndent = true
dst.WriteByte(c)
case ',':
dst.WriteByte(c)
newline(dst, prefix, indent, depth)
case ':':
dst.WriteByte(c)
dst.WriteByte(' ')
case '}', ']':
if needIndent {
// suppress indent in empty object/array
needIndent = false
} else {
depth--
newline(dst, prefix, indent, depth)
}
dst.WriteByte(c)
default:
dst.WriteByte(c)
}
}
if scan.eof() == scanError {
dst.Truncate(origLen)
return scan.err
}
return nil
}

118
gojson/number_test.go Normal file
View File

@@ -0,0 +1,118 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
import (
"regexp"
"testing"
)
func TestNumberIsValid(t *testing.T) {
// From: https://stackoverflow.com/a/13340826
var jsonNumberRegexp = regexp.MustCompile(`^-?(?:0|[1-9]\d*)(?:\.\d+)?(?:[eE][+-]?\d+)?$`)
validTests := []string{
"0",
"-0",
"1",
"-1",
"0.1",
"-0.1",
"1234",
"-1234",
"12.34",
"-12.34",
"12E0",
"12E1",
"12e34",
"12E-0",
"12e+1",
"12e-34",
"-12E0",
"-12E1",
"-12e34",
"-12E-0",
"-12e+1",
"-12e-34",
"1.2E0",
"1.2E1",
"1.2e34",
"1.2E-0",
"1.2e+1",
"1.2e-34",
"-1.2E0",
"-1.2E1",
"-1.2e34",
"-1.2E-0",
"-1.2e+1",
"-1.2e-34",
"0E0",
"0E1",
"0e34",
"0E-0",
"0e+1",
"0e-34",
"-0E0",
"-0E1",
"-0e34",
"-0E-0",
"-0e+1",
"-0e-34",
}
for _, test := range validTests {
if !isValidNumber(test) {
t.Errorf("%s should be valid", test)
}
var f float64
if err := Unmarshal([]byte(test), &f); err != nil {
t.Errorf("%s should be valid but Unmarshal failed: %v", test, err)
}
if !jsonNumberRegexp.MatchString(test) {
t.Errorf("%s should be valid but regexp does not match", test)
}
}
invalidTests := []string{
"",
"invalid",
"1.0.1",
"1..1",
"-1-2",
"012a42",
"01.2",
"012",
"12E12.12",
"1e2e3",
"1e+-2",
"1e--23",
"1e",
"e1",
"1e+",
"1ea",
"1a",
"1.a",
"1.",
"01",
"1.e1",
}
for _, test := range invalidTests {
if isValidNumber(test) {
t.Errorf("%s should be invalid", test)
}
var f float64
if err := Unmarshal([]byte(test), &f); err == nil {
t.Errorf("%s should be invalid but unmarshal wrote %v", test, f)
}
if jsonNumberRegexp.MatchString(test) {
t.Errorf("%s should be invalid but matches regexp", test)
}
}
}

610
gojson/scanner.go Normal file
View File

@@ -0,0 +1,610 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
// JSON value parser state machine.
// Just about at the limit of what is reasonable to write by hand.
// Some parts are a bit tedious, but overall it nicely factors out the
// otherwise common code from the multiple scanning functions
// in this package (Compact, Indent, checkValid, etc).
//
// This file starts with two simple examples using the scanner
// before diving into the scanner itself.
import (
"strconv"
"sync"
)
// Valid reports whether data is a valid JSON encoding.
func Valid(data []byte) bool {
scan := newScanner()
defer freeScanner(scan)
return checkValid(data, scan) == nil
}
// checkValid verifies that data is valid JSON-encoded data.
// scan is passed in for use by checkValid to avoid an allocation.
// checkValid returns nil or a SyntaxError.
func checkValid(data []byte, scan *scanner) error {
scan.reset()
for _, c := range data {
scan.bytes++
if scan.step(scan, c) == scanError {
return scan.err
}
}
if scan.eof() == scanError {
return scan.err
}
return nil
}
// A SyntaxError is a description of a JSON syntax error.
// Unmarshal will return a SyntaxError if the JSON can't be parsed.
type SyntaxError struct {
msg string // description of error
Offset int64 // error occurred after reading Offset bytes
}
func (e *SyntaxError) Error() string { return e.msg }
// A scanner is a JSON scanning state machine.
// Callers call scan.reset and then pass bytes in one at a time
// by calling scan.step(&scan, c) for each byte.
// The return value, referred to as an opcode, tells the
// caller about significant parsing events like beginning
// and ending literals, objects, and arrays, so that the
// caller can follow along if it wishes.
// The return value scanEnd indicates that a single top-level
// JSON value has been completed, *before* the byte that
// just got passed in. (The indication must be delayed in order
// to recognize the end of numbers: is 123 a whole value or
// the beginning of 12345e+6?).
type scanner struct {
// The step is a func to be called to execute the next transition.
// Also tried using an integer constant and a single func
// with a switch, but using the func directly was 10% faster
// on a 64-bit Mac Mini, and it's nicer to read.
step func(*scanner, byte) int
// Reached end of top-level value.
endTop bool
// Stack of what we're in the middle of - array values, object keys, object values.
parseState []int
// Error that happened, if any.
err error
// total bytes consumed, updated by decoder.Decode (and deliberately
// not set to zero by scan.reset)
bytes int64
}
var scannerPool = sync.Pool{
New: func() any {
return &scanner{}
},
}
func newScanner() *scanner {
scan := scannerPool.Get().(*scanner)
// scan.reset by design doesn't set bytes to zero
scan.bytes = 0
scan.reset()
return scan
}
func freeScanner(scan *scanner) {
// Avoid hanging on to too much memory in extreme cases.
if len(scan.parseState) > 1024 {
scan.parseState = nil
}
scannerPool.Put(scan)
}
// These values are returned by the state transition functions
// assigned to scanner.state and the method scanner.eof.
// They give details about the current state of the scan that
// callers might be interested to know about.
// It is okay to ignore the return value of any particular
// call to scanner.state: if one call returns scanError,
// every subsequent call will return scanError too.
const (
// Continue.
scanContinue = iota // uninteresting byte
scanBeginLiteral // end implied by next result != scanContinue
scanBeginObject // begin object
scanObjectKey // just finished object key (string)
scanObjectValue // just finished non-last object value
scanEndObject // end object (implies scanObjectValue if possible)
scanBeginArray // begin array
scanArrayValue // just finished array value
scanEndArray // end array (implies scanArrayValue if possible)
scanSkipSpace // space byte; can skip; known to be last "continue" result
// Stop.
scanEnd // top-level value ended *before* this byte; known to be first "stop" result
scanError // hit an error, scanner.err.
)
// These values are stored in the parseState stack.
// They give the current state of a composite value
// being scanned. If the parser is inside a nested value
// the parseState describes the nested state, outermost at entry 0.
const (
parseObjectKey = iota // parsing object key (before colon)
parseObjectValue // parsing object value (after colon)
parseArrayValue // parsing array value
)
// This limits the max nesting depth to prevent stack overflow.
// This is permitted by https://tools.ietf.org/html/rfc7159#section-9
const maxNestingDepth = 10000
// reset prepares the scanner for use.
// It must be called before calling s.step.
func (s *scanner) reset() {
s.step = stateBeginValue
s.parseState = s.parseState[0:0]
s.err = nil
s.endTop = false
}
// eof tells the scanner that the end of input has been reached.
// It returns a scan status just as s.step does.
func (s *scanner) eof() int {
if s.err != nil {
return scanError
}
if s.endTop {
return scanEnd
}
s.step(s, ' ')
if s.endTop {
return scanEnd
}
if s.err == nil {
s.err = &SyntaxError{"unexpected end of JSON input", s.bytes}
}
return scanError
}
// pushParseState pushes a new parse state p onto the parse stack.
// an error state is returned if maxNestingDepth was exceeded, otherwise successState is returned.
func (s *scanner) pushParseState(c byte, newParseState int, successState int) int {
s.parseState = append(s.parseState, newParseState)
if len(s.parseState) <= maxNestingDepth {
return successState
}
return s.error(c, "exceeded max depth")
}
// popParseState pops a parse state (already obtained) off the stack
// and updates s.step accordingly.
func (s *scanner) popParseState() {
n := len(s.parseState) - 1
s.parseState = s.parseState[0:n]
if n == 0 {
s.step = stateEndTop
s.endTop = true
} else {
s.step = stateEndValue
}
}
func isSpace(c byte) bool {
return c <= ' ' && (c == ' ' || c == '\t' || c == '\r' || c == '\n')
}
// stateBeginValueOrEmpty is the state after reading `[`.
func stateBeginValueOrEmpty(s *scanner, c byte) int {
if isSpace(c) {
return scanSkipSpace
}
if c == ']' {
return stateEndValue(s, c)
}
return stateBeginValue(s, c)
}
// stateBeginValue is the state at the beginning of the input.
func stateBeginValue(s *scanner, c byte) int {
if isSpace(c) {
return scanSkipSpace
}
switch c {
case '{':
s.step = stateBeginStringOrEmpty
return s.pushParseState(c, parseObjectKey, scanBeginObject)
case '[':
s.step = stateBeginValueOrEmpty
return s.pushParseState(c, parseArrayValue, scanBeginArray)
case '"':
s.step = stateInString
return scanBeginLiteral
case '-':
s.step = stateNeg
return scanBeginLiteral
case '0': // beginning of 0.123
s.step = state0
return scanBeginLiteral
case 't': // beginning of true
s.step = stateT
return scanBeginLiteral
case 'f': // beginning of false
s.step = stateF
return scanBeginLiteral
case 'n': // beginning of null
s.step = stateN
return scanBeginLiteral
}
if '1' <= c && c <= '9' { // beginning of 1234.5
s.step = state1
return scanBeginLiteral
}
return s.error(c, "looking for beginning of value")
}
// stateBeginStringOrEmpty is the state after reading `{`.
func stateBeginStringOrEmpty(s *scanner, c byte) int {
if isSpace(c) {
return scanSkipSpace
}
if c == '}' {
n := len(s.parseState)
s.parseState[n-1] = parseObjectValue
return stateEndValue(s, c)
}
return stateBeginString(s, c)
}
// stateBeginString is the state after reading `{"key": value,`.
func stateBeginString(s *scanner, c byte) int {
if isSpace(c) {
return scanSkipSpace
}
if c == '"' {
s.step = stateInString
return scanBeginLiteral
}
return s.error(c, "looking for beginning of object key string")
}
// stateEndValue is the state after completing a value,
// such as after reading `{}` or `true` or `["x"`.
func stateEndValue(s *scanner, c byte) int {
n := len(s.parseState)
if n == 0 {
// Completed top-level before the current byte.
s.step = stateEndTop
s.endTop = true
return stateEndTop(s, c)
}
if isSpace(c) {
s.step = stateEndValue
return scanSkipSpace
}
ps := s.parseState[n-1]
switch ps {
case parseObjectKey:
if c == ':' {
s.parseState[n-1] = parseObjectValue
s.step = stateBeginValue
return scanObjectKey
}
return s.error(c, "after object key")
case parseObjectValue:
if c == ',' {
s.parseState[n-1] = parseObjectKey
s.step = stateBeginString
return scanObjectValue
}
if c == '}' {
s.popParseState()
return scanEndObject
}
return s.error(c, "after object key:value pair")
case parseArrayValue:
if c == ',' {
s.step = stateBeginValue
return scanArrayValue
}
if c == ']' {
s.popParseState()
return scanEndArray
}
return s.error(c, "after array element")
}
return s.error(c, "")
}
// stateEndTop is the state after finishing the top-level value,
// such as after reading `{}` or `[1,2,3]`.
// Only space characters should be seen now.
func stateEndTop(s *scanner, c byte) int {
if !isSpace(c) {
// Complain about non-space byte on next call.
s.error(c, "after top-level value")
}
return scanEnd
}
// stateInString is the state after reading `"`.
func stateInString(s *scanner, c byte) int {
if c == '"' {
s.step = stateEndValue
return scanContinue
}
if c == '\\' {
s.step = stateInStringEsc
return scanContinue
}
if c < 0x20 {
return s.error(c, "in string literal")
}
return scanContinue
}
// stateInStringEsc is the state after reading `"\` during a quoted string.
func stateInStringEsc(s *scanner, c byte) int {
switch c {
case 'b', 'f', 'n', 'r', 't', '\\', '/', '"':
s.step = stateInString
return scanContinue
case 'u':
s.step = stateInStringEscU
return scanContinue
}
return s.error(c, "in string escape code")
}
// stateInStringEscU is the state after reading `"\u` during a quoted string.
func stateInStringEscU(s *scanner, c byte) int {
if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' {
s.step = stateInStringEscU1
return scanContinue
}
// numbers
return s.error(c, "in \\u hexadecimal character escape")
}
// stateInStringEscU1 is the state after reading `"\u1` during a quoted string.
func stateInStringEscU1(s *scanner, c byte) int {
if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' {
s.step = stateInStringEscU12
return scanContinue
}
// numbers
return s.error(c, "in \\u hexadecimal character escape")
}
// stateInStringEscU12 is the state after reading `"\u12` during a quoted string.
func stateInStringEscU12(s *scanner, c byte) int {
if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' {
s.step = stateInStringEscU123
return scanContinue
}
// numbers
return s.error(c, "in \\u hexadecimal character escape")
}
// stateInStringEscU123 is the state after reading `"\u123` during a quoted string.
func stateInStringEscU123(s *scanner, c byte) int {
if '0' <= c && c <= '9' || 'a' <= c && c <= 'f' || 'A' <= c && c <= 'F' {
s.step = stateInString
return scanContinue
}
// numbers
return s.error(c, "in \\u hexadecimal character escape")
}
// stateNeg is the state after reading `-` during a number.
func stateNeg(s *scanner, c byte) int {
if c == '0' {
s.step = state0
return scanContinue
}
if '1' <= c && c <= '9' {
s.step = state1
return scanContinue
}
return s.error(c, "in numeric literal")
}
// state1 is the state after reading a non-zero integer during a number,
// such as after reading `1` or `100` but not `0`.
func state1(s *scanner, c byte) int {
if '0' <= c && c <= '9' {
s.step = state1
return scanContinue
}
return state0(s, c)
}
// state0 is the state after reading `0` during a number.
func state0(s *scanner, c byte) int {
if c == '.' {
s.step = stateDot
return scanContinue
}
if c == 'e' || c == 'E' {
s.step = stateE
return scanContinue
}
return stateEndValue(s, c)
}
// stateDot is the state after reading the integer and decimal point in a number,
// such as after reading `1.`.
func stateDot(s *scanner, c byte) int {
if '0' <= c && c <= '9' {
s.step = stateDot0
return scanContinue
}
return s.error(c, "after decimal point in numeric literal")
}
// stateDot0 is the state after reading the integer, decimal point, and subsequent
// digits of a number, such as after reading `3.14`.
func stateDot0(s *scanner, c byte) int {
if '0' <= c && c <= '9' {
return scanContinue
}
if c == 'e' || c == 'E' {
s.step = stateE
return scanContinue
}
return stateEndValue(s, c)
}
// stateE is the state after reading the mantissa and e in a number,
// such as after reading `314e` or `0.314e`.
func stateE(s *scanner, c byte) int {
if c == '+' || c == '-' {
s.step = stateESign
return scanContinue
}
return stateESign(s, c)
}
// stateESign is the state after reading the mantissa, e, and sign in a number,
// such as after reading `314e-` or `0.314e+`.
func stateESign(s *scanner, c byte) int {
if '0' <= c && c <= '9' {
s.step = stateE0
return scanContinue
}
return s.error(c, "in exponent of numeric literal")
}
// stateE0 is the state after reading the mantissa, e, optional sign,
// and at least one digit of the exponent in a number,
// such as after reading `314e-2` or `0.314e+1` or `3.14e0`.
func stateE0(s *scanner, c byte) int {
if '0' <= c && c <= '9' {
return scanContinue
}
return stateEndValue(s, c)
}
// stateT is the state after reading `t`.
func stateT(s *scanner, c byte) int {
if c == 'r' {
s.step = stateTr
return scanContinue
}
return s.error(c, "in literal true (expecting 'r')")
}
// stateTr is the state after reading `tr`.
func stateTr(s *scanner, c byte) int {
if c == 'u' {
s.step = stateTru
return scanContinue
}
return s.error(c, "in literal true (expecting 'u')")
}
// stateTru is the state after reading `tru`.
func stateTru(s *scanner, c byte) int {
if c == 'e' {
s.step = stateEndValue
return scanContinue
}
return s.error(c, "in literal true (expecting 'e')")
}
// stateF is the state after reading `f`.
func stateF(s *scanner, c byte) int {
if c == 'a' {
s.step = stateFa
return scanContinue
}
return s.error(c, "in literal false (expecting 'a')")
}
// stateFa is the state after reading `fa`.
func stateFa(s *scanner, c byte) int {
if c == 'l' {
s.step = stateFal
return scanContinue
}
return s.error(c, "in literal false (expecting 'l')")
}
// stateFal is the state after reading `fal`.
func stateFal(s *scanner, c byte) int {
if c == 's' {
s.step = stateFals
return scanContinue
}
return s.error(c, "in literal false (expecting 's')")
}
// stateFals is the state after reading `fals`.
func stateFals(s *scanner, c byte) int {
if c == 'e' {
s.step = stateEndValue
return scanContinue
}
return s.error(c, "in literal false (expecting 'e')")
}
// stateN is the state after reading `n`.
func stateN(s *scanner, c byte) int {
if c == 'u' {
s.step = stateNu
return scanContinue
}
return s.error(c, "in literal null (expecting 'u')")
}
// stateNu is the state after reading `nu`.
func stateNu(s *scanner, c byte) int {
if c == 'l' {
s.step = stateNul
return scanContinue
}
return s.error(c, "in literal null (expecting 'l')")
}
// stateNul is the state after reading `nul`.
func stateNul(s *scanner, c byte) int {
if c == 'l' {
s.step = stateEndValue
return scanContinue
}
return s.error(c, "in literal null (expecting 'l')")
}
// stateError is the state after reaching a syntax error,
// such as after reading `[1}` or `5.1.2`.
func stateError(s *scanner, c byte) int {
return scanError
}
// error records an error and switches to the error state.
func (s *scanner) error(c byte, context string) int {
s.step = stateError
s.err = &SyntaxError{"invalid character " + quoteChar(c) + " " + context, s.bytes}
return scanError
}
// quoteChar formats c as a quoted character literal.
func quoteChar(c byte) string {
// special cases - different from quoted strings
if c == '\'' {
return `'\''`
}
if c == '"' {
return `'"'`
}
// use quoted string with different quotation marks
s := strconv.Quote(string(c))
return "'" + s[1:len(s)-1] + "'"
}

301
gojson/scanner_test.go Normal file
View File

@@ -0,0 +1,301 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
import (
"bytes"
"math"
"math/rand"
"reflect"
"testing"
)
var validTests = []struct {
data string
ok bool
}{
{`foo`, false},
{`}{`, false},
{`{]`, false},
{`{}`, true},
{`{"foo":"bar"}`, true},
{`{"foo":"bar","bar":{"baz":["qux"]}}`, true},
}
func TestValid(t *testing.T) {
for _, tt := range validTests {
if ok := Valid([]byte(tt.data)); ok != tt.ok {
t.Errorf("Valid(%#q) = %v, want %v", tt.data, ok, tt.ok)
}
}
}
// Tests of simple examples.
type example struct {
compact string
indent string
}
var examples = []example{
{`1`, `1`},
{`{}`, `{}`},
{`[]`, `[]`},
{`{"":2}`, "{\n\t\"\": 2\n}"},
{`[3]`, "[\n\t3\n]"},
{`[1,2,3]`, "[\n\t1,\n\t2,\n\t3\n]"},
{`{"x":1}`, "{\n\t\"x\": 1\n}"},
{ex1, ex1i},
{"{\"\":\"<>&\u2028\u2029\"}", "{\n\t\"\": \"<>&\u2028\u2029\"\n}"}, // See golang.org/issue/34070
}
var ex1 = `[true,false,null,"x",1,1.5,0,-5e+2]`
var ex1i = `[
true,
false,
null,
"x",
1,
1.5,
0,
-5e+2
]`
func TestCompact(t *testing.T) {
var buf bytes.Buffer
for _, tt := range examples {
buf.Reset()
if err := Compact(&buf, []byte(tt.compact)); err != nil {
t.Errorf("Compact(%#q): %v", tt.compact, err)
} else if s := buf.String(); s != tt.compact {
t.Errorf("Compact(%#q) = %#q, want original", tt.compact, s)
}
buf.Reset()
if err := Compact(&buf, []byte(tt.indent)); err != nil {
t.Errorf("Compact(%#q): %v", tt.indent, err)
continue
} else if s := buf.String(); s != tt.compact {
t.Errorf("Compact(%#q) = %#q, want %#q", tt.indent, s, tt.compact)
}
}
}
func TestCompactSeparators(t *testing.T) {
// U+2028 and U+2029 should be escaped inside strings.
// They should not appear outside strings.
tests := []struct {
in, compact string
}{
{"{\"\u2028\": 1}", "{\"\u2028\":1}"},
{"{\"\u2029\" :2}", "{\"\u2029\":2}"},
}
for _, tt := range tests {
var buf bytes.Buffer
if err := Compact(&buf, []byte(tt.in)); err != nil {
t.Errorf("Compact(%q): %v", tt.in, err)
} else if s := buf.String(); s != tt.compact {
t.Errorf("Compact(%q) = %q, want %q", tt.in, s, tt.compact)
}
}
}
func TestIndent(t *testing.T) {
var buf bytes.Buffer
for _, tt := range examples {
buf.Reset()
if err := Indent(&buf, []byte(tt.indent), "", "\t"); err != nil {
t.Errorf("Indent(%#q): %v", tt.indent, err)
} else if s := buf.String(); s != tt.indent {
t.Errorf("Indent(%#q) = %#q, want original", tt.indent, s)
}
buf.Reset()
if err := Indent(&buf, []byte(tt.compact), "", "\t"); err != nil {
t.Errorf("Indent(%#q): %v", tt.compact, err)
continue
} else if s := buf.String(); s != tt.indent {
t.Errorf("Indent(%#q) = %#q, want %#q", tt.compact, s, tt.indent)
}
}
}
// Tests of a large random structure.
func TestCompactBig(t *testing.T) {
initBig()
var buf bytes.Buffer
if err := Compact(&buf, jsonBig); err != nil {
t.Fatalf("Compact: %v", err)
}
b := buf.Bytes()
if !bytes.Equal(b, jsonBig) {
t.Error("Compact(jsonBig) != jsonBig")
diff(t, b, jsonBig)
return
}
}
func TestIndentBig(t *testing.T) {
t.Parallel()
initBig()
var buf bytes.Buffer
if err := Indent(&buf, jsonBig, "", "\t"); err != nil {
t.Fatalf("Indent1: %v", err)
}
b := buf.Bytes()
if len(b) == len(jsonBig) {
// jsonBig is compact (no unnecessary spaces);
// indenting should make it bigger
t.Fatalf("Indent(jsonBig) did not get bigger")
}
// should be idempotent
var buf1 bytes.Buffer
if err := Indent(&buf1, b, "", "\t"); err != nil {
t.Fatalf("Indent2: %v", err)
}
b1 := buf1.Bytes()
if !bytes.Equal(b1, b) {
t.Error("Indent(Indent(jsonBig)) != Indent(jsonBig)")
diff(t, b1, b)
return
}
// should get back to original
buf1.Reset()
if err := Compact(&buf1, b); err != nil {
t.Fatalf("Compact: %v", err)
}
b1 = buf1.Bytes()
if !bytes.Equal(b1, jsonBig) {
t.Error("Compact(Indent(jsonBig)) != jsonBig")
diff(t, b1, jsonBig)
return
}
}
type indentErrorTest struct {
in string
err error
}
var indentErrorTests = []indentErrorTest{
{`{"X": "foo", "Y"}`, &SyntaxError{"invalid character '}' after object key", 17}},
{`{"X": "foo" "Y": "bar"}`, &SyntaxError{"invalid character '\"' after object key:value pair", 13}},
}
func TestIndentErrors(t *testing.T) {
for i, tt := range indentErrorTests {
slice := make([]uint8, 0)
buf := bytes.NewBuffer(slice)
if err := Indent(buf, []uint8(tt.in), "", ""); err != nil {
if !reflect.DeepEqual(err, tt.err) {
t.Errorf("#%d: Indent: %#v", i, err)
continue
}
}
}
}
func diff(t *testing.T, a, b []byte) {
for i := 0; ; i++ {
if i >= len(a) || i >= len(b) || a[i] != b[i] {
j := i - 10
if j < 0 {
j = 0
}
t.Errorf("diverge at %d: «%s» vs «%s»", i, trim(a[j:]), trim(b[j:]))
return
}
}
}
func trim(b []byte) []byte {
if len(b) > 20 {
return b[0:20]
}
return b
}
// Generate a random JSON object.
var jsonBig []byte
func initBig() {
n := 10000
if testing.Short() {
n = 100
}
b, err := Marshal(genValue(n))
if err != nil {
panic(err)
}
jsonBig = b
}
func genValue(n int) any {
if n > 1 {
switch rand.Intn(2) {
case 0:
return genArray(n)
case 1:
return genMap(n)
}
}
switch rand.Intn(3) {
case 0:
return rand.Intn(2) == 0
case 1:
return rand.NormFloat64()
case 2:
return genString(30)
}
panic("unreachable")
}
func genString(stddev float64) string {
n := int(math.Abs(rand.NormFloat64()*stddev + stddev/2))
c := make([]rune, n)
for i := range c {
f := math.Abs(rand.NormFloat64()*64 + 32)
if f > 0x10ffff {
f = 0x10ffff
}
c[i] = rune(f)
}
return string(c)
}
func genArray(n int) []any {
f := int(math.Abs(rand.NormFloat64()) * math.Min(10, float64(n/2)))
if f > n {
f = n
}
if f < 1 {
f = 1
}
x := make([]any, f)
for i := range x {
x[i] = genValue(((i+1)*n)/f - (i*n)/f)
}
return x
}
func genMap(n int) map[string]any {
f := int(math.Abs(rand.NormFloat64()) * math.Min(10, float64(n/2)))
if f > n {
f = n
}
if n > 0 && f == 0 {
f = 1
}
x := make(map[string]any)
for i := 0; i < f; i++ {
x[genString(10)] = genValue(((i+1)*n)/f - (i*n)/f)
}
return x
}

524
gojson/stream.go Normal file
View File

@@ -0,0 +1,524 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
import (
"bytes"
"errors"
"io"
)
// A Decoder reads and decodes JSON values from an input stream.
type Decoder struct {
r io.Reader
buf []byte
d decodeState
scanp int // start of unread data in buf
scanned int64 // amount of data already scanned
scan scanner
err error
tokenState int
tokenStack []int
}
// NewDecoder returns a new decoder that reads from r.
//
// The decoder introduces its own buffering and may
// read data from r beyond the JSON values requested.
func NewDecoder(r io.Reader) *Decoder {
return &Decoder{r: r}
}
// UseNumber causes the Decoder to unmarshal a number into an interface{} as a
// Number instead of as a float64.
func (dec *Decoder) UseNumber() { dec.d.useNumber = true }
// DisallowUnknownFields causes the Decoder to return an error when the destination
// is a struct and the input contains object keys which do not match any
// non-ignored, exported fields in the destination.
func (dec *Decoder) DisallowUnknownFields() { dec.d.disallowUnknownFields = true }
// Decode reads the next JSON-encoded value from its
// input and stores it in the value pointed to by v.
//
// See the documentation for Unmarshal for details about
// the conversion of JSON into a Go value.
func (dec *Decoder) Decode(v any) error {
if dec.err != nil {
return dec.err
}
if err := dec.tokenPrepareForDecode(); err != nil {
return err
}
if !dec.tokenValueAllowed() {
return &SyntaxError{msg: "not at beginning of value", Offset: dec.InputOffset()}
}
// Read whole value into buffer.
n, err := dec.readValue()
if err != nil {
return err
}
dec.d.init(dec.buf[dec.scanp : dec.scanp+n])
dec.scanp += n
// Don't save err from unmarshal into dec.err:
// the connection is still usable since we read a complete JSON
// object from it before the error happened.
err = dec.d.unmarshal(v)
// fixup token streaming state
dec.tokenValueEnd()
return err
}
// Buffered returns a reader of the data remaining in the Decoder's
// buffer. The reader is valid until the next call to Decode.
func (dec *Decoder) Buffered() io.Reader {
return bytes.NewReader(dec.buf[dec.scanp:])
}
// readValue reads a JSON value into dec.buf.
// It returns the length of the encoding.
func (dec *Decoder) readValue() (int, error) {
dec.scan.reset()
scanp := dec.scanp
var err error
Input:
// help the compiler see that scanp is never negative, so it can remove
// some bounds checks below.
for scanp >= 0 {
// Look in the buffer for a new value.
for ; scanp < len(dec.buf); scanp++ {
c := dec.buf[scanp]
dec.scan.bytes++
switch dec.scan.step(&dec.scan, c) {
case scanEnd:
// scanEnd is delayed one byte so we decrement
// the scanner bytes count by 1 to ensure that
// this value is correct in the next call of Decode.
dec.scan.bytes--
break Input
case scanEndObject, scanEndArray:
// scanEnd is delayed one byte.
// We might block trying to get that byte from src,
// so instead invent a space byte.
if stateEndValue(&dec.scan, ' ') == scanEnd {
scanp++
break Input
}
case scanError:
dec.err = dec.scan.err
return 0, dec.scan.err
}
}
// Did the last read have an error?
// Delayed until now to allow buffer scan.
if err != nil {
if err == io.EOF {
if dec.scan.step(&dec.scan, ' ') == scanEnd {
break Input
}
if nonSpace(dec.buf) {
err = io.ErrUnexpectedEOF
}
}
dec.err = err
return 0, err
}
n := scanp - dec.scanp
err = dec.refill()
scanp = dec.scanp + n
}
return scanp - dec.scanp, nil
}
func (dec *Decoder) refill() error {
// Make room to read more into the buffer.
// First slide down data already consumed.
if dec.scanp > 0 {
dec.scanned += int64(dec.scanp)
n := copy(dec.buf, dec.buf[dec.scanp:])
dec.buf = dec.buf[:n]
dec.scanp = 0
}
// Grow buffer if not large enough.
const minRead = 512
if cap(dec.buf)-len(dec.buf) < minRead {
newBuf := make([]byte, len(dec.buf), 2*cap(dec.buf)+minRead)
copy(newBuf, dec.buf)
dec.buf = newBuf
}
// Read. Delay error for next iteration (after scan).
n, err := dec.r.Read(dec.buf[len(dec.buf):cap(dec.buf)])
dec.buf = dec.buf[0 : len(dec.buf)+n]
return err
}
func nonSpace(b []byte) bool {
for _, c := range b {
if !isSpace(c) {
return true
}
}
return false
}
// An Encoder writes JSON values to an output stream.
type Encoder struct {
w io.Writer
err error
escapeHTML bool
nilSafeSlices bool
nilSafeMaps bool
indentBuf *bytes.Buffer
indentPrefix string
indentValue string
}
// NewEncoder returns a new encoder that writes to w.
func NewEncoder(w io.Writer) *Encoder {
return &Encoder{w: w, escapeHTML: true}
}
// Encode writes the JSON encoding of v to the stream,
// followed by a newline character.
//
// See the documentation for Marshal for details about the
// conversion of Go values to JSON.
func (enc *Encoder) Encode(v any) error {
if enc.err != nil {
return enc.err
}
e := newEncodeState()
defer encodeStatePool.Put(e)
err := e.marshal(v, encOpts{escapeHTML: enc.escapeHTML, nilSafeMaps: enc.nilSafeMaps, nilSafeSlices: enc.nilSafeSlices})
if err != nil {
return err
}
// Terminate each value with a newline.
// This makes the output look a little nicer
// when debugging, and some kind of space
// is required if the encoded value was a number,
// so that the reader knows there aren't more
// digits coming.
e.WriteByte('\n')
b := e.Bytes()
if enc.indentPrefix != "" || enc.indentValue != "" {
if enc.indentBuf == nil {
enc.indentBuf = new(bytes.Buffer)
}
enc.indentBuf.Reset()
err = Indent(enc.indentBuf, b, enc.indentPrefix, enc.indentValue)
if err != nil {
return err
}
b = enc.indentBuf.Bytes()
}
if _, err = enc.w.Write(b); err != nil {
enc.err = err
}
return err
}
// SetIndent instructs the encoder to format each subsequent encoded
// value as if indented by the package-level function Indent(dst, src, prefix, indent).
// Calling SetIndent("", "") disables indentation.
func (enc *Encoder) SetIndent(prefix, indent string) {
enc.indentPrefix = prefix
enc.indentValue = indent
}
// SetNilSafeCollection specifies whether to represent nil slices and maps as
// '[]' or '{}' respectfully (flag on) instead of 'null' (default) when marshaling json.
func (enc *Encoder) SetNilSafeCollection(nilSafeSlices bool, nilSafeMaps bool) {
enc.nilSafeSlices = nilSafeSlices
enc.nilSafeMaps = nilSafeMaps
}
// SetEscapeHTML specifies whether problematic HTML characters
// should be escaped inside JSON quoted strings.
// The default behavior is to escape &, <, and > to \u0026, \u003c, and \u003e
// to avoid certain safety problems that can arise when embedding JSON in HTML.
//
// In non-HTML settings where the escaping interferes with the readability
// of the output, SetEscapeHTML(false) disables this behavior.
func (enc *Encoder) SetEscapeHTML(on bool) {
enc.escapeHTML = on
}
// RawMessage is a raw encoded JSON value.
// It implements Marshaler and Unmarshaler and can
// be used to delay JSON decoding or precompute a JSON encoding.
type RawMessage []byte
// MarshalJSON returns m as the JSON encoding of m.
func (m RawMessage) MarshalJSON() ([]byte, error) {
if m == nil {
return []byte("null"), nil
}
return m, nil
}
// UnmarshalJSON sets *m to a copy of data.
func (m *RawMessage) UnmarshalJSON(data []byte) error {
if m == nil {
return errors.New("json.RawMessage: UnmarshalJSON on nil pointer")
}
*m = append((*m)[0:0], data...)
return nil
}
var _ Marshaler = (*RawMessage)(nil)
var _ Unmarshaler = (*RawMessage)(nil)
// A Token holds a value of one of these types:
//
// Delim, for the four JSON delimiters [ ] { }
// bool, for JSON booleans
// float64, for JSON numbers
// Number, for JSON numbers
// string, for JSON string literals
// nil, for JSON null
type Token any
const (
tokenTopValue = iota
tokenArrayStart
tokenArrayValue
tokenArrayComma
tokenObjectStart
tokenObjectKey
tokenObjectColon
tokenObjectValue
tokenObjectComma
)
// advance tokenstate from a separator state to a value state
func (dec *Decoder) tokenPrepareForDecode() error {
// Note: Not calling peek before switch, to avoid
// putting peek into the standard Decode path.
// peek is only called when using the Token API.
switch dec.tokenState {
case tokenArrayComma:
c, err := dec.peek()
if err != nil {
return err
}
if c != ',' {
return &SyntaxError{"expected comma after array element", dec.InputOffset()}
}
dec.scanp++
dec.tokenState = tokenArrayValue
case tokenObjectColon:
c, err := dec.peek()
if err != nil {
return err
}
if c != ':' {
return &SyntaxError{"expected colon after object key", dec.InputOffset()}
}
dec.scanp++
dec.tokenState = tokenObjectValue
}
return nil
}
func (dec *Decoder) tokenValueAllowed() bool {
switch dec.tokenState {
case tokenTopValue, tokenArrayStart, tokenArrayValue, tokenObjectValue:
return true
}
return false
}
func (dec *Decoder) tokenValueEnd() {
switch dec.tokenState {
case tokenArrayStart, tokenArrayValue:
dec.tokenState = tokenArrayComma
case tokenObjectValue:
dec.tokenState = tokenObjectComma
}
}
// A Delim is a JSON array or object delimiter, one of [ ] { or }.
type Delim rune
func (d Delim) String() string {
return string(d)
}
// Token returns the next JSON token in the input stream.
// At the end of the input stream, Token returns nil, io.EOF.
//
// Token guarantees that the delimiters [ ] { } it returns are
// properly nested and matched: if Token encounters an unexpected
// delimiter in the input, it will return an error.
//
// The input stream consists of basic JSON values—bool, string,
// number, and null—along with delimiters [ ] { } of type Delim
// to mark the start and end of arrays and objects.
// Commas and colons are elided.
func (dec *Decoder) Token() (Token, error) {
for {
c, err := dec.peek()
if err != nil {
return nil, err
}
switch c {
case '[':
if !dec.tokenValueAllowed() {
return dec.tokenError(c)
}
dec.scanp++
dec.tokenStack = append(dec.tokenStack, dec.tokenState)
dec.tokenState = tokenArrayStart
return Delim('['), nil
case ']':
if dec.tokenState != tokenArrayStart && dec.tokenState != tokenArrayComma {
return dec.tokenError(c)
}
dec.scanp++
dec.tokenState = dec.tokenStack[len(dec.tokenStack)-1]
dec.tokenStack = dec.tokenStack[:len(dec.tokenStack)-1]
dec.tokenValueEnd()
return Delim(']'), nil
case '{':
if !dec.tokenValueAllowed() {
return dec.tokenError(c)
}
dec.scanp++
dec.tokenStack = append(dec.tokenStack, dec.tokenState)
dec.tokenState = tokenObjectStart
return Delim('{'), nil
case '}':
if dec.tokenState != tokenObjectStart && dec.tokenState != tokenObjectComma {
return dec.tokenError(c)
}
dec.scanp++
dec.tokenState = dec.tokenStack[len(dec.tokenStack)-1]
dec.tokenStack = dec.tokenStack[:len(dec.tokenStack)-1]
dec.tokenValueEnd()
return Delim('}'), nil
case ':':
if dec.tokenState != tokenObjectColon {
return dec.tokenError(c)
}
dec.scanp++
dec.tokenState = tokenObjectValue
continue
case ',':
if dec.tokenState == tokenArrayComma {
dec.scanp++
dec.tokenState = tokenArrayValue
continue
}
if dec.tokenState == tokenObjectComma {
dec.scanp++
dec.tokenState = tokenObjectKey
continue
}
return dec.tokenError(c)
case '"':
if dec.tokenState == tokenObjectStart || dec.tokenState == tokenObjectKey {
var x string
old := dec.tokenState
dec.tokenState = tokenTopValue
err := dec.Decode(&x)
dec.tokenState = old
if err != nil {
return nil, err
}
dec.tokenState = tokenObjectColon
return x, nil
}
fallthrough
default:
if !dec.tokenValueAllowed() {
return dec.tokenError(c)
}
var x any
if err := dec.Decode(&x); err != nil {
return nil, err
}
return x, nil
}
}
}
func (dec *Decoder) tokenError(c byte) (Token, error) {
var context string
switch dec.tokenState {
case tokenTopValue:
context = " looking for beginning of value"
case tokenArrayStart, tokenArrayValue, tokenObjectValue:
context = " looking for beginning of value"
case tokenArrayComma:
context = " after array element"
case tokenObjectKey:
context = " looking for beginning of object key string"
case tokenObjectColon:
context = " after object key"
case tokenObjectComma:
context = " after object key:value pair"
}
return nil, &SyntaxError{"invalid character " + quoteChar(c) + context, dec.InputOffset()}
}
// More reports whether there is another element in the
// current array or object being parsed.
func (dec *Decoder) More() bool {
c, err := dec.peek()
return err == nil && c != ']' && c != '}'
}
func (dec *Decoder) peek() (byte, error) {
var err error
for {
for i := dec.scanp; i < len(dec.buf); i++ {
c := dec.buf[i]
if isSpace(c) {
continue
}
dec.scanp = i
return c, nil
}
// buffer has been scanned, now report any error
if err != nil {
return 0, err
}
err = dec.refill()
}
}
// InputOffset returns the input stream byte offset of the current decoder position.
// The offset gives the location of the end of the most recently returned token
// and the beginning of the next token.
func (dec *Decoder) InputOffset() int64 {
return dec.scanned + int64(dec.scanp)
}

539
gojson/stream_test.go Normal file
View File

@@ -0,0 +1,539 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
import (
"bytes"
"io"
"log"
"net"
"net/http"
"net/http/httptest"
"reflect"
"runtime/debug"
"strings"
"testing"
)
// Test values for the stream test.
// One of each JSON kind.
var streamTest = []any{
0.1,
"hello",
nil,
true,
false,
[]any{"a", "b", "c"},
map[string]any{"": "Kelvin", "ß": "long s"},
3.14, // another value to make sure something can follow map
}
var streamEncoded = `0.1
"hello"
null
true
false
["a","b","c"]
{"ß":"long s","":"Kelvin"}
3.14
`
func TestEncoder(t *testing.T) {
for i := 0; i <= len(streamTest); i++ {
var buf strings.Builder
enc := NewEncoder(&buf)
// Check that enc.SetIndent("", "") turns off indentation.
enc.SetIndent(">", ".")
enc.SetIndent("", "")
for j, v := range streamTest[0:i] {
if err := enc.Encode(v); err != nil {
t.Fatalf("encode #%d: %v", j, err)
}
}
if have, want := buf.String(), nlines(streamEncoded, i); have != want {
t.Errorf("encoding %d items: mismatch", i)
diff(t, []byte(have), []byte(want))
break
}
}
}
func TestEncoderErrorAndReuseEncodeState(t *testing.T) {
// Disable the GC temporarily to prevent encodeState's in Pool being cleaned away during the test.
percent := debug.SetGCPercent(-1)
defer debug.SetGCPercent(percent)
// Trigger an error in Marshal with cyclic data.
type Dummy struct {
Name string
Next *Dummy
}
dummy := Dummy{Name: "Dummy"}
dummy.Next = &dummy
var buf bytes.Buffer
enc := NewEncoder(&buf)
if err := enc.Encode(dummy); err == nil {
t.Errorf("Encode(dummy) == nil; want error")
}
type Data struct {
A string
I int
}
data := Data{A: "a", I: 1}
if err := enc.Encode(data); err != nil {
t.Errorf("Marshal(%v) = %v", data, err)
}
var data2 Data
if err := Unmarshal(buf.Bytes(), &data2); err != nil {
t.Errorf("Unmarshal(%v) = %v", data2, err)
}
if data2 != data {
t.Errorf("expect: %v, but get: %v", data, data2)
}
}
var streamEncodedIndent = `0.1
"hello"
null
true
false
[
>."a",
>."b",
>."c"
>]
{
>."ß": "long s",
>."": "Kelvin"
>}
3.14
`
func TestEncoderIndent(t *testing.T) {
var buf strings.Builder
enc := NewEncoder(&buf)
enc.SetIndent(">", ".")
for _, v := range streamTest {
enc.Encode(v)
}
if have, want := buf.String(), streamEncodedIndent; have != want {
t.Error("indented encoding mismatch")
diff(t, []byte(have), []byte(want))
}
}
type strMarshaler string
func (s strMarshaler) MarshalJSON() ([]byte, error) {
return []byte(s), nil
}
type strPtrMarshaler string
func (s *strPtrMarshaler) MarshalJSON() ([]byte, error) {
return []byte(*s), nil
}
func TestEncoderSetEscapeHTML(t *testing.T) {
var c C
var ct CText
var tagStruct struct {
Valid int `json:"<>&#! "`
Invalid int `json:"\\"`
}
// This case is particularly interesting, as we force the encoder to
// take the address of the Ptr field to use its MarshalJSON method. This
// is why the '&' is important.
marshalerStruct := &struct {
NonPtr strMarshaler
Ptr strPtrMarshaler
}{`"<str>"`, `"<str>"`}
// https://golang.org/issue/34154
stringOption := struct {
Bar string `json:"bar,string"`
}{`<html>foobar</html>`}
for _, tt := range []struct {
name string
v any
wantEscape string
want string
}{
{"c", c, `"\u003c\u0026\u003e"`, `"<&>"`},
{"ct", ct, `"\"\u003c\u0026\u003e\""`, `"\"<&>\""`},
{`"<&>"`, "<&>", `"\u003c\u0026\u003e"`, `"<&>"`},
{
"tagStruct", tagStruct,
`{"\u003c\u003e\u0026#! ":0,"Invalid":0}`,
`{"<>&#! ":0,"Invalid":0}`,
},
{
`"<str>"`, marshalerStruct,
`{"NonPtr":"\u003cstr\u003e","Ptr":"\u003cstr\u003e"}`,
`{"NonPtr":"<str>","Ptr":"<str>"}`,
},
{
"stringOption", stringOption,
`{"bar":"\"\\u003chtml\\u003efoobar\\u003c/html\\u003e\""}`,
`{"bar":"\"<html>foobar</html>\""}`,
},
} {
var buf strings.Builder
enc := NewEncoder(&buf)
if err := enc.Encode(tt.v); err != nil {
t.Errorf("Encode(%s): %s", tt.name, err)
continue
}
if got := strings.TrimSpace(buf.String()); got != tt.wantEscape {
t.Errorf("Encode(%s) = %#q, want %#q", tt.name, got, tt.wantEscape)
}
buf.Reset()
enc.SetEscapeHTML(false)
if err := enc.Encode(tt.v); err != nil {
t.Errorf("SetEscapeHTML(false) Encode(%s): %s", tt.name, err)
continue
}
if got := strings.TrimSpace(buf.String()); got != tt.want {
t.Errorf("SetEscapeHTML(false) Encode(%s) = %#q, want %#q",
tt.name, got, tt.want)
}
}
}
func TestDecoder(t *testing.T) {
for i := 0; i <= len(streamTest); i++ {
// Use stream without newlines as input,
// just to stress the decoder even more.
// Our test input does not include back-to-back numbers.
// Otherwise stripping the newlines would
// merge two adjacent JSON values.
var buf bytes.Buffer
for _, c := range nlines(streamEncoded, i) {
if c != '\n' {
buf.WriteRune(c)
}
}
out := make([]any, i)
dec := NewDecoder(&buf)
for j := range out {
if err := dec.Decode(&out[j]); err != nil {
t.Fatalf("decode #%d/%d: %v", j, i, err)
}
}
if !reflect.DeepEqual(out, streamTest[0:i]) {
t.Errorf("decoding %d items: mismatch", i)
for j := range out {
if !reflect.DeepEqual(out[j], streamTest[j]) {
t.Errorf("#%d: have %v want %v", j, out[j], streamTest[j])
}
}
break
}
}
}
func TestDecoderBuffered(t *testing.T) {
r := strings.NewReader(`{"Name": "Gopher"} extra `)
var m struct {
Name string
}
d := NewDecoder(r)
err := d.Decode(&m)
if err != nil {
t.Fatal(err)
}
if m.Name != "Gopher" {
t.Errorf("Name = %q; want Gopher", m.Name)
}
rest, err := io.ReadAll(d.Buffered())
if err != nil {
t.Fatal(err)
}
if g, w := string(rest), " extra "; g != w {
t.Errorf("Remaining = %q; want %q", g, w)
}
}
func nlines(s string, n int) string {
if n <= 0 {
return ""
}
for i, c := range s {
if c == '\n' {
if n--; n == 0 {
return s[0 : i+1]
}
}
}
return s
}
func TestRawMessage(t *testing.T) {
var data struct {
X float64
Id RawMessage
Y float32
}
const raw = `["\u0056",null]`
const msg = `{"X":0.1,"Id":["\u0056",null],"Y":0.2}`
err := Unmarshal([]byte(msg), &data)
if err != nil {
t.Fatalf("Unmarshal: %v", err)
}
if string([]byte(data.Id)) != raw {
t.Fatalf("Raw mismatch: have %#q want %#q", []byte(data.Id), raw)
}
b, err := Marshal(&data)
if err != nil {
t.Fatalf("Marshal: %v", err)
}
if string(b) != msg {
t.Fatalf("Marshal: have %#q want %#q", b, msg)
}
}
func TestNullRawMessage(t *testing.T) {
var data struct {
X float64
Id RawMessage
IdPtr *RawMessage
Y float32
}
const msg = `{"X":0.1,"Id":null,"IdPtr":null,"Y":0.2}`
err := Unmarshal([]byte(msg), &data)
if err != nil {
t.Fatalf("Unmarshal: %v", err)
}
if want, got := "null", string(data.Id); want != got {
t.Fatalf("Raw mismatch: have %q, want %q", got, want)
}
if data.IdPtr != nil {
t.Fatalf("Raw pointer mismatch: have non-nil, want nil")
}
b, err := Marshal(&data)
if err != nil {
t.Fatalf("Marshal: %v", err)
}
if string(b) != msg {
t.Fatalf("Marshal: have %#q want %#q", b, msg)
}
}
var blockingTests = []string{
`{"x": 1}`,
`[1, 2, 3]`,
}
func TestBlocking(t *testing.T) {
for _, enc := range blockingTests {
r, w := net.Pipe()
go w.Write([]byte(enc))
var val any
// If Decode reads beyond what w.Write writes above,
// it will block, and the test will deadlock.
if err := NewDecoder(r).Decode(&val); err != nil {
t.Errorf("decoding %s: %v", enc, err)
}
r.Close()
w.Close()
}
}
type tokenStreamCase struct {
json string
expTokens []any
}
type decodeThis struct {
v any
}
var tokenStreamCases = []tokenStreamCase{
// streaming token cases
{json: `10`, expTokens: []any{float64(10)}},
{json: ` [10] `, expTokens: []any{
Delim('['), float64(10), Delim(']')}},
{json: ` [false,10,"b"] `, expTokens: []any{
Delim('['), false, float64(10), "b", Delim(']')}},
{json: `{ "a": 1 }`, expTokens: []any{
Delim('{'), "a", float64(1), Delim('}')}},
{json: `{"a": 1, "b":"3"}`, expTokens: []any{
Delim('{'), "a", float64(1), "b", "3", Delim('}')}},
{json: ` [{"a": 1},{"a": 2}] `, expTokens: []any{
Delim('['),
Delim('{'), "a", float64(1), Delim('}'),
Delim('{'), "a", float64(2), Delim('}'),
Delim(']')}},
{json: `{"obj": {"a": 1}}`, expTokens: []any{
Delim('{'), "obj", Delim('{'), "a", float64(1), Delim('}'),
Delim('}')}},
{json: `{"obj": [{"a": 1}]}`, expTokens: []any{
Delim('{'), "obj", Delim('['),
Delim('{'), "a", float64(1), Delim('}'),
Delim(']'), Delim('}')}},
// streaming tokens with intermittent Decode()
{json: `{ "a": 1 }`, expTokens: []any{
Delim('{'), "a",
decodeThis{float64(1)},
Delim('}')}},
{json: ` [ { "a" : 1 } ] `, expTokens: []any{
Delim('['),
decodeThis{map[string]any{"a": float64(1)}},
Delim(']')}},
{json: ` [{"a": 1},{"a": 2}] `, expTokens: []any{
Delim('['),
decodeThis{map[string]any{"a": float64(1)}},
decodeThis{map[string]any{"a": float64(2)}},
Delim(']')}},
{json: `{ "obj" : [ { "a" : 1 } ] }`, expTokens: []any{
Delim('{'), "obj", Delim('['),
decodeThis{map[string]any{"a": float64(1)}},
Delim(']'), Delim('}')}},
{json: `{"obj": {"a": 1}}`, expTokens: []any{
Delim('{'), "obj",
decodeThis{map[string]any{"a": float64(1)}},
Delim('}')}},
{json: `{"obj": [{"a": 1}]}`, expTokens: []any{
Delim('{'), "obj",
decodeThis{[]any{
map[string]any{"a": float64(1)},
}},
Delim('}')}},
{json: ` [{"a": 1} {"a": 2}] `, expTokens: []any{
Delim('['),
decodeThis{map[string]any{"a": float64(1)}},
decodeThis{&SyntaxError{"expected comma after array element", 11}},
}},
{json: `{ "` + strings.Repeat("a", 513) + `" 1 }`, expTokens: []any{
Delim('{'), strings.Repeat("a", 513),
decodeThis{&SyntaxError{"expected colon after object key", 518}},
}},
{json: `{ "\a" }`, expTokens: []any{
Delim('{'),
&SyntaxError{"invalid character 'a' in string escape code", 3},
}},
{json: ` \a`, expTokens: []any{
&SyntaxError{"invalid character '\\\\' looking for beginning of value", 1},
}},
}
func TestDecodeInStream(t *testing.T) {
for ci, tcase := range tokenStreamCases {
dec := NewDecoder(strings.NewReader(tcase.json))
for i, etk := range tcase.expTokens {
var tk any
var err error
if dt, ok := etk.(decodeThis); ok {
etk = dt.v
err = dec.Decode(&tk)
} else {
tk, err = dec.Token()
}
if experr, ok := etk.(error); ok {
if err == nil || !reflect.DeepEqual(err, experr) {
t.Errorf("case %v: Expected error %#v in %q, but was %#v", ci, experr, tcase.json, err)
}
break
} else if err == io.EOF {
t.Errorf("case %v: Unexpected EOF in %q", ci, tcase.json)
break
} else if err != nil {
t.Errorf("case %v: Unexpected error '%#v' in %q", ci, err, tcase.json)
break
}
if !reflect.DeepEqual(tk, etk) {
t.Errorf(`case %v: %q @ %v expected %T(%v) was %T(%v)`, ci, tcase.json, i, etk, etk, tk, tk)
break
}
}
}
}
// Test from golang.org/issue/11893
func TestHTTPDecoding(t *testing.T) {
const raw = `{ "foo": "bar" }`
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(raw))
}))
defer ts.Close()
res, err := http.Get(ts.URL)
if err != nil {
log.Fatalf("GET failed: %v", err)
}
defer res.Body.Close()
foo := struct {
Foo string
}{}
d := NewDecoder(res.Body)
err = d.Decode(&foo)
if err != nil {
t.Fatalf("Decode: %v", err)
}
if foo.Foo != "bar" {
t.Errorf("decoded %q; want \"bar\"", foo.Foo)
}
// make sure we get the EOF the second time
err = d.Decode(&foo)
if err != io.EOF {
t.Errorf("err = %v; want io.EOF", err)
}
}
func TestEncoderSetNilSafeCollection(t *testing.T) {
var (
nilSlice []interface{}
pNilSlice *[]interface{}
nilMap map[string]interface{}
pNilMap *map[string]interface{}
)
for _, tt := range []struct {
name string
v interface{}
want string
rescuedWant string
}{
{"nilSlice", nilSlice, "null", "[]"},
{"nonNilSlice", []interface{}{}, "[]", "[]"},
{"sliceWithValues", []interface{}{1, 2, 3}, "[1,2,3]", "[1,2,3]"},
{"pNilSlice", pNilSlice, "null", "null"},
{"nilMap", nilMap, "null", "{}"},
{"nonNilMap", map[string]interface{}{}, "{}", "{}"},
{"mapWithValues", map[string]interface{}{"1": 1, "2": 2, "3": 3}, "{\"1\":1,\"2\":2,\"3\":3}", "{\"1\":1,\"2\":2,\"3\":3}"},
{"pNilMap", pNilMap, "null", "null"},
} {
var buf bytes.Buffer
enc := NewEncoder(&buf)
if err := enc.Encode(tt.v); err != nil {
t.Fatalf("Encode(%s): %s", tt.name, err)
}
if got := strings.TrimSpace(buf.String()); got != tt.want {
t.Errorf("Encode(%s) = %#q, want %#q", tt.name, got, tt.want)
}
buf.Reset()
enc.SetNilSafeCollection(true, true)
if err := enc.Encode(tt.v); err != nil {
t.Fatalf("SetNilSafeCollection(true) Encode(%s): %s", tt.name, err)
}
if got := strings.TrimSpace(buf.String()); got != tt.rescuedWant {
t.Errorf("SetNilSafeCollection(true) Encode(%s) = %#q, want %#q",
tt.name, got, tt.want)
}
}
}

218
gojson/tables.go Normal file
View File

@@ -0,0 +1,218 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
import "unicode/utf8"
// safeSet holds the value true if the ASCII character with the given array
// position can be represented inside a JSON string without any further
// escaping.
//
// All values are true except for the ASCII control characters (0-31), the
// double quote ("), and the backslash character ("\").
var safeSet = [utf8.RuneSelf]bool{
' ': true,
'!': true,
'"': false,
'#': true,
'$': true,
'%': true,
'&': true,
'\'': true,
'(': true,
')': true,
'*': true,
'+': true,
',': true,
'-': true,
'.': true,
'/': true,
'0': true,
'1': true,
'2': true,
'3': true,
'4': true,
'5': true,
'6': true,
'7': true,
'8': true,
'9': true,
':': true,
';': true,
'<': true,
'=': true,
'>': true,
'?': true,
'@': true,
'A': true,
'B': true,
'C': true,
'D': true,
'E': true,
'F': true,
'G': true,
'H': true,
'I': true,
'J': true,
'K': true,
'L': true,
'M': true,
'N': true,
'O': true,
'P': true,
'Q': true,
'R': true,
'S': true,
'T': true,
'U': true,
'V': true,
'W': true,
'X': true,
'Y': true,
'Z': true,
'[': true,
'\\': false,
']': true,
'^': true,
'_': true,
'`': true,
'a': true,
'b': true,
'c': true,
'd': true,
'e': true,
'f': true,
'g': true,
'h': true,
'i': true,
'j': true,
'k': true,
'l': true,
'm': true,
'n': true,
'o': true,
'p': true,
'q': true,
'r': true,
's': true,
't': true,
'u': true,
'v': true,
'w': true,
'x': true,
'y': true,
'z': true,
'{': true,
'|': true,
'}': true,
'~': true,
'\u007f': true,
}
// htmlSafeSet holds the value true if the ASCII character with the given
// array position can be safely represented inside a JSON string, embedded
// inside of HTML <script> tags, without any additional escaping.
//
// All values are true except for the ASCII control characters (0-31), the
// double quote ("), the backslash character ("\"), HTML opening and closing
// tags ("<" and ">"), and the ampersand ("&").
var htmlSafeSet = [utf8.RuneSelf]bool{
' ': true,
'!': true,
'"': false,
'#': true,
'$': true,
'%': true,
'&': false,
'\'': true,
'(': true,
')': true,
'*': true,
'+': true,
',': true,
'-': true,
'.': true,
'/': true,
'0': true,
'1': true,
'2': true,
'3': true,
'4': true,
'5': true,
'6': true,
'7': true,
'8': true,
'9': true,
':': true,
';': true,
'<': false,
'=': true,
'>': false,
'?': true,
'@': true,
'A': true,
'B': true,
'C': true,
'D': true,
'E': true,
'F': true,
'G': true,
'H': true,
'I': true,
'J': true,
'K': true,
'L': true,
'M': true,
'N': true,
'O': true,
'P': true,
'Q': true,
'R': true,
'S': true,
'T': true,
'U': true,
'V': true,
'W': true,
'X': true,
'Y': true,
'Z': true,
'[': true,
'\\': false,
']': true,
'^': true,
'_': true,
'`': true,
'a': true,
'b': true,
'c': true,
'd': true,
'e': true,
'f': true,
'g': true,
'h': true,
'i': true,
'j': true,
'k': true,
'l': true,
'm': true,
'n': true,
'o': true,
'p': true,
'q': true,
'r': true,
's': true,
't': true,
'u': true,
'v': true,
'w': true,
'x': true,
'y': true,
'z': true,
'{': true,
'|': true,
'}': true,
'~': true,
'\u007f': true,
}

120
gojson/tagkey_test.go Normal file
View File

@@ -0,0 +1,120 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
import (
"testing"
)
type basicLatin2xTag struct {
V string `json:"$%-/"`
}
type basicLatin3xTag struct {
V string `json:"0123456789"`
}
type basicLatin4xTag struct {
V string `json:"ABCDEFGHIJKLMO"`
}
type basicLatin5xTag struct {
V string `json:"PQRSTUVWXYZ_"`
}
type basicLatin6xTag struct {
V string `json:"abcdefghijklmno"`
}
type basicLatin7xTag struct {
V string `json:"pqrstuvwxyz"`
}
type miscPlaneTag struct {
V string `json:"色は匂へど"`
}
type percentSlashTag struct {
V string `json:"text/html%"` // https://golang.org/issue/2718
}
type punctuationTag struct {
V string `json:"!#$%&()*+-./:;<=>?@[]^_{|}~ "` // https://golang.org/issue/3546
}
type dashTag struct {
V string `json:"-,"`
}
type emptyTag struct {
W string
}
type misnamedTag struct {
X string `jsom:"Misnamed"`
}
type badFormatTag struct {
Y string `:"BadFormat"`
}
type badCodeTag struct {
Z string `json:" !\"#&'()*+,."`
}
type spaceTag struct {
Q string `json:"With space"`
}
type unicodeTag struct {
W string `json:"Ελλάδα"`
}
var structTagObjectKeyTests = []struct {
raw any
value string
key string
}{
{basicLatin2xTag{"2x"}, "2x", "$%-/"},
{basicLatin3xTag{"3x"}, "3x", "0123456789"},
{basicLatin4xTag{"4x"}, "4x", "ABCDEFGHIJKLMO"},
{basicLatin5xTag{"5x"}, "5x", "PQRSTUVWXYZ_"},
{basicLatin6xTag{"6x"}, "6x", "abcdefghijklmno"},
{basicLatin7xTag{"7x"}, "7x", "pqrstuvwxyz"},
{miscPlaneTag{"いろはにほへと"}, "いろはにほへと", "色は匂へど"},
{dashTag{"foo"}, "foo", "-"},
{emptyTag{"Pour Moi"}, "Pour Moi", "W"},
{misnamedTag{"Animal Kingdom"}, "Animal Kingdom", "X"},
{badFormatTag{"Orfevre"}, "Orfevre", "Y"},
{badCodeTag{"Reliable Man"}, "Reliable Man", "Z"},
{percentSlashTag{"brut"}, "brut", "text/html%"},
{punctuationTag{"Union Rags"}, "Union Rags", "!#$%&()*+-./:;<=>?@[]^_{|}~ "},
{spaceTag{"Perreddu"}, "Perreddu", "With space"},
{unicodeTag{"Loukanikos"}, "Loukanikos", "Ελλάδα"},
}
func TestStructTagObjectKey(t *testing.T) {
for _, tt := range structTagObjectKeyTests {
b, err := Marshal(tt.raw)
if err != nil {
t.Fatalf("Marshal(%#q) failed: %v", tt.raw, err)
}
var f any
err = Unmarshal(b, &f)
if err != nil {
t.Fatalf("Unmarshal(%#q) failed: %v", b, err)
}
for i, v := range f.(map[string]any) {
switch i {
case tt.key:
if s, ok := v.(string); !ok || s != tt.value {
t.Fatalf("Unexpected value: %#q, want %v", s, tt.value)
}
default:
t.Fatalf("Unexpected key: %#q, from %#q", i, b)
}
}
}
}

38
gojson/tags.go Normal file
View File

@@ -0,0 +1,38 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
import (
"strings"
)
// tagOptions is the string following a comma in a struct field's "json"
// tag, or the empty string. It does not include the leading comma.
type tagOptions string
// parseTag splits a struct field's json tag into its name and
// comma-separated options.
func parseTag(tag string) (string, tagOptions) {
tag, opt, _ := strings.Cut(tag, ",")
return tag, tagOptions(opt)
}
// Contains reports whether a comma-separated list of options
// contains a particular substr flag. substr must be surrounded by a
// string boundary or commas.
func (o tagOptions) Contains(optionName string) bool {
if len(o) == 0 {
return false
}
s := string(o)
for s != "" {
var name string
name, s, _ = strings.Cut(s, ",")
if name == optionName {
return true
}
}
return false
}

28
gojson/tags_test.go Normal file
View File

@@ -0,0 +1,28 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package json
import (
"testing"
)
func TestTagParsing(t *testing.T) {
name, opts := parseTag("field,foobar,foo")
if name != "field" {
t.Fatalf("name = %q, want field", name)
}
for _, tt := range []struct {
opt string
want bool
}{
{"foobar", true},
{"foo", true},
{"bar", false},
} {
if opts.Contains(tt.opt) != tt.want {
t.Errorf("Contains(%q) = %v", tt.opt, !tt.want)
}
}
}

BIN
gojson/testdata/code.json.gz vendored Normal file

Binary file not shown.

View File

@@ -1,6 +1,8 @@
package langext
import (
"errors"
"fmt"
"reflect"
)
@@ -70,7 +72,73 @@ func ArrEqualsExact[T comparable](arr1 []T, arr2 []T) bool {
return true
}
func ArrAll(arr interface{}, fn func(int) bool) bool {
func ArrAll[T any](arr []T, fn func(T) bool) bool {
for _, av := range arr {
if !fn(av) {
return false
}
}
return true
}
func ArrAllErr[T any](arr []T, fn func(T) (bool, error)) (bool, error) {
for _, av := range arr {
v, err := fn(av)
if err != nil {
return false, err
}
if !v {
return false, nil
}
}
return true, nil
}
func ArrNone[T any](arr []T, fn func(T) bool) bool {
for _, av := range arr {
if fn(av) {
return false
}
}
return true
}
func ArrNoneErr[T any](arr []T, fn func(T) (bool, error)) (bool, error) {
for _, av := range arr {
v, err := fn(av)
if err != nil {
return false, err
}
if v {
return false, nil
}
}
return true, nil
}
func ArrAny[T any](arr []T, fn func(T) bool) bool {
for _, av := range arr {
if fn(av) {
return true
}
}
return false
}
func ArrAnyErr[T any](arr []T, fn func(T) (bool, error)) (bool, error) {
for _, av := range arr {
v, err := fn(av)
if err != nil {
return false, err
}
if v {
return true, nil
}
}
return false, nil
}
func ArrIdxAll(arr any, fn func(int) bool) bool {
av := reflect.ValueOf(arr)
for i := 0; i < av.Len(); i++ {
if !fn(i) {
@@ -80,7 +148,7 @@ func ArrAll(arr interface{}, fn func(int) bool) bool {
return true
}
func ArrAllErr(arr interface{}, fn func(int) (bool, error)) (bool, error) {
func ArrIdxAllErr(arr any, fn func(int) (bool, error)) (bool, error) {
av := reflect.ValueOf(arr)
for i := 0; i < av.Len(); i++ {
v, err := fn(i)
@@ -94,7 +162,7 @@ func ArrAllErr(arr interface{}, fn func(int) (bool, error)) (bool, error) {
return true, nil
}
func ArrNone(arr interface{}, fn func(int) bool) bool {
func ArrIdxNone(arr any, fn func(int) bool) bool {
av := reflect.ValueOf(arr)
for i := 0; i < av.Len(); i++ {
if fn(i) {
@@ -104,7 +172,7 @@ func ArrNone(arr interface{}, fn func(int) bool) bool {
return true
}
func ArrNoneErr(arr interface{}, fn func(int) (bool, error)) (bool, error) {
func ArrIdxNoneErr(arr any, fn func(int) (bool, error)) (bool, error) {
av := reflect.ValueOf(arr)
for i := 0; i < av.Len(); i++ {
v, err := fn(i)
@@ -118,7 +186,7 @@ func ArrNoneErr(arr interface{}, fn func(int) (bool, error)) (bool, error) {
return true, nil
}
func ArrAny(arr interface{}, fn func(int) bool) bool {
func ArrIdxAny(arr any, fn func(int) bool) bool {
av := reflect.ValueOf(arr)
for i := 0; i < av.Len(); i++ {
if fn(i) {
@@ -128,7 +196,7 @@ func ArrAny(arr interface{}, fn func(int) bool) bool {
return false
}
func ArrAnyErr(arr interface{}, fn func(int) (bool, error)) (bool, error) {
func ArrIdxAnyErr(arr any, fn func(int) (bool, error)) (bool, error) {
av := reflect.ValueOf(arr)
for i := 0; i < av.Len(); i++ {
v, err := fn(i)
@@ -142,7 +210,7 @@ func ArrAnyErr(arr interface{}, fn func(int) (bool, error)) (bool, error) {
return false, nil
}
func ArrFirst[T comparable](arr []T, comp func(v T) bool) (T, bool) {
func ArrFirst[T any](arr []T, comp func(v T) bool) (T, bool) {
for _, v := range arr {
if comp(v) {
return v, true
@@ -151,7 +219,16 @@ func ArrFirst[T comparable](arr []T, comp func(v T) bool) (T, bool) {
return *new(T), false
}
func ArrLast[T comparable](arr []T, comp func(v T) bool) (T, bool) {
func ArrFirstOrNil[T any](arr []T, comp func(v T) bool) *T {
for _, v := range arr {
if comp(v) {
return Ptr(v)
}
}
return nil
}
func ArrLast[T any](arr []T, comp func(v T) bool) (T, bool) {
found := false
result := *new(T)
for _, v := range arr {
@@ -163,6 +240,22 @@ func ArrLast[T comparable](arr []T, comp func(v T) bool) (T, bool) {
return result, found
}
func ArrLastOrNil[T any](arr []T, comp func(v T) bool) *T {
found := false
result := *new(T)
for _, v := range arr {
if comp(v) {
found = true
result = v
}
}
if found {
return Ptr(result)
} else {
return nil
}
}
func ArrFirstIndex[T comparable](arr []T, needle T) int {
for i, v := range arr {
if v == needle {
@@ -199,6 +292,66 @@ func ArrMap[T1 any, T2 any](arr []T1, conv func(v T1) T2) []T2 {
return r
}
func MapMap[TK comparable, TV any, TR any](inmap map[TK]TV, conv func(k TK, v TV) TR) []TR {
r := make([]TR, 0, len(inmap))
for k, v := range inmap {
r = append(r, conv(k, v))
}
return r
}
func MapMapErr[TK comparable, TV any, TR any](inmap map[TK]TV, conv func(k TK, v TV) (TR, error)) ([]TR, error) {
r := make([]TR, 0, len(inmap))
for k, v := range inmap {
elem, err := conv(k, v)
if err != nil {
return nil, err
}
r = append(r, elem)
}
return r, nil
}
func ArrMapExt[T1 any, T2 any](arr []T1, conv func(idx int, v T1) T2) []T2 {
r := make([]T2, len(arr))
for i, v := range arr {
r[i] = conv(i, v)
}
return r
}
func ArrMapErr[T1 any, T2 any](arr []T1, conv func(v T1) (T2, error)) ([]T2, error) {
var err error
r := make([]T2, len(arr))
for i, v := range arr {
r[i], err = conv(v)
if err != nil {
return nil, err
}
}
return r, nil
}
func ArrFilterMap[T1 any, T2 any](arr []T1, filter func(v T1) bool, conv func(v T1) T2) []T2 {
r := make([]T2, 0, len(arr))
for _, v := range arr {
if filter(v) {
r = append(r, conv(v))
}
}
return r
}
func ArrFilter[T any](arr []T, filter func(v T) bool) []T {
r := make([]T, 0, len(arr))
for _, v := range arr {
if filter(v) {
r = append(r, v)
}
}
return r
}
func ArrSum[T NumberConstraint](arr []T) T {
var r T = 0
for _, v := range arr {
@@ -206,3 +359,84 @@ func ArrSum[T NumberConstraint](arr []T) T {
}
return r
}
func ArrFlatten[T1 any, T2 any](arr []T1, conv func(v T1) []T2) []T2 {
r := make([]T2, 0, len(arr))
for _, v1 := range arr {
r = append(r, conv(v1)...)
}
return r
}
func ArrFlattenDirect[T1 any](arr [][]T1) []T1 {
r := make([]T1, 0, len(arr))
for _, v1 := range arr {
r = append(r, v1...)
}
return r
}
func ArrCastToAny[T1 any](arr []T1) []any {
r := make([]any, len(arr))
for i, v := range arr {
r[i] = any(v)
}
return r
}
func ArrCastSafe[T1 any, T2 any](arr []T1) []T2 {
r := make([]T2, 0, len(arr))
for _, v := range arr {
if vcast, ok := any(v).(T2); ok {
r = append(r, vcast)
}
}
return r
}
func ArrCastErr[T1 any, T2 any](arr []T1) ([]T2, error) {
r := make([]T2, len(arr))
for i, v := range arr {
if vcast, ok := any(v).(T2); ok {
r[i] = vcast
} else {
return nil, errors.New(fmt.Sprintf("Cannot cast element %d of type %T to type %s", i, v, *new(T2)))
}
}
return r, nil
}
func ArrCastPanic[T1 any, T2 any](arr []T1) []T2 {
r := make([]T2, len(arr))
for i, v := range arr {
if vcast, ok := any(v).(T2); ok {
r[i] = vcast
} else {
panic(fmt.Sprintf("Cannot cast element %d of type %T to type %s", i, v, *new(T2)))
}
}
return r
}
func ArrConcat[T any](arr ...[]T) []T {
c := 0
for _, v := range arr {
c += len(v)
}
r := make([]T, c)
i := 0
for _, av := range arr {
for _, v := range av {
r[i] = v
i++
}
}
return r
}
// ArrCopy does a shallow copy of the 'in' array
func ArrCopy[T any](in []T) []T {
out := make([]T, len(in))
copy(out, in)
return out
}

178
langext/base58.go Normal file
View File

@@ -0,0 +1,178 @@
package langext
import (
"bytes"
"errors"
"math/big"
)
// shamelessly stolen from https://github.com/btcsuite/
type B58Encoding struct {
bigRadix [11]*big.Int
bigRadix10 *big.Int
alphabet string
alphabetIdx0 byte
b58 [256]byte
}
var Base58DefaultEncoding = newBase58Encoding("123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz")
var Base58FlickrEncoding = newBase58Encoding("123456789abcdefghijkmnopqrstuvwxyzABCDEFGHJKLMNPQRSTUVWXYZ")
var Base58RippleEncoding = newBase58Encoding("rpshnaf39wBUDNEGHJKLM4PQRST7VWXYZ2bcdeCg65jkm8oFqi1tuvAxyz")
var Base58BitcoinEncoding = newBase58Encoding("123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz")
func newBase58Encoding(alphabet string) *B58Encoding {
bigRadix10 := big.NewInt(58 * 58 * 58 * 58 * 58 * 58 * 58 * 58 * 58 * 58)
enc := &B58Encoding{
alphabet: alphabet,
alphabetIdx0: '1',
bigRadix: [...]*big.Int{
big.NewInt(0),
big.NewInt(58),
big.NewInt(58 * 58),
big.NewInt(58 * 58 * 58),
big.NewInt(58 * 58 * 58 * 58),
big.NewInt(58 * 58 * 58 * 58 * 58),
big.NewInt(58 * 58 * 58 * 58 * 58 * 58),
big.NewInt(58 * 58 * 58 * 58 * 58 * 58 * 58),
big.NewInt(58 * 58 * 58 * 58 * 58 * 58 * 58 * 58),
big.NewInt(58 * 58 * 58 * 58 * 58 * 58 * 58 * 58 * 58),
bigRadix10,
},
bigRadix10: bigRadix10,
}
b58 := make([]byte, 0, 256)
for i := byte(0); i < 32; i++ {
for j := byte(0); j < 8; j++ {
b := i*8 + j
idx := bytes.IndexByte([]byte(alphabet), b)
if idx == -1 {
b58 = append(b58, 255)
} else {
b58 = append(b58, byte(idx))
}
}
}
enc.b58 = *((*[256]byte)(b58))
return enc
}
func (enc *B58Encoding) EncodeString(src string) (string, error) {
v, err := enc.Encode([]byte(src))
if err != nil {
return "", err
}
return string(v), nil
}
func (enc *B58Encoding) Encode(src []byte) ([]byte, error) {
x := new(big.Int)
x.SetBytes(src)
// maximum length of output is log58(2^(8*len(b))) == len(b) * 8 / log(58)
maxlen := int(float64(len(src))*1.365658237309761) + 1
answer := make([]byte, 0, maxlen)
mod := new(big.Int)
for x.Sign() > 0 {
// Calculating with big.Int is slow for each iteration.
// x, mod = x / 58, x % 58
//
// Instead we can try to do as much calculations on int64.
// x, mod = x / 58^10, x % 58^10
//
// Which will give us mod, which is 10 digit base58 number.
// We'll loop that 10 times to convert to the answer.
x.DivMod(x, enc.bigRadix10, mod)
if x.Sign() == 0 {
// When x = 0, we need to ensure we don't add any extra zeros.
m := mod.Int64()
for m > 0 {
answer = append(answer, enc.alphabet[m%58])
m /= 58
}
} else {
m := mod.Int64()
for i := 0; i < 10; i++ {
answer = append(answer, enc.alphabet[m%58])
m /= 58
}
}
}
// leading zero bytes
for _, i := range src {
if i != 0 {
break
}
answer = append(answer, enc.alphabetIdx0)
}
// reverse
alen := len(answer)
for i := 0; i < alen/2; i++ {
answer[i], answer[alen-1-i] = answer[alen-1-i], answer[i]
}
return answer, nil
}
func (enc *B58Encoding) DecodeString(src string) (string, error) {
v, err := enc.Decode([]byte(src))
if err != nil {
return "", err
}
return string(v), nil
}
func (enc *B58Encoding) Decode(src []byte) ([]byte, error) {
answer := big.NewInt(0)
scratch := new(big.Int)
for t := src; len(t) > 0; {
n := len(t)
if n > 10 {
n = 10
}
total := uint64(0)
for _, v := range t[:n] {
if v > 255 {
return []byte{}, errors.New("invalid char in input")
}
tmp := enc.b58[v]
if tmp == 255 {
return []byte{}, errors.New("invalid char in input")
}
total = total*58 + uint64(tmp)
}
answer.Mul(answer, enc.bigRadix[n])
scratch.SetUint64(total)
answer.Add(answer, scratch)
t = t[n:]
}
tmpval := answer.Bytes()
var numZeros int
for numZeros = 0; numZeros < len(src); numZeros++ {
if src[numZeros] != enc.alphabetIdx0 {
break
}
}
flen := numZeros + len(tmpval)
val := make([]byte, flen)
copy(val[numZeros:], tmpval)
return val, nil
}

67
langext/base58_test.go Normal file
View File

@@ -0,0 +1,67 @@
package langext
import (
"testing"
)
func _encStr(t *testing.T, enc *B58Encoding, v string) string {
v, err := enc.EncodeString(v)
if err != nil {
t.Error(err)
}
return v
}
func _decStr(t *testing.T, enc *B58Encoding, v string) string {
v, err := enc.DecodeString(v)
if err != nil {
t.Error(err)
}
return v
}
func TestBase58DefaultEncoding(t *testing.T) {
tst.AssertEqual(t, _encStr(t, Base58DefaultEncoding, "Hello"), "9Ajdvzr")
tst.AssertEqual(t, _encStr(t, Base58DefaultEncoding, "If debugging is the process of removing software bugs, then programming must be the process of putting them in."), "48638SMcJuah5okqPx4kCVf5d8QAdgbdNf28g7ReY13prUENNbMyssjq5GjsrJHF5zeZfqs4uJMUJHr7VbrU4XBUZ2Fw9DVtqtn9N1eXucEWSEZahXV6w4ysGSWqGdpeYTJf1MdDzTg8vfcQViifJjZX")
}
func TestBase58DefaultDecoding(t *testing.T) {
tst.AssertEqual(t, _decStr(t, Base58DefaultEncoding, "9Ajdvzr"), "Hello")
tst.AssertEqual(t, _decStr(t, Base58DefaultEncoding, "48638SMcJuah5okqPx4kCVf5d8QAdgbdNf28g7ReY13prUENNbMyssjq5GjsrJHF5zeZfqs4uJMUJHr7VbrU4XBUZ2Fw9DVtqtn9N1eXucEWSEZahXV6w4ysGSWqGdpeYTJf1MdDzTg8vfcQViifJjZX"), "If debugging is the process of removing software bugs, then programming must be the process of putting them in.")
}
func TestBase58RippleEncoding(t *testing.T) {
tst.AssertEqual(t, _encStr(t, Base58RippleEncoding, "Hello"), "9wjdvzi")
tst.AssertEqual(t, _encStr(t, Base58RippleEncoding, "If debugging is the process of removing software bugs, then programming must be the process of putting them in."), "h3as3SMcJu26nokqPxhkUVCnd3Qwdgbd4Cp3gfReYrsFi7N44bMy11jqnGj1iJHEnzeZCq1huJM7JHifVbi7hXB7ZpEA9DVtqt894reXucNWSNZ26XVaAhy1GSWqGdFeYTJCrMdDzTg3vCcQV55CJjZX")
}
func TestBase58RippleDecoding(t *testing.T) {
tst.AssertEqual(t, _decStr(t, Base58RippleEncoding, "9wjdvzi"), "Hello")
tst.AssertEqual(t, _decStr(t, Base58RippleEncoding, "h3as3SMcJu26nokqPxhkUVCnd3Qwdgbd4Cp3gfReYrsFi7N44bMy11jqnGj1iJHEnzeZCq1huJM7JHifVbi7hXB7ZpEA9DVtqt894reXucNWSNZ26XVaAhy1GSWqGdFeYTJCrMdDzTg3vCcQV55CJjZX"), "If debugging is the process of removing software bugs, then programming must be the process of putting them in.")
}
func TestBase58BitcoinEncoding(t *testing.T) {
tst.AssertEqual(t, _encStr(t, Base58BitcoinEncoding, "Hello"), "9Ajdvzr")
tst.AssertEqual(t, _encStr(t, Base58BitcoinEncoding, "If debugging is the process of removing software bugs, then programming must be the process of putting them in."), "48638SMcJuah5okqPx4kCVf5d8QAdgbdNf28g7ReY13prUENNbMyssjq5GjsrJHF5zeZfqs4uJMUJHr7VbrU4XBUZ2Fw9DVtqtn9N1eXucEWSEZahXV6w4ysGSWqGdpeYTJf1MdDzTg8vfcQViifJjZX")
}
func TestBase58BitcoinDecoding(t *testing.T) {
tst.AssertEqual(t, _decStr(t, Base58BitcoinEncoding, "9Ajdvzr"), "Hello")
tst.AssertEqual(t, _decStr(t, Base58BitcoinEncoding, "48638SMcJuah5okqPx4kCVf5d8QAdgbdNf28g7ReY13prUENNbMyssjq5GjsrJHF5zeZfqs4uJMUJHr7VbrU4XBUZ2Fw9DVtqtn9N1eXucEWSEZahXV6w4ysGSWqGdpeYTJf1MdDzTg8vfcQViifJjZX"), "If debugging is the process of removing software bugs, then programming must be the process of putting them in.")
}
func TestBase58FlickrEncoding(t *testing.T) {
tst.AssertEqual(t, _encStr(t, Base58FlickrEncoding, "Hello"), "9aJCVZR")
tst.AssertEqual(t, _encStr(t, Base58FlickrEncoding, "If debugging is the process of removing software bugs, then programming must be the process of putting them in."), "48638rmBiUzG5NKQoX4KcuE5C8paCFACnE28F7qDx13PRtennAmYSSJQ5gJSRihf5ZDyEQS4UimtihR7uARt4wbty2fW9duTQTM9n1DwUBevreyzGwu6W4YSgrvQgCPDxsiE1mCdZsF8VEBpuHHEiJyw")
}
func TestBase58FlickrDecoding(t *testing.T) {
tst.AssertEqual(t, _decStr(t, Base58FlickrEncoding, "9aJCVZR"), "Hello")
tst.AssertEqual(t, _decStr(t, Base58FlickrEncoding, "48638rmBiUzG5NKQoX4KcuE5C8paCFACnE28F7qDx13PRtennAmYSSJQ5gJSRihf5ZDyEQS4UimtihR7uARt4wbty2fW9duTQTM9n1DwUBevreyzGwu6W4YSgrvQgCPDxsiE1mCdZsF8VEBpuHHEiJyw"), "If debugging is the process of removing software bugs, then programming must be the process of putting them in.")
}
func tst.AssertEqual(t *testing.T, actual string, expected string) {
if actual != expected {
t.Errorf("values differ: Actual: '%v', Expected: '%v'", actual, expected)
}
}

View File

@@ -15,3 +15,35 @@ func Conditional[T any](v bool, resTrue T, resFalse T) T {
return resFalse
}
}
func ConditionalFn00[T any](v bool, resTrue T, resFalse T) T {
if v {
return resTrue
} else {
return resFalse
}
}
func ConditionalFn10[T any](v bool, resTrue func() T, resFalse T) T {
if v {
return resTrue()
} else {
return resFalse
}
}
func ConditionalFn01[T any](v bool, resTrue T, resFalse func() T) T {
if v {
return resTrue
} else {
return resFalse()
}
}
func ConditionalFn11[T any](v bool, resTrue func() T, resFalse func() T) T {
if v {
return resTrue()
} else {
return resFalse()
}
}

View File

@@ -60,3 +60,12 @@ func CoalesceStringer(s fmt.Stringer, def string) string {
return s.String()
}
}
func SafeCast[T any](v any, def T) T {
switch r := v.(type) {
case T:
return r
default:
return def
}
}

View File

@@ -31,16 +31,16 @@ func CompareIntArr(arr1 []int, arr2 []int) bool {
return false
}
func CompareArr[T OrderedConstraint](arr1 []T, arr2 []T) bool {
func CompareArr[T OrderedConstraint](arr1 []T, arr2 []T) int {
for i := 0; i < len(arr1) || i < len(arr2); i++ {
if i < len(arr1) && i < len(arr2) {
if arr1[i] < arr2[i] {
return true
return -1
} else if arr1[i] > arr2[i] {
return false
return +2
} else {
continue
}
@@ -49,15 +49,55 @@ func CompareArr[T OrderedConstraint](arr1 []T, arr2 []T) bool {
if i < len(arr1) {
return true
return +1
} else { // if i < len(arr2)
return false
return -1
}
}
return false
return 0
}
func CompareString(a, b string) int {
if a == b {
return 0
}
if a < b {
return -1
}
return +1
}
func CompareInt(a, b int) int {
if a == b {
return 0
}
if a < b {
return -1
}
return +1
}
func CompareInt64(a, b int64) int {
if a == b {
return 0
}
if a < b {
return -1
}
return +1
}
func Compare[T OrderedConstraint](a, b T) int {
if a == b {
return 0
}
if a < b {
return -1
}
return +1
}

View File

@@ -7,3 +7,19 @@ func MapKeyArr[T comparable, V any](v map[T]V) []T {
}
return result
}
func ArrToMap[T comparable, V any](a []V, keyfunc func(V) T) map[T]V {
result := make(map[T]V, len(a))
for _, v := range a {
result[keyfunc(v)] = v
}
return result
}
func CopyMap[K comparable, V any](a map[K]V) map[K]V {
result := make(map[K]V, len(a))
for k, v := range a {
result[k] = v
}
return result
}

71
langext/panic.go Normal file
View File

@@ -0,0 +1,71 @@
package langext
type PanicWrappedErr struct {
panic any
}
func (p PanicWrappedErr) Error() string {
return "A panic occured"
}
func (p PanicWrappedErr) ReoveredObj() any {
return p.panic
}
func RunPanicSafe(fn func()) (err error) {
defer func() {
if rec := recover(); rec != nil {
err = PanicWrappedErr{panic: rec}
}
}()
fn()
return nil
}
func RunPanicSafeR1(fn func() error) (err error) {
defer func() {
if rec := recover(); rec != nil {
err = PanicWrappedErr{panic: rec}
}
}()
return fn()
}
func RunPanicSafeR2[T1 any](fn func() (T1, error)) (r1 T1, err error) {
defer func() {
if rec := recover(); rec != nil {
r1 = *new(T1)
err = PanicWrappedErr{panic: rec}
}
}()
return fn()
}
func RunPanicSafeR3[T1 any, T2 any](fn func() (T1, T2, error)) (r1 T1, r2 T2, err error) {
defer func() {
if rec := recover(); rec != nil {
r1 = *new(T1)
r2 = *new(T2)
err = PanicWrappedErr{panic: rec}
}
}()
return fn()
}
func RunPanicSafeR4[T1 any, T2 any, T3 any](fn func() (T1, T2, T3, error)) (r1 T1, r2 T2, r3 T3, err error) {
defer func() {
if rec := recover(); rec != nil {
r1 = *new(T1)
r2 = *new(T2)
r3 = *new(T3)
err = PanicWrappedErr{panic: rec}
}
}()
return fn()
}

View File

@@ -34,3 +34,7 @@ func IsNil(i interface{}) bool {
}
return false
}
func PtrEquals[T comparable](v1 *T, v2 *T) bool {
return (v1 == nil && v2 == nil) || (v1 != nil && v2 != nil && *v1 == *v2)
}

111
langext/reflection.go Normal file
View File

@@ -0,0 +1,111 @@
package langext
import (
"reflect"
)
var reflectBasicTypes = []reflect.Type{
reflect.Bool: reflect.TypeOf(false),
reflect.Int: reflect.TypeOf(int(0)),
reflect.Int8: reflect.TypeOf(int8(0)),
reflect.Int16: reflect.TypeOf(int16(0)),
reflect.Int32: reflect.TypeOf(int32(0)),
reflect.Int64: reflect.TypeOf(int64(0)),
reflect.Uint: reflect.TypeOf(uint(0)),
reflect.Uint8: reflect.TypeOf(uint8(0)),
reflect.Uint16: reflect.TypeOf(uint16(0)),
reflect.Uint32: reflect.TypeOf(uint32(0)),
reflect.Uint64: reflect.TypeOf(uint64(0)),
reflect.Uintptr: reflect.TypeOf(uintptr(0)),
reflect.Float32: reflect.TypeOf(float32(0)),
reflect.Float64: reflect.TypeOf(float64(0)),
reflect.Complex64: reflect.TypeOf(complex64(0)),
reflect.Complex128: reflect.TypeOf(complex128(0)),
reflect.String: reflect.TypeOf(""),
}
// Underlying returns the underlying type of t (without type alias)
//
// https://github.com/golang/go/issues/39574#issuecomment-655664772
func Underlying(t reflect.Type) (ret reflect.Type) {
if t.Name() == "" {
// t is an unnamed type. the underlying type is t itself
return t
}
kind := t.Kind()
if ret = reflectBasicTypes[kind]; ret != nil {
return ret
}
switch kind {
case reflect.Array:
ret = reflect.ArrayOf(t.Len(), t.Elem())
case reflect.Chan:
ret = reflect.ChanOf(t.ChanDir(), t.Elem())
case reflect.Map:
ret = reflect.MapOf(t.Key(), t.Elem())
case reflect.Func:
nIn := t.NumIn()
nOut := t.NumOut()
in := make([]reflect.Type, nIn)
out := make([]reflect.Type, nOut)
for i := 0; i < nIn; i++ {
in[i] = t.In(i)
}
for i := 0; i < nOut; i++ {
out[i] = t.Out(i)
}
ret = reflect.FuncOf(in, out, t.IsVariadic())
case reflect.Interface:
// not supported
case reflect.Ptr:
ret = reflect.PtrTo(t.Elem())
case reflect.Slice:
ret = reflect.SliceOf(t.Elem())
case reflect.Struct:
// only partially supported: embedded fields
// and unexported fields may cause panic in reflect.StructOf()
defer func() {
// if a panic happens, return t unmodified
if recover() != nil && ret == nil {
ret = t
}
}()
n := t.NumField()
fields := make([]reflect.StructField, n)
for i := 0; i < n; i++ {
fields[i] = t.Field(i)
}
ret = reflect.StructOf(fields)
}
return ret
}
// TryCast works similar to `v2, ok := v.(T)`
// Except it works through type alias'
func TryCast[T any](v any) (T, bool) {
underlying := Underlying(reflect.TypeOf(v))
def := *new(T)
if underlying != Underlying(reflect.TypeOf(def)) {
return def, false
}
r1 := reflect.ValueOf(v)
if !r1.CanConvert(underlying) {
return def, false
}
r2 := r1.Convert(underlying)
r3 := r2.Interface()
r4, ok := r3.(T)
if !ok {
return def, false
}
return r4, true
}

57
langext/sort.go Normal file
View File

@@ -0,0 +1,57 @@
package langext
import "sort"
func Sort[T OrderedConstraint](arr []T) {
sort.Slice(arr, func(i1, i2 int) bool {
return arr[i1] < arr[i2]
})
}
func SortStable[T OrderedConstraint](arr []T) {
sort.SliceStable(arr, func(i1, i2 int) bool {
return arr[i1] < arr[i2]
})
}
func IsSorted[T OrderedConstraint](arr []T) bool {
return sort.SliceIsSorted(arr, func(i1, i2 int) bool {
return arr[i1] < arr[i2]
})
}
func SortSlice[T any](arr []T, less func(v1, v2 T) bool) {
sort.Slice(arr, func(i1, i2 int) bool {
return less(arr[i1], arr[i2])
})
}
func SortSliceStable[T any](arr []T, less func(v1, v2 T) bool) {
sort.SliceStable(arr, func(i1, i2 int) bool {
return less(arr[i1], arr[i2])
})
}
func IsSliceSorted[T any](arr []T, less func(v1, v2 T) bool) bool {
return sort.SliceIsSorted(arr, func(i1, i2 int) bool {
return less(arr[i1], arr[i2])
})
}
func SortBy[TElem any, TSel OrderedConstraint](arr []TElem, selector func(v TElem) TSel) {
sort.Slice(arr, func(i1, i2 int) bool {
return selector(arr[i1]) < selector(arr[i2])
})
}
func SortByStable[TElem any, TSel OrderedConstraint](arr []TElem, selector func(v TElem) TSel) {
sort.SliceStable(arr, func(i1, i2 int) bool {
return selector(arr[i1]) < selector(arr[i2])
})
}
func IsSortedBy[TElem any, TSel OrderedConstraint](arr []TElem, selector func(v TElem) TSel) {
sort.SliceStable(arr, func(i1, i2 int) bool {
return selector(arr[i1]) < selector(arr[i2])
})
}

View File

@@ -107,3 +107,11 @@ func NumToStringOpt[V IntConstraint](v *V, fallback string) string {
return fmt.Sprintf("%d", v)
}
}
func StrRepeat(val string, count int) string {
r := ""
for i := 0; i < count; i++ {
r += val
}
return r
}

View File

@@ -22,6 +22,31 @@ func Max[T langext.OrderedConstraint](v1 T, v2 T) T {
}
}
func Max3[T langext.OrderedConstraint](v1 T, v2 T, v3 T) T {
result := v1
if v2 > result {
result = v2
}
if v3 > result {
result = v3
}
return result
}
func Max4[T langext.OrderedConstraint](v1 T, v2 T, v3 T, v4 T) T {
result := v1
if v2 > result {
result = v2
}
if v3 > result {
result = v3
}
if v4 > result {
result = v4
}
return result
}
func Min[T langext.OrderedConstraint](v1 T, v2 T) T {
if v1 < v2 {
return v1
@@ -30,6 +55,31 @@ func Min[T langext.OrderedConstraint](v1 T, v2 T) T {
}
}
func Min3[T langext.OrderedConstraint](v1 T, v2 T, v3 T) T {
result := v1
if v2 < result {
result = v2
}
if v3 < result {
result = v3
}
return result
}
func Min4[T langext.OrderedConstraint](v1 T, v2 T, v3 T, v4 T) T {
result := v1
if v2 < result {
result = v2
}
if v3 < result {
result = v3
}
if v4 < result {
result = v4
}
return result
}
func Abs[T langext.NumberConstraint](v T) T {
if v < 0 {
return -v

49
mongoext/pipeline.go Normal file
View File

@@ -0,0 +1,49 @@
package mongoext
import (
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
)
// FixTextSearchPipeline moves {$match:{$text:{$search}}} entries to the front of the pipeline (otherwise its an mongo error)
func FixTextSearchPipeline(pipeline mongo.Pipeline) mongo.Pipeline {
dget := func(v bson.D, k string) (bson.M, bool) {
for _, e := range v {
if e.Key == k {
if mv, ok := e.Value.(bson.M); ok {
return mv, true
}
}
}
return nil, false
}
mget := func(v bson.M, k string) (bson.M, bool) {
for ekey, eval := range v {
if ekey == k {
if mv, ok := eval.(bson.M); ok {
return mv, true
}
}
}
return nil, false
}
result := make([]bson.D, 0, len(pipeline))
for _, entry := range pipeline {
if v0, ok := dget(entry, "$match"); ok {
if v1, ok := mget(v0, "$text"); ok {
if _, ok := v1["$search"]; ok {
result = append([]bson.D{entry}, result...)
continue
}
}
}
result = append(result, entry)
}
return result
}

30
mongoext/projections.go Normal file
View File

@@ -0,0 +1,30 @@
package mongoext
import (
"go.mongodb.org/mongo-driver/bson"
"reflect"
"strings"
)
// ProjectionFromStruct automatically generated a mongodb projection for a struct
// This way you can pretty much always write
// `options.FindOne().SetProjection(mongoutils.ProjectionFromStruct(...your_model...))`
// to only get the data from mongodb that you will actually use in the later decode step
func ProjectionFromStruct(obj interface{}) bson.M {
v := reflect.ValueOf(obj)
t := v.Type()
result := bson.M{}
for i := 0; i < v.NumField(); i++ {
tag := t.Field(i).Tag.Get("bson")
if tag == "" {
continue
}
tag = strings.Split(tag, ",")[0]
result[tag] = 1
}
return result
}

25
mongoext/registry.go Normal file
View File

@@ -0,0 +1,25 @@
package mongoext
import (
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/bsoncodec"
"gogs.mikescher.com/BlackForestBytes/goext/rfctime"
"reflect"
)
func CreateGoExtBsonRegistry() *bsoncodec.Registry {
rb := bsoncodec.NewRegistryBuilder()
rb.RegisterTypeDecoder(reflect.TypeOf(rfctime.RFC3339Time{}), rfctime.RFC3339Time{})
rb.RegisterTypeDecoder(reflect.TypeOf(&rfctime.RFC3339Time{}), rfctime.RFC3339Time{})
rb.RegisterTypeDecoder(reflect.TypeOf(rfctime.RFC3339NanoTime{}), rfctime.RFC3339NanoTime{})
rb.RegisterTypeDecoder(reflect.TypeOf(&rfctime.RFC3339NanoTime{}), rfctime.RFC3339NanoTime{})
bsoncodec.DefaultValueEncoders{}.RegisterDefaultEncoders(rb)
bsoncodec.DefaultValueDecoders{}.RegisterDefaultDecoders(rb)
bson.PrimitiveCodecs{}.RegisterPrimitiveCodecs(rb)
return rb.Build()
}

191
rext/wrapper.go Normal file
View File

@@ -0,0 +1,191 @@
package rext
import (
"gogs.mikescher.com/BlackForestBytes/goext/langext"
"regexp"
)
type Regex interface {
IsMatch(haystack string) bool
MatchFirst(haystack string) (RegexMatch, bool)
MatchAll(haystack string) []RegexMatch
ReplaceAll(haystack string, repl string, literal bool) string
ReplaceAllFunc(haystack string, repl func(string) string) string
RemoveAll(haystack string) string
GroupCount() int
}
type regexWrapper struct {
rex *regexp.Regexp
subnames []string
}
type RegexMatch struct {
haystack string
submatchesIndex []int
subnames []string
}
type RegexMatchGroup struct {
haystack string
start int
end int
}
type OptRegexMatchGroup struct {
v *RegexMatchGroup
}
func W(rex *regexp.Regexp) Regex {
return &regexWrapper{rex: rex, subnames: rex.SubexpNames()}
}
// ---------------------------------------------------------------------------------------------------------------------
func (w *regexWrapper) IsMatch(haystack string) bool {
return w.rex.MatchString(haystack)
}
func (w *regexWrapper) MatchFirst(haystack string) (RegexMatch, bool) {
res := w.rex.FindStringSubmatchIndex(haystack)
if res == nil {
return RegexMatch{}, false
}
return RegexMatch{haystack: haystack, submatchesIndex: res, subnames: w.subnames}, true
}
func (w *regexWrapper) MatchAll(haystack string) []RegexMatch {
resarr := w.rex.FindAllStringSubmatchIndex(haystack, -1)
matches := make([]RegexMatch, 0, len(resarr))
for _, res := range resarr {
matches = append(matches, RegexMatch{haystack: haystack, submatchesIndex: res, subnames: w.subnames})
}
return matches
}
func (w *regexWrapper) ReplaceAll(haystack string, repl string, literal bool) string {
if literal {
// do not expand placeholder aka $1, $2, ...
return w.rex.ReplaceAllLiteralString(haystack, repl)
} else {
return w.rex.ReplaceAllString(haystack, repl)
}
}
func (w *regexWrapper) ReplaceAllFunc(haystack string, repl func(string) string) string {
return w.rex.ReplaceAllStringFunc(haystack, repl)
}
func (w *regexWrapper) RemoveAll(haystack string) string {
return w.rex.ReplaceAllLiteralString(haystack, "")
}
// GroupCount returns the amount of groups in this match, does not count group-0 (whole match)
func (w *regexWrapper) GroupCount() int {
return len(w.subnames) - 1
}
// ---------------------------------------------------------------------------------------------------------------------
func (m RegexMatch) FullMatch() RegexMatchGroup {
return RegexMatchGroup{haystack: m.haystack, start: m.submatchesIndex[0], end: m.submatchesIndex[1]}
}
// GroupCount returns the amount of groups in this match, does not count group-0 (whole match)
func (m RegexMatch) GroupCount() int {
return len(m.subnames) - 1
}
// GroupByIndex returns the value of a matched group (group 0 == whole match)
func (m RegexMatch) GroupByIndex(idx int) RegexMatchGroup {
return RegexMatchGroup{haystack: m.haystack, start: m.submatchesIndex[idx*2], end: m.submatchesIndex[idx*2+1]}
}
// GroupByName returns the value of a matched group (panics if not found!)
func (m RegexMatch) GroupByName(name string) RegexMatchGroup {
for idx, subname := range m.subnames {
if subname == name {
return RegexMatchGroup{haystack: m.haystack, start: m.submatchesIndex[idx*2], end: m.submatchesIndex[idx*2+1]}
}
}
panic("failed to find regex-group by name")
}
// GroupByName returns the value of a matched group (returns empty OptRegexMatchGroup if not found)
func (m RegexMatch) GroupByNameOrEmpty(name string) OptRegexMatchGroup {
for idx, subname := range m.subnames {
if subname == name && (m.submatchesIndex[idx*2] != -1 || m.submatchesIndex[idx*2+1] != -1) {
return OptRegexMatchGroup{&RegexMatchGroup{haystack: m.haystack, start: m.submatchesIndex[idx*2], end: m.submatchesIndex[idx*2+1]}}
}
}
return OptRegexMatchGroup{}
}
// ---------------------------------------------------------------------------------------------------------------------
func (g RegexMatchGroup) Value() string {
return g.haystack[g.start:g.end]
}
func (g RegexMatchGroup) Start() int {
return g.start
}
func (g RegexMatchGroup) End() int {
return g.end
}
func (g RegexMatchGroup) Range() (int, int) {
return g.start, g.end
}
func (g RegexMatchGroup) Length() int {
return g.end - g.start
}
// ---------------------------------------------------------------------------------------------------------------------
func (g OptRegexMatchGroup) Value() string {
return g.v.Value()
}
func (g OptRegexMatchGroup) ValueOrEmpty() string {
if g.v == nil {
return ""
}
return g.v.Value()
}
func (g OptRegexMatchGroup) ValueOrNil() *string {
if g.v == nil {
return nil
}
return langext.Ptr(g.v.Value())
}
func (g OptRegexMatchGroup) IsEmpty() bool {
return g.v == nil
}
func (g OptRegexMatchGroup) Exists() bool {
return g.v != nil
}
func (g OptRegexMatchGroup) Start() int {
return g.v.Start()
}
func (g OptRegexMatchGroup) End() int {
return g.v.End()
}
func (g OptRegexMatchGroup) Range() (int, int) {
return g.v.Range()
}
func (g OptRegexMatchGroup) Length() int {
return g.v.Length()
}

47
rext/wrapper_test.go Normal file
View File

@@ -0,0 +1,47 @@
package rext
import (
"gogs.mikescher.com/BlackForestBytes/goext/tst"
"regexp"
"testing"
)
func TestGroupByNameOrEmpty1(t *testing.T) {
regex1 := W(regexp.MustCompile("0(?P<group1>A+)B(?P<group2>C+)0"))
match1, ok1 := regex1.MatchFirst("Hello 0AAAABCCC0 Bye.")
tst.AssertTrue(t, ok1)
tst.AssertFalse(t, match1.GroupByNameOrEmpty("group1").IsEmpty())
tst.AssertEqual(t, match1.GroupByNameOrEmpty("group1").ValueOrEmpty(), "AAAA")
tst.AssertEqual(t, *match1.GroupByNameOrEmpty("group1").ValueOrNil(), "AAAA")
tst.AssertFalse(t, match1.GroupByNameOrEmpty("group2").IsEmpty())
tst.AssertEqual(t, match1.GroupByNameOrEmpty("group2").ValueOrEmpty(), "CCC")
tst.AssertEqual(t, *match1.GroupByNameOrEmpty("group2").ValueOrNil(), "CCC")
}
func TestGroupByNameOrEmpty2(t *testing.T) {
regex1 := W(regexp.MustCompile("0(?P<group1>A+)B(?P<group2>C+)(?P<group3>C+)?0"))
match1, ok1 := regex1.MatchFirst("Hello 0AAAABCCC0 Bye.")
tst.AssertTrue(t, ok1)
tst.AssertFalse(t, match1.GroupByNameOrEmpty("group1").IsEmpty())
tst.AssertEqual(t, match1.GroupByNameOrEmpty("group1").ValueOrEmpty(), "AAAA")
tst.AssertEqual(t, *match1.GroupByNameOrEmpty("group1").ValueOrNil(), "AAAA")
tst.AssertFalse(t, match1.GroupByNameOrEmpty("group2").IsEmpty())
tst.AssertEqual(t, match1.GroupByNameOrEmpty("group2").ValueOrEmpty(), "CCC")
tst.AssertEqual(t, *match1.GroupByNameOrEmpty("group2").ValueOrNil(), "CCC")
tst.AssertTrue(t, match1.GroupByNameOrEmpty("group3").IsEmpty())
tst.AssertEqual(t, match1.GroupByNameOrEmpty("group3").ValueOrEmpty(), "")
tst.AssertPtrEqual(t, match1.GroupByNameOrEmpty("group3").ValueOrNil(), nil)
}

101
rfctime/interface.go Normal file
View File

@@ -0,0 +1,101 @@
package rfctime
import "time"
type RFCTime interface {
AnyTime
Time() time.Time
Serialize() string
After(u AnyTime) bool
Before(u AnyTime) bool
Equal(u AnyTime) bool
Sub(u AnyTime) time.Duration
}
type AnyTime interface {
MarshalJSON() ([]byte, error)
MarshalBinary() ([]byte, error)
GobEncode() ([]byte, error)
MarshalText() ([]byte, error)
IsZero() bool
Date() (year int, month time.Month, day int)
Year() int
Month() time.Month
Day() int
Weekday() time.Weekday
ISOWeek() (year, week int)
Clock() (hour, min, sec int)
Hour() int
Minute() int
Second() int
Nanosecond() int
YearDay() int
Unix() int64
UnixMilli() int64
UnixMicro() int64
UnixNano() int64
Format(layout string) string
GoString() string
String() string
Location() *time.Location
}
type RFCDuration interface {
Time() time.Time
Serialize() string
UnmarshalJSON(bytes []byte) error
MarshalJSON() ([]byte, error)
MarshalBinary() ([]byte, error)
UnmarshalBinary(data []byte) error
GobEncode() ([]byte, error)
GobDecode(data []byte) error
MarshalText() ([]byte, error)
UnmarshalText(data []byte) error
After(u AnyTime) bool
Before(u AnyTime) bool
Equal(u AnyTime) bool
IsZero() bool
Date() (year int, month time.Month, day int)
Year() int
Month() time.Month
Day() int
Weekday() time.Weekday
ISOWeek() (year, week int)
Clock() (hour, min, sec int)
Hour() int
Minute() int
Second() int
Nanosecond() int
YearDay() int
Sub(u AnyTime) time.Duration
Unix() int64
UnixMilli() int64
UnixMicro() int64
UnixNano() int64
Format(layout string) string
GoString() string
String() string
}
func tt(v AnyTime) time.Time {
if r, ok := v.(time.Time); ok {
return r
}
if r, ok := v.(RFCTime); ok {
return r.Time()
}
return time.Unix(0, v.UnixNano()).In(v.Location())
}

50
rfctime/interface_test.go Normal file
View File

@@ -0,0 +1,50 @@
package rfctime
import (
"testing"
"time"
)
func TestAnyTimeInterface(t *testing.T) {
var v AnyTime
v = NowRFC3339Nano()
tst.AssertEqual(t, v.String(), v.String())
v = NowRFC3339()
tst.AssertEqual(t, v.String(), v.String())
v = NowUnix()
tst.AssertEqual(t, v.String(), v.String())
v = NowUnixMilli()
tst.AssertEqual(t, v.String(), v.String())
v = NowUnixNano()
tst.AssertEqual(t, v.String(), v.String())
v = time.Now()
tst.AssertEqual(t, v.String(), v.String())
}
func TestRFCTimeInterface(t *testing.T) {
var v RFCTime
v = NowRFC3339Nano()
tst.AssertEqual(t, v.String(), v.String())
v = NowRFC3339()
tst.AssertEqual(t, v.String(), v.String())
v = NowUnix()
tst.AssertEqual(t, v.String(), v.String())
v = NowUnixMilli()
tst.AssertEqual(t, v.String(), v.String())
v = NowUnixNano()
tst.AssertEqual(t, v.String(), v.String())
}

244
rfctime/rfc3339.go Normal file
View File

@@ -0,0 +1,244 @@
package rfctime
import (
"encoding/json"
"errors"
"fmt"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/bsoncodec"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/bsontype"
"reflect"
"time"
)
type RFC3339Time time.Time
func (t RFC3339Time) Time() time.Time {
return time.Time(t)
}
func (t RFC3339Time) MarshalBinary() ([]byte, error) {
return (time.Time)(t).MarshalBinary()
}
func (t *RFC3339Time) UnmarshalBinary(data []byte) error {
return (*time.Time)(t).UnmarshalBinary(data)
}
func (t RFC3339Time) GobEncode() ([]byte, error) {
return (time.Time)(t).GobEncode()
}
func (t *RFC3339Time) GobDecode(data []byte) error {
return (*time.Time)(t).GobDecode(data)
}
func (t *RFC3339Time) UnmarshalJSON(data []byte) error {
str := ""
if err := json.Unmarshal(data, &str); err != nil {
return err
}
t0, err := time.Parse(t.FormatStr(), str)
if err != nil {
return err
}
*t = RFC3339Time(t0)
return nil
}
func (t RFC3339Time) MarshalJSON() ([]byte, error) {
str := t.Time().Format(t.FormatStr())
return json.Marshal(str)
}
func (t RFC3339Time) MarshalText() ([]byte, error) {
b := make([]byte, 0, len(t.FormatStr()))
return t.Time().AppendFormat(b, t.FormatStr()), nil
}
func (t *RFC3339Time) UnmarshalText(data []byte) error {
var err error
v, err := time.Parse(t.FormatStr(), string(data))
if err != nil {
return err
}
tt := RFC3339Time(v)
*t = tt
return nil
}
func (t *RFC3339Time) UnmarshalBSONValue(bt bsontype.Type, data []byte) error {
if bt == bsontype.Null {
// we can't set nil in UnmarshalBSONValue (so we use default(struct))
// Use mongoext.CreateGoExtBsonRegistry if you need to unmarsh pointer values
// https://stackoverflow.com/questions/75167597
// https://jira.mongodb.org/browse/GODRIVER-2252
*t = RFC3339Time{}
return nil
}
if bt != bsontype.DateTime {
return errors.New(fmt.Sprintf("cannot unmarshal %v into RFC3339Time", bt))
}
var tt time.Time
err := bson.Unmarshal(data, &tt)
if err != nil {
return err
}
*t = RFC3339Time(tt)
return nil
}
func (t RFC3339Time) MarshalBSONValue() (bsontype.Type, []byte, error) {
return bson.MarshalValue(time.Time(t))
}
func (t RFC3339Time) DecodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
if val.Kind() == reflect.Ptr && val.IsNil() {
if !val.CanSet() {
return errors.New("ValueUnmarshalerDecodeValue")
}
val.Set(reflect.New(val.Type().Elem()))
}
tp, src, err := bsonrw.Copier{}.CopyValueToBytes(vr)
if err != nil {
return err
}
if val.Kind() == reflect.Ptr && len(src) == 0 {
val.Set(reflect.Zero(val.Type()))
return nil
}
err = t.UnmarshalBSONValue(tp, src)
if err != nil {
return err
}
return nil
}
func (t RFC3339Time) Serialize() string {
return t.Time().Format(t.FormatStr())
}
func (t RFC3339Time) FormatStr() string {
return time.RFC3339
}
func (t RFC3339Time) After(u AnyTime) bool {
return t.Time().After(tt(u))
}
func (t RFC3339Time) Before(u AnyTime) bool {
return t.Time().Before(tt(u))
}
func (t RFC3339Time) Equal(u AnyTime) bool {
return t.Time().Equal(tt(u))
}
func (t RFC3339Time) IsZero() bool {
return t.Time().IsZero()
}
func (t RFC3339Time) Date() (year int, month time.Month, day int) {
return t.Time().Date()
}
func (t RFC3339Time) Year() int {
return t.Time().Year()
}
func (t RFC3339Time) Month() time.Month {
return t.Time().Month()
}
func (t RFC3339Time) Day() int {
return t.Time().Day()
}
func (t RFC3339Time) Weekday() time.Weekday {
return t.Time().Weekday()
}
func (t RFC3339Time) ISOWeek() (year, week int) {
return t.Time().ISOWeek()
}
func (t RFC3339Time) Clock() (hour, min, sec int) {
return t.Time().Clock()
}
func (t RFC3339Time) Hour() int {
return t.Time().Hour()
}
func (t RFC3339Time) Minute() int {
return t.Time().Minute()
}
func (t RFC3339Time) Second() int {
return t.Time().Second()
}
func (t RFC3339Time) Nanosecond() int {
return t.Time().Nanosecond()
}
func (t RFC3339Time) YearDay() int {
return t.Time().YearDay()
}
func (t RFC3339Time) Add(d time.Duration) RFC3339Time {
return RFC3339Time(t.Time().Add(d))
}
func (t RFC3339Time) Sub(u AnyTime) time.Duration {
return t.Time().Sub(tt(u))
}
func (t RFC3339Time) AddDate(years int, months int, days int) RFC3339Time {
return RFC3339Time(t.Time().AddDate(years, months, days))
}
func (t RFC3339Time) Unix() int64 {
return t.Time().Unix()
}
func (t RFC3339Time) UnixMilli() int64 {
return t.Time().UnixMilli()
}
func (t RFC3339Time) UnixMicro() int64 {
return t.Time().UnixMicro()
}
func (t RFC3339Time) UnixNano() int64 {
return t.Time().UnixNano()
}
func (t RFC3339Time) Format(layout string) string {
return t.Time().Format(layout)
}
func (t RFC3339Time) GoString() string {
return t.Time().GoString()
}
func (t RFC3339Time) String() string {
return t.Time().String()
}
func (t RFC3339Time) Location() *time.Location {
return t.Time().Location()
}
func NewRFC3339(t time.Time) RFC3339Time {
return RFC3339Time(t)
}
func NowRFC3339() RFC3339Time {
return RFC3339Time(time.Now())
}

250
rfctime/rfc3339Nano.go Normal file
View File

@@ -0,0 +1,250 @@
package rfctime
import (
"encoding/json"
"errors"
"fmt"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/bsoncodec"
"go.mongodb.org/mongo-driver/bson/bsonrw"
"go.mongodb.org/mongo-driver/bson/bsontype"
"reflect"
"time"
)
type RFC3339NanoTime time.Time
func (t RFC3339NanoTime) Time() time.Time {
return time.Time(t)
}
func (t RFC3339NanoTime) MarshalBinary() ([]byte, error) {
return (time.Time)(t).MarshalBinary()
}
func (t *RFC3339NanoTime) UnmarshalBinary(data []byte) error {
return (*time.Time)(t).UnmarshalBinary(data)
}
func (t RFC3339NanoTime) GobEncode() ([]byte, error) {
return (time.Time)(t).GobEncode()
}
func (t *RFC3339NanoTime) GobDecode(data []byte) error {
return (*time.Time)(t).GobDecode(data)
}
func (t *RFC3339NanoTime) UnmarshalJSON(data []byte) error {
str := ""
if err := json.Unmarshal(data, &str); err != nil {
return err
}
t0, err := time.Parse(t.FormatStr(), str)
if err != nil {
return err
}
*t = RFC3339NanoTime(t0)
return nil
}
func (t RFC3339NanoTime) MarshalJSON() ([]byte, error) {
str := t.Time().Format(t.FormatStr())
return json.Marshal(str)
}
func (t RFC3339NanoTime) MarshalText() ([]byte, error) {
b := make([]byte, 0, len(t.FormatStr()))
return t.Time().AppendFormat(b, t.FormatStr()), nil
}
func (t *RFC3339NanoTime) UnmarshalText(data []byte) error {
var err error
v, err := time.Parse(t.FormatStr(), string(data))
if err != nil {
return err
}
tt := RFC3339NanoTime(v)
*t = tt
return nil
}
func (t *RFC3339NanoTime) UnmarshalBSONValue(bt bsontype.Type, data []byte) error {
if bt == bsontype.Null {
// we can't set nil in UnmarshalBSONValue (so we use default(struct))
// Use mongoext.CreateGoExtBsonRegistry if you need to unmarsh pointer values
// https://stackoverflow.com/questions/75167597
// https://jira.mongodb.org/browse/GODRIVER-2252
*t = RFC3339NanoTime{}
return nil
}
if bt != bsontype.DateTime {
return errors.New(fmt.Sprintf("cannot unmarshal %v into RFC3339NanoTime", bt))
}
var tt time.Time
err := bson.RawValue{Type: bt, Value: data}.Unmarshal(&tt)
if err != nil {
return err
}
*t = RFC3339NanoTime(tt)
return nil
}
func (t RFC3339NanoTime) MarshalBSONValue() (bsontype.Type, []byte, error) {
return bson.MarshalValue(time.Time(t))
}
func (t RFC3339NanoTime) DecodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
if val.Kind() == reflect.Ptr && val.IsNil() {
if !val.CanSet() {
return errors.New("ValueUnmarshalerDecodeValue")
}
val.Set(reflect.New(val.Type().Elem()))
}
tp, src, err := bsonrw.Copier{}.CopyValueToBytes(vr)
if err != nil {
return err
}
if val.Kind() == reflect.Ptr && len(src) == 0 {
val.Set(reflect.Zero(val.Type()))
return nil
}
err = t.UnmarshalBSONValue(tp, src)
if err != nil {
return err
}
if val.Kind() == reflect.Ptr {
val.Set(reflect.ValueOf(&t))
} else {
val.Set(reflect.ValueOf(t))
}
return nil
}
func (t RFC3339NanoTime) Serialize() string {
return t.Time().Format(t.FormatStr())
}
func (t RFC3339NanoTime) FormatStr() string {
return time.RFC3339Nano
}
func (t RFC3339NanoTime) After(u AnyTime) bool {
return t.Time().After(tt(u))
}
func (t RFC3339NanoTime) Before(u AnyTime) bool {
return t.Time().Before(tt(u))
}
func (t RFC3339NanoTime) Equal(u AnyTime) bool {
return t.Time().Equal(tt(u))
}
func (t RFC3339NanoTime) IsZero() bool {
return t.Time().IsZero()
}
func (t RFC3339NanoTime) Date() (year int, month time.Month, day int) {
return t.Time().Date()
}
func (t RFC3339NanoTime) Year() int {
return t.Time().Year()
}
func (t RFC3339NanoTime) Month() time.Month {
return t.Time().Month()
}
func (t RFC3339NanoTime) Day() int {
return t.Time().Day()
}
func (t RFC3339NanoTime) Weekday() time.Weekday {
return t.Time().Weekday()
}
func (t RFC3339NanoTime) ISOWeek() (year, week int) {
return t.Time().ISOWeek()
}
func (t RFC3339NanoTime) Clock() (hour, min, sec int) {
return t.Time().Clock()
}
func (t RFC3339NanoTime) Hour() int {
return t.Time().Hour()
}
func (t RFC3339NanoTime) Minute() int {
return t.Time().Minute()
}
func (t RFC3339NanoTime) Second() int {
return t.Time().Second()
}
func (t RFC3339NanoTime) Nanosecond() int {
return t.Time().Nanosecond()
}
func (t RFC3339NanoTime) YearDay() int {
return t.Time().YearDay()
}
func (t RFC3339NanoTime) Add(d time.Duration) RFC3339NanoTime {
return RFC3339NanoTime(t.Time().Add(d))
}
func (t RFC3339NanoTime) Sub(u AnyTime) time.Duration {
return t.Time().Sub(tt(u))
}
func (t RFC3339NanoTime) AddDate(years int, months int, days int) RFC3339NanoTime {
return RFC3339NanoTime(t.Time().AddDate(years, months, days))
}
func (t RFC3339NanoTime) Unix() int64 {
return t.Time().Unix()
}
func (t RFC3339NanoTime) UnixMilli() int64 {
return t.Time().UnixMilli()
}
func (t RFC3339NanoTime) UnixMicro() int64 {
return t.Time().UnixMicro()
}
func (t RFC3339NanoTime) UnixNano() int64 {
return t.Time().UnixNano()
}
func (t RFC3339NanoTime) Format(layout string) string {
return t.Time().Format(layout)
}
func (t RFC3339NanoTime) GoString() string {
return t.Time().GoString()
}
func (t RFC3339NanoTime) String() string {
return t.Time().String()
}
func (t RFC3339NanoTime) Location() *time.Location {
return t.Time().Location()
}
func NewRFC3339Nano(t time.Time) RFC3339NanoTime {
return RFC3339NanoTime(t)
}
func NowRFC3339Nano() RFC3339NanoTime {
return RFC3339NanoTime(time.Now())
}

View File

@@ -0,0 +1,47 @@
package rfctime
import (
"encoding/json"
"gogs.mikescher.com/BlackForestBytes/goext/tst"
"testing"
"time"
)
func TestRoundtrip(t *testing.T) {
type Wrap struct {
Value RFC3339NanoTime `json:"v"`
}
val1 := NewRFC3339Nano(time.Unix(0, 1675951556820915171))
w1 := Wrap{val1}
jstr1, err := json.Marshal(w1)
if err != nil {
panic(err)
}
if string(jstr1) != "{\"v\":\"2023-02-09T15:05:56.820915171+01:00\"}" {
t.Errorf(string(jstr1))
t.Errorf("repr differs")
}
w2 := Wrap{}
err = json.Unmarshal(jstr1, &w2)
if err != nil {
panic(err)
}
jstr2, err := json.Marshal(w2)
if err != nil {
panic(err)
}
tst.AssertEqual(t, string(jstr1), string(jstr2))
if !w1.Value.Equal(&w2.Value) {
t.Errorf("time differs")
}
}

59
rfctime/seconds.go Normal file
View File

@@ -0,0 +1,59 @@
package rfctime
import (
"encoding/json"
"gogs.mikescher.com/BlackForestBytes/goext/timeext"
"time"
)
type SecondsF64 time.Duration
func (d SecondsF64) Duration() time.Duration {
return time.Duration(d)
}
func (d SecondsF64) String() string {
return d.Duration().String()
}
func (d SecondsF64) Nanoseconds() int64 {
return d.Duration().Nanoseconds()
}
func (d SecondsF64) Microseconds() int64 {
return d.Duration().Microseconds()
}
func (d SecondsF64) Milliseconds() int64 {
return d.Duration().Milliseconds()
}
func (d SecondsF64) Seconds() float64 {
return d.Duration().Seconds()
}
func (d SecondsF64) Minutes() float64 {
return d.Duration().Minutes()
}
func (d SecondsF64) Hours() float64 {
return d.Duration().Hours()
}
func (d *SecondsF64) UnmarshalJSON(data []byte) error {
var secs float64 = 0
if err := json.Unmarshal(data, &secs); err != nil {
return err
}
*d = SecondsF64(timeext.FromSeconds(secs))
return nil
}
func (d SecondsF64) MarshalJSON() ([]byte, error) {
secs := d.Seconds()
return json.Marshal(secs)
}
func NewSecondsF64(t time.Duration) SecondsF64 {
return SecondsF64(t)
}

180
rfctime/unix.go Normal file
View File

@@ -0,0 +1,180 @@
package rfctime
import (
"encoding/json"
"strconv"
"time"
)
type UnixTime time.Time
func (t UnixTime) Time() time.Time {
return time.Time(t)
}
func (t UnixTime) MarshalBinary() ([]byte, error) {
return (time.Time)(t).MarshalBinary()
}
func (t *UnixTime) UnmarshalBinary(data []byte) error {
return (*time.Time)(t).UnmarshalBinary(data)
}
func (t UnixTime) GobEncode() ([]byte, error) {
return (time.Time)(t).GobEncode()
}
func (t *UnixTime) GobDecode(data []byte) error {
return (*time.Time)(t).GobDecode(data)
}
func (t *UnixTime) UnmarshalJSON(data []byte) error {
str := ""
if err := json.Unmarshal(data, &str); err != nil {
return err
}
t0, err := strconv.ParseInt(str, 10, 64)
if err != nil {
return err
}
*t = UnixTime(time.Unix(t0, 0))
return nil
}
func (t UnixTime) MarshalJSON() ([]byte, error) {
str := strconv.FormatInt(t.Time().Unix(), 10)
return json.Marshal(str)
}
func (t UnixTime) MarshalText() ([]byte, error) {
return []byte(strconv.FormatInt(t.Time().Unix(), 10)), nil
}
func (t *UnixTime) UnmarshalText(data []byte) error {
t0, err := strconv.ParseInt(string(data), 10, 64)
if err != nil {
return err
}
*t = UnixTime(time.Unix(t0, 0))
return nil
}
func (t UnixTime) Serialize() string {
return strconv.FormatInt(t.Time().Unix(), 10)
}
func (t UnixTime) After(u AnyTime) bool {
return t.Time().After(tt(u))
}
func (t UnixTime) Before(u AnyTime) bool {
return t.Time().Before(tt(u))
}
func (t UnixTime) Equal(u AnyTime) bool {
return t.Time().Equal(tt(u))
}
func (t UnixTime) IsZero() bool {
return t.Time().IsZero()
}
func (t UnixTime) Date() (year int, month time.Month, day int) {
return t.Time().Date()
}
func (t UnixTime) Year() int {
return t.Time().Year()
}
func (t UnixTime) Month() time.Month {
return t.Time().Month()
}
func (t UnixTime) Day() int {
return t.Time().Day()
}
func (t UnixTime) Weekday() time.Weekday {
return t.Time().Weekday()
}
func (t UnixTime) ISOWeek() (year, week int) {
return t.Time().ISOWeek()
}
func (t UnixTime) Clock() (hour, min, sec int) {
return t.Time().Clock()
}
func (t UnixTime) Hour() int {
return t.Time().Hour()
}
func (t UnixTime) Minute() int {
return t.Time().Minute()
}
func (t UnixTime) Second() int {
return t.Time().Second()
}
func (t UnixTime) Nanosecond() int {
return t.Time().Nanosecond()
}
func (t UnixTime) YearDay() int {
return t.Time().YearDay()
}
func (t UnixTime) Add(d time.Duration) UnixTime {
return UnixTime(t.Time().Add(d))
}
func (t UnixTime) Sub(u AnyTime) time.Duration {
return t.Time().Sub(tt(u))
}
func (t UnixTime) AddDate(years int, months int, days int) UnixTime {
return UnixTime(t.Time().AddDate(years, months, days))
}
func (t UnixTime) Unix() int64 {
return t.Time().Unix()
}
func (t UnixTime) UnixMilli() int64 {
return t.Time().UnixMilli()
}
func (t UnixTime) UnixMicro() int64 {
return t.Time().UnixMicro()
}
func (t UnixTime) UnixNano() int64 {
return t.Time().UnixNano()
}
func (t UnixTime) Format(layout string) string {
return t.Time().Format(layout)
}
func (t UnixTime) GoString() string {
return t.Time().GoString()
}
func (t UnixTime) String() string {
return t.Time().String()
}
func (t UnixTime) Location() *time.Location {
return t.Time().Location()
}
func NewUnix(t time.Time) UnixTime {
return UnixTime(t)
}
func NowUnix() UnixTime {
return UnixTime(time.Now())
}

180
rfctime/unixMilli.go Normal file
View File

@@ -0,0 +1,180 @@
package rfctime
import (
"encoding/json"
"strconv"
"time"
)
type UnixMilliTime time.Time
func (t UnixMilliTime) Time() time.Time {
return time.Time(t)
}
func (t UnixMilliTime) MarshalBinary() ([]byte, error) {
return (time.Time)(t).MarshalBinary()
}
func (t *UnixMilliTime) UnmarshalBinary(data []byte) error {
return (*time.Time)(t).UnmarshalBinary(data)
}
func (t UnixMilliTime) GobEncode() ([]byte, error) {
return (time.Time)(t).GobEncode()
}
func (t *UnixMilliTime) GobDecode(data []byte) error {
return (*time.Time)(t).GobDecode(data)
}
func (t *UnixMilliTime) UnmarshalJSON(data []byte) error {
str := ""
if err := json.Unmarshal(data, &str); err != nil {
return err
}
t0, err := strconv.ParseInt(str, 10, 64)
if err != nil {
return err
}
*t = UnixMilliTime(time.UnixMilli(t0))
return nil
}
func (t UnixMilliTime) MarshalJSON() ([]byte, error) {
str := strconv.FormatInt(t.Time().UnixMilli(), 10)
return json.Marshal(str)
}
func (t UnixMilliTime) MarshalText() ([]byte, error) {
return []byte(strconv.FormatInt(t.Time().UnixMilli(), 10)), nil
}
func (t *UnixMilliTime) UnmarshalText(data []byte) error {
t0, err := strconv.ParseInt(string(data), 10, 64)
if err != nil {
return err
}
*t = UnixMilliTime(time.UnixMilli(t0))
return nil
}
func (t UnixMilliTime) Serialize() string {
return strconv.FormatInt(t.Time().UnixMilli(), 10)
}
func (t UnixMilliTime) After(u AnyTime) bool {
return t.Time().After(tt(u))
}
func (t UnixMilliTime) Before(u AnyTime) bool {
return t.Time().Before(tt(u))
}
func (t UnixMilliTime) Equal(u AnyTime) bool {
return t.Time().Equal(tt(u))
}
func (t UnixMilliTime) IsZero() bool {
return t.Time().IsZero()
}
func (t UnixMilliTime) Date() (year int, month time.Month, day int) {
return t.Time().Date()
}
func (t UnixMilliTime) Year() int {
return t.Time().Year()
}
func (t UnixMilliTime) Month() time.Month {
return t.Time().Month()
}
func (t UnixMilliTime) Day() int {
return t.Time().Day()
}
func (t UnixMilliTime) Weekday() time.Weekday {
return t.Time().Weekday()
}
func (t UnixMilliTime) ISOWeek() (year, week int) {
return t.Time().ISOWeek()
}
func (t UnixMilliTime) Clock() (hour, min, sec int) {
return t.Time().Clock()
}
func (t UnixMilliTime) Hour() int {
return t.Time().Hour()
}
func (t UnixMilliTime) Minute() int {
return t.Time().Minute()
}
func (t UnixMilliTime) Second() int {
return t.Time().Second()
}
func (t UnixMilliTime) Nanosecond() int {
return t.Time().Nanosecond()
}
func (t UnixMilliTime) YearDay() int {
return t.Time().YearDay()
}
func (t UnixMilliTime) Add(d time.Duration) UnixMilliTime {
return UnixMilliTime(t.Time().Add(d))
}
func (t UnixMilliTime) Sub(u AnyTime) time.Duration {
return t.Time().Sub(tt(u))
}
func (t UnixMilliTime) AddDate(years int, months int, days int) UnixMilliTime {
return UnixMilliTime(t.Time().AddDate(years, months, days))
}
func (t UnixMilliTime) Unix() int64 {
return t.Time().Unix()
}
func (t UnixMilliTime) UnixMilli() int64 {
return t.Time().UnixMilli()
}
func (t UnixMilliTime) UnixMicro() int64 {
return t.Time().UnixMicro()
}
func (t UnixMilliTime) UnixNano() int64 {
return t.Time().UnixNano()
}
func (t UnixMilliTime) Format(layout string) string {
return t.Time().Format(layout)
}
func (t UnixMilliTime) GoString() string {
return t.Time().GoString()
}
func (t UnixMilliTime) String() string {
return t.Time().String()
}
func (t UnixMilliTime) Location() *time.Location {
return t.Time().Location()
}
func NewUnixMilli(t time.Time) UnixMilliTime {
return UnixMilliTime(t)
}
func NowUnixMilli() UnixMilliTime {
return UnixMilliTime(time.Now())
}

180
rfctime/unixNano.go Normal file
View File

@@ -0,0 +1,180 @@
package rfctime
import (
"encoding/json"
"strconv"
"time"
)
type UnixNanoTime time.Time
func (t UnixNanoTime) Time() time.Time {
return time.Time(t)
}
func (t UnixNanoTime) MarshalBinary() ([]byte, error) {
return (time.Time)(t).MarshalBinary()
}
func (t *UnixNanoTime) UnmarshalBinary(data []byte) error {
return (*time.Time)(t).UnmarshalBinary(data)
}
func (t UnixNanoTime) GobEncode() ([]byte, error) {
return (time.Time)(t).GobEncode()
}
func (t *UnixNanoTime) GobDecode(data []byte) error {
return (*time.Time)(t).GobDecode(data)
}
func (t *UnixNanoTime) UnmarshalJSON(data []byte) error {
str := ""
if err := json.Unmarshal(data, &str); err != nil {
return err
}
t0, err := strconv.ParseInt(str, 10, 64)
if err != nil {
return err
}
*t = UnixNanoTime(time.Unix(0, t0))
return nil
}
func (t UnixNanoTime) MarshalJSON() ([]byte, error) {
str := strconv.FormatInt(t.Time().UnixNano(), 10)
return json.Marshal(str)
}
func (t UnixNanoTime) MarshalText() ([]byte, error) {
return []byte(strconv.FormatInt(t.Time().UnixNano(), 10)), nil
}
func (t *UnixNanoTime) UnmarshalText(data []byte) error {
t0, err := strconv.ParseInt(string(data), 10, 64)
if err != nil {
return err
}
*t = UnixNanoTime(time.Unix(0, t0))
return nil
}
func (t UnixNanoTime) Serialize() string {
return strconv.FormatInt(t.Time().UnixNano(), 10)
}
func (t UnixNanoTime) After(u AnyTime) bool {
return t.Time().After(tt(u))
}
func (t UnixNanoTime) Before(u AnyTime) bool {
return t.Time().Before(tt(u))
}
func (t UnixNanoTime) Equal(u AnyTime) bool {
return t.Time().Equal(tt(u))
}
func (t UnixNanoTime) IsZero() bool {
return t.Time().IsZero()
}
func (t UnixNanoTime) Date() (year int, month time.Month, day int) {
return t.Time().Date()
}
func (t UnixNanoTime) Year() int {
return t.Time().Year()
}
func (t UnixNanoTime) Month() time.Month {
return t.Time().Month()
}
func (t UnixNanoTime) Day() int {
return t.Time().Day()
}
func (t UnixNanoTime) Weekday() time.Weekday {
return t.Time().Weekday()
}
func (t UnixNanoTime) ISOWeek() (year, week int) {
return t.Time().ISOWeek()
}
func (t UnixNanoTime) Clock() (hour, min, sec int) {
return t.Time().Clock()
}
func (t UnixNanoTime) Hour() int {
return t.Time().Hour()
}
func (t UnixNanoTime) Minute() int {
return t.Time().Minute()
}
func (t UnixNanoTime) Second() int {
return t.Time().Second()
}
func (t UnixNanoTime) Nanosecond() int {
return t.Time().Nanosecond()
}
func (t UnixNanoTime) YearDay() int {
return t.Time().YearDay()
}
func (t UnixNanoTime) Add(d time.Duration) UnixNanoTime {
return UnixNanoTime(t.Time().Add(d))
}
func (t UnixNanoTime) Sub(u AnyTime) time.Duration {
return t.Time().Sub(tt(u))
}
func (t UnixNanoTime) AddDate(years int, months int, days int) UnixNanoTime {
return UnixNanoTime(t.Time().AddDate(years, months, days))
}
func (t UnixNanoTime) Unix() int64 {
return t.Time().Unix()
}
func (t UnixNanoTime) UnixMilli() int64 {
return t.Time().UnixNano()
}
func (t UnixNanoTime) UnixMicro() int64 {
return t.Time().UnixMicro()
}
func (t UnixNanoTime) UnixNano() int64 {
return t.Time().UnixNano()
}
func (t UnixNanoTime) Format(layout string) string {
return t.Time().Format(layout)
}
func (t UnixNanoTime) GoString() string {
return t.Time().GoString()
}
func (t UnixNanoTime) String() string {
return t.Time().String()
}
func (t UnixNanoTime) Location() *time.Location {
return t.Time().Location()
}
func NewUnixNano(t time.Time) UnixNanoTime {
return UnixNanoTime(t)
}
func NowUnixNano() UnixNanoTime {
return UnixNanoTime(time.Now())
}

128
sq/database.go Normal file
View File

@@ -0,0 +1,128 @@
package sq
import (
"context"
"database/sql"
"github.com/jmoiron/sqlx"
"sync"
)
type DB interface {
Exec(ctx context.Context, sql string, prep PP) (sql.Result, error)
Query(ctx context.Context, sql string, prep PP) (*sqlx.Rows, error)
Ping(ctx context.Context) error
BeginTransaction(ctx context.Context, iso sql.IsolationLevel) (Tx, error)
AddListener(listener Listener)
Exit() error
}
type database struct {
db *sqlx.DB
txctr uint16
lock sync.Mutex
lstr []Listener
}
func NewDB(db *sqlx.DB) DB {
return &database{
db: db,
txctr: 0,
lock: sync.Mutex{},
lstr: make([]Listener, 0),
}
}
func (db *database) AddListener(listener Listener) {
db.lstr = append(db.lstr, listener)
}
func (db *database) Exec(ctx context.Context, sqlstr string, prep PP) (sql.Result, error) {
origsql := sqlstr
for _, v := range db.lstr {
err := v.PreExec(ctx, nil, &sqlstr, &prep)
if err != nil {
return nil, err
}
}
res, err := db.db.NamedExecContext(ctx, sqlstr, prep)
for _, v := range db.lstr {
v.PostExec(nil, origsql, sqlstr, prep)
}
if err != nil {
return nil, err
}
return res, nil
}
func (db *database) Query(ctx context.Context, sqlstr string, prep PP) (*sqlx.Rows, error) {
origsql := sqlstr
for _, v := range db.lstr {
err := v.PreQuery(ctx, nil, &sqlstr, &prep)
if err != nil {
return nil, err
}
}
rows, err := sqlx.NamedQueryContext(ctx, db.db, sqlstr, prep)
for _, v := range db.lstr {
v.PostQuery(nil, origsql, sqlstr, prep)
}
if err != nil {
return nil, err
}
return rows, nil
}
func (db *database) Ping(ctx context.Context) error {
for _, v := range db.lstr {
err := v.PrePing(ctx)
if err != nil {
return err
}
}
err := db.db.PingContext(ctx)
for _, v := range db.lstr {
v.PostPing(err)
}
if err != nil {
return err
}
return nil
}
func (db *database) BeginTransaction(ctx context.Context, iso sql.IsolationLevel) (Tx, error) {
db.lock.Lock()
txid := db.txctr
db.txctr += 1 // with overflow !
db.lock.Unlock()
for _, v := range db.lstr {
err := v.PreTxBegin(ctx, txid)
if err != nil {
return nil, err
}
}
xtx, err := db.db.BeginTxx(ctx, &sql.TxOptions{Isolation: iso})
if err != nil {
return nil, err
}
for _, v := range db.lstr {
v.PostTxBegin(txid, err)
}
return NewTransaction(xtx, txid, db.lstr), nil
}
func (db *database) Exit() error {
return db.db.Close()
}

19
sq/listener.go Normal file
View File

@@ -0,0 +1,19 @@
package sq
import "context"
type Listener interface {
PrePing(ctx context.Context) error
PreTxBegin(ctx context.Context, txid uint16) error
PreTxCommit(txid uint16) error
PreTxRollback(txid uint16) error
PreQuery(ctx context.Context, txID *uint16, sql *string, params *PP) error
PreExec(ctx context.Context, txID *uint16, sql *string, params *PP) error
PostPing(result error)
PostTxBegin(txid uint16, result error)
PostTxCommit(txid uint16, result error)
PostTxRollback(txid uint16, result error)
PostQuery(txID *uint16, sqlOriginal string, sqlReal string, params PP)
PostExec(txID *uint16, sqlOriginal string, sqlReal string, params PP)
}

13
sq/params.go Normal file
View File

@@ -0,0 +1,13 @@
package sq
type PP map[string]any
func Join(pps ...PP) PP {
r := PP{}
for _, add := range pps {
for k, v := range add {
r[k] = v
}
}
return r
}

12
sq/queryable.go Normal file
View File

@@ -0,0 +1,12 @@
package sq
import (
"context"
"database/sql"
"github.com/jmoiron/sqlx"
)
type Queryable interface {
Exec(ctx context.Context, sql string, prep PP) (sql.Result, error)
Query(ctx context.Context, sql string, prep PP) (*sqlx.Rows, error)
}

140
sq/scanner.go Normal file
View File

@@ -0,0 +1,140 @@
package sq
import (
"database/sql"
"errors"
"github.com/jmoiron/sqlx"
)
type StructScanMode string
const (
SModeFast StructScanMode = "FAST"
SModeExtended StructScanMode = "EXTENDED"
)
type StructScanSafety string
const (
Safe StructScanSafety = "SAFE"
Unsafe StructScanSafety = "UNSAFE"
)
func ScanSingle[TData any](rows *sqlx.Rows, mode StructScanMode, sec StructScanSafety, close bool) (TData, error) {
if rows.Next() {
var strscan *StructScanner
if sec == Safe {
strscan = NewStructScanner(rows, false)
var data TData
err := strscan.Start(&data)
if err != nil {
return *new(TData), err
}
} else if sec == Unsafe {
strscan = NewStructScanner(rows, true)
var data TData
err := strscan.Start(&data)
if err != nil {
return *new(TData), err
}
} else {
return *new(TData), errors.New("unknown value for <sec>")
}
var data TData
if mode == SModeFast {
err := strscan.StructScanBase(&data)
if err != nil {
return *new(TData), err
}
} else if mode == SModeExtended {
err := strscan.StructScanExt(&data)
if err != nil {
return *new(TData), err
}
} else {
return *new(TData), errors.New("unknown value for <mode>")
}
if rows.Next() {
if close {
_ = rows.Close()
}
return *new(TData), errors.New("sql returned more than one row")
}
if close {
err := rows.Close()
if err != nil {
return *new(TData), err
}
}
if err := rows.Err(); err != nil {
return *new(TData), err
}
return data, nil
} else {
if close {
_ = rows.Close()
}
return *new(TData), sql.ErrNoRows
}
}
func ScanAll[TData any](rows *sqlx.Rows, mode StructScanMode, sec StructScanSafety, close bool) ([]TData, error) {
var strscan *StructScanner
if sec == Safe {
strscan = NewStructScanner(rows, false)
var data TData
err := strscan.Start(&data)
if err != nil {
return nil, err
}
} else if sec == Unsafe {
strscan = NewStructScanner(rows, true)
var data TData
err := strscan.Start(&data)
if err != nil {
return nil, err
}
} else {
return nil, errors.New("unknown value for <sec>")
}
res := make([]TData, 0)
for rows.Next() {
if mode == SModeFast {
var data TData
err := strscan.StructScanBase(&data)
if err != nil {
return nil, err
}
res = append(res, data)
} else if mode == SModeExtended {
var data TData
err := strscan.StructScanExt(&data)
if err != nil {
return nil, err
}
res = append(res, data)
} else {
return nil, errors.New("unknown value for <mode>")
}
}
if close {
err := strscan.rows.Close()
if err != nil {
return nil, err
}
}
if err := rows.Err(); err != nil {
return nil, err
}
return res, nil
}

223
sq/structscanner.go Normal file
View File

@@ -0,0 +1,223 @@
package sq
import (
"errors"
"fmt"
"github.com/jmoiron/sqlx"
"github.com/jmoiron/sqlx/reflectx"
"reflect"
)
// forked from sqlx, but added ability to unmarshal optional-nested structs
type StructScanner struct {
rows *sqlx.Rows
Mapper *reflectx.Mapper
unsafe bool
fields [][]int
values []any
columns []string
}
func NewStructScanner(rows *sqlx.Rows, unsafe bool) *StructScanner {
return &StructScanner{
rows: rows,
Mapper: reflectx.NewMapper("db"),
unsafe: unsafe,
}
}
func (r *StructScanner) Start(dest any) error {
v := reflect.ValueOf(dest)
if v.Kind() != reflect.Ptr {
return errors.New("must pass a pointer, not a value, to StructScan destination")
}
columns, err := r.rows.Columns()
if err != nil {
return err
}
r.columns = columns
r.fields = r.Mapper.TraversalsByName(v.Type(), columns)
// if we are not unsafe and are missing fields, return an error
if f, err := missingFields(r.fields); err != nil && !r.unsafe {
return fmt.Errorf("missing destination name %s in %T", columns[f], dest)
}
r.values = make([]interface{}, len(columns))
return nil
}
// StructScanExt forked from github.com/jmoiron/sqlx@v1.3.5/sqlx.go
// does also wok with nullabel structs (from LEFT JOIN's)
func (r *StructScanner) StructScanExt(dest any) error {
v := reflect.ValueOf(dest)
if v.Kind() != reflect.Ptr {
return errors.New("must pass a pointer, not a value, to StructScan destination")
}
// ========= STEP 1 :: =========
v = v.Elem()
err := fieldsByTraversalExtended(v, r.fields, r.values)
if err != nil {
return err
}
// scan into the struct field pointers and append to our results
err = r.rows.Scan(r.values...)
if err != nil {
return err
}
nullStructs := make(map[string]bool)
for i, traversal := range r.fields {
if len(traversal) == 0 {
continue
}
isnsil := reflect.ValueOf(r.values[i]).Elem().IsNil()
for i := 1; i < len(traversal); i++ {
canParentNil := reflectx.FieldByIndexes(v, traversal[0:i]).Kind() == reflect.Pointer
k := fmt.Sprintf("%v", traversal[0:i])
if v, ok := nullStructs[k]; ok {
nullStructs[k] = canParentNil && v && isnsil
} else {
nullStructs[k] = canParentNil && isnsil
}
}
}
forcenulled := make(map[string]bool)
for i, traversal := range r.fields {
if len(traversal) == 0 {
continue
}
anyparentnull := false
for i := 1; i < len(traversal); i++ {
k := fmt.Sprintf("%v", traversal[0:i])
if nv, ok := nullStructs[k]; ok && nv {
if _, ok := forcenulled[k]; !ok {
f := reflectx.FieldByIndexes(v, traversal[0:i])
f.Set(reflect.Zero(f.Type())) // set to nil
forcenulled[k] = true
}
anyparentnull = true
break
}
}
if anyparentnull {
continue
}
f := reflectx.FieldByIndexes(v, traversal)
val1 := reflect.ValueOf(r.values[i])
val2 := val1.Elem()
val3 := val2.Elem()
if val2.IsNil() {
if f.Kind() != reflect.Pointer {
return errors.New(fmt.Sprintf("Cannot set field %v to NULL value from column '%s' (type: %s)", traversal, r.columns[i], f.Type().String()))
}
f.Set(reflect.Zero(f.Type())) // set to nil
} else {
f.Set(val3)
}
}
return r.rows.Err()
}
// StructScanBase forked from github.com/jmoiron/sqlx@v1.3.5/sqlx.go
// without (relevant) changes
func (r *StructScanner) StructScanBase(dest any) error {
v := reflect.ValueOf(dest)
if v.Kind() != reflect.Ptr {
return errors.New("must pass a pointer, not a value, to StructScan destination")
}
v = v.Elem()
err := fieldsByTraversalBase(v, r.fields, r.values, true)
if err != nil {
return err
}
// scan into the struct field pointers and append to our results
err = r.rows.Scan(r.values...)
if err != nil {
return err
}
return r.rows.Err()
}
// fieldsByTraversal forked from github.com/jmoiron/sqlx@v1.3.5/sqlx.go
func fieldsByTraversalExtended(v reflect.Value, traversals [][]int, values []interface{}) error {
v = reflect.Indirect(v)
if v.Kind() != reflect.Struct {
return errors.New("argument not a struct")
}
for i, traversal := range traversals {
if len(traversal) == 0 {
values[i] = new(interface{})
continue
}
f := reflectx.FieldByIndexes(v, traversal)
values[i] = reflect.New(reflect.PointerTo(f.Type())).Interface()
}
return nil
}
// fieldsByTraversal forked from github.com/jmoiron/sqlx@v1.3.5/sqlx.go
func fieldsByTraversalBase(v reflect.Value, traversals [][]int, values []interface{}, ptrs bool) error {
v = reflect.Indirect(v)
if v.Kind() != reflect.Struct {
return errors.New("argument not a struct")
}
for i, traversal := range traversals {
if len(traversal) == 0 {
values[i] = new(interface{})
continue
}
f := reflectx.FieldByIndexes(v, traversal)
if ptrs {
values[i] = f.Addr().Interface()
} else {
values[i] = f.Interface()
}
}
return nil
}
// missingFields forked from github.com/jmoiron/sqlx@v1.3.5/sqlx.go
func missingFields(transversals [][]int) (field int, err error) {
for i, t := range transversals {
if len(t) == 0 {
return i, errors.New("missing field")
}
}
return 0, nil
}

105
sq/transaction.go Normal file
View File

@@ -0,0 +1,105 @@
package sq
import (
"context"
"database/sql"
"github.com/jmoiron/sqlx"
"gogs.mikescher.com/BlackForestBytes/goext/langext"
)
type Tx interface {
Rollback() error
Commit() error
Exec(ctx context.Context, sql string, prep PP) (sql.Result, error)
Query(ctx context.Context, sql string, prep PP) (*sqlx.Rows, error)
}
type transaction struct {
tx *sqlx.Tx
id uint16
lstr []Listener
}
func NewTransaction(xtx *sqlx.Tx, txid uint16, lstr []Listener) Tx {
return &transaction{
tx: xtx,
id: txid,
lstr: lstr,
}
}
func (tx *transaction) Rollback() error {
for _, v := range tx.lstr {
err := v.PreTxRollback(tx.id)
if err != nil {
return err
}
}
result := tx.tx.Rollback()
for _, v := range tx.lstr {
v.PostTxRollback(tx.id, result)
}
return result
}
func (tx *transaction) Commit() error {
for _, v := range tx.lstr {
err := v.PreTxCommit(tx.id)
if err != nil {
return err
}
}
result := tx.tx.Commit()
for _, v := range tx.lstr {
v.PostTxRollback(tx.id, result)
}
return result
}
func (tx *transaction) Exec(ctx context.Context, sqlstr string, prep PP) (sql.Result, error) {
origsql := sqlstr
for _, v := range tx.lstr {
err := v.PreExec(ctx, langext.Ptr(tx.id), &sqlstr, &prep)
if err != nil {
return nil, err
}
}
res, err := tx.tx.NamedExecContext(ctx, sqlstr, prep)
for _, v := range tx.lstr {
v.PostExec(langext.Ptr(tx.id), origsql, sqlstr, prep)
}
if err != nil {
return nil, err
}
return res, nil
}
func (tx *transaction) Query(ctx context.Context, sqlstr string, prep PP) (*sqlx.Rows, error) {
origsql := sqlstr
for _, v := range tx.lstr {
err := v.PreQuery(ctx, langext.Ptr(tx.id), &sqlstr, &prep)
if err != nil {
return nil, err
}
}
rows, err := sqlx.NamedQueryContext(ctx, tx.tx, sqlstr, prep)
for _, v := range tx.lstr {
v.PostQuery(langext.Ptr(tx.id), origsql, sqlstr, prep)
}
if err != nil {
return nil, err
}
return rows, nil
}

109
syncext/atomic.go Normal file
View File

@@ -0,0 +1,109 @@
package syncext
import (
"context"
"gogs.mikescher.com/BlackForestBytes/goext/langext"
"sync"
"time"
)
type AtomicBool struct {
v bool
listener map[string]chan bool
lock sync.Mutex
}
func NewAtomicBool(value bool) *AtomicBool {
return &AtomicBool{
v: value,
listener: make(map[string]chan bool),
lock: sync.Mutex{},
}
}
func (a *AtomicBool) Get() bool {
a.lock.Lock()
defer a.lock.Unlock()
return a.v
}
func (a *AtomicBool) Set(value bool) {
a.lock.Lock()
defer a.lock.Unlock()
a.v = value
for k, v := range a.listener {
select {
case v <- value:
// message sent
default:
// no receiver on channel
delete(a.listener, k)
}
}
}
func (a *AtomicBool) Wait(waitFor bool) {
_ = a.WaitWithContext(context.Background(), waitFor)
}
func (a *AtomicBool) WaitWithTimeout(timeout time.Duration, waitFor bool) error {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
return a.WaitWithContext(ctx, waitFor)
}
func (a *AtomicBool) WaitWithContext(ctx context.Context, waitFor bool) error {
if err := ctx.Err(); err != nil {
return err
}
if a.Get() == waitFor {
return nil
}
uuid, _ := langext.NewHexUUID()
waitchan := make(chan bool)
a.lock.Lock()
a.listener[uuid] = waitchan
a.lock.Unlock()
defer func() {
a.lock.Lock()
delete(a.listener, uuid)
a.lock.Unlock()
}()
for {
if err := ctx.Err(); err != nil {
return err
}
timeOut := 1024 * time.Millisecond
if dl, ok := ctx.Deadline(); ok {
timeOutMax := dl.Sub(time.Now())
if timeOutMax <= 0 {
timeOut = 0
} else if 0 < timeOutMax && timeOutMax < timeOut {
timeOut = timeOutMax
}
}
if v, ok := ReadChannelWithTimeout(waitchan, timeOut); ok {
if v == waitFor {
return nil
}
} else {
if err := ctx.Err(); err != nil {
return err
}
if a.Get() == waitFor {
return nil
}
}
}
}

45
syncext/channel.go Normal file
View File

@@ -0,0 +1,45 @@
package syncext
import (
"time"
)
// https://gobyexample.com/non-blocking-channel-operations
// https://gobyexample.com/timeouts
// https://groups.google.com/g/golang-nuts/c/Oth9CmJPoqo
func ReadChannelWithTimeout[T any](c chan T, timeout time.Duration) (T, bool) {
select {
case msg := <-c:
return msg, true
case <-time.After(timeout):
return *new(T), false
}
}
func WriteChannelWithTimeout[T any](c chan T, msg T, timeout time.Duration) bool {
select {
case c <- msg:
return true
case <-time.After(timeout):
return false
}
}
func ReadNonBlocking[T any](c chan T) (T, bool) {
select {
case msg := <-c:
return msg, true
default:
return *new(T), false
}
}
func WriteNonBlocking[T any](c chan T, msg T) bool {
select {
case c <- msg:
return true
default:
return false
}
}

121
syncext/channel_test.go Normal file
View File

@@ -0,0 +1,121 @@
package syncext
import (
"testing"
"time"
)
func TestTimeoutReadBuffered(t *testing.T) {
c := make(chan int, 1)
go func() {
time.Sleep(200 * time.Millisecond)
c <- 112
}()
_, ok := ReadChannelWithTimeout(c, 100*time.Millisecond)
if ok {
t.Error("Read success, but should timeout")
}
}
func TestTimeoutReadBigBuffered(t *testing.T) {
c := make(chan int, 128)
go func() {
time.Sleep(200 * time.Millisecond)
c <- 112
}()
_, ok := ReadChannelWithTimeout(c, 100*time.Millisecond)
if ok {
t.Error("Read success, but should timeout")
}
}
func TestTimeoutReadUnbuffered(t *testing.T) {
c := make(chan int)
go func() {
time.Sleep(200 * time.Millisecond)
c <- 112
}()
_, ok := ReadChannelWithTimeout(c, 100*time.Millisecond)
if ok {
t.Error("Read success, but should timeout")
}
}
func TestNoTimeoutAfterStartReadBuffered(t *testing.T) {
c := make(chan int, 1)
go func() {
time.Sleep(10 * time.Millisecond)
c <- 112
}()
_, ok := ReadChannelWithTimeout(c, 100*time.Millisecond)
if !ok {
t.Error("Read timeout, but should have succeeded")
}
}
func TestNoTimeoutAfterStartReadBigBuffered(t *testing.T) {
c := make(chan int, 128)
go func() {
time.Sleep(10 * time.Millisecond)
c <- 112
}()
_, ok := ReadChannelWithTimeout(c, 100*time.Millisecond)
if !ok {
t.Error("Read timeout, but should have succeeded")
}
}
func TestNoTimeoutAfterStartReadUnbuffered(t *testing.T) {
c := make(chan int)
go func() {
time.Sleep(10 * time.Millisecond)
c <- 112
}()
_, ok := ReadChannelWithTimeout(c, 100*time.Millisecond)
if !ok {
t.Error("Read timeout, but should have succeeded")
}
}
func TestNoTimeoutBeforeStartReadBuffered(t *testing.T) {
c := make(chan int, 1)
c <- 112
_, ok := ReadChannelWithTimeout(c, 10*time.Millisecond)
if !ok {
t.Error("Read timeout, but should have succeeded")
}
}
func TestNoTimeoutBeforeStartReadBigBuffered(t *testing.T) {
c := make(chan int, 128)
c <- 112
_, ok := ReadChannelWithTimeout(c, 10*time.Millisecond)
if !ok {
t.Error("Read timeout, but should have succeeded")
}
}

View File

@@ -14,26 +14,26 @@ func TestSupportsColors(t *testing.T) {
}
func TestColor(t *testing.T) {
assertEqual(t, Red("test"), "\033[31mtest\u001B[0m")
assertEqual(t, Green("test"), "\033[32mtest\u001B[0m")
assertEqual(t, Yellow("test"), "\033[33mtest\u001B[0m")
assertEqual(t, Blue("test"), "\033[34mtest\u001B[0m")
assertEqual(t, Purple("test"), "\033[35mtest\u001B[0m")
assertEqual(t, Cyan("test"), "\033[36mtest\u001B[0m")
assertEqual(t, Gray("test"), "\033[37mtest\u001B[0m")
assertEqual(t, White("test"), "\033[97mtest\u001B[0m")
tst.AssertEqual(t, Red("test"), "\033[31mtest\u001B[0m")
tst.AssertEqual(t, Green("test"), "\033[32mtest\u001B[0m")
tst.AssertEqual(t, Yellow("test"), "\033[33mtest\u001B[0m")
tst.AssertEqual(t, Blue("test"), "\033[34mtest\u001B[0m")
tst.AssertEqual(t, Purple("test"), "\033[35mtest\u001B[0m")
tst.AssertEqual(t, Cyan("test"), "\033[36mtest\u001B[0m")
tst.AssertEqual(t, Gray("test"), "\033[37mtest\u001B[0m")
tst.AssertEqual(t, White("test"), "\033[97mtest\u001B[0m")
assertEqual(t, CleanString(Red("test")), "test")
assertEqual(t, CleanString(Green("test")), "test")
assertEqual(t, CleanString(Yellow("test")), "test")
assertEqual(t, CleanString(Blue("test")), "test")
assertEqual(t, CleanString(Purple("test")), "test")
assertEqual(t, CleanString(Cyan("test")), "test")
assertEqual(t, CleanString(Gray("test")), "test")
assertEqual(t, CleanString(White("test")), "test")
tst.AssertEqual(t, CleanString(Red("test")), "test")
tst.AssertEqual(t, CleanString(Green("test")), "test")
tst.AssertEqual(t, CleanString(Yellow("test")), "test")
tst.AssertEqual(t, CleanString(Blue("test")), "test")
tst.AssertEqual(t, CleanString(Purple("test")), "test")
tst.AssertEqual(t, CleanString(Cyan("test")), "test")
tst.AssertEqual(t, CleanString(Gray("test")), "test")
tst.AssertEqual(t, CleanString(White("test")), "test")
}
func assertEqual(t *testing.T, actual string, expected string) {
func tst.AssertEqual(t *testing.T, actual string, expected string) {
if actual != expected {
t.Errorf("values differ: Actual: '%v', Expected: '%v'", actual, expected)
}

152
timeext/parser.go Normal file
View File

@@ -0,0 +1,152 @@
package timeext
import (
"errors"
"fmt"
"strconv"
"strings"
"time"
)
var durationShortStringMap = map[string]time.Duration{
"ns": time.Nanosecond,
"nanosecond": time.Nanosecond,
"nanoseconds": time.Nanosecond,
"us": time.Microsecond,
"microsecond": time.Microsecond,
"microseconds": time.Microsecond,
"ms": time.Millisecond,
"millisecond": time.Millisecond,
"milliseconds": time.Millisecond,
"s": time.Second,
"sec": time.Second,
"second": time.Second,
"seconds": time.Second,
"m": time.Minute,
"min": time.Minute,
"minute": time.Minute,
"minutes": time.Minute,
"h": time.Hour,
"hour": time.Hour,
"hours": time.Hour,
"d": 24 * time.Hour,
"day": 24 * time.Hour,
"days": 24 * time.Hour,
"w": 7 * 24 * time.Hour,
"wk": 7 * 24 * time.Hour,
"week": 7 * 24 * time.Hour,
"weeks": 7 * 24 * time.Hour,
}
// ParseDurationShortString parses a duration in string format to a time.Duration
// Examples for allowed formats:
// - '10m'
// - '10min'
// - '1minute'
// - '10minutes'
// - '10.5minutes'
// - '50s'
// - '50sec'
// - '1second'
// - '50seconds'
// - '100ms'
// - '100millisseconds'
// - '1h'
// - '1hour'
// - '2hours'
// - '1d'
// - '1day'
// - '10days'
// - '1d10m'
// - '1d10m200sec'
// - '1d:10m'
// - '1d 10m'
// - '1d,10m'
func ParseDurationShortString(s string) (time.Duration, error) {
s = strings.ToLower(s)
segments := make([]string, 0)
collector := ""
prevWasNum := true
for _, chr := range s {
if chr >= '0' && chr <= '9' || chr == '.' {
if prevWasNum {
collector += string(chr)
} else {
segments = append(segments, collector)
prevWasNum = true
collector = string(chr)
}
} else if chr == ' ' || chr == ':' || chr == ',' {
continue
} else if chr >= 'a' && chr <= 'z' {
prevWasNum = false
collector += string(chr)
} else {
return 0, errors.New("unexpected character: " + string(chr))
}
}
if !prevWasNum {
segments = append(segments, collector)
}
result := time.Duration(0)
for _, seg := range segments {
segDur, err := parseDurationShortStringSegment(seg)
if err != nil {
return 0, err
}
result += segDur
}
return result, nil
}
func parseDurationShortStringSegment(segment string) (time.Duration, error) {
num := ""
unit := ""
part0 := true
for _, chr := range segment {
if part0 {
if chr >= 'a' && chr <= 'z' {
part0 = false
unit += string(chr)
} else if chr >= '0' && chr <= '9' || chr == '.' {
num += string(chr)
} else {
return 0, errors.New(fmt.Sprintf("Unexpected character '%s' in segment [%s]", string(chr), segment))
}
} else {
if chr >= 'a' && chr <= 'z' {
unit += string(chr)
} else if chr >= '0' && chr <= '9' || chr == '.' {
return 0, errors.New(fmt.Sprintf("Unexpected number '%s' in segment [%s]", string(chr), segment))
} else {
return 0, errors.New(fmt.Sprintf("Unexpected character '%s' in segment [%s]", string(chr), segment))
}
}
}
fpnum, err := strconv.ParseFloat(num, 64)
if err != nil {
return 0, errors.New(fmt.Sprintf("Failed to parse floating-point number '%s' in segment [%s]", num, segment))
}
if mult, ok := durationShortStringMap[unit]; ok {
return time.Duration(int64(fpnum * float64(mult))), nil
} else {
return 0, errors.New(fmt.Sprintf("Unknown unit '%s' in segment [%s]", unit, segment))
}
}

70
timeext/parser_test.go Normal file
View File

@@ -0,0 +1,70 @@
package timeext
import (
"testing"
"time"
)
func TestParseDurationShortString(t *testing.T) {
tst.AssertPDSSEqual(t, FromSeconds(1), "1s")
tst.AssertPDSSEqual(t, FromSeconds(1), "1sec")
tst.AssertPDSSEqual(t, FromSeconds(1), "1second")
tst.AssertPDSSEqual(t, FromSeconds(1), "1seconds")
tst.AssertPDSSEqual(t, FromSeconds(100), "100second")
tst.AssertPDSSEqual(t, FromSeconds(100), "100seconds")
tst.AssertPDSSEqual(t, FromSeconds(1883639.77), "1883639.77second")
tst.AssertPDSSEqual(t, FromSeconds(1883639.77), "1883639.77seconds")
tst.AssertPDSSEqual(t, FromSeconds(50), "50s")
tst.AssertPDSSEqual(t, FromSeconds(50), "50sec")
tst.AssertPDSSEqual(t, FromSeconds(1), "1second")
tst.AssertPDSSEqual(t, FromSeconds(50), "50seconds")
tst.AssertPDSSEqual(t, FromMinutes(10), "10m")
tst.AssertPDSSEqual(t, FromMinutes(10), "10min")
tst.AssertPDSSEqual(t, FromMinutes(1), "1minute")
tst.AssertPDSSEqual(t, FromMinutes(10), "10minutes")
tst.AssertPDSSEqual(t, FromMinutes(10.5), "10.5minutes")
tst.AssertPDSSEqual(t, FromMilliseconds(100), "100ms")
tst.AssertPDSSEqual(t, FromMilliseconds(100), "100milliseconds")
tst.AssertPDSSEqual(t, FromMilliseconds(100), "100millisecond")
tst.AssertPDSSEqual(t, FromNanoseconds(99235), "99235ns")
tst.AssertPDSSEqual(t, FromNanoseconds(99235), "99235nanoseconds")
tst.AssertPDSSEqual(t, FromNanoseconds(99235), "99235nanosecond")
tst.AssertPDSSEqual(t, FromMicroseconds(99235), "99235us")
tst.AssertPDSSEqual(t, FromMicroseconds(99235), "99235microseconds")
tst.AssertPDSSEqual(t, FromMicroseconds(99235), "99235microsecond")
tst.AssertPDSSEqual(t, FromHours(1), "1h")
tst.AssertPDSSEqual(t, FromHours(1), "1hour")
tst.AssertPDSSEqual(t, FromHours(2), "2hours")
tst.AssertPDSSEqual(t, FromDays(1), "1d")
tst.AssertPDSSEqual(t, FromDays(1), "1day")
tst.AssertPDSSEqual(t, FromDays(10), "10days")
tst.AssertPDSSEqual(t, FromDays(1), "1days")
tst.AssertPDSSEqual(t, FromDays(10), "10day")
tst.AssertPDSSEqual(t, FromDays(1)+FromMinutes(10), "1d10m")
tst.AssertPDSSEqual(t, FromDays(1)+FromMinutes(10)+FromSeconds(200), "1d10m200sec")
tst.AssertPDSSEqual(t, FromDays(1)+FromMinutes(10), "1d:10m")
tst.AssertPDSSEqual(t, FromDays(1)+FromMinutes(10), "1d 10m")
tst.AssertPDSSEqual(t, FromDays(1)+FromMinutes(10), "1d,10m")
tst.AssertPDSSEqual(t, FromDays(1)+FromMinutes(10), "1d, 10m")
tst.AssertPDSSEqual(t, FromDays(1)+FromSeconds(1000), "1d 1000seconds")
tst.AssertPDSSEqual(t, FromDays(1), "86400s")
}
func assertPDSSEqual(t *testing.T, expected time.Duration, fmt string) {
actual, err := ParseDurationShortString(fmt)
if err != nil {
t.Errorf("ParseDurationShortString('%s') failed with %v", fmt, err)
}
if actual != expected {
t.Errorf("values differ: Actual: '%v', Expected: '%v'", actual.String(), expected.String())
}
}

65
tst/assertions.go Normal file
View File

@@ -0,0 +1,65 @@
package tst
import (
"encoding/hex"
"testing"
)
func AssertEqual[T comparable](t *testing.T, actual T, expected T) {
if actual != expected {
t.Errorf("values differ: Actual: '%v', Expected: '%v'", actual, expected)
}
}
func AssertNotEqual[T comparable](t *testing.T, actual T, expected T) {
if actual == expected {
t.Errorf("values do not differ: Actual: '%v', Expected: '%v'", actual, expected)
}
}
func AssertDeRefEqual[T comparable](t *testing.T, actual *T, expected T) {
if actual == nil {
t.Errorf("values differ: Actual: NIL, Expected: '%v'", expected)
}
if *actual != expected {
t.Errorf("values differ: Actual: '%v', Expected: '%v'", actual, expected)
}
}
func AssertPtrEqual[T comparable](t *testing.T, actual *T, expected *T) {
if actual == nil && expected == nil {
return
}
if actual != nil && expected != nil {
if *actual != *expected {
t.Errorf("values differ: Actual: '%v', Expected: '%v'", *actual, *expected)
} else {
return
}
}
if actual == nil && expected != nil {
t.Errorf("values differ: Actual: nil, Expected: not-nil")
}
if actual != nil && expected == nil {
t.Errorf("values differ: Actual: not-nil, Expected: nil")
}
}
func AssertHexEqual(t *testing.T, expected string, actual []byte) {
actualStr := hex.EncodeToString(actual)
if actualStr != expected {
t.Errorf("values differ: Actual: '%v', Expected: '%v'", actualStr, expected)
}
}
func AssertTrue(t *testing.T, value bool) {
if !value {
t.Error("value should be true")
}
}
func AssertFalse(t *testing.T, value bool) {
if value {
t.Error("value should be false")
}
}

48
tst/identAssertions.go Normal file
View File

@@ -0,0 +1,48 @@
package tst
import (
"testing"
)
func AssertIdentEqual[T comparable](t *testing.T, ident string, actual T, expected T) {
if actual != expected {
t.Errorf("[%s] values differ: Actual: '%v', Expected: '%v'", ident, actual, expected)
}
}
func AssertIdentNotEqual[T comparable](t *testing.T, ident string, actual T, expected T) {
if actual == expected {
t.Errorf("[%s] values do not differ: Actual: '%v', Expected: '%v'", ident, actual, expected)
}
}
func AssertIdentPtrEqual[T comparable](t *testing.T, ident string, actual *T, expected *T) {
if actual == nil && expected == nil {
return
}
if actual != nil && expected != nil {
if *actual != *expected {
t.Errorf("[%s] values differ: Actual: '%v', Expected: '%v'", ident, *actual, *expected)
} else {
return
}
}
if actual == nil && expected != nil {
t.Errorf("[%s] values differ: Actual: nil, Expected: not-nil", ident)
}
if actual != nil && expected == nil {
t.Errorf("[%s] values differ: Actual: not-nil, Expected: nil", ident)
}
}
func AssertIdentTrue(t *testing.T, ident string, value bool) {
if !value {
t.Errorf("[%s] value should be true", ident)
}
}
func AssertIdentFalse(t *testing.T, ident string, value bool) {
if !value {
t.Errorf("[%s] value should be false", ident)
}
}