Compare commits

..

25 Commits

Author SHA1 Message Date
289b9f47a2 v0.0.95 2023-03-28 16:29:16 +02:00
007c44df85 v0.0.94 2023-03-21 16:00:15 +01:00
a6252f0743 v0.0.93 2023-03-15 15:41:55 +01:00
86c01659d7 base58 2023-03-15 14:00:48 +01:00
62acddda5e v0.0.91 2023-03-11 14:38:19 +01:00
ee325f67fd v0.0.90 2023-03-09 14:51:53 +01:00
dba0cd229e v0.0.89 2023-03-07 10:43:30 +01:00
ec4dba173f v0.0.88 2023-02-16 13:27:34 +01:00
22ce2d26f3 v0.0.87 2023-02-16 13:22:15 +01:00
4fd768e573 v0.0.86 2023-02-14 17:18:58 +01:00
bf16a8165f v0.0.85 2023-02-14 16:25:45 +01:00
9f5612248a fix fd0 read error on long stdout output (scanner buffer was too small) 2023-02-13 01:41:33 +01:00
4a2b830252 added more tests to cmdrunner (reproduce another ?? cmdrunner bug...) 2023-02-09 16:49:33 +01:00
c492c80881 v0.0.83 2023-02-09 15:06:37 +01:00
26dd16d021 v0.0.82 2023-02-09 15:01:54 +01:00
b0b43de8ca v0.0.81 2023-02-09 11:27:49 +01:00
94f72e4ddf v0.0.80 2023-02-09 11:16:23 +01:00
df4388e6dc v0.0.79 2023-02-08 18:55:51 +01:00
fd33b43f31 v0.0.78 2023-02-03 01:05:36 +01:00
be4de07eb8 v0.0.77 2023-02-03 00:59:54 +01:00
36ed474bfe v0.0.76 2023-01-31 23:46:35 +01:00
fdc590c8c3 v0.0.75 2023-01-31 22:41:12 +01:00
1990e5d32d v0.0.74 2023-01-31 11:01:45 +01:00
72883cf6bd v0.0.73 2023-01-31 10:56:30 +01:00
ff08d5f180 v0.0.72 2023-01-30 19:55:55 +01:00
28 changed files with 1629 additions and 244 deletions

View File

@@ -7,6 +7,20 @@ set -o pipefail # Return value of a pipeline is the value of the last (rightmos
IFS=$'\n\t' # Set $IFS to only newline and tab. IFS=$'\n\t' # Set $IFS to only newline and tab.
function black() { echo -e "\x1B[30m $1 \x1B[0m"; }
function red() { echo -e "\x1B[31m $1 \x1B[0m"; }
function green() { echo -e "\x1B[32m $1 \x1B[0m"; }
function yellow(){ echo -e "\x1B[33m $1 \x1B[0m"; }
function blue() { echo -e "\x1B[34m $1 \x1B[0m"; }
function purple(){ echo -e "\x1B[35m $1 \x1B[0m"; }
function cyan() { echo -e "\x1B[36m $1 \x1B[0m"; }
function white() { echo -e "\x1B[37m $1 \x1B[0m"; }
if [ "$( git rev-parse --abbrev-ref HEAD )" != "master" ]; then
>&2 red "[ERROR] Can only create versions of <master>"
exit 1
fi
curr_vers=$(git describe --tags --abbrev=0 | sed 's/v//g') curr_vers=$(git describe --tags --abbrev=0 | sed 's/v//g')
next_ver=$(echo "$curr_vers" | awk -F. -v OFS=. 'NF==1{print ++$NF}; NF>1{if(length($NF+1)>length($NF))$(NF-1)++; $NF=sprintf("%0*d", length($NF), ($NF+1)%(10^length($NF))); print}') next_ver=$(echo "$curr_vers" | awk -F. -v OFS=. 'NF==1{print ++$NF}; NF>1{if(length($NF+1)>length($NF))$(NF-1)++; $NF=sprintf("%0*d", length($NF), ($NF+1)%(10^length($NF))); print}')
@@ -18,7 +32,13 @@ echo ""
git add --verbose . git add --verbose .
git commit -a -m "v${next_ver}" msg="v${next_ver}"
if [ $# -gt 0 ]; then
msg="$1"
fi
git commit -a -m "${msg}"
git tag "v${next_ver}" git tag "v${next_ver}"

View File

@@ -2,6 +2,7 @@ package cmdext
import ( import (
"fmt" "fmt"
"gogs.mikescher.com/BlackForestBytes/goext/langext"
"time" "time"
) )
@@ -10,6 +11,9 @@ type CommandRunner struct {
args []string args []string
timeout *time.Duration timeout *time.Duration
env []string env []string
listener []CommandListener
enforceExitCodes *[]int
enforceNoTimeout bool
} }
func Runner(program string) *CommandRunner { func Runner(program string) *CommandRunner {
@@ -18,6 +22,9 @@ func Runner(program string) *CommandRunner {
args: make([]string, 0), args: make([]string, 0),
timeout: nil, timeout: nil,
env: make([]string, 0), env: make([]string, 0),
listener: make([]CommandListener, 0),
enforceExitCodes: nil,
enforceNoTimeout: false,
} }
} }
@@ -51,6 +58,36 @@ func (r *CommandRunner) Envs(env []string) *CommandRunner {
return r return r
} }
func (r *CommandRunner) EnsureExitcode(arg ...int) *CommandRunner {
r.enforceExitCodes = langext.Ptr(langext.ForceArray(arg))
return r
}
func (r *CommandRunner) FailOnExitCode() *CommandRunner {
r.enforceExitCodes = langext.Ptr([]int{0})
return r
}
func (r *CommandRunner) FailOnTimeout() *CommandRunner {
r.enforceNoTimeout = true
return r
}
func (r *CommandRunner) Listen(lstr CommandListener) *CommandRunner {
r.listener = append(r.listener, lstr)
return r
}
func (r *CommandRunner) ListenStdout(lstr func(string)) *CommandRunner {
r.listener = append(r.listener, genericCommandListener{_readStdoutLine: &lstr})
return r
}
func (r *CommandRunner) ListenStderr(lstr func(string)) *CommandRunner {
r.listener = append(r.listener, genericCommandListener{_readStderrLine: &lstr})
return r
}
func (r *CommandRunner) Run() (CommandResult, error) { func (r *CommandRunner) Run() (CommandResult, error) {
return run(*r) return run(*r)
} }

View File

@@ -1,11 +1,17 @@
package cmdext package cmdext
import ( import (
"bufio" "errors"
"gogs.mikescher.com/BlackForestBytes/goext/langext"
"gogs.mikescher.com/BlackForestBytes/goext/mathext"
"gogs.mikescher.com/BlackForestBytes/goext/syncext"
"os/exec" "os/exec"
"time" "time"
) )
var ErrExitCode = errors.New("process exited with an unexpected exitcode")
var ErrTimeout = errors.New("process did not exit after the specified timeout")
type CommandResult struct { type CommandResult struct {
StdOut string StdOut string
StdErr string StdErr string
@@ -29,51 +35,43 @@ func run(opt CommandRunner) (CommandResult, error) {
return CommandResult{}, err return CommandResult{}, err
} }
preader := pipeReader{
lineBufferSize: langext.Ptr(128 * 1024 * 1024), // 128MB max size of a single line, is hopefully enough....
stdout: stdoutPipe,
stderr: stderrPipe,
}
err = cmd.Start() err = cmd.Start()
if err != nil { if err != nil {
return CommandResult{}, err return CommandResult{}, err
} }
errch := make(chan error, 1) type resultObj struct {
go func() { errch <- cmd.Wait() }() stdout string
stderr string
combch := make(chan string, 32) stdcombined string
stopCombch := make(chan bool) err error
stdout := ""
go func() {
scanner := bufio.NewScanner(stdoutPipe)
for scanner.Scan() {
txt := scanner.Text()
stdout += txt
combch <- txt
} }
}()
stderr := "" outputChan := make(chan resultObj)
go func() { go func() {
scanner := bufio.NewScanner(stderrPipe) // we need to first fully read the pipes and then call Wait
for scanner.Scan() { // see https://pkg.go.dev/os/exec#Cmd.StdoutPipe
txt := scanner.Text()
stderr += txt
combch <- txt
}
}()
defer func() { stdout, stderr, stdcombined, err := preader.Read(opt.listener)
stopCombch <- true if err != nil {
}() outputChan <- resultObj{stdout, stderr, stdcombined, err}
_ = cmd.Process.Kill()
stdcombined := ""
go func() {
for {
select {
case txt := <-combch:
stdcombined += txt
case <-stopCombch:
return return
} }
err = cmd.Wait()
if err != nil {
outputChan <- resultObj{stdout, stderr, stdcombined, err}
} else {
outputChan <- resultObj{stdout, stderr, stdcombined, nil}
} }
}() }()
var timeoutChan <-chan time.Time = make(chan time.Time, 1) var timeoutChan <-chan time.Time = make(chan time.Time, 1)
@@ -85,33 +83,72 @@ func run(opt CommandRunner) (CommandResult, error) {
case <-timeoutChan: case <-timeoutChan:
_ = cmd.Process.Kill() _ = cmd.Process.Kill()
return CommandResult{ for _, lstr := range opt.listener {
StdOut: stdout, lstr.Timeout()
StdErr: stderr, }
StdCombined: stdcombined,
if fallback, ok := syncext.ReadChannelWithTimeout(outputChan, mathext.Min(32*time.Millisecond, *opt.timeout)); ok {
// most of the time the cmd.Process.Kill() should also ahve finished the pipereader
// and we can at least return the already collected stdout, stderr, etc
res := CommandResult{
StdOut: fallback.stdout,
StdErr: fallback.stderr,
StdCombined: fallback.stdcombined,
ExitCode: -1, ExitCode: -1,
CommandTimedOut: true, CommandTimedOut: true,
}, nil }
if opt.enforceNoTimeout {
return res, ErrTimeout
}
return res, nil
} else {
res := CommandResult{
StdOut: "",
StdErr: "",
StdCombined: "",
ExitCode: -1,
CommandTimedOut: true,
}
if opt.enforceNoTimeout {
return res, ErrTimeout
}
return res, nil
}
case err := <-errch: case outobj := <-outputChan:
if exiterr, ok := err.(*exec.ExitError); ok { if exiterr, ok := outobj.err.(*exec.ExitError); ok {
return CommandResult{ excode := exiterr.ExitCode()
StdOut: stdout, for _, lstr := range opt.listener {
StdErr: stderr, lstr.Finished(excode)
StdCombined: stdcombined, }
ExitCode: exiterr.ExitCode(), res := CommandResult{
StdOut: outobj.stdout,
StdErr: outobj.stderr,
StdCombined: outobj.stdcombined,
ExitCode: excode,
CommandTimedOut: false, CommandTimedOut: false,
}, nil }
if opt.enforceExitCodes != nil && !langext.InArray(excode, *opt.enforceExitCodes) {
return res, ErrExitCode
}
return res, nil
} else if err != nil { } else if err != nil {
return CommandResult{}, err return CommandResult{}, err
} else { } else {
return CommandResult{ for _, lstr := range opt.listener {
StdOut: stdout, lstr.Finished(0)
StdErr: stderr, }
StdCombined: stdcombined, res := CommandResult{
StdOut: outobj.stdout,
StdErr: outobj.stderr,
StdCombined: outobj.stdcombined,
ExitCode: 0, ExitCode: 0,
CommandTimedOut: false, CommandTimedOut: false,
}, nil }
if opt.enforceExitCodes != nil && !langext.InArray(0, *opt.enforceExitCodes) {
return res, ErrExitCode
}
return res, nil
} }
} }
} }

323
cmdext/cmdrunner_test.go Normal file
View File

@@ -0,0 +1,323 @@
package cmdext
import (
"fmt"
"testing"
"time"
)
func TestStdout(t *testing.T) {
res1, err := Runner("printf").Arg("hello").Run()
if err != nil {
t.Errorf("%v", err)
}
if res1.CommandTimedOut {
t.Errorf("Timeout")
}
if res1.ExitCode != 0 {
t.Errorf("res1.ExitCode == %v", res1.ExitCode)
}
if res1.StdErr != "" {
t.Errorf("res1.StdErr == '%v'", res1.StdErr)
}
if res1.StdOut != "hello" {
t.Errorf("res1.StdOut == '%v'", res1.StdOut)
}
if res1.StdCombined != "hello\n" {
t.Errorf("res1.StdCombined == '%v'", res1.StdCombined)
}
}
func TestStderr(t *testing.T) {
res1, err := Runner("python").Arg("-c").Arg("import sys; print(\"error\", file=sys.stderr, end='')").Run()
if err != nil {
t.Errorf("%v", err)
}
if res1.CommandTimedOut {
t.Errorf("Timeout")
}
if res1.ExitCode != 0 {
t.Errorf("res1.ExitCode == %v", res1.ExitCode)
}
if res1.StdErr != "error" {
t.Errorf("res1.StdErr == '%v'", res1.StdErr)
}
if res1.StdOut != "" {
t.Errorf("res1.StdOut == '%v'", res1.StdOut)
}
if res1.StdCombined != "error\n" {
t.Errorf("res1.StdCombined == '%v'", res1.StdCombined)
}
}
func TestStdcombined(t *testing.T) {
res1, err := Runner("python").
Arg("-c").
Arg("import sys; import time; print(\"1\", file=sys.stderr, flush=True); time.sleep(0.1); print(\"2\", file=sys.stdout, flush=True); time.sleep(0.1); print(\"3\", file=sys.stderr, flush=True)").
Run()
if err != nil {
t.Errorf("%v", err)
}
if res1.CommandTimedOut {
t.Errorf("Timeout")
}
if res1.ExitCode != 0 {
t.Errorf("res1.ExitCode == %v", res1.ExitCode)
}
if res1.StdErr != "1\n3\n" {
t.Errorf("res1.StdErr == '%v'", res1.StdErr)
}
if res1.StdOut != "2\n" {
t.Errorf("res1.StdOut == '%v'", res1.StdOut)
}
if res1.StdCombined != "1\n2\n3\n" {
t.Errorf("res1.StdCombined == '%v'", res1.StdCombined)
}
}
func TestPartialRead(t *testing.T) {
res1, err := Runner("python").
Arg("-c").
Arg("import sys; import time; print(\"first message\", flush=True); time.sleep(5); print(\"cant see me\", flush=True);").
Timeout(100 * time.Millisecond).
Run()
if err != nil {
t.Errorf("%v", err)
}
if !res1.CommandTimedOut {
t.Errorf("!CommandTimedOut")
}
if res1.StdErr != "" {
t.Errorf("res1.StdErr == '%v'", res1.StdErr)
}
if res1.StdOut != "first message\n" {
t.Errorf("res1.StdOut == '%v'", res1.StdOut)
}
if res1.StdCombined != "first message\n" {
t.Errorf("res1.StdCombined == '%v'", res1.StdCombined)
}
}
func TestPartialReadStderr(t *testing.T) {
res1, err := Runner("python").
Arg("-c").
Arg("import sys; import time; print(\"first message\", file=sys.stderr, flush=True); time.sleep(5); print(\"cant see me\", file=sys.stderr, flush=True);").
Timeout(100 * time.Millisecond).
Run()
if err != nil {
t.Errorf("%v", err)
}
if !res1.CommandTimedOut {
t.Errorf("!CommandTimedOut")
}
if res1.StdErr != "first message\n" {
t.Errorf("res1.StdErr == '%v'", res1.StdErr)
}
if res1.StdOut != "" {
t.Errorf("res1.StdOut == '%v'", res1.StdOut)
}
if res1.StdCombined != "first message\n" {
t.Errorf("res1.StdCombined == '%v'", res1.StdCombined)
}
}
func TestReadUnflushedStdout(t *testing.T) {
res1, err := Runner("python").Arg("-c").Arg("import sys; print(\"message101\", file=sys.stdout, end='')").Run()
if err != nil {
t.Errorf("%v", err)
}
if res1.CommandTimedOut {
t.Errorf("Timeout")
}
if res1.ExitCode != 0 {
t.Errorf("res1.ExitCode == %v", res1.ExitCode)
}
if res1.StdErr != "" {
t.Errorf("res1.StdErr == '%v'", res1.StdErr)
}
if res1.StdOut != "message101" {
t.Errorf("res1.StdOut == '%v'", res1.StdOut)
}
if res1.StdCombined != "message101\n" {
t.Errorf("res1.StdCombined == '%v'", res1.StdCombined)
}
}
func TestReadUnflushedStderr(t *testing.T) {
res1, err := Runner("python").Arg("-c").Arg("import sys; print(\"message101\", file=sys.stderr, end='')").Run()
if err != nil {
t.Errorf("%v", err)
}
if res1.CommandTimedOut {
t.Errorf("Timeout")
}
if res1.ExitCode != 0 {
t.Errorf("res1.ExitCode == %v", res1.ExitCode)
}
if res1.StdErr != "message101" {
t.Errorf("res1.StdErr == '%v'", res1.StdErr)
}
if res1.StdOut != "" {
t.Errorf("res1.StdOut == '%v'", res1.StdOut)
}
if res1.StdCombined != "message101\n" {
t.Errorf("res1.StdCombined == '%v'", res1.StdCombined)
}
}
func TestPartialReadUnflushed(t *testing.T) {
t.SkipNow()
res1, err := Runner("python").
Arg("-c").
Arg("import sys; import time; print(\"first message\", end=''); time.sleep(5); print(\"cant see me\", end='');").
Timeout(100 * time.Millisecond).
Run()
if err != nil {
t.Errorf("%v", err)
}
if !res1.CommandTimedOut {
t.Errorf("!CommandTimedOut")
}
if res1.StdErr != "" {
t.Errorf("res1.StdErr == '%v'", res1.StdErr)
}
if res1.StdOut != "first message" {
t.Errorf("res1.StdOut == '%v'", res1.StdOut)
}
if res1.StdCombined != "first message" {
t.Errorf("res1.StdCombined == '%v'", res1.StdCombined)
}
}
func TestPartialReadUnflushedStderr(t *testing.T) {
t.SkipNow()
res1, err := Runner("python").
Arg("-c").
Arg("import sys; import time; print(\"first message\", file=sys.stderr, end=''); time.sleep(5); print(\"cant see me\", file=sys.stderr, end='');").
Timeout(100 * time.Millisecond).
Run()
if err != nil {
t.Errorf("%v", err)
}
if !res1.CommandTimedOut {
t.Errorf("!CommandTimedOut")
}
if res1.StdErr != "first message" {
t.Errorf("res1.StdErr == '%v'", res1.StdErr)
}
if res1.StdOut != "" {
t.Errorf("res1.StdOut == '%v'", res1.StdOut)
}
if res1.StdCombined != "first message" {
t.Errorf("res1.StdCombined == '%v'", res1.StdCombined)
}
}
func TestListener(t *testing.T) {
res1, err := Runner("python").
Arg("-c").
Arg("import sys;" +
"import time;" +
"print(\"message 1\", flush=True);" +
"time.sleep(1);" +
"print(\"message 2\", flush=True);" +
"time.sleep(1);" +
"print(\"message 3\", flush=True);" +
"time.sleep(1);" +
"print(\"message 4\", file=sys.stderr, flush=True);" +
"time.sleep(1);" +
"print(\"message 5\", flush=True);" +
"time.sleep(1);" +
"print(\"final\");").
ListenStdout(func(s string) { fmt.Printf("@@STDOUT <<- %v (%v)\n", s, time.Now().Format(time.RFC3339Nano)) }).
ListenStderr(func(s string) { fmt.Printf("@@STDERR <<- %v (%v)\n", s, time.Now().Format(time.RFC3339Nano)) }).
Timeout(10 * time.Second).
Run()
if err != nil {
panic(err)
}
if res1.CommandTimedOut {
t.Errorf("Timeout")
}
if res1.ExitCode != 0 {
t.Errorf("res1.ExitCode == %v", res1.ExitCode)
}
}
func TestLongStdout(t *testing.T) {
res1, err := Runner("python").
Arg("-c").
Arg("import sys; import time; print(\"X\" * 125001 + \"\\n\"); print(\"Y\" * 125001 + \"\\n\"); print(\"Z\" * 125001 + \"\\n\");").
Timeout(5000 * time.Millisecond).
Run()
if err != nil {
t.Errorf("%v", err)
}
if res1.CommandTimedOut {
t.Errorf("Timeout")
}
if res1.ExitCode != 0 {
t.Errorf("res1.ExitCode == %v", res1.ExitCode)
}
if res1.StdErr != "" {
t.Errorf("res1.StdErr == '%v'", res1.StdErr)
}
if len(res1.StdOut) != 375009 {
t.Errorf("len(res1.StdOut) == '%v'", len(res1.StdOut))
}
}
func TestFailOnTimeout(t *testing.T) {
_, err := Runner("sleep").Arg("2").Timeout(200 * time.Millisecond).FailOnTimeout().Run()
if err != ErrTimeout {
t.Errorf("wrong err := %v", err)
}
}
func TestFailOnExitcode(t *testing.T) {
_, err := Runner("false").Timeout(200 * time.Millisecond).FailOnExitCode().Run()
if err != ErrExitCode {
t.Errorf("wrong err := %v", err)
}
}
func TestEnsureExitcode1(t *testing.T) {
_, err := Runner("false").Timeout(200 * time.Millisecond).EnsureExitcode(1).Run()
if err != nil {
t.Errorf("wrong err := %v", err)
}
}
func TestEnsureExitcode2(t *testing.T) {
_, err := Runner("false").Timeout(200*time.Millisecond).EnsureExitcode(0, 2, 3).Run()
if err != ErrExitCode {
t.Errorf("wrong err := %v", err)
}
}

57
cmdext/listener.go Normal file
View File

@@ -0,0 +1,57 @@
package cmdext
type CommandListener interface {
ReadRawStdout([]byte)
ReadRawStderr([]byte)
ReadStdoutLine(string)
ReadStderrLine(string)
Finished(int)
Timeout()
}
type genericCommandListener struct {
_readRawStdout *func([]byte)
_readRawStderr *func([]byte)
_readStdoutLine *func(string)
_readStderrLine *func(string)
_finished *func(int)
_timeout *func()
}
func (g genericCommandListener) ReadRawStdout(v []byte) {
if g._readRawStdout != nil {
(*g._readRawStdout)(v)
}
}
func (g genericCommandListener) ReadRawStderr(v []byte) {
if g._readRawStderr != nil {
(*g._readRawStderr)(v)
}
}
func (g genericCommandListener) ReadStdoutLine(v string) {
if g._readStdoutLine != nil {
(*g._readStdoutLine)(v)
}
}
func (g genericCommandListener) ReadStderrLine(v string) {
if g._readStderrLine != nil {
(*g._readStderrLine)(v)
}
}
func (g genericCommandListener) Finished(v int) {
if g._finished != nil {
(*g._finished)(v)
}
}
func (g genericCommandListener) Timeout() {
if g._timeout != nil {
(*g._timeout)()
}
}

158
cmdext/pipereader.go Normal file
View File

@@ -0,0 +1,158 @@
package cmdext
import (
"bufio"
"gogs.mikescher.com/BlackForestBytes/goext/syncext"
"io"
"sync"
)
type pipeReader struct {
lineBufferSize *int
stdout io.ReadCloser
stderr io.ReadCloser
}
// Read ready stdout and stdin until finished
// also splits both pipes into lines and calld the listener
func (pr *pipeReader) Read(listener []CommandListener) (string, string, string, error) {
type combevt struct {
line string
stop bool
}
errch := make(chan error, 8)
wg := sync.WaitGroup{}
// [1] read raw stdout
wg.Add(1)
stdoutBufferReader, stdoutBufferWriter := io.Pipe()
stdout := ""
go func() {
buf := make([]byte, 128)
for true {
n, out := pr.stdout.Read(buf)
if n > 0 {
txt := string(buf[:n])
stdout += txt
_, _ = stdoutBufferWriter.Write(buf[:n])
for _, lstr := range listener {
lstr.ReadRawStdout(buf[:n])
}
}
if out == io.EOF {
break
}
if out != nil {
errch <- out
break
}
}
_ = stdoutBufferWriter.Close()
wg.Done()
}()
// [2] read raw stderr
wg.Add(1)
stderrBufferReader, stderrBufferWriter := io.Pipe()
stderr := ""
go func() {
buf := make([]byte, 128)
for true {
n, err := pr.stderr.Read(buf)
if n > 0 {
txt := string(buf[:n])
stderr += txt
_, _ = stderrBufferWriter.Write(buf[:n])
for _, lstr := range listener {
lstr.ReadRawStderr(buf[:n])
}
}
if err == io.EOF {
break
}
if err != nil {
errch <- err
break
}
}
_ = stderrBufferWriter.Close()
wg.Done()
}()
combch := make(chan combevt, 32)
// [3] collect stdout line-by-line
wg.Add(1)
go func() {
scanner := bufio.NewScanner(stdoutBufferReader)
if pr.lineBufferSize != nil {
scanner.Buffer([]byte{}, *pr.lineBufferSize)
}
for scanner.Scan() {
txt := scanner.Text()
for _, lstr := range listener {
lstr.ReadStdoutLine(txt)
}
combch <- combevt{txt, false}
}
if err := scanner.Err(); err != nil {
errch <- err
}
combch <- combevt{"", true}
wg.Done()
}()
// [4] collect stderr line-by-line
wg.Add(1)
go func() {
scanner := bufio.NewScanner(stderrBufferReader)
if pr.lineBufferSize != nil {
scanner.Buffer([]byte{}, *pr.lineBufferSize)
}
for scanner.Scan() {
txt := scanner.Text()
for _, lstr := range listener {
lstr.ReadStderrLine(txt)
}
combch <- combevt{txt, false}
}
if err := scanner.Err(); err != nil {
errch <- err
}
combch <- combevt{"", true}
wg.Done()
}()
// [5] combine stdcombined
wg.Add(1)
stdcombined := ""
go func() {
stopctr := 0
for stopctr < 2 {
vvv := <-combch
if vvv.stop {
stopctr++
} else {
stdcombined += vvv.line + "\n" // this comes from bufio.Scanner and has no newlines...
}
}
wg.Done()
}()
// wait for all (5) goroutines to finish
wg.Wait()
if err, ok := syncext.ReadNonBlocking(errch); ok {
return "", "", "", err
}
return stdout, stderr, stdcombined, nil
}

View File

@@ -22,10 +22,10 @@ import (
// //
// sub-structs are recursively parsed (if they have an env tag) and the env-variable keys are delimited by the delim parameter // sub-structs are recursively parsed (if they have an env tag) and the env-variable keys are delimited by the delim parameter
// sub-structs with `env:""` are also parsed, but the delimited is skipped (they are handled as if they were one level higher) // sub-structs with `env:""` are also parsed, but the delimited is skipped (they are handled as if they were one level higher)
func ApplyEnvOverrides[T any](c *T, delim string) error { func ApplyEnvOverrides[T any](prefix string, c *T, delim string) error {
rval := reflect.ValueOf(c).Elem() rval := reflect.ValueOf(c).Elem()
return processEnvOverrides(rval, delim, "") return processEnvOverrides(rval, delim, prefix)
} }
func processEnvOverrides(rval reflect.Value, delim string, prefix string) error { func processEnvOverrides(rval reflect.Value, delim string, prefix string) error {
@@ -70,103 +70,114 @@ func processEnvOverrides(rval reflect.Value, delim string, prefix string) error
continue continue
} }
if rvfield.Type() == reflect.TypeOf("") { if rvfield.Type().Kind() == reflect.Pointer {
rvfield.Set(reflect.ValueOf(envval)) newval, err := parseEnvToValue(envval, fullEnvKey, rvfield.Type().Elem())
if err != nil {
return err
}
// converts reflect.Value to pointer
ptrval := reflect.New(rvfield.Type().Elem())
ptrval.Elem().Set(newval)
rvfield.Set(ptrval)
fmt.Printf("[CONF] Overwrite config '%s' with '%s'\n", fullEnvKey, envval) fmt.Printf("[CONF] Overwrite config '%s' with '%s'\n", fullEnvKey, envval)
} else if rvfield.Type() == reflect.TypeOf(int(0)) {
envint, err := strconv.ParseInt(envval, 10, bits.UintSize)
if err != nil {
return errors.New(fmt.Sprintf("Failed to parse env-config variable '%s' to int (value := '%s')", fullEnvKey, envval))
}
rvfield.Set(reflect.ValueOf(int(envint)))
fmt.Printf("[CONF] Overwrite config '%s' with '%s'\n", fullEnvKey, envval)
} else if rvfield.Type() == reflect.TypeOf(int64(0)) {
envint, err := strconv.ParseInt(envval, 10, 64)
if err != nil {
return errors.New(fmt.Sprintf("Failed to parse env-config variable '%s' to int64 (value := '%s')", fullEnvKey, envval))
}
rvfield.Set(reflect.ValueOf(int64(envint)))
fmt.Printf("[CONF] Overwrite config '%s' with '%s'\n", fullEnvKey, envval)
} else if rvfield.Type() == reflect.TypeOf(int32(0)) {
envint, err := strconv.ParseInt(envval, 10, 32)
if err != nil {
return errors.New(fmt.Sprintf("Failed to parse env-config variable '%s' to int32 (value := '%s')", fullEnvKey, envval))
}
rvfield.Set(reflect.ValueOf(int32(envint)))
fmt.Printf("[CONF] Overwrite config '%s' with '%s'\n", fullEnvKey, envval)
} else if rvfield.Type() == reflect.TypeOf(int8(0)) {
envint, err := strconv.ParseInt(envval, 10, 8)
if err != nil {
return errors.New(fmt.Sprintf("Failed to parse env-config variable '%s' to int32 (value := '%s')", fullEnvKey, envval))
}
rvfield.Set(reflect.ValueOf(int8(envint)))
fmt.Printf("[CONF] Overwrite config '%s' with '%s'\n", fullEnvKey, envval)
} else if rvfield.Type() == reflect.TypeOf(time.Duration(0)) {
dur, err := timeext.ParseDurationShortString(envval)
if err != nil {
return errors.New(fmt.Sprintf("Failed to parse env-config variable '%s' to duration (value := '%s')", fullEnvKey, envval))
}
rvfield.Set(reflect.ValueOf(dur))
fmt.Printf("[CONF] Overwrite config '%s' with '%s'\n", fullEnvKey, dur.String())
} else if rvfield.Type() == reflect.TypeOf(time.UnixMilli(0)) {
tim, err := time.Parse(time.RFC3339Nano, envval)
if err != nil {
return errors.New(fmt.Sprintf("Failed to parse env-config variable '%s' to time.time (value := '%s')", fullEnvKey, envval))
}
rvfield.Set(reflect.ValueOf(tim))
fmt.Printf("[CONF] Overwrite config '%s' with '%s'\n", fullEnvKey, tim.String())
} else if rvfield.Type().ConvertibleTo(reflect.TypeOf(int(0))) {
envint, err := strconv.ParseInt(envval, 10, 8)
if err != nil {
return errors.New(fmt.Sprintf("Failed to parse env-config variable '%s' to <%s, ,int> (value := '%s')", rvfield.Type().Name(), fullEnvKey, envval))
}
envcvl := reflect.ValueOf(envint).Convert(rvfield.Type())
rvfield.Set(envcvl)
fmt.Printf("[CONF] Overwrite config '%s' with '%v'\n", fullEnvKey, envcvl.Interface())
} else if rvfield.Type().ConvertibleTo(reflect.TypeOf("")) {
envcvl := reflect.ValueOf(envval).Convert(rvfield.Type())
rvfield.Set(envcvl)
fmt.Printf("[CONF] Overwrite config '%s' with '%v'\n", fullEnvKey, envcvl.Interface())
} else { } else {
return errors.New(fmt.Sprintf("Unknown kind/type in config: [ %s | %s ]", rvfield.Kind().String(), rvfield.Type().String()))
newval, err := parseEnvToValue(envval, fullEnvKey, rvfield.Type())
if err != nil {
return err
} }
rvfield.Set(newval)
fmt.Printf("[CONF] Overwrite config '%s' with '%s'\n", fullEnvKey, envval)
}
} }
return nil return nil
} }
func parseEnvToValue(envval string, fullEnvKey string, rvtype reflect.Type) (reflect.Value, error) {
if rvtype == reflect.TypeOf("") {
return reflect.ValueOf(envval), nil
} else if rvtype == reflect.TypeOf(int(0)) {
envint, err := strconv.ParseInt(envval, 10, bits.UintSize)
if err != nil {
return reflect.Value{}, errors.New(fmt.Sprintf("Failed to parse env-config variable '%s' to int (value := '%s')", fullEnvKey, envval))
}
return reflect.ValueOf(int(envint)), nil
} else if rvtype == reflect.TypeOf(int64(0)) {
envint, err := strconv.ParseInt(envval, 10, 64)
if err != nil {
return reflect.Value{}, errors.New(fmt.Sprintf("Failed to parse env-config variable '%s' to int64 (value := '%s')", fullEnvKey, envval))
}
return reflect.ValueOf(int64(envint)), nil
} else if rvtype == reflect.TypeOf(int32(0)) {
envint, err := strconv.ParseInt(envval, 10, 32)
if err != nil {
return reflect.Value{}, errors.New(fmt.Sprintf("Failed to parse env-config variable '%s' to int32 (value := '%s')", fullEnvKey, envval))
}
return reflect.ValueOf(int32(envint)), nil
} else if rvtype == reflect.TypeOf(int8(0)) {
envint, err := strconv.ParseInt(envval, 10, 8)
if err != nil {
return reflect.Value{}, errors.New(fmt.Sprintf("Failed to parse env-config variable '%s' to int32 (value := '%s')", fullEnvKey, envval))
}
return reflect.ValueOf(int8(envint)), nil
} else if rvtype == reflect.TypeOf(time.Duration(0)) {
dur, err := timeext.ParseDurationShortString(envval)
if err != nil {
return reflect.Value{}, errors.New(fmt.Sprintf("Failed to parse env-config variable '%s' to duration (value := '%s')", fullEnvKey, envval))
}
return reflect.ValueOf(dur), nil
} else if rvtype == reflect.TypeOf(time.UnixMilli(0)) {
tim, err := time.Parse(time.RFC3339Nano, envval)
if err != nil {
return reflect.Value{}, errors.New(fmt.Sprintf("Failed to parse env-config variable '%s' to time.time (value := '%s')", fullEnvKey, envval))
}
return reflect.ValueOf(tim), nil
} else if rvtype.ConvertibleTo(reflect.TypeOf(int(0))) {
envint, err := strconv.ParseInt(envval, 10, 8)
if err != nil {
return reflect.Value{}, errors.New(fmt.Sprintf("Failed to parse env-config variable '%s' to <%s, ,int> (value := '%s')", rvtype.Name(), fullEnvKey, envval))
}
envcvl := reflect.ValueOf(envint).Convert(rvtype)
return envcvl, nil
} else if rvtype.ConvertibleTo(reflect.TypeOf("")) {
envcvl := reflect.ValueOf(envval).Convert(rvtype)
return envcvl, nil
} else {
return reflect.Value{}, errors.New(fmt.Sprintf("Unknown kind/type in config: [ %s | %s ]", rvtype.Kind().String(), rvtype.String()))
}
}

View File

@@ -213,8 +213,65 @@ func TestApplyEnvOverridesRecursive(t *testing.T) {
assertEqual(t, data.Sub4.V9, time.Unix(2335219200, 0).UTC()) assertEqual(t, data.Sub4.V9, time.Unix(2335219200, 0).UTC())
} }
func TestApplyEnvOverridesPointer(t *testing.T) {
type aliasint int
type aliasstring string
type testdata struct {
V1 *int `env:"TEST_V1"`
VX *string ``
V2 *string `env:"TEST_V2"`
V3 *int8 `env:"TEST_V3"`
V4 *int32 `env:"TEST_V4"`
V5 *int64 `env:"TEST_V5"`
V6 *aliasint `env:"TEST_V6"`
VY *aliasint ``
V7 *aliasstring `env:"TEST_V7"`
V8 *time.Duration `env:"TEST_V8"`
V9 *time.Time `env:"TEST_V9"`
}
data := testdata{}
t.Setenv("TEST_V1", "846")
t.Setenv("TEST_V2", "hello_world")
t.Setenv("TEST_V3", "6")
t.Setenv("TEST_V4", "333")
t.Setenv("TEST_V5", "-937")
t.Setenv("TEST_V6", "070")
t.Setenv("TEST_V7", "AAAAAA")
t.Setenv("TEST_V8", "1min4s")
t.Setenv("TEST_V9", "2009-11-10T23:00:00Z")
err := ApplyEnvOverrides(&data, ".")
if err != nil {
t.Errorf("%v", err)
t.FailNow()
}
assertPtrEqual(t, data.V1, 846)
assertPtrEqual(t, data.V2, "hello_world")
assertPtrEqual(t, data.V3, 6)
assertPtrEqual(t, data.V4, 333)
assertPtrEqual(t, data.V5, -937)
assertPtrEqual(t, data.V6, 70)
assertPtrEqual(t, data.V7, "AAAAAA")
assertPtrEqual(t, data.V8, time.Second*64)
assertPtrEqual(t, data.V9, time.Unix(1257894000, 0).UTC())
}
func assertEqual[T comparable](t *testing.T, actual T, expected T) { func assertEqual[T comparable](t *testing.T, actual T, expected T) {
if actual != expected { if actual != expected {
t.Errorf("values differ: Actual: '%v', Expected: '%v'", actual, expected) t.Errorf("values differ: Actual: '%v', Expected: '%v'", actual, expected)
} }
} }
func assertPtrEqual[T comparable](t *testing.T, actual *T, expected T) {
if actual == nil {
t.Errorf("values differ: Actual: NIL, Expected: '%v'", expected)
}
if *actual != expected {
t.Errorf("values differ: Actual: '%v', Expected: '%v'", actual, expected)
}
}

View File

@@ -1,10 +1,13 @@
package cryptext package cryptext
import ( import (
"bytes"
"crypto/aes" "crypto/aes"
"crypto/cipher" "crypto/cipher"
"crypto/rand" "crypto/rand"
"encoding/base64" "crypto/sha256"
"encoding/base32"
"encoding/json"
"errors" "errors"
"golang.org/x/crypto/scrypt" "golang.org/x/crypto/scrypt"
"io" "io"
@@ -12,35 +15,90 @@ import (
// https://stackoverflow.com/a/18819040/1761622 // https://stackoverflow.com/a/18819040/1761622
func EncryptAESSimple(password, text []byte) ([]byte, error) { type aesPayload struct {
Salt []byte `json:"s"`
key, err := scrypt.Key(password, nil, 32768, 8, 1, 32) // this is not 100% correct, rounds too low and salt is missing IV []byte `json:"i"`
if err != nil { Data []byte `json:"d"`
return nil, err Rounds int `json:"r"`
} Version uint `json:"v"`
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
b := base64.StdEncoding.EncodeToString(text)
ciphertext := make([]byte, aes.BlockSize+len(b))
iv := ciphertext[:aes.BlockSize]
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
return nil, err
}
cfb := cipher.NewCFBEncrypter(block, iv)
cfb.XORKeyStream(ciphertext[aes.BlockSize:], []byte(b))
return ciphertext, nil
} }
func DecryptAESSimple(password, text []byte) ([]byte, error) { func EncryptAESSimple(password []byte, data []byte, rounds int) (string, error) {
key, err := scrypt.Key(password, nil, 32768, 8, 1, 32) // this is not 100% correct, rounds too low and salt is missing salt := make([]byte, 8)
_, err := io.ReadFull(rand.Reader, salt)
if err != nil {
return "", err
}
key, err := scrypt.Key(password, salt, rounds, 8, 1, 32)
if err != nil {
return "", err
}
block, err := aes.NewCipher(key)
if err != nil {
return "", err
}
h := sha256.New()
h.Write(data)
checksum := h.Sum(nil)
if len(checksum) != 32 {
return "", errors.New("wrong cs size")
}
ciphertext := make([]byte, 32+len(data))
iv := make([]byte, aes.BlockSize)
_, err = io.ReadFull(rand.Reader, iv)
if err != nil {
return "", err
}
combinedData := make([]byte, 0, 32+len(data))
combinedData = append(combinedData, checksum...)
combinedData = append(combinedData, data...)
cfb := cipher.NewCFBEncrypter(block, iv)
cfb.XORKeyStream(ciphertext, combinedData)
pl := aesPayload{
Salt: salt,
IV: iv,
Data: ciphertext,
Version: 1,
Rounds: rounds,
}
jbin, err := json.Marshal(pl)
if err != nil {
return "", err
}
res := base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(jbin)
return res, nil
}
func DecryptAESSimple(password []byte, encText string) ([]byte, error) {
jbin, err := base32.StdEncoding.WithPadding(base32.NoPadding).DecodeString(encText)
if err != nil {
return nil, err
}
var pl aesPayload
err = json.Unmarshal(jbin, &pl)
if err != nil {
return nil, err
}
if pl.Version != 1 {
return nil, errors.New("unsupported version")
}
key, err := scrypt.Key(password, pl.Salt, pl.Rounds, 8, 1, 32) // this is not 100% correct, rounds too low and salt is missing
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -50,18 +108,24 @@ func DecryptAESSimple(password, text []byte) ([]byte, error) {
return nil, err return nil, err
} }
if len(text) < aes.BlockSize { dest := make([]byte, len(pl.Data))
return nil, errors.New("ciphertext too short")
cfb := cipher.NewCFBDecrypter(block, pl.IV)
cfb.XORKeyStream(dest, pl.Data)
if len(dest) < 32 {
return nil, errors.New("payload too small")
} }
iv := text[:aes.BlockSize] chck := dest[:32]
text = text[aes.BlockSize:] data := dest[32:]
cfb := cipher.NewCFBDecrypter(block, iv)
cfb.XORKeyStream(text, text)
data, err := base64.StdEncoding.DecodeString(string(text)) h := sha256.New()
if err != nil { h.Write(data)
return nil, err chck2 := h.Sum(nil)
if !bytes.Equal(chck, chck2) {
return nil, errors.New("checksum mismatch")
} }
return data, nil return data, nil

View File

@@ -1,6 +1,9 @@
package cryptext package cryptext
import "testing" import (
"fmt"
"testing"
)
func TestEncryptAESSimple(t *testing.T) { func TestEncryptAESSimple(t *testing.T) {
@@ -8,15 +11,25 @@ func TestEncryptAESSimple(t *testing.T) {
str1 := []byte("Hello World") str1 := []byte("Hello World")
str2, err := EncryptAESSimple(pw, str1) str2, err := EncryptAESSimple(pw, str1, 512)
if err != nil { if err != nil {
panic(err) panic(err)
} }
fmt.Printf("%s\n", str2)
str3, err := DecryptAESSimple(pw, str2) str3, err := DecryptAESSimple(pw, str2)
if err != nil { if err != nil {
panic(err) panic(err)
} }
assertEqual(t, string(str1), string(str3)) assertEqual(t, string(str1), string(str3))
str4, err := EncryptAESSimple(pw, str3, 512)
if err != nil {
panic(err)
}
assertNotEqual(t, string(str2), string(str4))
} }

View File

@@ -23,3 +23,9 @@ func assertEqual(t *testing.T, actual string, expected string) {
t.Errorf("values differ: Actual: '%v', Expected: '%v'", actual, expected) t.Errorf("values differ: Actual: '%v', Expected: '%v'", actual, expected)
} }
} }
func assertNotEqual(t *testing.T, actual string, expected string) {
if actual == expected {
t.Errorf("values do not differ: Actual: '%v', Expected: '%v'", actual, expected)
}
}

View File

@@ -12,7 +12,7 @@ func init() {
} }
func TestResultCache1(t *testing.T) { func TestResultCache1(t *testing.T) {
cache := NewLRUMap[string](8) cache := NewLRUMap[string, string](8)
verifyLRUList(cache, t) verifyLRUList(cache, t)
key := randomKey() key := randomKey()
@@ -50,7 +50,7 @@ func TestResultCache1(t *testing.T) {
} }
func TestResultCache2(t *testing.T) { func TestResultCache2(t *testing.T) {
cache := NewLRUMap[string](8) cache := NewLRUMap[string, string](8)
verifyLRUList(cache, t) verifyLRUList(cache, t)
key1 := "key1" key1 := "key1"
@@ -150,7 +150,7 @@ func TestResultCache2(t *testing.T) {
} }
func TestResultCache3(t *testing.T) { func TestResultCache3(t *testing.T) {
cache := NewLRUMap[string](8) cache := NewLRUMap[string, string](8)
verifyLRUList(cache, t) verifyLRUList(cache, t)
key1 := "key1" key1 := "key1"
@@ -173,7 +173,7 @@ func TestResultCache3(t *testing.T) {
} }
// does a basic consistency check over the internal cache representation // does a basic consistency check over the internal cache representation
func verifyLRUList[TData any](cache *LRUMap[TData], t *testing.T) { func verifyLRUList[TKey comparable, TData any](cache *LRUMap[TKey, TData], t *testing.T) {
size := 0 size := 0
tailFound := false tailFound := false

1
go.mod
View File

@@ -9,5 +9,6 @@ require (
require ( require (
github.com/jmoiron/sqlx v1.3.5 // indirect github.com/jmoiron/sqlx v1.3.5 // indirect
go.mongodb.org/mongo-driver v1.11.1 // indirect
golang.org/x/crypto v0.4.0 // indirect golang.org/x/crypto v0.4.0 // indirect
) )

35
go.sum
View File

@@ -1,15 +1,50 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/jmoiron/sqlx v1.3.5 h1:vFFPA71p1o5gAeqtEAwLU4dnX2napprKtHr7PYIcN3g= github.com/jmoiron/sqlx v1.3.5 h1:vFFPA71p1o5gAeqtEAwLU4dnX2napprKtHr7PYIcN3g=
github.com/jmoiron/sqlx v1.3.5/go.mod h1:nRVWtLre0KfCLJvgxzCsLVMogSvQ1zNJtpYr2Ccp0mQ= github.com/jmoiron/sqlx v1.3.5/go.mod h1:nRVWtLre0KfCLJvgxzCsLVMogSvQ1zNJtpYr2Ccp0mQ=
github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk=
github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI=
github.com/xdg-go/scram v1.1.1/go.mod h1:RaEWvsqvNKKvBPvcKeFjrG2cJqOkHTiyTpzz23ni57g=
github.com/xdg-go/stringprep v1.0.3/go.mod h1:W3f5j4i+9rC0kuIEJL0ky1VpHXQU3ocBgklLGvcBnW8=
github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7JulP+udvsHwJoVG1YGAP6VLg4y9I5dyZdqmA=
go.mongodb.org/mongo-driver v1.11.1 h1:QP0znIRTuL0jf1oBQoAoM0C6ZJfBK4kx0Uumtv1A7w8=
go.mongodb.org/mongo-driver v1.11.1/go.mod h1:s7p5vEtfbeR1gYi6pnj3c3/urpbLv2T5Sfd6Rp2HBB8=
golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.4.0 h1:UVQgzMY87xqpKNgb+kDsll2Igd33HszWHFLmpaRMq/8= golang.org/x/crypto v0.4.0 h1:UVQgzMY87xqpKNgb+kDsll2Igd33HszWHFLmpaRMq/8=
golang.org/x/crypto v0.4.0/go.mod h1:3quD/ATkf6oY+rnes5c3ExXTbLc8mueNue5/DoinL80= golang.org/x/crypto v0.4.0/go.mod h1:3quD/ATkf6oY+rnes5c3ExXTbLc8mueNue5/DoinL80=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U= golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.3.0 h1:w8ZOecv6NaNa/zC8944JTU3vz4u6Lagfk4RPQxv92NQ= golang.org/x/sys v0.3.0 h1:w8ZOecv6NaNa/zC8944JTU3vz4u6Lagfk4RPQxv92NQ=
golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.1.0 h1:g6Z6vPFA9dYBAF7DWcH6sCcOntplXsDKcliusYijMlw= golang.org/x/term v0.1.0 h1:g6Z6vPFA9dYBAF7DWcH6sCcOntplXsDKcliusYijMlw=
golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.3.0 h1:qoo4akIqOcDME5bhc/NgxUdovd6BSS2uMsVjB56q1xI= golang.org/x/term v0.3.0 h1:qoo4akIqOcDME5bhc/NgxUdovd6BSS2uMsVjB56q1xI=
golang.org/x/term v0.3.0/go.mod h1:q750SLmJuPmVoN1blW3UFBPREJfb1KmY3vwxfr+nFDA= golang.org/x/term v0.3.0/go.mod h1:q750SLmJuPmVoN1blW3UFBPREJfb1KmY3vwxfr+nFDA=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -265,6 +265,66 @@ func ArrMap[T1 any, T2 any](arr []T1, conv func(v T1) T2) []T2 {
return r return r
} }
func MapMap[TK comparable, TV any, TR any](inmap map[TK]TV, conv func(k TK, v TV) TR) []TR {
r := make([]TR, 0, len(inmap))
for k, v := range inmap {
r = append(r, conv(k, v))
}
return r
}
func MapMapErr[TK comparable, TV any, TR any](inmap map[TK]TV, conv func(k TK, v TV) (TR, error)) ([]TR, error) {
r := make([]TR, 0, len(inmap))
for k, v := range inmap {
elem, err := conv(k, v)
if err != nil {
return nil, err
}
r = append(r, elem)
}
return r, nil
}
func ArrMapExt[T1 any, T2 any](arr []T1, conv func(idx int, v T1) T2) []T2 {
r := make([]T2, len(arr))
for i, v := range arr {
r[i] = conv(i, v)
}
return r
}
func ArrMapErr[T1 any, T2 any](arr []T1, conv func(v T1) (T2, error)) ([]T2, error) {
var err error
r := make([]T2, len(arr))
for i, v := range arr {
r[i], err = conv(v)
if err != nil {
return nil, err
}
}
return r, nil
}
func ArrFilterMap[T1 any, T2 any](arr []T1, filter func(v T1) bool, conv func(v T1) T2) []T2 {
r := make([]T2, 0, len(arr))
for _, v := range arr {
if filter(v) {
r = append(r, conv(v))
}
}
return r
}
func ArrFilter[T any](arr []T, filter func(v T) bool) []T {
r := make([]T, 0, len(arr))
for _, v := range arr {
if filter(v) {
r = append(r, v)
}
}
return r
}
func ArrSum[T NumberConstraint](arr []T) T { func ArrSum[T NumberConstraint](arr []T) T {
var r T = 0 var r T = 0
for _, v := range arr { for _, v := range arr {
@@ -272,3 +332,19 @@ func ArrSum[T NumberConstraint](arr []T) T {
} }
return r return r
} }
func ArrFlatten[T1 any, T2 any](arr []T1, conv func(v T1) []T2) []T2 {
r := make([]T2, 0, len(arr))
for _, v1 := range arr {
r = append(r, conv(v1)...)
}
return r
}
func ArrFlattenDirect[T1 any](arr [][]T1) []T1 {
r := make([]T1, 0, len(arr))
for _, v1 := range arr {
r = append(r, v1...)
}
return r
}

178
langext/base58.go Normal file
View File

@@ -0,0 +1,178 @@
package langext
import (
"bytes"
"errors"
"math/big"
)
// shamelessly stolen from https://github.com/btcsuite/
type B58Encoding struct {
bigRadix [11]*big.Int
bigRadix10 *big.Int
alphabet string
alphabetIdx0 byte
b58 [256]byte
}
var Base58DefaultEncoding = newBase58Encoding("123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz")
var Base58FlickrEncoding = newBase58Encoding("123456789abcdefghijkmnopqrstuvwxyzABCDEFGHJKLMNPQRSTUVWXYZ")
var Base58RippleEncoding = newBase58Encoding("rpshnaf39wBUDNEGHJKLM4PQRST7VWXYZ2bcdeCg65jkm8oFqi1tuvAxyz")
var Base58BitcoinEncoding = newBase58Encoding("123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz")
func newBase58Encoding(alphabet string) *B58Encoding {
bigRadix10 := big.NewInt(58 * 58 * 58 * 58 * 58 * 58 * 58 * 58 * 58 * 58)
enc := &B58Encoding{
alphabet: alphabet,
alphabetIdx0: '1',
bigRadix: [...]*big.Int{
big.NewInt(0),
big.NewInt(58),
big.NewInt(58 * 58),
big.NewInt(58 * 58 * 58),
big.NewInt(58 * 58 * 58 * 58),
big.NewInt(58 * 58 * 58 * 58 * 58),
big.NewInt(58 * 58 * 58 * 58 * 58 * 58),
big.NewInt(58 * 58 * 58 * 58 * 58 * 58 * 58),
big.NewInt(58 * 58 * 58 * 58 * 58 * 58 * 58 * 58),
big.NewInt(58 * 58 * 58 * 58 * 58 * 58 * 58 * 58 * 58),
bigRadix10,
},
bigRadix10: bigRadix10,
}
b58 := make([]byte, 0, 256)
for i := byte(0); i < 32; i++ {
for j := byte(0); j < 8; j++ {
b := i*8 + j
idx := bytes.IndexByte([]byte(alphabet), b)
if idx == -1 {
b58 = append(b58, 255)
} else {
b58 = append(b58, byte(idx))
}
}
}
enc.b58 = *((*[256]byte)(b58))
return enc
}
func (enc *B58Encoding) EncodeString(src string) (string, error) {
v, err := enc.Encode([]byte(src))
if err != nil {
return "", err
}
return string(v), nil
}
func (enc *B58Encoding) Encode(src []byte) ([]byte, error) {
x := new(big.Int)
x.SetBytes(src)
// maximum length of output is log58(2^(8*len(b))) == len(b) * 8 / log(58)
maxlen := int(float64(len(src))*1.365658237309761) + 1
answer := make([]byte, 0, maxlen)
mod := new(big.Int)
for x.Sign() > 0 {
// Calculating with big.Int is slow for each iteration.
// x, mod = x / 58, x % 58
//
// Instead we can try to do as much calculations on int64.
// x, mod = x / 58^10, x % 58^10
//
// Which will give us mod, which is 10 digit base58 number.
// We'll loop that 10 times to convert to the answer.
x.DivMod(x, enc.bigRadix10, mod)
if x.Sign() == 0 {
// When x = 0, we need to ensure we don't add any extra zeros.
m := mod.Int64()
for m > 0 {
answer = append(answer, enc.alphabet[m%58])
m /= 58
}
} else {
m := mod.Int64()
for i := 0; i < 10; i++ {
answer = append(answer, enc.alphabet[m%58])
m /= 58
}
}
}
// leading zero bytes
for _, i := range src {
if i != 0 {
break
}
answer = append(answer, enc.alphabetIdx0)
}
// reverse
alen := len(answer)
for i := 0; i < alen/2; i++ {
answer[i], answer[alen-1-i] = answer[alen-1-i], answer[i]
}
return answer, nil
}
func (enc *B58Encoding) DecodeString(src string) (string, error) {
v, err := enc.Decode([]byte(src))
if err != nil {
return "", err
}
return string(v), nil
}
func (enc *B58Encoding) Decode(src []byte) ([]byte, error) {
answer := big.NewInt(0)
scratch := new(big.Int)
for t := src; len(t) > 0; {
n := len(t)
if n > 10 {
n = 10
}
total := uint64(0)
for _, v := range t[:n] {
if v > 255 {
return []byte{}, errors.New("invalid char in input")
}
tmp := enc.b58[v]
if tmp == 255 {
return []byte{}, errors.New("invalid char in input")
}
total = total*58 + uint64(tmp)
}
answer.Mul(answer, enc.bigRadix[n])
scratch.SetUint64(total)
answer.Add(answer, scratch)
t = t[n:]
}
tmpval := answer.Bytes()
var numZeros int
for numZeros = 0; numZeros < len(src); numZeros++ {
if src[numZeros] != enc.alphabetIdx0 {
break
}
}
flen := numZeros + len(tmpval)
val := make([]byte, flen)
copy(val[numZeros:], tmpval)
return val, nil
}

67
langext/base58_test.go Normal file
View File

@@ -0,0 +1,67 @@
package langext
import (
"testing"
)
func _encStr(t *testing.T, enc *B58Encoding, v string) string {
v, err := enc.EncodeString(v)
if err != nil {
t.Error(err)
}
return v
}
func _decStr(t *testing.T, enc *B58Encoding, v string) string {
v, err := enc.DecodeString(v)
if err != nil {
t.Error(err)
}
return v
}
func TestBase58DefaultEncoding(t *testing.T) {
assertEqual(t, _encStr(t, Base58DefaultEncoding, "Hello"), "9Ajdvzr")
assertEqual(t, _encStr(t, Base58DefaultEncoding, "If debugging is the process of removing software bugs, then programming must be the process of putting them in."), "48638SMcJuah5okqPx4kCVf5d8QAdgbdNf28g7ReY13prUENNbMyssjq5GjsrJHF5zeZfqs4uJMUJHr7VbrU4XBUZ2Fw9DVtqtn9N1eXucEWSEZahXV6w4ysGSWqGdpeYTJf1MdDzTg8vfcQViifJjZX")
}
func TestBase58DefaultDecoding(t *testing.T) {
assertEqual(t, _decStr(t, Base58DefaultEncoding, "9Ajdvzr"), "Hello")
assertEqual(t, _decStr(t, Base58DefaultEncoding, "48638SMcJuah5okqPx4kCVf5d8QAdgbdNf28g7ReY13prUENNbMyssjq5GjsrJHF5zeZfqs4uJMUJHr7VbrU4XBUZ2Fw9DVtqtn9N1eXucEWSEZahXV6w4ysGSWqGdpeYTJf1MdDzTg8vfcQViifJjZX"), "If debugging is the process of removing software bugs, then programming must be the process of putting them in.")
}
func TestBase58RippleEncoding(t *testing.T) {
assertEqual(t, _encStr(t, Base58RippleEncoding, "Hello"), "9wjdvzi")
assertEqual(t, _encStr(t, Base58RippleEncoding, "If debugging is the process of removing software bugs, then programming must be the process of putting them in."), "h3as3SMcJu26nokqPxhkUVCnd3Qwdgbd4Cp3gfReYrsFi7N44bMy11jqnGj1iJHEnzeZCq1huJM7JHifVbi7hXB7ZpEA9DVtqt894reXucNWSNZ26XVaAhy1GSWqGdFeYTJCrMdDzTg3vCcQV55CJjZX")
}
func TestBase58RippleDecoding(t *testing.T) {
assertEqual(t, _decStr(t, Base58RippleEncoding, "9wjdvzi"), "Hello")
assertEqual(t, _decStr(t, Base58RippleEncoding, "h3as3SMcJu26nokqPxhkUVCnd3Qwdgbd4Cp3gfReYrsFi7N44bMy11jqnGj1iJHEnzeZCq1huJM7JHifVbi7hXB7ZpEA9DVtqt894reXucNWSNZ26XVaAhy1GSWqGdFeYTJCrMdDzTg3vCcQV55CJjZX"), "If debugging is the process of removing software bugs, then programming must be the process of putting them in.")
}
func TestBase58BitcoinEncoding(t *testing.T) {
assertEqual(t, _encStr(t, Base58BitcoinEncoding, "Hello"), "9Ajdvzr")
assertEqual(t, _encStr(t, Base58BitcoinEncoding, "If debugging is the process of removing software bugs, then programming must be the process of putting them in."), "48638SMcJuah5okqPx4kCVf5d8QAdgbdNf28g7ReY13prUENNbMyssjq5GjsrJHF5zeZfqs4uJMUJHr7VbrU4XBUZ2Fw9DVtqtn9N1eXucEWSEZahXV6w4ysGSWqGdpeYTJf1MdDzTg8vfcQViifJjZX")
}
func TestBase58BitcoinDecoding(t *testing.T) {
assertEqual(t, _decStr(t, Base58BitcoinEncoding, "9Ajdvzr"), "Hello")
assertEqual(t, _decStr(t, Base58BitcoinEncoding, "48638SMcJuah5okqPx4kCVf5d8QAdgbdNf28g7ReY13prUENNbMyssjq5GjsrJHF5zeZfqs4uJMUJHr7VbrU4XBUZ2Fw9DVtqtn9N1eXucEWSEZahXV6w4ysGSWqGdpeYTJf1MdDzTg8vfcQViifJjZX"), "If debugging is the process of removing software bugs, then programming must be the process of putting them in.")
}
func TestBase58FlickrEncoding(t *testing.T) {
assertEqual(t, _encStr(t, Base58FlickrEncoding, "Hello"), "9aJCVZR")
assertEqual(t, _encStr(t, Base58FlickrEncoding, "If debugging is the process of removing software bugs, then programming must be the process of putting them in."), "48638rmBiUzG5NKQoX4KcuE5C8paCFACnE28F7qDx13PRtennAmYSSJQ5gJSRihf5ZDyEQS4UimtihR7uARt4wbty2fW9duTQTM9n1DwUBevreyzGwu6W4YSgrvQgCPDxsiE1mCdZsF8VEBpuHHEiJyw")
}
func TestBase58FlickrDecoding(t *testing.T) {
assertEqual(t, _decStr(t, Base58FlickrEncoding, "9aJCVZR"), "Hello")
assertEqual(t, _decStr(t, Base58FlickrEncoding, "48638rmBiUzG5NKQoX4KcuE5C8paCFACnE28F7qDx13PRtennAmYSSJQ5gJSRihf5ZDyEQS4UimtihR7uARt4wbty2fW9duTQTM9n1DwUBevreyzGwu6W4YSgrvQgCPDxsiE1mCdZsF8VEBpuHHEiJyw"), "If debugging is the process of removing software bugs, then programming must be the process of putting them in.")
}
func assertEqual(t *testing.T, actual string, expected string) {
if actual != expected {
t.Errorf("values differ: Actual: '%v', Expected: '%v'", actual, expected)
}
}

View File

@@ -60,3 +60,12 @@ func CoalesceStringer(s fmt.Stringer, def string) string {
return s.String() return s.String()
} }
} }
func SafeCast[T any](v any, def T) T {
switch r := v.(type) {
case T:
return r
default:
return def
}
}

View File

@@ -7,3 +7,11 @@ func MapKeyArr[T comparable, V any](v map[T]V) []T {
} }
return result return result
} }
func ArrToMap[T comparable, V any](a []V, keyfunc func(V) T) map[T]V {
result := make(map[T]V, len(a))
for _, v := range a {
result[keyfunc(v)] = v
}
return result
}

View File

@@ -3,6 +3,52 @@ package rfctime
import "time" import "time"
type RFCTime interface { type RFCTime interface {
AnyTime
Time() time.Time
Serialize() string
After(u AnyTime) bool
Before(u AnyTime) bool
Equal(u AnyTime) bool
Sub(u AnyTime) time.Duration
}
type AnyTime interface {
MarshalJSON() ([]byte, error)
MarshalBinary() ([]byte, error)
GobEncode() ([]byte, error)
MarshalText() ([]byte, error)
IsZero() bool
Date() (year int, month time.Month, day int)
Year() int
Month() time.Month
Day() int
Weekday() time.Weekday
ISOWeek() (year, week int)
Clock() (hour, min, sec int)
Hour() int
Minute() int
Second() int
Nanosecond() int
YearDay() int
Unix() int64
UnixMilli() int64
UnixMicro() int64
UnixNano() int64
Format(layout string) string
GoString() string
String() string
Location() *time.Location
}
type RFCDuration interface {
Time() time.Time Time() time.Time
Serialize() string Serialize() string
@@ -18,9 +64,9 @@ type RFCTime interface {
MarshalText() ([]byte, error) MarshalText() ([]byte, error)
UnmarshalText(data []byte) error UnmarshalText(data []byte) error
After(u RFCTime) bool After(u AnyTime) bool
Before(u RFCTime) bool Before(u AnyTime) bool
Equal(u RFCTime) bool Equal(u AnyTime) bool
IsZero() bool IsZero() bool
Date() (year int, month time.Month, day int) Date() (year int, month time.Month, day int)
Year() int Year() int
@@ -34,7 +80,7 @@ type RFCTime interface {
Second() int Second() int
Nanosecond() int Nanosecond() int
YearDay() int YearDay() int
Sub(u RFCTime) time.Duration Sub(u AnyTime) time.Duration
Unix() int64 Unix() int64
UnixMilli() int64 UnixMilli() int64
UnixMicro() int64 UnixMicro() int64
@@ -43,3 +89,13 @@ type RFCTime interface {
GoString() string GoString() string
String() string String() string
} }
func tt(v AnyTime) time.Time {
if r, ok := v.(time.Time); ok {
return r
}
if r, ok := v.(RFCTime); ok {
return r.Time()
}
return time.Unix(0, v.UnixNano()).In(v.Location())
}

50
rfctime/interface_test.go Normal file
View File

@@ -0,0 +1,50 @@
package rfctime
import (
"testing"
"time"
)
func TestAnyTimeInterface(t *testing.T) {
var v AnyTime
v = NowRFC3339Nano()
assertEqual(t, v.String(), v.String())
v = NowRFC3339()
assertEqual(t, v.String(), v.String())
v = NowUnix()
assertEqual(t, v.String(), v.String())
v = NowUnixMilli()
assertEqual(t, v.String(), v.String())
v = NowUnixNano()
assertEqual(t, v.String(), v.String())
v = time.Now()
assertEqual(t, v.String(), v.String())
}
func TestRFCTimeInterface(t *testing.T) {
var v RFCTime
v = NowRFC3339Nano()
assertEqual(t, v.String(), v.String())
v = NowRFC3339()
assertEqual(t, v.String(), v.String())
v = NowUnix()
assertEqual(t, v.String(), v.String())
v = NowUnixMilli()
assertEqual(t, v.String(), v.String())
v = NowUnixNano()
assertEqual(t, v.String(), v.String())
}

View File

@@ -2,6 +2,10 @@ package rfctime
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/bsontype"
"time" "time"
) )
@@ -61,6 +65,23 @@ func (t *RFC3339Time) UnmarshalText(data []byte) error {
return nil return nil
} }
func (t *RFC3339Time) UnmarshalBSONValue(bt bsontype.Type, data []byte) error {
if bt != bsontype.DateTime {
return errors.New(fmt.Sprintf("cannot unmarshal %v into RFC3339Time", bt))
}
var tt time.Time
err := bson.Unmarshal(data, &tt)
if err != nil {
return err
}
*t = RFC3339Time(tt)
return nil
}
func (t RFC3339Time) MarshalBSONValue() (bsontype.Type, []byte, error) {
return bson.MarshalValue(time.Time(t))
}
func (t RFC3339Time) Serialize() string { func (t RFC3339Time) Serialize() string {
return t.Time().Format(t.FormatStr()) return t.Time().Format(t.FormatStr())
} }
@@ -69,16 +90,16 @@ func (t RFC3339Time) FormatStr() string {
return time.RFC3339 return time.RFC3339
} }
func (t RFC3339Time) After(u RFCTime) bool { func (t RFC3339Time) After(u AnyTime) bool {
return t.Time().After(u.Time()) return t.Time().After(tt(u))
} }
func (t RFC3339Time) Before(u RFCTime) bool { func (t RFC3339Time) Before(u AnyTime) bool {
return t.Time().Before(u.Time()) return t.Time().Before(tt(u))
} }
func (t RFC3339Time) Equal(u RFCTime) bool { func (t RFC3339Time) Equal(u AnyTime) bool {
return t.Time().Equal(u.Time()) return t.Time().Equal(tt(u))
} }
func (t RFC3339Time) IsZero() bool { func (t RFC3339Time) IsZero() bool {
@@ -137,8 +158,8 @@ func (t RFC3339Time) Add(d time.Duration) RFC3339Time {
return RFC3339Time(t.Time().Add(d)) return RFC3339Time(t.Time().Add(d))
} }
func (t RFC3339Time) Sub(u RFCTime) time.Duration { func (t RFC3339Time) Sub(u AnyTime) time.Duration {
return t.Time().Sub(u.Time()) return t.Time().Sub(tt(u))
} }
func (t RFC3339Time) AddDate(years int, months int, days int) RFC3339Time { func (t RFC3339Time) AddDate(years int, months int, days int) RFC3339Time {
@@ -173,6 +194,10 @@ func (t RFC3339Time) String() string {
return t.Time().String() return t.Time().String()
} }
func (t RFC3339Time) Location() *time.Location {
return t.Time().Location()
}
func NewRFC3339(t time.Time) RFC3339Time { func NewRFC3339(t time.Time) RFC3339Time {
return RFC3339Time(t) return RFC3339Time(t)
} }

View File

@@ -2,6 +2,10 @@ package rfctime
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/bsontype"
"time" "time"
) )
@@ -61,6 +65,23 @@ func (t *RFC3339NanoTime) UnmarshalText(data []byte) error {
return nil return nil
} }
func (t *RFC3339NanoTime) UnmarshalBSONValue(bt bsontype.Type, data []byte) error {
if bt != bsontype.DateTime {
return errors.New(fmt.Sprintf("cannot unmarshal %v into RFC3339NanoTime", bt))
}
var tt time.Time
err := bson.RawValue{Type: bt, Value: data}.Unmarshal(&tt)
if err != nil {
return err
}
*t = RFC3339NanoTime(tt)
return nil
}
func (t RFC3339NanoTime) MarshalBSONValue() (bsontype.Type, []byte, error) {
return bson.MarshalValue(time.Time(t))
}
func (t RFC3339NanoTime) Serialize() string { func (t RFC3339NanoTime) Serialize() string {
return t.Time().Format(t.FormatStr()) return t.Time().Format(t.FormatStr())
} }
@@ -69,16 +90,16 @@ func (t RFC3339NanoTime) FormatStr() string {
return time.RFC3339Nano return time.RFC3339Nano
} }
func (t RFC3339NanoTime) After(u RFCTime) bool { func (t RFC3339NanoTime) After(u AnyTime) bool {
return t.Time().After(u.Time()) return t.Time().After(tt(u))
} }
func (t RFC3339NanoTime) Before(u RFCTime) bool { func (t RFC3339NanoTime) Before(u AnyTime) bool {
return t.Time().Before(u.Time()) return t.Time().Before(tt(u))
} }
func (t RFC3339NanoTime) Equal(u RFCTime) bool { func (t RFC3339NanoTime) Equal(u AnyTime) bool {
return t.Time().Equal(u.Time()) return t.Time().Equal(tt(u))
} }
func (t RFC3339NanoTime) IsZero() bool { func (t RFC3339NanoTime) IsZero() bool {
@@ -137,8 +158,8 @@ func (t RFC3339NanoTime) Add(d time.Duration) RFC3339NanoTime {
return RFC3339NanoTime(t.Time().Add(d)) return RFC3339NanoTime(t.Time().Add(d))
} }
func (t RFC3339NanoTime) Sub(u RFCTime) time.Duration { func (t RFC3339NanoTime) Sub(u AnyTime) time.Duration {
return t.Time().Sub(u.Time()) return t.Time().Sub(tt(u))
} }
func (t RFC3339NanoTime) AddDate(years int, months int, days int) RFC3339NanoTime { func (t RFC3339NanoTime) AddDate(years int, months int, days int) RFC3339NanoTime {
@@ -173,6 +194,10 @@ func (t RFC3339NanoTime) String() string {
return t.Time().String() return t.Time().String()
} }
func (t RFC3339NanoTime) Location() *time.Location {
return t.Time().Location()
}
func NewRFC3339Nano(t time.Time) RFC3339NanoTime { func NewRFC3339Nano(t time.Time) RFC3339NanoTime {
return RFC3339NanoTime(t) return RFC3339NanoTime(t)
} }

View File

@@ -12,7 +12,7 @@ func TestRoundtrip(t *testing.T) {
Value RFC3339NanoTime `json:"v"` Value RFC3339NanoTime `json:"v"`
} }
val1 := NewRFC3339Nano(time.Now()) val1 := NewRFC3339Nano(time.Unix(0, 1675951556820915171))
w1 := Wrap{val1} w1 := Wrap{val1}
jstr1, err := json.Marshal(w1) jstr1, err := json.Marshal(w1)
@@ -20,7 +20,8 @@ func TestRoundtrip(t *testing.T) {
panic(err) panic(err)
} }
if string(jstr1) != "{\"v\":\"2023-01-29T20:32:36.149692117+01:00\"}" { if string(jstr1) != "{\"v\":\"2023-02-09T15:05:56.820915171+01:00\"}" {
t.Errorf(string(jstr1))
t.Errorf("repr differs") t.Errorf("repr differs")
} }

59
rfctime/seconds.go Normal file
View File

@@ -0,0 +1,59 @@
package rfctime
import (
"encoding/json"
"gogs.mikescher.com/BlackForestBytes/goext/timeext"
"time"
)
type SecondsF64 time.Duration
func (d SecondsF64) Duration() time.Duration {
return time.Duration(d)
}
func (d SecondsF64) String() string {
return d.Duration().String()
}
func (d SecondsF64) Nanoseconds() int64 {
return d.Duration().Nanoseconds()
}
func (d SecondsF64) Microseconds() int64 {
return d.Duration().Microseconds()
}
func (d SecondsF64) Milliseconds() int64 {
return d.Duration().Milliseconds()
}
func (d SecondsF64) Seconds() float64 {
return d.Duration().Seconds()
}
func (d SecondsF64) Minutes() float64 {
return d.Duration().Minutes()
}
func (d SecondsF64) Hours() float64 {
return d.Duration().Hours()
}
func (d *SecondsF64) UnmarshalJSON(data []byte) error {
var secs float64 = 0
if err := json.Unmarshal(data, &secs); err != nil {
return err
}
*d = SecondsF64(timeext.FromSeconds(secs))
return nil
}
func (d SecondsF64) MarshalJSON() ([]byte, error) {
secs := d.Seconds()
return json.Marshal(secs)
}
func NewSecondsF64(t time.Duration) SecondsF64 {
return SecondsF64(t)
}

View File

@@ -63,16 +63,16 @@ func (t UnixTime) Serialize() string {
return strconv.FormatInt(t.Time().Unix(), 10) return strconv.FormatInt(t.Time().Unix(), 10)
} }
func (t UnixTime) After(u RFCTime) bool { func (t UnixTime) After(u AnyTime) bool {
return t.Time().After(u.Time()) return t.Time().After(tt(u))
} }
func (t UnixTime) Before(u RFCTime) bool { func (t UnixTime) Before(u AnyTime) bool {
return t.Time().Before(u.Time()) return t.Time().Before(tt(u))
} }
func (t UnixTime) Equal(u RFCTime) bool { func (t UnixTime) Equal(u AnyTime) bool {
return t.Time().Equal(u.Time()) return t.Time().Equal(tt(u))
} }
func (t UnixTime) IsZero() bool { func (t UnixTime) IsZero() bool {
@@ -131,8 +131,8 @@ func (t UnixTime) Add(d time.Duration) UnixTime {
return UnixTime(t.Time().Add(d)) return UnixTime(t.Time().Add(d))
} }
func (t UnixTime) Sub(u RFCTime) time.Duration { func (t UnixTime) Sub(u AnyTime) time.Duration {
return t.Time().Sub(u.Time()) return t.Time().Sub(tt(u))
} }
func (t UnixTime) AddDate(years int, months int, days int) UnixTime { func (t UnixTime) AddDate(years int, months int, days int) UnixTime {
@@ -167,6 +167,10 @@ func (t UnixTime) String() string {
return t.Time().String() return t.Time().String()
} }
func (t UnixTime) Location() *time.Location {
return t.Time().Location()
}
func NewUnix(t time.Time) UnixTime { func NewUnix(t time.Time) UnixTime {
return UnixTime(t) return UnixTime(t)
} }

View File

@@ -63,16 +63,16 @@ func (t UnixMilliTime) Serialize() string {
return strconv.FormatInt(t.Time().UnixMilli(), 10) return strconv.FormatInt(t.Time().UnixMilli(), 10)
} }
func (t UnixMilliTime) After(u RFCTime) bool { func (t UnixMilliTime) After(u AnyTime) bool {
return t.Time().After(u.Time()) return t.Time().After(tt(u))
} }
func (t UnixMilliTime) Before(u RFCTime) bool { func (t UnixMilliTime) Before(u AnyTime) bool {
return t.Time().Before(u.Time()) return t.Time().Before(tt(u))
} }
func (t UnixMilliTime) Equal(u RFCTime) bool { func (t UnixMilliTime) Equal(u AnyTime) bool {
return t.Time().Equal(u.Time()) return t.Time().Equal(tt(u))
} }
func (t UnixMilliTime) IsZero() bool { func (t UnixMilliTime) IsZero() bool {
@@ -131,8 +131,8 @@ func (t UnixMilliTime) Add(d time.Duration) UnixMilliTime {
return UnixMilliTime(t.Time().Add(d)) return UnixMilliTime(t.Time().Add(d))
} }
func (t UnixMilliTime) Sub(u RFCTime) time.Duration { func (t UnixMilliTime) Sub(u AnyTime) time.Duration {
return t.Time().Sub(u.Time()) return t.Time().Sub(tt(u))
} }
func (t UnixMilliTime) AddDate(years int, months int, days int) UnixMilliTime { func (t UnixMilliTime) AddDate(years int, months int, days int) UnixMilliTime {
@@ -167,6 +167,10 @@ func (t UnixMilliTime) String() string {
return t.Time().String() return t.Time().String()
} }
func (t UnixMilliTime) Location() *time.Location {
return t.Time().Location()
}
func NewUnixMilli(t time.Time) UnixMilliTime { func NewUnixMilli(t time.Time) UnixMilliTime {
return UnixMilliTime(t) return UnixMilliTime(t)
} }

View File

@@ -63,16 +63,16 @@ func (t UnixNanoTime) Serialize() string {
return strconv.FormatInt(t.Time().UnixNano(), 10) return strconv.FormatInt(t.Time().UnixNano(), 10)
} }
func (t UnixNanoTime) After(u RFCTime) bool { func (t UnixNanoTime) After(u AnyTime) bool {
return t.Time().After(u.Time()) return t.Time().After(tt(u))
} }
func (t UnixNanoTime) Before(u RFCTime) bool { func (t UnixNanoTime) Before(u AnyTime) bool {
return t.Time().Before(u.Time()) return t.Time().Before(tt(u))
} }
func (t UnixNanoTime) Equal(u RFCTime) bool { func (t UnixNanoTime) Equal(u AnyTime) bool {
return t.Time().Equal(u.Time()) return t.Time().Equal(tt(u))
} }
func (t UnixNanoTime) IsZero() bool { func (t UnixNanoTime) IsZero() bool {
@@ -131,8 +131,8 @@ func (t UnixNanoTime) Add(d time.Duration) UnixNanoTime {
return UnixNanoTime(t.Time().Add(d)) return UnixNanoTime(t.Time().Add(d))
} }
func (t UnixNanoTime) Sub(u RFCTime) time.Duration { func (t UnixNanoTime) Sub(u AnyTime) time.Duration {
return t.Time().Sub(u.Time()) return t.Time().Sub(tt(u))
} }
func (t UnixNanoTime) AddDate(years int, months int, days int) UnixNanoTime { func (t UnixNanoTime) AddDate(years int, months int, days int) UnixNanoTime {
@@ -167,6 +167,10 @@ func (t UnixNanoTime) String() string {
return t.Time().String() return t.Time().String()
} }
func (t UnixNanoTime) Location() *time.Location {
return t.Time().Location()
}
func NewUnixNano(t time.Time) UnixNanoTime { func NewUnixNano(t time.Time) UnixNanoTime {
return UnixNanoTime(t) return UnixNanoTime(t)
} }