Compare commits
172 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
145b1138d7
|
|||
|
fad2e4ff6d
|
|||
|
e12764c0a2
|
|||
|
53aa8c05b0
|
|||
|
4f96907758
|
|||
|
b3131e3ba6
|
|||
|
02d6894ec6
|
|||
|
dad0e3240d
|
|||
|
a5cece2bc7
|
|||
|
cdce955887
|
|||
|
90862ed3f5
|
|||
|
18c172d69a
|
|||
|
80cea13437
|
|||
| f9576a2fec | |||
| 26d542c9a2 | |||
| d30e778bd4 | |||
| 73e867f75a | |||
| 84b87d61f2 | |||
|
f62e7499ec
|
|||
|
1edc2712ed
|
|||
|
63c28b4141
|
|||
|
b01e659bb4
|
|||
|
0923fa7c09
|
|||
|
e19cb30713
|
|||
|
90dc6079d5
|
|||
|
5b5c262994
|
|||
|
44ec5e4804
|
|||
| 2054c04442 | |||
| 373d28d405 | |||
|
8aaf2cd257
|
|||
|
a373876d32
|
|||
|
12324ba60d
|
|||
|
a032d09ea2
|
|||
|
4bebc9ea38
|
|||
|
580906ea08
|
|||
|
0f28e3d11b
|
|||
|
55f80432bb
|
|||
|
d6daf0e285
|
|||
|
1c2bc060da
|
|||
|
34023dca4c
|
|||
|
69f2dd73c5
|
|||
| dce6f634b1 | |||
| 833f74427f | |||
| 535a699584 | |||
|
b439b95f83
|
|||
| 995a82d90a | |||
|
0db10845ed
|
|||
|
128ca25aa2
|
|||
|
f5d13ebe64
|
|||
|
9730a91ad5
|
|||
|
8c16e4d982
|
|||
|
039a53a395
|
|||
|
2cf571579b
|
|||
|
9a537bb8c2
|
|||
|
78ad103151
|
|||
|
c764a946ff
|
|||
|
ef59b1241f
|
|||
|
a70ab33559
|
|||
|
a58bb4b14b
|
|||
|
dc62bbe55f
|
|||
|
b832d77d3e
|
|||
|
38467cb4e7
|
|||
| 68b06158b3 | |||
|
5f51173276
|
|||
|
1586314e3e
|
|||
|
254fe1556a
|
|||
|
52e74b59f5
|
|||
|
64f2cd7219
|
|||
|
a29aec8fb5
|
|||
|
8ea9b3f79f
|
|||
|
a4b2a0589f
|
|||
|
4ef5f6059b
|
|||
|
b23a444aa2
|
|||
|
09932046f8
|
|||
|
37e52595a2
|
|||
|
95d7c90492
|
|||
|
23a3235c7e
|
|||
|
506d276962
|
|||
|
2a0cf84416
|
|||
|
073aa84dd4
|
|||
|
a0dc9e92e4
|
|||
|
98c591b019
|
|||
|
a93b93a3cd
|
|||
|
49bc52d63e
|
|||
|
959020e3c0
|
|||
| 395e83acf6 | |||
|
55ff89f179
|
|||
|
cbaa283f74
|
|||
|
20fb1f5601
|
|||
|
cc58639306
|
|||
|
cea822ffa6
|
|||
|
c20ae20cc1
|
|||
|
f07cd79b96
|
|||
|
164c462b96
|
|||
|
5e6cb63f14
|
|||
|
4832aa9d6c
|
|||
|
4d606d3131
|
|||
|
be9b9e8ccf
|
|||
|
28cdfc5bd2
|
|||
|
10a6627323
|
|||
|
06b3b4116e
|
|||
|
ff821390f7
|
|||
|
c8e9c34706
|
|||
|
b7c48cb467
|
|||
|
a0a80899f5
|
|||
|
3543441b96
|
|||
|
eef12da4e6
|
|||
|
d009aafd4e
|
|||
|
f7b4aa48d7
|
|||
|
36b092774d
|
|||
|
a8c6e39ac5
|
|||
|
62f2ce9268
|
|||
|
49375e90f0
|
|||
|
d8cf255c80
|
|||
|
b520282ba0
|
|||
|
27cc9366b5
|
|||
|
d9517fe73c
|
|||
|
8a92a6cc52
|
|||
|
9b2028ab54
|
|||
|
207fd331d5
|
|||
|
54b0d6701d
|
|||
|
fc2657179b
|
|||
|
d4894e31fe
|
|||
|
0ddfaf666b
|
|||
|
e154137105
|
|||
|
9b9a79b4ad
|
|||
|
5a8d7110e4
|
|||
|
d47c84cd47
|
|||
|
c571f3f888
|
|||
|
e884ba6b89
|
|||
|
1a8e31e5ef
|
|||
|
eccc0fe9e5
|
|||
|
c8dec24a0d
|
|||
|
b8cb989e54
|
|||
|
ec672fbd49
|
|||
|
cfb0b53fc7
|
|||
|
a7389f44fa
|
|||
|
69f0fedd66
|
|||
|
335ef4d8e8
|
|||
|
61801ff20d
|
|||
|
361dca5c85
|
|||
|
9f85a243e8
|
|||
|
dc6cb274ee
|
|||
|
f6b47792a4
|
|||
|
295b3ef793
|
|||
|
721c176337
|
|||
|
ebba6545a3
|
|||
|
19c7e22ced
|
|||
|
9f883b458f
|
|||
|
1f456c5134
|
|||
|
d7fbef37db
|
|||
|
a1668b6e5a
|
|||
|
3a17edfaf0
|
|||
|
3320a9c19d
|
|||
|
8dcd8a270a
|
|||
|
03a9b276d8
|
|||
|
9c8cde384f
|
|||
|
99b000ecf4
|
|||
|
a173e30090
|
|||
|
a3481a7d2d
|
|||
|
a8e6f98a89
|
|||
|
ab805403b9
|
|||
|
1e98d351ce
|
|||
|
c40bdc8e9e
|
|||
|
7204562879
|
|||
|
741611a2e1
|
|||
|
133aeb8374
|
|||
|
b78a468632
|
|||
|
f1b4480e0f
|
|||
|
ffffe4bf24
|
|||
|
413bf3c848
|
|||
|
646990b549
|
@@ -0,0 +1,44 @@
|
||||
{
|
||||
"permissions": {
|
||||
"allow": [
|
||||
"Bash(make:*)",
|
||||
|
||||
"Bash(mkdir:*)",
|
||||
|
||||
"Bash(go build:*)",
|
||||
"Bash(go test:*)",
|
||||
"Bash(go get:*)",
|
||||
"Bash(go mod:*)",
|
||||
"Bash(go clean:*)",
|
||||
"Bash(go doc:*)",
|
||||
|
||||
"Bash(grep:*)",
|
||||
"Bash(find:*)",
|
||||
"Bash(rg:*)",
|
||||
"Bash(base64:*)",
|
||||
"Bash(sed:*)",
|
||||
"Bash(ls:*)",
|
||||
|
||||
"Bash(curl:*)",
|
||||
|
||||
"Bash(timeout 60s go test -v -count=1 ./...)",
|
||||
"Bash(timeout 60s go test -v -count=1 ./tests/integration/...)",
|
||||
"Bash(timeout 60s go test:*)",
|
||||
"Bash(timeout 300 make test)",
|
||||
"Bash(timeout 30s go test ./tests/integration -run:*)",
|
||||
|
||||
"Bash(done)",
|
||||
"Bash(awk:*)",
|
||||
|
||||
"WebFetch(domain:platform.openai.com)"
|
||||
],
|
||||
"deny": [
|
||||
],
|
||||
"defaultMode": "acceptEdits"
|
||||
},
|
||||
"env": {
|
||||
"CLAUDE_CODE_ENABLE_TELEMETRY": "0",
|
||||
"DISABLE_ERROR_REPORTING": "1",
|
||||
"DISABLE_TELEMETRY": "1"
|
||||
}
|
||||
}
|
||||
@@ -4,7 +4,7 @@
|
||||
# https://docs.github.com/en/actions/learn-github-actions/contexts#github-context
|
||||
|
||||
name: Build Docker and Deploy
|
||||
run-name: Build & Deploy ${{ gitea.ref }} on ${{ gitea.actor }}
|
||||
run-name: "[test]: ${{ github.event.head_commit.message }}"
|
||||
|
||||
on:
|
||||
push:
|
||||
@@ -39,17 +39,4 @@ jobs:
|
||||
- name: Run tests
|
||||
run: cd "${{ gitea.workspace }}" && make test
|
||||
|
||||
- name: Send failure mail
|
||||
if: failure()
|
||||
uses: dawidd6/action-send-mail@v3
|
||||
with:
|
||||
server_address: smtp.fastmail.com
|
||||
server_port: 465
|
||||
secure: true
|
||||
username: ${{secrets.MAIL_USERNAME}}
|
||||
password: ${{secrets.MAIL_PASSWORD}}
|
||||
subject: Pipeline on '${{ gitea.repository }}' failed
|
||||
to: ${{ steps.commiter_info.outputs.MAIL }}
|
||||
from: Gitea Actions <gitea_actions@blackforestbytes.de>
|
||||
body: "Go to https://gogs.blackforestbytes.com/${{ gitea.repository }}/actions"
|
||||
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
|
||||
.claude-queue
|
||||
|
||||
##########################################################################
|
||||
|
||||
.idea/**/workspace.xml
|
||||
|
||||
Generated
+6
@@ -0,0 +1,6 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="AgentMigrationStateService">
|
||||
<option name="migrationStatus" value="COMPLETED" />
|
||||
</component>
|
||||
</project>
|
||||
Generated
+6
@@ -0,0 +1,6 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="AskMigrationStateService">
|
||||
<option name="migrationStatus" value="COMPLETED" />
|
||||
</component>
|
||||
</project>
|
||||
+6
@@ -0,0 +1,6 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="Ask2AgentMigrationStateService">
|
||||
<option name="migrationStatus" value="COMPLETED" />
|
||||
</component>
|
||||
</project>
|
||||
Generated
+6
@@ -0,0 +1,6 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="EditMigrationStateService">
|
||||
<option name="migrationStatus" value="COMPLETED" />
|
||||
</component>
|
||||
</project>
|
||||
Generated
+11
@@ -0,0 +1,11 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="GoImports">
|
||||
<option name="excludedPackages">
|
||||
<array>
|
||||
<option value="github.com/pkg/errors" />
|
||||
<option value="golang.org/x/net/context" />
|
||||
</array>
|
||||
</option>
|
||||
</component>
|
||||
</project>
|
||||
@@ -5,10 +5,10 @@ A collection of general & useful library methods
|
||||
|
||||
This should not have any heavy dependencies (gin, mongo, etc) and add missing basic language features...
|
||||
|
||||
Potentially needs `export GOPRIVATE="gogs.mikescher.com"`
|
||||
Potentially needs `export GOPRIVATE="git.blackforestbytes.com"`
|
||||
|
||||
|
||||
### Packages:
|
||||
## Packages:
|
||||
|
||||
| Name | Maintainer | Description |
|
||||
|-------------|------------|---------------------------------------------------------------------------------------------------------------|
|
||||
@@ -20,15 +20,20 @@ Potentially needs `export GOPRIVATE="gogs.mikescher.com"`
|
||||
| zipext | Mike | Utility for zip/gzip/tar etc |
|
||||
| reflectext | Mike | Utility for golang reflection |
|
||||
| fsext | Mike | Utility for filesytem access |
|
||||
| ctxext | Mike | Utility for context.Context |
|
||||
| | | |
|
||||
| mongoext | Mike | Utility/Helper functions for mongodb |
|
||||
| mongoext | Mike | Utility/Helper functions for mongodb (kinda abandoned) |
|
||||
| cursortoken | Mike | MongoDB cursortoken implementation |
|
||||
| pagination | Mike | Pagination implementation |
|
||||
| | | |
|
||||
| ginext | Mike | gin wrapper |
|
||||
| wsw | Mike | websocket wrapper |
|
||||
| | | |
|
||||
| totpext | Mike | Implementation of TOTP (2-Factor-Auth) |
|
||||
| termext | Mike | Utilities for terminals (mostly color output) |
|
||||
| confext | Mike | Parses environment configuration into structs |
|
||||
| cmdext | Mike | Runner for external commands/processes |
|
||||
| excelext | Mike | Build Excel files |
|
||||
| | | |
|
||||
| sq | Mike | Utility functions for sql based databases (primarily sqlite) |
|
||||
| tst | Mike | Utility functions for unit tests |
|
||||
@@ -42,4 +47,69 @@ Potentially needs `export GOPRIVATE="gogs.mikescher.com"`
|
||||
| wmo | Mike | Mongo Wrapper, wraps mongodb with a better interface |
|
||||
| | | |
|
||||
| scn | Mike | SimpleCloudNotifier |
|
||||
| | | |
|
||||
| | | |
|
||||
|
||||
|
||||
|
||||
## Usage:
|
||||
|
||||
### exerr
|
||||
|
||||
- see **mongoext/builder.go** for full info
|
||||
|
||||
Short summary:
|
||||
- An better error package with metadata, listener, api-output and error-traces
|
||||
- Initialize with `exerr.Init()`
|
||||
- *Never* return `err` direct, always use exerr.Wrap(err, "...") - add metadata where applicable
|
||||
- at the end either Print(), Fatal() or Output() your error (print = stdout, fatal = panic, output = json API response)
|
||||
- You can add listeners with exerr.RegisterListener(), and save the full errors to a db or smth
|
||||
|
||||
### wmo
|
||||
|
||||
- A typed wrapper around the official mongo-go-driver
|
||||
- Use `wmo.W[...](...)` to wrap the collections and type-ify them
|
||||
- The new collections have all the usual methods, but types
|
||||
- Also they have List() and Paginate() methods for paginated listings (witehr with a cursortoken or page/limit)
|
||||
- Register additional hooks with `WithDecodeFunc`, `WithUnmarshalHook`, `WithMarshalHook`, `WithModifyingPipeline`, `WithModifyingPipelineFunc`
|
||||
- List(), Paginate(), etc support filter interfaces
|
||||
- Rule(s) of thumb:
|
||||
- filter the results in the filter interface
|
||||
- sort the results in the sort function of the filter interface
|
||||
- add joins ($lookup's) in the `WithModifyingPipelineFunc`/`WithModifyingPipeline`
|
||||
|
||||
#### ginext
|
||||
|
||||
- A wrapper around gin-gonic/gin
|
||||
- create the gin engine with `ginext.NewEngine`
|
||||
- Add routes with `engine.Routes()...`
|
||||
- `.Use(..)` adds a middleware
|
||||
- `.Group(..)` adds a group
|
||||
- `.Get().Handle(..)` adds a handler
|
||||
- Handler return values (in contract to ginext) - values implement the `ginext.HTTPResponse` interface
|
||||
- Every handler starts with something like:
|
||||
```go
|
||||
func (handler Handler) CommunityMetricsValues(pctx ginext.PreContext) ginext.HTTPResponse {
|
||||
type communityURI struct {
|
||||
Version string `uri:"version"`
|
||||
CommunityID models.CommunityID `uri:"cid"`
|
||||
}
|
||||
type body struct {
|
||||
UserID models.UserID `json:"userID"`
|
||||
EventID models.EventID `json:"eventID"`
|
||||
}
|
||||
|
||||
var u uri
|
||||
var b body
|
||||
ctx, gctx, httpErr := pctx.URI(&u).Body(&b).Start() // can have more unmarshaller, like header, form, etc
|
||||
if httpErr != nil {
|
||||
return *httpErr
|
||||
}
|
||||
defer ctx.Cancel()
|
||||
|
||||
// do stuff
|
||||
}
|
||||
```
|
||||
|
||||
#### sq
|
||||
|
||||
- TODO (like mongoext for sqlite/sql databases)
|
||||
@@ -5,11 +5,11 @@ import (
|
||||
_ "embed"
|
||||
"errors"
|
||||
"fmt"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/cryptext"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/langext"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/rext"
|
||||
"go/format"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/cryptext"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/langext"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/rext"
|
||||
"io"
|
||||
"os"
|
||||
"path"
|
||||
@@ -64,16 +64,17 @@ func GenerateCharsetIDSpecs(sourceDir string, destFile string, opt CSIDGenOption
|
||||
files = langext.ArrFilter(files, func(v os.DirEntry) bool { return !strings.HasSuffix(v.Name(), "_gen.go") })
|
||||
langext.SortBy(files, func(v os.DirEntry) string { return v.Name() })
|
||||
|
||||
newChecksumStr := goext.GoextVersion
|
||||
var newChecksumStr strings.Builder
|
||||
newChecksumStr.WriteString(goext.GoextVersion)
|
||||
for _, f := range files {
|
||||
content, err := os.ReadFile(path.Join(sourceDir, f.Name()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newChecksumStr += "\n" + f.Name() + "\t" + cryptext.BytesSha256(content)
|
||||
newChecksumStr.WriteString("\n" + f.Name() + "\t" + cryptext.BytesSha256(content))
|
||||
}
|
||||
|
||||
newChecksum := cryptext.BytesSha256([]byte(newChecksumStr))
|
||||
newChecksum := cryptext.BytesSha256([]byte(newChecksumStr.String()))
|
||||
|
||||
if newChecksum != oldChecksum {
|
||||
fmt.Printf("[CSIDGenerate] Checksum has changed ( %s -> %s ), will generate new file\n\n", oldChecksum, newChecksum)
|
||||
|
||||
@@ -7,9 +7,9 @@ import "crypto/sha256"
|
||||
import "fmt"
|
||||
import "github.com/go-playground/validator/v10"
|
||||
import "github.com/rs/zerolog/log"
|
||||
import "gogs.mikescher.com/BlackForestBytes/goext/exerr"
|
||||
import "gogs.mikescher.com/BlackForestBytes/goext/langext"
|
||||
import "gogs.mikescher.com/BlackForestBytes/goext/rext"
|
||||
import "git.blackforestbytes.com/BlackForestBytes/goext/exerr"
|
||||
import "git.blackforestbytes.com/BlackForestBytes/goext/langext"
|
||||
import "git.blackforestbytes.com/BlackForestBytes/goext/rext"
|
||||
import "math/big"
|
||||
import "reflect"
|
||||
import "regexp"
|
||||
@@ -183,6 +183,10 @@ func (id {{.Name}}) CheckString() string {
|
||||
return getCheckString(prefix{{.Name}}, string(id))
|
||||
}
|
||||
|
||||
func (id {{.Name}}) IsZero() bool {
|
||||
return id == ""
|
||||
}
|
||||
|
||||
func (id {{.Name}}) Regex() rext.Regex {
|
||||
return regex{{.Name}}
|
||||
}
|
||||
|
||||
@@ -3,9 +3,9 @@ package bfcodegen
|
||||
import (
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/cmdext"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/langext"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/tst"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/cmdext"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/langext"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/tst"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
+11
-10
@@ -6,11 +6,11 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/cryptext"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/langext"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/rext"
|
||||
"go/format"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/cryptext"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/langext"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/rext"
|
||||
"io"
|
||||
"os"
|
||||
"path"
|
||||
@@ -46,7 +46,7 @@ var rexEnumPackage = rext.W(regexp.MustCompile(`^package\s+(?P<name>[A-Za-z0-9_]
|
||||
|
||||
var rexEnumDef = rext.W(regexp.MustCompile(`^\s*type\s+(?P<name>[A-Za-z0-9_]+)\s+(?P<type>[A-Za-z0-9_]+)\s*//\s*(@enum:type).*$`))
|
||||
|
||||
var rexEnumValueDef = rext.W(regexp.MustCompile(`^\s*(?P<name>[A-Za-z0-9_]+)\s+(?P<type>[A-Za-z0-9_]+)\s*=\s*(?P<value>("[A-Za-z0-9_:\s\-.]+"|[0-9]+))\s*(//(?P<comm>.*))?.*$`))
|
||||
var rexEnumValueDef = rext.W(regexp.MustCompile(`^\s*(?P<name>[A-Za-z0-9_]+)\s+(?P<type>[A-Za-z0-9_]+)\s*=\s*(?P<value>("[/@A-Za-z0-9_:\s\-.]*"|[0-9]+))\s*(//(?P<comm>.*))?.*$`))
|
||||
|
||||
var rexEnumChecksumConst = rext.W(regexp.MustCompile(`const ChecksumEnumGenerator = "(?P<cs>[A-Za-z0-9_]*)"`))
|
||||
|
||||
@@ -95,16 +95,17 @@ func _generateEnumSpecs(sourceDir string, destFile string, oldChecksum string, g
|
||||
files = langext.ArrFilter(files, func(v os.DirEntry) bool { return !strings.HasSuffix(v.Name(), "_gen.go") })
|
||||
langext.SortBy(files, func(v os.DirEntry) string { return v.Name() })
|
||||
|
||||
newChecksumStr := goext.GoextVersion
|
||||
var newChecksumStr strings.Builder
|
||||
newChecksumStr.WriteString(goext.GoextVersion)
|
||||
for _, f := range files {
|
||||
content, err := os.ReadFile(path.Join(sourceDir, f.Name()))
|
||||
if err != nil {
|
||||
return "", "", false, err
|
||||
}
|
||||
newChecksumStr += "\n" + f.Name() + "\t" + cryptext.BytesSha256(content)
|
||||
newChecksumStr.WriteString("\n" + f.Name() + "\t" + cryptext.BytesSha256(content))
|
||||
}
|
||||
|
||||
newChecksum := cryptext.BytesSha256([]byte(newChecksumStr))
|
||||
newChecksum := cryptext.BytesSha256([]byte(newChecksumStr.String()))
|
||||
|
||||
if newChecksum != oldChecksum {
|
||||
fmt.Printf("[EnumGenerate] Checksum has changed ( %s -> %s ), will generate new file\n\n", oldChecksum, newChecksum)
|
||||
@@ -213,7 +214,7 @@ func processEnumFile(basedir string, fn string, debugOutput bool) ([]EnumDef, st
|
||||
var descr *string = nil
|
||||
var data *map[string]any = nil
|
||||
if comment != nil {
|
||||
comment = langext.Ptr(strings.TrimSpace(*comment))
|
||||
comment = new(strings.TrimSpace(*comment))
|
||||
if strings.HasPrefix(*comment, "{") {
|
||||
if v, ok := tryParseDataComment(*comment); ok {
|
||||
data = &v
|
||||
@@ -278,7 +279,7 @@ func tryParseDataComment(s string) (map[string]any, bool) {
|
||||
|
||||
rv := reflect.ValueOf(v)
|
||||
|
||||
if rv.Kind() == reflect.Ptr && rv.IsNil() {
|
||||
if rv.Kind() == reflect.Pointer && rv.IsNil() {
|
||||
continue
|
||||
}
|
||||
if rv.Kind() == reflect.Bool {
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
|
||||
package {{.PkgName}}
|
||||
|
||||
import "gogs.mikescher.com/BlackForestBytes/goext/langext"
|
||||
import "gogs.mikescher.com/BlackForestBytes/goext/enums"
|
||||
import "git.blackforestbytes.com/BlackForestBytes/goext/langext"
|
||||
import "git.blackforestbytes.com/BlackForestBytes/goext/enums"
|
||||
|
||||
const ChecksumEnumGenerator = "{{.Checksum}}" // GoExtVersion: {{.GoextVersion}}
|
||||
|
||||
@@ -108,7 +108,7 @@ func (e {{.EnumTypeName}}) PackageName() string {
|
||||
}
|
||||
|
||||
func (e {{.EnumTypeName}}) Meta() enums.EnumMetaValue {
|
||||
{{if $hasDescr}} return enums.EnumMetaValue{VarName: e.VarName(), Value: e, Description: langext.Ptr(e.Description())} {{else}} return enums.EnumMetaValue{VarName: e.VarName(), Value: e, Description: nil} {{end}}
|
||||
{{if $hasDescr}} return enums.EnumMetaValue{VarName: e.VarName(), Value: e, Description: new(e.Description())} {{else}} return enums.EnumMetaValue{VarName: e.VarName(), Value: e, Description: nil} {{end}}
|
||||
}
|
||||
|
||||
{{if $hasDescr}}
|
||||
@@ -117,6 +117,20 @@ func (e {{.EnumTypeName}}) DescriptionMeta() enums.EnumDescriptionMetaValue {
|
||||
}
|
||||
{{end}}
|
||||
|
||||
{{if $hasData}}
|
||||
func (e {{.EnumTypeName}}) DataMeta() enums.EnumDataMetaValue {
|
||||
return enums.EnumDataMetaValue{
|
||||
VarName: e.VarName(),
|
||||
Value: e,
|
||||
{{if $hasDescr}} Description: new(e.Description()), {{else}} Description: nil, {{end}}
|
||||
Data: map[string]any{
|
||||
{{ range $datakey, $datatype := $enumdef | generalDataKeys }} "{{ $datakey }}": e.Data().{{ $datakey | godatakey }},
|
||||
{{ end }}
|
||||
},
|
||||
}
|
||||
}
|
||||
{{end}}
|
||||
|
||||
func Parse{{.EnumTypeName}}(vv string) ({{.EnumTypeName}}, bool) {
|
||||
for _, ev := range __{{.EnumTypeName}}Values {
|
||||
if string(ev) == vv {
|
||||
@@ -136,6 +150,14 @@ func {{.EnumTypeName}}ValuesMeta() []enums.EnumMetaValue {
|
||||
}
|
||||
}
|
||||
|
||||
{{if $hasData}}
|
||||
func {{.EnumTypeName}}ValuesDataMeta() []enums.EnumDataMetaValue {
|
||||
return []enums.EnumDataMetaValue{ {{range .Values}}
|
||||
{{.VarName}}.DataMeta(), {{end}}
|
||||
}
|
||||
}
|
||||
{{end}}
|
||||
|
||||
{{if $hasDescr}}
|
||||
func {{.EnumTypeName}}ValuesDescriptionMeta() []enums.EnumDescriptionMetaValue {
|
||||
return []enums.EnumDescriptionMetaValue{ {{range .Values}}
|
||||
|
||||
@@ -3,9 +3,9 @@ package bfcodegen
|
||||
import (
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/cmdext"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/langext"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/tst"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/cmdext"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/langext"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/tst"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
@@ -5,11 +5,11 @@ import (
|
||||
_ "embed"
|
||||
"errors"
|
||||
"fmt"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/cryptext"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/langext"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/rext"
|
||||
"go/format"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/cryptext"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/langext"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/rext"
|
||||
"io"
|
||||
"os"
|
||||
"path"
|
||||
@@ -63,16 +63,17 @@ func GenerateIDSpecs(sourceDir string, destFile string, opt IDGenOptions) error
|
||||
files = langext.ArrFilter(files, func(v os.DirEntry) bool { return !strings.HasSuffix(v.Name(), "_gen.go") })
|
||||
langext.SortBy(files, func(v os.DirEntry) string { return v.Name() })
|
||||
|
||||
newChecksumStr := goext.GoextVersion
|
||||
var newChecksumStr strings.Builder
|
||||
newChecksumStr.WriteString(goext.GoextVersion)
|
||||
for _, f := range files {
|
||||
content, err := os.ReadFile(path.Join(sourceDir, f.Name()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newChecksumStr += "\n" + f.Name() + "\t" + cryptext.BytesSha256(content)
|
||||
newChecksumStr.WriteString("\n" + f.Name() + "\t" + cryptext.BytesSha256(content))
|
||||
}
|
||||
|
||||
newChecksum := cryptext.BytesSha256([]byte(newChecksumStr))
|
||||
newChecksum := cryptext.BytesSha256([]byte(newChecksumStr.String()))
|
||||
|
||||
if newChecksum != oldChecksum {
|
||||
fmt.Printf("[IDGenerate] Checksum has changed ( %s -> %s ), will generate new file\n\n", oldChecksum, newChecksum)
|
||||
|
||||
@@ -2,10 +2,9 @@
|
||||
|
||||
package {{.PkgName}}
|
||||
|
||||
import "go.mongodb.org/mongo-driver/bson"
|
||||
import "go.mongodb.org/mongo-driver/bson/bsontype"
|
||||
import "go.mongodb.org/mongo-driver/bson/primitive"
|
||||
import "gogs.mikescher.com/BlackForestBytes/goext/exerr"
|
||||
import "go.mongodb.org/mongo-driver/v2/bson"
|
||||
import "git.blackforestbytes.com/BlackForestBytes/goext/exerr"
|
||||
import "git.blackforestbytes.com/BlackForestBytes/goext/wmo"
|
||||
|
||||
const ChecksumIDGenerator = "{{.Checksum}}" // GoExtVersion: {{.GoextVersion}}
|
||||
|
||||
@@ -13,9 +12,10 @@ const ChecksumIDGenerator = "{{.Checksum}}" // GoExtVersion: {{.GoextVersion}}
|
||||
|
||||
// ================================ {{.Name}} ({{.FileRelative}}) ================================
|
||||
|
||||
func (i {{.Name}}) MarshalBSONValue() (bsontype.Type, []byte, error) {
|
||||
if objId, err := primitive.ObjectIDFromHex(string(i)); err == nil {
|
||||
return bson.MarshalValue(objId)
|
||||
func (i {{.Name}}) MarshalBSONValue() (byte, []byte, error) {
|
||||
if objId, err := bson.ObjectIDFromHex(string(i)); err == nil {
|
||||
tp, data, err := bson.MarshalValue(objId)
|
||||
return byte(tp), data, err
|
||||
} else {
|
||||
return 0, nil, exerr.New(exerr.TypeMarshalEntityID, "Failed to marshal {{.Name}}("+i.String()+") to ObjectId").Str("value", string(i)).Type("type", i).Build()
|
||||
}
|
||||
@@ -25,12 +25,12 @@ func (i {{.Name}}) String() string {
|
||||
return string(i)
|
||||
}
|
||||
|
||||
func (i {{.Name}}) ObjID() (primitive.ObjectID, error) {
|
||||
return primitive.ObjectIDFromHex(string(i))
|
||||
func (i {{.Name}}) ObjID() (bson.ObjectID, error) {
|
||||
return bson.ObjectIDFromHex(string(i))
|
||||
}
|
||||
|
||||
func (i {{.Name}}) Valid() bool {
|
||||
_, err := primitive.ObjectIDFromHex(string(i))
|
||||
_, err := bson.ObjectIDFromHex(string(i))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
@@ -38,10 +38,21 @@ func (i {{.Name}}) Valid() bool {
|
||||
func (i {{.Name}}) AsAny() {{$.AnyDef.Name}} {
|
||||
return {{$.AnyDef.Name}}(i)
|
||||
}
|
||||
|
||||
func (i {{.Name}}) AsAnyPtr() *{{$.AnyDef.Name}} {
|
||||
v := {{$.AnyDef.Name}}(i)
|
||||
return &v
|
||||
}
|
||||
{{end}}
|
||||
|
||||
func New{{.Name}}() {{.Name}} {
|
||||
return {{.Name}}(primitive.NewObjectID().Hex())
|
||||
func (i {{.Name}}) IsZero() bool {
|
||||
return i == ""
|
||||
}
|
||||
|
||||
func New{{.Name}}() {{.Name}} {
|
||||
return {{.Name}}(bson.NewObjectID().Hex())
|
||||
}
|
||||
|
||||
var _ wmo.MongoEntityID = (*{{.Name}})(nil)
|
||||
|
||||
{{end}}
|
||||
@@ -3,9 +3,9 @@ package bfcodegen
|
||||
import (
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/cmdext"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/langext"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/tst"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/cmdext"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/langext"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/tst"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
@@ -0,0 +1,160 @@
|
||||
package bfcodegen
|
||||
|
||||
import (
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/langext"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/tst"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestProcessCSIDFileSimple(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
src := `package mymodels
|
||||
|
||||
type UserID string // @csid:type [USR]
|
||||
type OrderID string // @csid:type [ORD]
|
||||
`
|
||||
fp := writeTestFile(t, dir, "models.go", src)
|
||||
|
||||
ids, pkg, err := processCSIDFile(dir, fp, false)
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertEqual(t, pkg, "mymodels")
|
||||
tst.AssertEqual(t, len(ids), 2)
|
||||
tst.AssertEqual(t, ids[0].Name, "UserID")
|
||||
tst.AssertEqual(t, ids[0].Prefix, "USR")
|
||||
tst.AssertEqual(t, ids[1].Name, "OrderID")
|
||||
tst.AssertEqual(t, ids[1].Prefix, "ORD")
|
||||
tst.AssertEqual(t, ids[0].FileRelative, "models.go")
|
||||
}
|
||||
|
||||
func TestProcessCSIDFilePrefixMustBeUppercase(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// lowercase prefix should not match the regex (only [A-Z0-9]{3})
|
||||
src := `package x
|
||||
|
||||
type FooID string // @csid:type [usr]
|
||||
`
|
||||
fp := writeTestFile(t, dir, "x.go", src)
|
||||
|
||||
ids, pkg, err := processCSIDFile(dir, fp, false)
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertEqual(t, pkg, "x")
|
||||
tst.AssertEqual(t, len(ids), 0)
|
||||
}
|
||||
|
||||
func TestProcessCSIDFileGeneratedHeaderSkipped(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
src := `// Code generated by csid-generate.go DO NOT EDIT.
|
||||
|
||||
package x
|
||||
|
||||
type SkipMeID string // @csid:type [SKP]
|
||||
`
|
||||
fp := writeTestFile(t, dir, "skip.go", src)
|
||||
|
||||
ids, pkg, err := processCSIDFile(dir, fp, false)
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertEqual(t, pkg, "")
|
||||
tst.AssertEqual(t, len(ids), 0)
|
||||
}
|
||||
|
||||
func TestGenerateCharsetIDSpecsEndToEnd(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
src1 := `package models
|
||||
|
||||
type EntityID string // @csid:type [ENT]
|
||||
type UserID string // @csid:type [USR]
|
||||
`
|
||||
writeTestFile(t, dir, "a_models.go", src1)
|
||||
|
||||
src2 := `package models
|
||||
|
||||
type OrderID string // @csid:type [ORD]
|
||||
`
|
||||
writeTestFile(t, dir, "b_models.go", src2)
|
||||
|
||||
dest := filepath.Join(dir, "csid_gen.go")
|
||||
|
||||
err := GenerateCharsetIDSpecs(dir, dest, CSIDGenOptions{DebugOutput: langext.PFalse})
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
out, err := os.ReadFile(dest)
|
||||
tst.AssertNoErr(t, err)
|
||||
outStr := string(out)
|
||||
|
||||
tst.AssertTrue(t, strings.Contains(outStr, "package models"))
|
||||
tst.AssertTrue(t, strings.Contains(outStr, "ChecksumCharsetIDGenerator"))
|
||||
tst.AssertTrue(t, strings.Contains(outStr, "func NewUserID()"))
|
||||
tst.AssertTrue(t, strings.Contains(outStr, "func NewOrderID()"))
|
||||
tst.AssertTrue(t, strings.Contains(outStr, "func NewEntityID()"))
|
||||
tst.AssertTrue(t, strings.Contains(outStr, `prefixUserID`) && strings.Contains(outStr, `"USR"`))
|
||||
tst.AssertTrue(t, strings.Contains(outStr, `prefixOrderID`) && strings.Contains(outStr, `"ORD"`))
|
||||
tst.AssertTrue(t, strings.Contains(outStr, `prefixEntityID`) && strings.Contains(outStr, `"ENT"`))
|
||||
}
|
||||
|
||||
func TestGenerateCharsetIDSpecsIdempotentWhenUnchanged(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
src := `package models
|
||||
|
||||
type SomeID string // @csid:type [SOM]
|
||||
`
|
||||
writeTestFile(t, dir, "models.go", src)
|
||||
dest := filepath.Join(dir, "csid_gen.go")
|
||||
|
||||
err := GenerateCharsetIDSpecs(dir, dest, CSIDGenOptions{})
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
content1, err := os.ReadFile(dest)
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
err = GenerateCharsetIDSpecs(dir, dest, CSIDGenOptions{})
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
content2, err := os.ReadFile(dest)
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
tst.AssertEqual(t, string(content1), string(content2))
|
||||
}
|
||||
|
||||
func TestGenerateCharsetIDSpecsErrorsWithoutPackage(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
src := `// Code generated by csid-generate.go DO NOT EDIT.
|
||||
|
||||
package x
|
||||
|
||||
type SkippedID string // @csid:type [SKP]
|
||||
`
|
||||
writeTestFile(t, dir, "z.go", src)
|
||||
dest := filepath.Join(dir, "csid_gen.go")
|
||||
|
||||
err := GenerateCharsetIDSpecs(dir, dest, CSIDGenOptions{})
|
||||
tst.AssertTrue(t, err != nil)
|
||||
}
|
||||
|
||||
func TestGenerateCharsetIDSpecsMissingDir(t *testing.T) {
|
||||
dir := filepath.Join(t.TempDir(), "definitely-missing")
|
||||
err := GenerateCharsetIDSpecs(dir, filepath.Join(dir, "csid_gen.go"), CSIDGenOptions{})
|
||||
tst.AssertTrue(t, err != nil)
|
||||
}
|
||||
|
||||
func TestFmtCSIDOutputContainsAllNames(t *testing.T) {
|
||||
ids := []CSIDDef{
|
||||
{File: "a.go", FileRelative: "a.go", Name: "AlphaID", Prefix: "ALP"},
|
||||
{File: "b.go", FileRelative: "b.go", Name: "BetaID", Prefix: "BET"},
|
||||
}
|
||||
out := fmtCSIDOutput("CHK_XYZ", ids, "models")
|
||||
tst.AssertTrue(t, strings.Contains(out, "package models"))
|
||||
tst.AssertTrue(t, strings.Contains(out, "CHK_XYZ"))
|
||||
tst.AssertTrue(t, strings.Contains(out, "AlphaID"))
|
||||
tst.AssertTrue(t, strings.Contains(out, "BetaID"))
|
||||
tst.AssertTrue(t, strings.Contains(out, `prefixAlphaID`) && strings.Contains(out, `"ALP"`))
|
||||
tst.AssertTrue(t, strings.Contains(out, `prefixBetaID`) && strings.Contains(out, `"BET"`))
|
||||
}
|
||||
@@ -0,0 +1,369 @@
|
||||
package bfcodegen
|
||||
|
||||
import (
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/langext"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/tst"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestProcessEnumFileBasicStringEnum(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
src := `package mymodels
|
||||
|
||||
type Color string // @enum:type
|
||||
|
||||
const (
|
||||
ColorRed Color = "red"
|
||||
ColorBlue Color = "blue"
|
||||
ColorGreen Color = "green"
|
||||
)
|
||||
`
|
||||
fp := writeTestFile(t, dir, "color.go", src)
|
||||
|
||||
enums, pkg, err := processEnumFile(dir, fp, false)
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertEqual(t, pkg, "mymodels")
|
||||
tst.AssertEqual(t, len(enums), 1)
|
||||
tst.AssertEqual(t, enums[0].EnumTypeName, "Color")
|
||||
tst.AssertEqual(t, enums[0].Type, "string")
|
||||
tst.AssertEqual(t, len(enums[0].Values), 3)
|
||||
tst.AssertEqual(t, enums[0].Values[0].VarName, "ColorRed")
|
||||
tst.AssertEqual(t, enums[0].Values[0].Value, `"red"`)
|
||||
tst.AssertEqual(t, enums[0].Values[1].VarName, "ColorBlue")
|
||||
tst.AssertEqual(t, enums[0].Values[2].VarName, "ColorGreen")
|
||||
}
|
||||
|
||||
func TestProcessEnumFileIntEnum(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
src := `package m
|
||||
|
||||
type Priority int // @enum:type
|
||||
|
||||
const (
|
||||
PriorityLow Priority = 0
|
||||
PriorityMedium Priority = 1
|
||||
PriorityHigh Priority = 2
|
||||
)
|
||||
`
|
||||
fp := writeTestFile(t, dir, "prio.go", src)
|
||||
|
||||
enums, pkg, err := processEnumFile(dir, fp, false)
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertEqual(t, pkg, "m")
|
||||
tst.AssertEqual(t, len(enums), 1)
|
||||
tst.AssertEqual(t, enums[0].EnumTypeName, "Priority")
|
||||
tst.AssertEqual(t, enums[0].Type, "int")
|
||||
tst.AssertEqual(t, len(enums[0].Values), 3)
|
||||
tst.AssertEqual(t, enums[0].Values[0].Value, "0")
|
||||
tst.AssertEqual(t, enums[0].Values[2].Value, "2")
|
||||
}
|
||||
|
||||
func TestProcessEnumFileWithDescriptions(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
src := `package m
|
||||
|
||||
type Status string // @enum:type
|
||||
|
||||
const (
|
||||
StatusActive Status = "active" // The active status
|
||||
StatusInactive Status = "inactive" // The inactive status
|
||||
)
|
||||
`
|
||||
fp := writeTestFile(t, dir, "s.go", src)
|
||||
|
||||
enums, _, err := processEnumFile(dir, fp, false)
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertEqual(t, len(enums), 1)
|
||||
tst.AssertEqual(t, len(enums[0].Values), 2)
|
||||
|
||||
v0 := enums[0].Values[0]
|
||||
tst.AssertTrue(t, v0.Description != nil)
|
||||
tst.AssertEqual(t, *v0.Description, "The active status")
|
||||
tst.AssertTrue(t, v0.Data == nil)
|
||||
}
|
||||
|
||||
func TestProcessEnumFileWithDataComment(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
src := `package m
|
||||
|
||||
type Severity string // @enum:type
|
||||
|
||||
const (
|
||||
SeverityLow Severity = "low" // {"description": "Low severity", "weight": 1}
|
||||
SeverityHigh Severity = "high" // {"description": "High severity", "weight": 9}
|
||||
)
|
||||
`
|
||||
fp := writeTestFile(t, dir, "sev.go", src)
|
||||
|
||||
enums, _, err := processEnumFile(dir, fp, false)
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertEqual(t, len(enums), 1)
|
||||
tst.AssertEqual(t, len(enums[0].Values), 2)
|
||||
|
||||
v0 := enums[0].Values[0]
|
||||
tst.AssertTrue(t, v0.Data != nil)
|
||||
tst.AssertTrue(t, v0.Description != nil)
|
||||
tst.AssertEqual(t, *v0.Description, "Low severity")
|
||||
}
|
||||
|
||||
func TestProcessEnumFileNonMatchingValuesNotAttached(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
src := `package m
|
||||
|
||||
type Color string // @enum:type
|
||||
|
||||
const (
|
||||
ColorRed Color = "red"
|
||||
)
|
||||
|
||||
const (
|
||||
OtherX OtherType = "x"
|
||||
)
|
||||
`
|
||||
fp := writeTestFile(t, dir, "c.go", src)
|
||||
|
||||
enums, _, err := processEnumFile(dir, fp, false)
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertEqual(t, len(enums), 1)
|
||||
tst.AssertEqual(t, len(enums[0].Values), 1)
|
||||
tst.AssertEqual(t, enums[0].Values[0].VarName, "ColorRed")
|
||||
}
|
||||
|
||||
func TestProcessEnumFileGeneratedHeaderSkipped(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
src := `// Code generated by enum-generate.go DO NOT EDIT.
|
||||
|
||||
package x
|
||||
|
||||
type Foo string // @enum:type
|
||||
`
|
||||
fp := writeTestFile(t, dir, "skip.go", src)
|
||||
|
||||
enums, pkg, err := processEnumFile(dir, fp, false)
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertEqual(t, pkg, "")
|
||||
tst.AssertEqual(t, len(enums), 0)
|
||||
}
|
||||
|
||||
func TestTryParseDataCommentValid(t *testing.T) {
|
||||
m, ok := tryParseDataComment(`{"description": "hello", "weight": 5}`)
|
||||
tst.AssertTrue(t, ok)
|
||||
descr, _ := m["description"].(string)
|
||||
tst.AssertEqual(t, descr, "hello")
|
||||
weight, _ := m["weight"].(float64)
|
||||
tst.AssertEqual(t, weight, float64(5))
|
||||
}
|
||||
|
||||
func TestTryParseDataCommentBool(t *testing.T) {
|
||||
m, ok := tryParseDataComment(`{"a": true, "b": false}`)
|
||||
tst.AssertTrue(t, ok)
|
||||
a, _ := m["a"].(bool)
|
||||
tst.AssertTrue(t, a)
|
||||
b, _ := m["b"].(bool)
|
||||
tst.AssertFalse(t, b)
|
||||
}
|
||||
|
||||
func TestTryParseDataCommentRejectsNull(t *testing.T) {
|
||||
// null becomes a nil interface — its reflect.Kind is Invalid, not Pointer,
|
||||
// so it does not match any of the allowed kinds and is rejected.
|
||||
_, ok := tryParseDataComment(`{"x": null}`)
|
||||
tst.AssertFalse(t, ok)
|
||||
}
|
||||
|
||||
func TestTryParseDataCommentInvalidJSON(t *testing.T) {
|
||||
_, ok := tryParseDataComment(`{not valid json}`)
|
||||
tst.AssertFalse(t, ok)
|
||||
}
|
||||
|
||||
func TestTryParseDataCommentRejectsArrays(t *testing.T) {
|
||||
// arrays as values are not in the supported kinds list
|
||||
_, ok := tryParseDataComment(`{"x": [1, 2, 3]}`)
|
||||
tst.AssertFalse(t, ok)
|
||||
}
|
||||
|
||||
func TestTryParseDataCommentRejectsObjects(t *testing.T) {
|
||||
_, ok := tryParseDataComment(`{"x": {"nested": 1}}`)
|
||||
tst.AssertFalse(t, ok)
|
||||
}
|
||||
|
||||
func TestGenerateEnumSpecsEndToEnd(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
src := `package models
|
||||
|
||||
type Color string // @enum:type
|
||||
|
||||
const (
|
||||
ColorRed Color = "red"
|
||||
ColorGreen Color = "green"
|
||||
ColorBlue Color = "blue"
|
||||
)
|
||||
`
|
||||
writeTestFile(t, dir, "color.go", src)
|
||||
dest := filepath.Join(dir, "enum_gen.go")
|
||||
|
||||
err := GenerateEnumSpecs(dir, dest, EnumGenOptions{
|
||||
DebugOutput: langext.PFalse,
|
||||
GoFormat: langext.PTrue,
|
||||
})
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
out, err := os.ReadFile(dest)
|
||||
tst.AssertNoErr(t, err)
|
||||
outStr := string(out)
|
||||
|
||||
tst.AssertTrue(t, strings.Contains(outStr, "package models"))
|
||||
tst.AssertTrue(t, strings.Contains(outStr, "ChecksumEnumGenerator"))
|
||||
tst.AssertTrue(t, strings.Contains(outStr, "ParseColor"))
|
||||
tst.AssertTrue(t, strings.Contains(outStr, "ColorValues"))
|
||||
tst.AssertTrue(t, strings.Contains(outStr, "ColorRed"))
|
||||
tst.AssertTrue(t, strings.Contains(outStr, "ColorBlue"))
|
||||
tst.AssertTrue(t, strings.Contains(outStr, "ColorGreen"))
|
||||
}
|
||||
|
||||
func TestGenerateEnumSpecsDeterministic(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
src := `package models
|
||||
|
||||
type Status string // @enum:type
|
||||
|
||||
const (
|
||||
StatusActive Status = "active" // The active one
|
||||
StatusOff Status = "off" // The off one
|
||||
)
|
||||
`
|
||||
writeTestFile(t, dir, "s.go", src)
|
||||
|
||||
s1, cs1, changed1, err := _generateEnumSpecs(dir, "", "N/A", true, false)
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertTrue(t, changed1)
|
||||
|
||||
s2, cs2, changed2, err := _generateEnumSpecs(dir, "", "N/A", true, false)
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertTrue(t, changed2)
|
||||
|
||||
tst.AssertEqual(t, cs1, cs2)
|
||||
tst.AssertEqual(t, s1, s2)
|
||||
}
|
||||
|
||||
func TestGenerateEnumSpecsNoChangeWhenChecksumMatches(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
src := `package models
|
||||
|
||||
type Status string // @enum:type
|
||||
|
||||
const (
|
||||
StatusActive Status = "active"
|
||||
)
|
||||
`
|
||||
writeTestFile(t, dir, "s.go", src)
|
||||
|
||||
_, cs, changed, err := _generateEnumSpecs(dir, "", "N/A", true, false)
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertTrue(t, changed)
|
||||
|
||||
s2, cs2, changed2, err := _generateEnumSpecs(dir, "", cs, true, false)
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertFalse(t, changed2)
|
||||
tst.AssertEqual(t, cs2, cs)
|
||||
tst.AssertEqual(t, s2, "")
|
||||
}
|
||||
|
||||
func TestGenerateEnumSpecsWithoutGoFormat(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
src := `package models
|
||||
|
||||
type Color string // @enum:type
|
||||
|
||||
const (
|
||||
ColorRed Color = "red"
|
||||
)
|
||||
`
|
||||
writeTestFile(t, dir, "c.go", src)
|
||||
|
||||
out, _, _, err := _generateEnumSpecs(dir, "", "N/A", false, false)
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertTrue(t, strings.Contains(out, "ColorRed"))
|
||||
tst.AssertTrue(t, strings.Contains(out, "package models"))
|
||||
}
|
||||
|
||||
func TestGenerateEnumSpecsErrorsWithoutPackage(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
src := `// Code generated by enum-generate.go DO NOT EDIT.
|
||||
|
||||
package x
|
||||
|
||||
type Foo string // @enum:type
|
||||
`
|
||||
writeTestFile(t, dir, "z.go", src)
|
||||
|
||||
_, _, _, err := _generateEnumSpecs(dir, "", "N/A", false, false)
|
||||
tst.AssertTrue(t, err != nil)
|
||||
}
|
||||
|
||||
func TestGenerateEnumSpecsMissingDir(t *testing.T) {
|
||||
dir := filepath.Join(t.TempDir(), "definitely-missing")
|
||||
_, _, _, err := _generateEnumSpecs(dir, "", "N/A", false, false)
|
||||
tst.AssertTrue(t, err != nil)
|
||||
}
|
||||
|
||||
func TestFmtEnumOutputContainsTypes(t *testing.T) {
|
||||
descr := "the red one"
|
||||
enums := []EnumDef{
|
||||
{
|
||||
File: "color.go",
|
||||
FileRelative: "color.go",
|
||||
EnumTypeName: "Color",
|
||||
Type: "string",
|
||||
Values: []EnumDefVal{
|
||||
{VarName: "ColorRed", Value: `"red"`, Description: &descr},
|
||||
{VarName: "ColorBlue", Value: `"blue"`, Description: &descr},
|
||||
},
|
||||
},
|
||||
}
|
||||
out := fmtEnumOutput("CHK1", enums, "models")
|
||||
tst.AssertTrue(t, strings.Contains(out, "package models"))
|
||||
tst.AssertTrue(t, strings.Contains(out, "CHK1"))
|
||||
tst.AssertTrue(t, strings.Contains(out, "ColorRed"))
|
||||
tst.AssertTrue(t, strings.Contains(out, "ColorBlue"))
|
||||
tst.AssertTrue(t, strings.Contains(out, "ParseColor"))
|
||||
}
|
||||
|
||||
func TestGenerateEnumSpecsSkipsGenFile(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
src := `package models
|
||||
|
||||
type Color string // @enum:type
|
||||
|
||||
const (
|
||||
ColorRed Color = "red"
|
||||
)
|
||||
`
|
||||
writeTestFile(t, dir, "c.go", src)
|
||||
|
||||
// generated file in same dir - should be filtered out
|
||||
gensrc := `package models
|
||||
|
||||
type ShouldBeIgnored string // @enum:type
|
||||
`
|
||||
writeTestFile(t, dir, "ignored_gen.go", gensrc)
|
||||
|
||||
out, _, _, err := _generateEnumSpecs(dir, filepath.Join(dir, "enum_gen.go"), "N/A", false, false)
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertTrue(t, strings.Contains(out, "ColorRed"))
|
||||
tst.AssertFalse(t, strings.Contains(out, "ShouldBeIgnored"))
|
||||
}
|
||||
@@ -0,0 +1,209 @@
|
||||
package bfcodegen
|
||||
|
||||
import (
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/langext"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/tst"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func writeTestFile(t *testing.T, dir string, name string, content string) string {
|
||||
t.Helper()
|
||||
p := filepath.Join(dir, name)
|
||||
err := os.WriteFile(p, []byte(content), 0o644)
|
||||
tst.AssertNoErr(t, err)
|
||||
return p
|
||||
}
|
||||
|
||||
func TestProcessIDFileSimple(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
src := `package mymodels
|
||||
|
||||
type UserID string // @id:type
|
||||
type OrderID string // @id:type
|
||||
`
|
||||
fp := writeTestFile(t, dir, "models.go", src)
|
||||
|
||||
ids, pkg, err := processIDFile(dir, fp, false)
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertEqual(t, pkg, "mymodels")
|
||||
tst.AssertEqual(t, len(ids), 2)
|
||||
tst.AssertEqual(t, ids[0].Name, "UserID")
|
||||
tst.AssertEqual(t, ids[1].Name, "OrderID")
|
||||
tst.AssertEqual(t, ids[0].FileRelative, "models.go")
|
||||
}
|
||||
|
||||
func TestProcessIDFileNoMatches(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
src := `package x
|
||||
|
||||
type Foo string
|
||||
type Bar int
|
||||
type Baz string // not the right marker
|
||||
`
|
||||
fp := writeTestFile(t, dir, "x.go", src)
|
||||
|
||||
ids, pkg, err := processIDFile(dir, fp, false)
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertEqual(t, pkg, "x")
|
||||
tst.AssertEqual(t, len(ids), 0)
|
||||
}
|
||||
|
||||
func TestProcessIDFileGeneratedHeaderSkipped(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
src := `// Code generated by id-generate.go DO NOT EDIT.
|
||||
|
||||
package x
|
||||
|
||||
type SkipMeID string // @id:type
|
||||
`
|
||||
fp := writeTestFile(t, dir, "skip.go", src)
|
||||
|
||||
ids, pkg, err := processIDFile(dir, fp, false)
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertEqual(t, pkg, "")
|
||||
tst.AssertEqual(t, len(ids), 0)
|
||||
}
|
||||
|
||||
func TestProcessIDFileMissingFile(t *testing.T) {
|
||||
_, _, err := processIDFile(t.TempDir(), filepath.Join(t.TempDir(), "does_not_exist.go"), false)
|
||||
tst.AssertTrue(t, err != nil)
|
||||
}
|
||||
|
||||
func TestGenerateIDSpecsEndToEnd(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
src1 := `package models
|
||||
|
||||
type UserID string // @id:type
|
||||
type AnyID string // @id:type
|
||||
`
|
||||
writeTestFile(t, dir, "a_models.go", src1)
|
||||
|
||||
src2 := `package models
|
||||
|
||||
type OrderID string // @id:type
|
||||
`
|
||||
writeTestFile(t, dir, "b_models.go", src2)
|
||||
|
||||
dest := filepath.Join(dir, "id_gen.go")
|
||||
|
||||
err := GenerateIDSpecs(dir, dest, IDGenOptions{DebugOutput: langext.PFalse})
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
out, err := os.ReadFile(dest)
|
||||
tst.AssertNoErr(t, err)
|
||||
outStr := string(out)
|
||||
|
||||
tst.AssertTrue(t, strings.Contains(outStr, "package models"))
|
||||
tst.AssertTrue(t, strings.Contains(outStr, "ChecksumIDGenerator"))
|
||||
tst.AssertTrue(t, strings.Contains(outStr, "func NewUserID()"))
|
||||
tst.AssertTrue(t, strings.Contains(outStr, "func NewOrderID()"))
|
||||
tst.AssertTrue(t, strings.Contains(outStr, "func NewAnyID()"))
|
||||
tst.AssertTrue(t, strings.Contains(outStr, "AsAny()"))
|
||||
}
|
||||
|
||||
func TestGenerateIDSpecsIdempotentWhenUnchanged(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
src := `package models
|
||||
|
||||
type SomeID string // @id:type
|
||||
`
|
||||
writeTestFile(t, dir, "models.go", src)
|
||||
dest := filepath.Join(dir, "id_gen.go")
|
||||
|
||||
err := GenerateIDSpecs(dir, dest, IDGenOptions{})
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
stat1, err := os.Stat(dest)
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
content1, err := os.ReadFile(dest)
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
err = GenerateIDSpecs(dir, dest, IDGenOptions{})
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
stat2, err := os.Stat(dest)
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
content2, err := os.ReadFile(dest)
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
tst.AssertEqual(t, stat1.ModTime().Equal(stat2.ModTime()), true)
|
||||
tst.AssertEqual(t, string(content1), string(content2))
|
||||
}
|
||||
|
||||
func TestGenerateIDSpecsRegeneratesAfterChange(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
src := `package models
|
||||
|
||||
type FirstID string // @id:type
|
||||
`
|
||||
fp := writeTestFile(t, dir, "models.go", src)
|
||||
dest := filepath.Join(dir, "id_gen.go")
|
||||
|
||||
err := GenerateIDSpecs(dir, dest, IDGenOptions{})
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
content1, err := os.ReadFile(dest)
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertTrue(t, strings.Contains(string(content1), "FirstID"))
|
||||
tst.AssertFalse(t, strings.Contains(string(content1), "SecondID"))
|
||||
|
||||
src2 := `package models
|
||||
|
||||
type FirstID string // @id:type
|
||||
type SecondID string // @id:type
|
||||
`
|
||||
err = os.WriteFile(fp, []byte(src2), 0o644)
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
err = GenerateIDSpecs(dir, dest, IDGenOptions{})
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
content2, err := os.ReadFile(dest)
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertTrue(t, strings.Contains(string(content2), "SecondID"))
|
||||
}
|
||||
|
||||
func TestGenerateIDSpecsErrorsWithoutPackage(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
src := `// Code generated by id-generate.go DO NOT EDIT.
|
||||
|
||||
package x
|
||||
|
||||
type SkippedID string // @id:type
|
||||
`
|
||||
writeTestFile(t, dir, "z.go", src)
|
||||
dest := filepath.Join(dir, "id_gen.go")
|
||||
|
||||
err := GenerateIDSpecs(dir, dest, IDGenOptions{})
|
||||
tst.AssertTrue(t, err != nil)
|
||||
}
|
||||
|
||||
func TestGenerateIDSpecsMissingDir(t *testing.T) {
|
||||
dir := filepath.Join(t.TempDir(), "definitely-missing")
|
||||
err := GenerateIDSpecs(dir, filepath.Join(dir, "id_gen.go"), IDGenOptions{})
|
||||
tst.AssertTrue(t, err != nil)
|
||||
}
|
||||
|
||||
func TestFmtIDOutputContainsAllNames(t *testing.T) {
|
||||
ids := []IDDef{
|
||||
{File: "a.go", FileRelative: "a.go", Name: "AlphaID"},
|
||||
{File: "b.go", FileRelative: "b.go", Name: "BetaID"},
|
||||
}
|
||||
out := fmtIDOutput("CHK_ABC", ids, "models")
|
||||
tst.AssertTrue(t, strings.Contains(out, "package models"))
|
||||
tst.AssertTrue(t, strings.Contains(out, "CHK_ABC"))
|
||||
tst.AssertTrue(t, strings.Contains(out, "AlphaID"))
|
||||
tst.AssertTrue(t, strings.Contains(out, "BetaID"))
|
||||
}
|
||||
+3
-3
@@ -2,7 +2,7 @@ package cmdext
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/langext"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/langext"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -61,12 +61,12 @@ func (r *CommandRunner) Envs(env []string) *CommandRunner {
|
||||
}
|
||||
|
||||
func (r *CommandRunner) EnsureExitcode(arg ...int) *CommandRunner {
|
||||
r.enforceExitCodes = langext.Ptr(langext.ForceArray(arg))
|
||||
r.enforceExitCodes = new(langext.ForceArray(arg))
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *CommandRunner) FailOnExitCode() *CommandRunner {
|
||||
r.enforceExitCodes = langext.Ptr([]int{0})
|
||||
r.enforceExitCodes = new([]int{0})
|
||||
return r
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,324 @@
|
||||
package cmdext
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestRunnerInit(t *testing.T) {
|
||||
r := Runner("myprog")
|
||||
|
||||
if r == nil {
|
||||
t.Fatalf("Runner returned nil")
|
||||
}
|
||||
if r.program != "myprog" {
|
||||
t.Errorf("program == %v, want myprog", r.program)
|
||||
}
|
||||
if r.args == nil {
|
||||
t.Errorf("args is nil, want empty slice")
|
||||
}
|
||||
if len(r.args) != 0 {
|
||||
t.Errorf("len(args) == %v, want 0", len(r.args))
|
||||
}
|
||||
if r.env == nil {
|
||||
t.Errorf("env is nil, want empty slice")
|
||||
}
|
||||
if len(r.env) != 0 {
|
||||
t.Errorf("len(env) == %v, want 0", len(r.env))
|
||||
}
|
||||
if r.listener == nil {
|
||||
t.Errorf("listener is nil, want empty slice")
|
||||
}
|
||||
if len(r.listener) != 0 {
|
||||
t.Errorf("len(listener) == %v, want 0", len(r.listener))
|
||||
}
|
||||
if r.timeout != nil {
|
||||
t.Errorf("timeout == %v, want nil", r.timeout)
|
||||
}
|
||||
if r.enforceExitCodes != nil {
|
||||
t.Errorf("enforceExitCodes == %v, want nil", r.enforceExitCodes)
|
||||
}
|
||||
if r.enforceNoTimeout {
|
||||
t.Errorf("enforceNoTimeout == true, want false")
|
||||
}
|
||||
if r.enforceNoStderr {
|
||||
t.Errorf("enforceNoStderr == true, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestArgSingle(t *testing.T) {
|
||||
r := Runner("p").Arg("a")
|
||||
if !reflect.DeepEqual(r.args, []string{"a"}) {
|
||||
t.Errorf("args == %v, want [a]", r.args)
|
||||
}
|
||||
}
|
||||
|
||||
func TestArgMultiple(t *testing.T) {
|
||||
r := Runner("p").Arg("a").Arg("b").Arg("c")
|
||||
if !reflect.DeepEqual(r.args, []string{"a", "b", "c"}) {
|
||||
t.Errorf("args == %v, want [a b c]", r.args)
|
||||
}
|
||||
}
|
||||
|
||||
func TestArgsAppendsAll(t *testing.T) {
|
||||
r := Runner("p").Args([]string{"x", "y"}).Args([]string{"z"})
|
||||
if !reflect.DeepEqual(r.args, []string{"x", "y", "z"}) {
|
||||
t.Errorf("args == %v, want [x y z]", r.args)
|
||||
}
|
||||
}
|
||||
|
||||
func TestArgAndArgsMixed(t *testing.T) {
|
||||
r := Runner("p").Arg("a").Args([]string{"b", "c"}).Arg("d")
|
||||
if !reflect.DeepEqual(r.args, []string{"a", "b", "c", "d"}) {
|
||||
t.Errorf("args == %v, want [a b c d]", r.args)
|
||||
}
|
||||
}
|
||||
|
||||
func TestArgsEmptySlice(t *testing.T) {
|
||||
r := Runner("p").Args([]string{})
|
||||
if len(r.args) != 0 {
|
||||
t.Errorf("len(args) == %v, want 0", len(r.args))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimeoutSet(t *testing.T) {
|
||||
d := 500 * time.Millisecond
|
||||
r := Runner("p").Timeout(d)
|
||||
if r.timeout == nil {
|
||||
t.Fatalf("timeout is nil")
|
||||
}
|
||||
if *r.timeout != d {
|
||||
t.Errorf("timeout == %v, want %v", *r.timeout, d)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimeoutOverride(t *testing.T) {
|
||||
r := Runner("p").Timeout(1 * time.Second).Timeout(2 * time.Second)
|
||||
if *r.timeout != 2*time.Second {
|
||||
t.Errorf("timeout == %v, want 2s", *r.timeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnv(t *testing.T) {
|
||||
r := Runner("p").Env("KEY", "VALUE")
|
||||
if !reflect.DeepEqual(r.env, []string{"KEY=VALUE"}) {
|
||||
t.Errorf("env == %v, want [KEY=VALUE]", r.env)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnvMultiple(t *testing.T) {
|
||||
r := Runner("p").Env("A", "1").Env("B", "2")
|
||||
if !reflect.DeepEqual(r.env, []string{"A=1", "B=2"}) {
|
||||
t.Errorf("env == %v, want [A=1 B=2]", r.env)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnvWithEmptyValue(t *testing.T) {
|
||||
r := Runner("p").Env("KEY", "")
|
||||
if !reflect.DeepEqual(r.env, []string{"KEY="}) {
|
||||
t.Errorf("env == %v, want [KEY=]", r.env)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRawEnv(t *testing.T) {
|
||||
r := Runner("p").RawEnv("FOO=BAR=BAZ")
|
||||
if !reflect.DeepEqual(r.env, []string{"FOO=BAR=BAZ"}) {
|
||||
t.Errorf("env == %v, want [FOO=BAR=BAZ]", r.env)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnvs(t *testing.T) {
|
||||
r := Runner("p").Envs([]string{"A=1", "B=2"})
|
||||
if !reflect.DeepEqual(r.env, []string{"A=1", "B=2"}) {
|
||||
t.Errorf("env == %v, want [A=1 B=2]", r.env)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnvMixed(t *testing.T) {
|
||||
r := Runner("p").Env("A", "1").RawEnv("B=2").Envs([]string{"C=3", "D=4"})
|
||||
if !reflect.DeepEqual(r.env, []string{"A=1", "B=2", "C=3", "D=4"}) {
|
||||
t.Errorf("env == %v, want [A=1 B=2 C=3 D=4]", r.env)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureExitcodeSingle(t *testing.T) {
|
||||
r := Runner("p").EnsureExitcode(2)
|
||||
if r.enforceExitCodes == nil {
|
||||
t.Fatalf("enforceExitCodes is nil")
|
||||
}
|
||||
if !reflect.DeepEqual(*r.enforceExitCodes, []int{2}) {
|
||||
t.Errorf("enforceExitCodes == %v, want [2]", *r.enforceExitCodes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureExitcodeMultiple(t *testing.T) {
|
||||
r := Runner("p").EnsureExitcode(0, 1, 2)
|
||||
if r.enforceExitCodes == nil {
|
||||
t.Fatalf("enforceExitCodes is nil")
|
||||
}
|
||||
if !reflect.DeepEqual(*r.enforceExitCodes, []int{0, 1, 2}) {
|
||||
t.Errorf("enforceExitCodes == %v, want [0 1 2]", *r.enforceExitCodes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailOnExitCode(t *testing.T) {
|
||||
r := Runner("p").FailOnExitCode()
|
||||
if r.enforceExitCodes == nil {
|
||||
t.Fatalf("enforceExitCodes is nil")
|
||||
}
|
||||
if !reflect.DeepEqual(*r.enforceExitCodes, []int{0}) {
|
||||
t.Errorf("enforceExitCodes == %v, want [0]", *r.enforceExitCodes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailOnTimeoutFlag(t *testing.T) {
|
||||
r := Runner("p")
|
||||
if r.enforceNoTimeout {
|
||||
t.Errorf("enforceNoTimeout was true before set")
|
||||
}
|
||||
r = r.FailOnTimeout()
|
||||
if !r.enforceNoTimeout {
|
||||
t.Errorf("enforceNoTimeout == false after FailOnTimeout()")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailOnStderrFlag(t *testing.T) {
|
||||
r := Runner("p")
|
||||
if r.enforceNoStderr {
|
||||
t.Errorf("enforceNoStderr was true before set")
|
||||
}
|
||||
r = r.FailOnStderr()
|
||||
if !r.enforceNoStderr {
|
||||
t.Errorf("enforceNoStderr == false after FailOnStderr()")
|
||||
}
|
||||
}
|
||||
|
||||
func TestListen(t *testing.T) {
|
||||
r := Runner("p").Listen(genericCommandListener{})
|
||||
if len(r.listener) != 1 {
|
||||
t.Errorf("len(listener) == %v, want 1", len(r.listener))
|
||||
}
|
||||
}
|
||||
|
||||
func TestListenMultiple(t *testing.T) {
|
||||
r := Runner("p").
|
||||
Listen(genericCommandListener{}).
|
||||
Listen(genericCommandListener{}).
|
||||
Listen(genericCommandListener{})
|
||||
if len(r.listener) != 3 {
|
||||
t.Errorf("len(listener) == %v, want 3", len(r.listener))
|
||||
}
|
||||
}
|
||||
|
||||
func TestListenStdoutAddsListener(t *testing.T) {
|
||||
r := Runner("p").ListenStdout(func(string) {})
|
||||
if len(r.listener) != 1 {
|
||||
t.Errorf("len(listener) == %v, want 1", len(r.listener))
|
||||
}
|
||||
}
|
||||
|
||||
func TestListenStdoutForwardsCalls(t *testing.T) {
|
||||
got := ""
|
||||
r := Runner("p").ListenStdout(func(s string) { got = s })
|
||||
if len(r.listener) != 1 {
|
||||
t.Fatalf("len(listener) == %v, want 1", len(r.listener))
|
||||
}
|
||||
r.listener[0].ReadStdoutLine("hello")
|
||||
if got != "hello" {
|
||||
t.Errorf("listener got %q, want hello", got)
|
||||
}
|
||||
// non-stdout methods should not panic and should not affect state
|
||||
r.listener[0].ReadStderrLine("nope")
|
||||
r.listener[0].ReadRawStdout([]byte("raw"))
|
||||
r.listener[0].ReadRawStderr([]byte("raw"))
|
||||
r.listener[0].Finished(0)
|
||||
r.listener[0].Timeout()
|
||||
if got != "hello" {
|
||||
t.Errorf("listener got mutated to %q, want hello", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListenStderrAddsListener(t *testing.T) {
|
||||
r := Runner("p").ListenStderr(func(string) {})
|
||||
if len(r.listener) != 1 {
|
||||
t.Errorf("len(listener) == %v, want 1", len(r.listener))
|
||||
}
|
||||
}
|
||||
|
||||
func TestListenStderrForwardsCalls(t *testing.T) {
|
||||
got := ""
|
||||
r := Runner("p").ListenStderr(func(s string) { got = s })
|
||||
if len(r.listener) != 1 {
|
||||
t.Fatalf("len(listener) == %v, want 1", len(r.listener))
|
||||
}
|
||||
r.listener[0].ReadStderrLine("oops")
|
||||
if got != "oops" {
|
||||
t.Errorf("listener got %q, want oops", got)
|
||||
}
|
||||
r.listener[0].ReadStdoutLine("nope")
|
||||
if got != "oops" {
|
||||
t.Errorf("listener got mutated to %q, want oops", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChainReturnsSameInstance(t *testing.T) {
|
||||
r := Runner("p")
|
||||
if r.Arg("a") != r {
|
||||
t.Errorf("Arg returned different instance")
|
||||
}
|
||||
if r.Args([]string{"b"}) != r {
|
||||
t.Errorf("Args returned different instance")
|
||||
}
|
||||
if r.Timeout(time.Second) != r {
|
||||
t.Errorf("Timeout returned different instance")
|
||||
}
|
||||
if r.Env("K", "V") != r {
|
||||
t.Errorf("Env returned different instance")
|
||||
}
|
||||
if r.RawEnv("K=V") != r {
|
||||
t.Errorf("RawEnv returned different instance")
|
||||
}
|
||||
if r.Envs([]string{"K=V"}) != r {
|
||||
t.Errorf("Envs returned different instance")
|
||||
}
|
||||
if r.EnsureExitcode(0) != r {
|
||||
t.Errorf("EnsureExitcode returned different instance")
|
||||
}
|
||||
if r.FailOnExitCode() != r {
|
||||
t.Errorf("FailOnExitCode returned different instance")
|
||||
}
|
||||
if r.FailOnTimeout() != r {
|
||||
t.Errorf("FailOnTimeout returned different instance")
|
||||
}
|
||||
if r.FailOnStderr() != r {
|
||||
t.Errorf("FailOnStderr returned different instance")
|
||||
}
|
||||
if r.Listen(genericCommandListener{}) != r {
|
||||
t.Errorf("Listen returned different instance")
|
||||
}
|
||||
if r.ListenStdout(func(string) {}) != r {
|
||||
t.Errorf("ListenStdout returned different instance")
|
||||
}
|
||||
if r.ListenStderr(func(string) {}) != r {
|
||||
t.Errorf("ListenStderr returned different instance")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSeparateInstancesIndependent(t *testing.T) {
|
||||
r1 := Runner("p1").Arg("a")
|
||||
r2 := Runner("p2").Arg("b")
|
||||
|
||||
if r1.program != "p1" {
|
||||
t.Errorf("r1.program == %v, want p1", r1.program)
|
||||
}
|
||||
if r2.program != "p2" {
|
||||
t.Errorf("r2.program == %v, want p2", r2.program)
|
||||
}
|
||||
if !reflect.DeepEqual(r1.args, []string{"a"}) {
|
||||
t.Errorf("r1.args == %v, want [a]", r1.args)
|
||||
}
|
||||
if !reflect.DeepEqual(r2.args, []string{"b"}) {
|
||||
t.Errorf("r2.args == %v, want [b]", r2.args)
|
||||
}
|
||||
}
|
||||
+5
-5
@@ -2,9 +2,9 @@ package cmdext
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/langext"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/mathext"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/syncext"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/langext"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/mathext"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/syncext"
|
||||
"os/exec"
|
||||
"time"
|
||||
)
|
||||
@@ -37,7 +37,7 @@ func run(opt CommandRunner) (CommandResult, error) {
|
||||
}
|
||||
|
||||
preader := pipeReader{
|
||||
lineBufferSize: langext.Ptr(128 * 1024 * 1024), // 128MB max size of a single line, is hopefully enough....
|
||||
lineBufferSize: new(128 * 1024 * 1024), // 128MB max size of a single line, is hopefully enough....
|
||||
stdout: stdoutPipe,
|
||||
stderr: stderrPipe,
|
||||
}
|
||||
@@ -66,7 +66,7 @@ func run(opt CommandRunner) (CommandResult, error) {
|
||||
|
||||
if opt.enforceNoStderr {
|
||||
listener = append(listener, genericCommandListener{
|
||||
_readRawStderr: langext.Ptr(func(v []byte) {
|
||||
_readRawStderr: new(func(v []byte) {
|
||||
if len(v) > 0 {
|
||||
stderrFailChan <- true
|
||||
}
|
||||
|
||||
@@ -0,0 +1,143 @@
|
||||
package cmdext
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGenericListenerEmptyDoesNotPanic(t *testing.T) {
|
||||
l := genericCommandListener{}
|
||||
l.ReadRawStdout([]byte("x"))
|
||||
l.ReadRawStderr([]byte("x"))
|
||||
l.ReadStdoutLine("x")
|
||||
l.ReadStderrLine("x")
|
||||
l.Finished(0)
|
||||
l.Timeout()
|
||||
}
|
||||
|
||||
func TestGenericListenerReadRawStdout(t *testing.T) {
|
||||
var got []byte
|
||||
fn := func(b []byte) { got = append(got, b...) }
|
||||
l := genericCommandListener{_readRawStdout: &fn}
|
||||
|
||||
l.ReadRawStdout([]byte("hello"))
|
||||
l.ReadRawStdout([]byte(" world"))
|
||||
|
||||
if string(got) != "hello world" {
|
||||
t.Errorf("got %q, want %q", string(got), "hello world")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericListenerReadRawStderr(t *testing.T) {
|
||||
var got []byte
|
||||
fn := func(b []byte) { got = append(got, b...) }
|
||||
l := genericCommandListener{_readRawStderr: &fn}
|
||||
|
||||
l.ReadRawStderr([]byte("err"))
|
||||
|
||||
if string(got) != "err" {
|
||||
t.Errorf("got %q, want %q", string(got), "err")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericListenerReadStdoutLine(t *testing.T) {
|
||||
var got []string
|
||||
fn := func(s string) { got = append(got, s) }
|
||||
l := genericCommandListener{_readStdoutLine: &fn}
|
||||
|
||||
l.ReadStdoutLine("line1")
|
||||
l.ReadStdoutLine("line2")
|
||||
|
||||
if !reflect.DeepEqual(got, []string{"line1", "line2"}) {
|
||||
t.Errorf("got %v, want [line1 line2]", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericListenerReadStderrLine(t *testing.T) {
|
||||
var got []string
|
||||
fn := func(s string) { got = append(got, s) }
|
||||
l := genericCommandListener{_readStderrLine: &fn}
|
||||
|
||||
l.ReadStderrLine("line1")
|
||||
l.ReadStderrLine("line2")
|
||||
|
||||
if !reflect.DeepEqual(got, []string{"line1", "line2"}) {
|
||||
t.Errorf("got %v, want [line1 line2]", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericListenerFinished(t *testing.T) {
|
||||
var got int
|
||||
called := false
|
||||
fn := func(v int) { got = v; called = true }
|
||||
l := genericCommandListener{_finished: &fn}
|
||||
|
||||
l.Finished(42)
|
||||
|
||||
if !called {
|
||||
t.Errorf("Finished callback was not called")
|
||||
}
|
||||
if got != 42 {
|
||||
t.Errorf("got %v, want 42", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericListenerTimeout(t *testing.T) {
|
||||
called := false
|
||||
fn := func() { called = true }
|
||||
l := genericCommandListener{_timeout: &fn}
|
||||
|
||||
l.Timeout()
|
||||
|
||||
if !called {
|
||||
t.Errorf("Timeout callback was not called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenericListenerOnlySpecifiedCalled(t *testing.T) {
|
||||
stdoutCalled := false
|
||||
stderrCalled := false
|
||||
stdoutFn := func(string) { stdoutCalled = true }
|
||||
stderrFn := func(string) { stderrCalled = true }
|
||||
l := genericCommandListener{_readStdoutLine: &stdoutFn, _readStderrLine: &stderrFn}
|
||||
|
||||
l.ReadStdoutLine("x")
|
||||
if !stdoutCalled {
|
||||
t.Errorf("stdout callback not called")
|
||||
}
|
||||
if stderrCalled {
|
||||
t.Errorf("stderr callback called when it shouldn't be")
|
||||
}
|
||||
|
||||
stdoutCalled = false
|
||||
l.ReadStderrLine("x")
|
||||
if stdoutCalled {
|
||||
t.Errorf("stdout callback called when it shouldn't be")
|
||||
}
|
||||
if !stderrCalled {
|
||||
t.Errorf("stderr callback not called")
|
||||
}
|
||||
|
||||
// these have no callbacks set; should be no-ops
|
||||
l.ReadRawStdout([]byte("x"))
|
||||
l.ReadRawStderr([]byte("x"))
|
||||
l.Finished(0)
|
||||
l.Timeout()
|
||||
}
|
||||
|
||||
func TestGenericListenerImplementsCommandListener(t *testing.T) {
|
||||
var _ CommandListener = genericCommandListener{}
|
||||
}
|
||||
|
||||
func TestGenericListenerEmptyByteSlice(t *testing.T) {
|
||||
calls := 0
|
||||
fn := func(b []byte) { calls++ }
|
||||
l := genericCommandListener{_readRawStdout: &fn}
|
||||
|
||||
l.ReadRawStdout([]byte{})
|
||||
l.ReadRawStdout(nil)
|
||||
|
||||
if calls != 2 {
|
||||
t.Errorf("calls == %v, want 2", calls)
|
||||
}
|
||||
}
|
||||
+13
-16
@@ -2,8 +2,9 @@ package cmdext
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/syncext"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/syncext"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
@@ -29,14 +30,14 @@ func (pr *pipeReader) Read(listener []CommandListener) (string, string, string,
|
||||
|
||||
wg.Add(1)
|
||||
stdoutBufferReader, stdoutBufferWriter := io.Pipe()
|
||||
stdout := ""
|
||||
var stdout strings.Builder
|
||||
go func() {
|
||||
buf := make([]byte, 128)
|
||||
for {
|
||||
n, err := pr.stdout.Read(buf)
|
||||
if n > 0 {
|
||||
txt := string(buf[:n])
|
||||
stdout += txt
|
||||
stdout.WriteString(txt)
|
||||
_, _ = stdoutBufferWriter.Write(buf[:n])
|
||||
for _, lstr := range listener {
|
||||
lstr.ReadRawStdout(buf[:n])
|
||||
@@ -58,7 +59,7 @@ func (pr *pipeReader) Read(listener []CommandListener) (string, string, string,
|
||||
|
||||
wg.Add(1)
|
||||
stderrBufferReader, stderrBufferWriter := io.Pipe()
|
||||
stderr := ""
|
||||
var stderr strings.Builder
|
||||
go func() {
|
||||
buf := make([]byte, 128)
|
||||
for {
|
||||
@@ -66,7 +67,7 @@ func (pr *pipeReader) Read(listener []CommandListener) (string, string, string,
|
||||
|
||||
if n > 0 {
|
||||
txt := string(buf[:n])
|
||||
stderr += txt
|
||||
stderr.WriteString(txt)
|
||||
_, _ = stderrBufferWriter.Write(buf[:n])
|
||||
for _, lstr := range listener {
|
||||
lstr.ReadRawStderr(buf[:n])
|
||||
@@ -88,8 +89,7 @@ func (pr *pipeReader) Read(listener []CommandListener) (string, string, string,
|
||||
|
||||
// [3] collect stdout line-by-line
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
wg.Go(func() {
|
||||
scanner := bufio.NewScanner(stdoutBufferReader)
|
||||
if pr.lineBufferSize != nil {
|
||||
scanner.Buffer([]byte{}, *pr.lineBufferSize)
|
||||
@@ -105,13 +105,11 @@ func (pr *pipeReader) Read(listener []CommandListener) (string, string, string,
|
||||
errch <- err
|
||||
}
|
||||
combch <- combevt{"", true}
|
||||
wg.Done()
|
||||
}()
|
||||
})
|
||||
|
||||
// [4] collect stderr line-by-line
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
wg.Go(func() {
|
||||
scanner := bufio.NewScanner(stderrBufferReader)
|
||||
if pr.lineBufferSize != nil {
|
||||
scanner.Buffer([]byte{}, *pr.lineBufferSize)
|
||||
@@ -127,13 +125,12 @@ func (pr *pipeReader) Read(listener []CommandListener) (string, string, string,
|
||||
errch <- err
|
||||
}
|
||||
combch <- combevt{"", true}
|
||||
wg.Done()
|
||||
}()
|
||||
})
|
||||
|
||||
// [5] combine stdcombined
|
||||
|
||||
wg.Add(1)
|
||||
stdcombined := ""
|
||||
var stdcombined strings.Builder
|
||||
go func() {
|
||||
stopctr := 0
|
||||
for stopctr < 2 {
|
||||
@@ -141,7 +138,7 @@ func (pr *pipeReader) Read(listener []CommandListener) (string, string, string,
|
||||
if vvv.stop {
|
||||
stopctr++
|
||||
} else {
|
||||
stdcombined += vvv.line + "\n" // this comes from bufio.Scanner and has no newlines...
|
||||
stdcombined.WriteString(vvv.line + "\n") // this comes from bufio.Scanner and has no newlines...
|
||||
}
|
||||
}
|
||||
wg.Done()
|
||||
@@ -154,5 +151,5 @@ func (pr *pipeReader) Read(listener []CommandListener) (string, string, string,
|
||||
return "", "", "", err
|
||||
}
|
||||
|
||||
return stdout, stderr, stdcombined, nil
|
||||
return stdout.String(), stderr.String(), stdcombined.String(), nil
|
||||
}
|
||||
|
||||
+10
-10
@@ -3,7 +3,7 @@ package confext
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/timeext"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/timeext"
|
||||
"math/bits"
|
||||
"os"
|
||||
"reflect"
|
||||
@@ -101,11 +101,11 @@ func processEnvOverrides(rval reflect.Value, delim string, prefix string) error
|
||||
}
|
||||
|
||||
func parseEnvToValue(envval string, fullEnvKey string, rvtype reflect.Type) (reflect.Value, error) {
|
||||
if rvtype == reflect.TypeOf("") {
|
||||
if rvtype == reflect.TypeFor[string]() {
|
||||
|
||||
return reflect.ValueOf(envval), nil
|
||||
|
||||
} else if rvtype == reflect.TypeOf(int(0)) {
|
||||
} else if rvtype == reflect.TypeFor[int]() {
|
||||
|
||||
envint, err := strconv.ParseInt(envval, 10, bits.UintSize)
|
||||
if err != nil {
|
||||
@@ -114,7 +114,7 @@ func parseEnvToValue(envval string, fullEnvKey string, rvtype reflect.Type) (ref
|
||||
|
||||
return reflect.ValueOf(int(envint)), nil
|
||||
|
||||
} else if rvtype == reflect.TypeOf(int64(0)) {
|
||||
} else if rvtype == reflect.TypeFor[int64]() {
|
||||
|
||||
envint, err := strconv.ParseInt(envval, 10, 64)
|
||||
if err != nil {
|
||||
@@ -123,7 +123,7 @@ func parseEnvToValue(envval string, fullEnvKey string, rvtype reflect.Type) (ref
|
||||
|
||||
return reflect.ValueOf(int64(envint)), nil
|
||||
|
||||
} else if rvtype == reflect.TypeOf(int32(0)) {
|
||||
} else if rvtype == reflect.TypeFor[int32]() {
|
||||
|
||||
envint, err := strconv.ParseInt(envval, 10, 32)
|
||||
if err != nil {
|
||||
@@ -132,7 +132,7 @@ func parseEnvToValue(envval string, fullEnvKey string, rvtype reflect.Type) (ref
|
||||
|
||||
return reflect.ValueOf(int32(envint)), nil
|
||||
|
||||
} else if rvtype == reflect.TypeOf(int8(0)) {
|
||||
} else if rvtype == reflect.TypeFor[int8]() {
|
||||
|
||||
envint, err := strconv.ParseInt(envval, 10, 8)
|
||||
if err != nil {
|
||||
@@ -141,7 +141,7 @@ func parseEnvToValue(envval string, fullEnvKey string, rvtype reflect.Type) (ref
|
||||
|
||||
return reflect.ValueOf(int8(envint)), nil
|
||||
|
||||
} else if rvtype == reflect.TypeOf(time.Duration(0)) {
|
||||
} else if rvtype == reflect.TypeFor[time.Duration]() {
|
||||
|
||||
dur, err := timeext.ParseDurationShortString(envval)
|
||||
if err != nil {
|
||||
@@ -159,7 +159,7 @@ func parseEnvToValue(envval string, fullEnvKey string, rvtype reflect.Type) (ref
|
||||
|
||||
return reflect.ValueOf(tim), nil
|
||||
|
||||
} else if rvtype.ConvertibleTo(reflect.TypeOf(int(0))) {
|
||||
} else if rvtype.ConvertibleTo(reflect.TypeFor[int]()) {
|
||||
|
||||
envint, err := strconv.ParseInt(envval, 10, 8)
|
||||
if err != nil {
|
||||
@@ -170,7 +170,7 @@ func parseEnvToValue(envval string, fullEnvKey string, rvtype reflect.Type) (ref
|
||||
|
||||
return envcvl, nil
|
||||
|
||||
} else if rvtype.ConvertibleTo(reflect.TypeOf(false)) {
|
||||
} else if rvtype.ConvertibleTo(reflect.TypeFor[bool]()) {
|
||||
|
||||
if strings.TrimSpace(strings.ToLower(envval)) == "true" {
|
||||
return reflect.ValueOf(true).Convert(rvtype), nil
|
||||
@@ -184,7 +184,7 @@ func parseEnvToValue(envval string, fullEnvKey string, rvtype reflect.Type) (ref
|
||||
return reflect.Value{}, errors.New(fmt.Sprintf("Failed to parse env-config variable '%s' to <%s, ,bool> (value := '%s')", rvtype.Name(), fullEnvKey, envval))
|
||||
}
|
||||
|
||||
} else if rvtype.ConvertibleTo(reflect.TypeOf("")) {
|
||||
} else if rvtype.ConvertibleTo(reflect.TypeFor[string]()) {
|
||||
|
||||
envcvl := reflect.ValueOf(envval).Convert(rvtype)
|
||||
return envcvl, nil
|
||||
|
||||
@@ -0,0 +1,390 @@
|
||||
package confext
|
||||
|
||||
import (
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/tst"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestApplyEnvOverridesPrefix(t *testing.T) {
|
||||
type testdata struct {
|
||||
V1 int `env:"V1"`
|
||||
V2 string `env:"V2"`
|
||||
}
|
||||
|
||||
data := testdata{V1: 1, V2: "x"}
|
||||
|
||||
t.Setenv("MYAPP_V1", "42")
|
||||
t.Setenv("MYAPP_V2", "hello")
|
||||
t.Setenv("V1", "111")
|
||||
t.Setenv("V2", "noprefix")
|
||||
|
||||
err := ApplyEnvOverrides("MYAPP_", &data, ".")
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
tst.AssertEqual(t, data.V1, 42)
|
||||
tst.AssertEqual(t, data.V2, "hello")
|
||||
}
|
||||
|
||||
func TestApplyEnvOverridesUnexportedFieldsIgnored(t *testing.T) {
|
||||
type testdata struct {
|
||||
V1 int `env:"TEST_V1"`
|
||||
v2 int `env:"TEST_V2"`
|
||||
}
|
||||
|
||||
data := testdata{V1: 1, v2: 2}
|
||||
|
||||
t.Setenv("TEST_V1", "11")
|
||||
t.Setenv("TEST_V2", "22")
|
||||
|
||||
err := ApplyEnvOverrides("", &data, ".")
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
tst.AssertEqual(t, data.V1, 11)
|
||||
tst.AssertEqual(t, data.v2, 2)
|
||||
}
|
||||
|
||||
func TestApplyEnvOverridesNoEnvTagIgnored(t *testing.T) {
|
||||
type testdata struct {
|
||||
V1 int `env:"TEST_V1"`
|
||||
V2 int ``
|
||||
}
|
||||
|
||||
data := testdata{V1: 1, V2: 2}
|
||||
|
||||
t.Setenv("TEST_V1", "11")
|
||||
t.Setenv("V2", "22")
|
||||
|
||||
err := ApplyEnvOverrides("", &data, ".")
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
tst.AssertEqual(t, data.V1, 11)
|
||||
tst.AssertEqual(t, data.V2, 2)
|
||||
}
|
||||
|
||||
func TestApplyEnvOverridesDashTagIgnored(t *testing.T) {
|
||||
type testdata struct {
|
||||
V1 int `env:"TEST_V1"`
|
||||
V2 string `env:"-"`
|
||||
}
|
||||
|
||||
data := testdata{V1: 1, V2: "no"}
|
||||
|
||||
t.Setenv("TEST_V1", "11")
|
||||
t.Setenv("-", "yes")
|
||||
|
||||
err := ApplyEnvOverrides("", &data, ".")
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
tst.AssertEqual(t, data.V1, 11)
|
||||
tst.AssertEqual(t, data.V2, "no")
|
||||
}
|
||||
|
||||
func TestApplyEnvOverridesEnvNotSetKeepsValue(t *testing.T) {
|
||||
type testdata struct {
|
||||
V1 int `env:"NOT_SET_INT_KEY_XYZ"`
|
||||
V2 string `env:"NOT_SET_STR_KEY_XYZ"`
|
||||
}
|
||||
|
||||
data := testdata{V1: 7, V2: "keep"}
|
||||
|
||||
err := ApplyEnvOverrides("", &data, ".")
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
tst.AssertEqual(t, data.V1, 7)
|
||||
tst.AssertEqual(t, data.V2, "keep")
|
||||
}
|
||||
|
||||
func TestApplyEnvOverridesBoolVariants(t *testing.T) {
|
||||
type testdata struct {
|
||||
B1 bool `env:"B1"`
|
||||
B2 bool `env:"B2"`
|
||||
B3 bool `env:"B3"`
|
||||
B4 bool `env:"B4"`
|
||||
B5 bool `env:"B5"`
|
||||
B6 bool `env:"B6"`
|
||||
}
|
||||
|
||||
data := testdata{}
|
||||
|
||||
t.Setenv("B1", "true")
|
||||
t.Setenv("B2", "false")
|
||||
t.Setenv("B3", "1")
|
||||
t.Setenv("B4", "0")
|
||||
t.Setenv("B5", " TRUE ")
|
||||
t.Setenv("B6", "FaLsE")
|
||||
|
||||
err := ApplyEnvOverrides("", &data, ".")
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
tst.AssertEqual(t, data.B1, true)
|
||||
tst.AssertEqual(t, data.B2, false)
|
||||
tst.AssertEqual(t, data.B3, true)
|
||||
tst.AssertEqual(t, data.B4, false)
|
||||
tst.AssertEqual(t, data.B5, true)
|
||||
tst.AssertEqual(t, data.B6, false)
|
||||
}
|
||||
|
||||
func TestApplyEnvOverridesInvalidIntReturnsError(t *testing.T) {
|
||||
type testdata struct {
|
||||
V1 int `env:"BAD_INT"`
|
||||
}
|
||||
|
||||
data := testdata{}
|
||||
|
||||
t.Setenv("BAD_INT", "not_a_number")
|
||||
|
||||
err := ApplyEnvOverrides("", &data, ".")
|
||||
if err == nil {
|
||||
t.Errorf("expected error for invalid int, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyEnvOverridesInvalidInt8ReturnsError(t *testing.T) {
|
||||
type testdata struct {
|
||||
V1 int8 `env:"BAD_INT8"`
|
||||
}
|
||||
|
||||
data := testdata{}
|
||||
|
||||
t.Setenv("BAD_INT8", "9999")
|
||||
|
||||
err := ApplyEnvOverrides("", &data, ".")
|
||||
if err == nil {
|
||||
t.Errorf("expected error for invalid int8, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyEnvOverridesInvalidInt32ReturnsError(t *testing.T) {
|
||||
type testdata struct {
|
||||
V1 int32 `env:"BAD_INT32"`
|
||||
}
|
||||
|
||||
data := testdata{}
|
||||
|
||||
t.Setenv("BAD_INT32", "not_an_int32")
|
||||
|
||||
err := ApplyEnvOverrides("", &data, ".")
|
||||
if err == nil {
|
||||
t.Errorf("expected error for invalid int32, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyEnvOverridesInvalidInt64ReturnsError(t *testing.T) {
|
||||
type testdata struct {
|
||||
V1 int64 `env:"BAD_INT64"`
|
||||
}
|
||||
|
||||
data := testdata{}
|
||||
|
||||
t.Setenv("BAD_INT64", "not_an_int64")
|
||||
|
||||
err := ApplyEnvOverrides("", &data, ".")
|
||||
if err == nil {
|
||||
t.Errorf("expected error for invalid int64, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyEnvOverridesInvalidDurationReturnsError(t *testing.T) {
|
||||
type testdata struct {
|
||||
V1 time.Duration `env:"BAD_DUR"`
|
||||
}
|
||||
|
||||
data := testdata{}
|
||||
|
||||
t.Setenv("BAD_DUR", "not_a_duration")
|
||||
|
||||
err := ApplyEnvOverrides("", &data, ".")
|
||||
if err == nil {
|
||||
t.Errorf("expected error for invalid duration, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyEnvOverridesInvalidTimeReturnsError(t *testing.T) {
|
||||
type testdata struct {
|
||||
V1 time.Time `env:"BAD_TIME"`
|
||||
}
|
||||
|
||||
data := testdata{}
|
||||
|
||||
t.Setenv("BAD_TIME", "not_a_time")
|
||||
|
||||
err := ApplyEnvOverrides("", &data, ".")
|
||||
if err == nil {
|
||||
t.Errorf("expected error for invalid time, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyEnvOverridesInvalidBoolReturnsError(t *testing.T) {
|
||||
type testdata struct {
|
||||
V1 bool `env:"BAD_BOOL"`
|
||||
}
|
||||
|
||||
data := testdata{}
|
||||
|
||||
t.Setenv("BAD_BOOL", "yesno")
|
||||
|
||||
err := ApplyEnvOverrides("", &data, ".")
|
||||
if err == nil {
|
||||
t.Errorf("expected error for invalid bool, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyEnvOverridesUnsupportedTypeReturnsError(t *testing.T) {
|
||||
type testdata struct {
|
||||
V1 []int `env:"UNSUPPORTED"`
|
||||
}
|
||||
|
||||
data := testdata{}
|
||||
|
||||
t.Setenv("UNSUPPORTED", "1,2,3")
|
||||
|
||||
err := ApplyEnvOverrides("", &data, ".")
|
||||
if err == nil {
|
||||
t.Errorf("expected error for unsupported type, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyEnvOverridesFloatUnsupportedReturnsError(t *testing.T) {
|
||||
type testdata struct {
|
||||
V1 float64 `env:"UNSUPPORTED_FLOAT"`
|
||||
}
|
||||
|
||||
data := testdata{}
|
||||
|
||||
t.Setenv("UNSUPPORTED_FLOAT", "1.5")
|
||||
|
||||
err := ApplyEnvOverrides("", &data, ".")
|
||||
if err == nil {
|
||||
t.Errorf("expected error for float64, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyEnvOverridesPointerInvalidReturnsError(t *testing.T) {
|
||||
type testdata struct {
|
||||
V1 *int `env:"PTR_BAD_INT"`
|
||||
}
|
||||
|
||||
data := testdata{}
|
||||
|
||||
t.Setenv("PTR_BAD_INT", "not_a_number")
|
||||
|
||||
err := ApplyEnvOverrides("", &data, ".")
|
||||
if err == nil {
|
||||
t.Errorf("expected error for invalid pointer int, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyEnvOverridesPointerNotSetStaysNil(t *testing.T) {
|
||||
type testdata struct {
|
||||
V1 *int `env:"PTR_NOT_SET_KEY_ABC"`
|
||||
V2 *string `env:"PTR_NOT_SET_KEY_DEF"`
|
||||
}
|
||||
|
||||
data := testdata{}
|
||||
|
||||
err := ApplyEnvOverrides("", &data, ".")
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
if data.V1 != nil {
|
||||
t.Errorf("expected V1 to remain nil, got %v", *data.V1)
|
||||
}
|
||||
if data.V2 != nil {
|
||||
t.Errorf("expected V2 to remain nil, got %v", *data.V2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyEnvOverridesAliasBool(t *testing.T) {
|
||||
type aliasbool bool
|
||||
|
||||
type testdata struct {
|
||||
V1 aliasbool `env:"ALIAS_BOOL"`
|
||||
}
|
||||
|
||||
data := testdata{}
|
||||
|
||||
t.Setenv("ALIAS_BOOL", "true")
|
||||
|
||||
err := ApplyEnvOverrides("", &data, ".")
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
tst.AssertEqual(t, data.V1, aliasbool(true))
|
||||
}
|
||||
|
||||
func TestApplyEnvOverridesNestedRecursiveError(t *testing.T) {
|
||||
type subdata struct {
|
||||
V1 int `env:"V1"`
|
||||
}
|
||||
type testdata struct {
|
||||
Sub subdata `env:"SUB"`
|
||||
}
|
||||
|
||||
data := testdata{}
|
||||
|
||||
t.Setenv("SUB.V1", "not_a_number")
|
||||
|
||||
err := ApplyEnvOverrides("", &data, ".")
|
||||
if err == nil {
|
||||
t.Errorf("expected error from nested struct invalid value, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyEnvOverridesTimeFieldInsideStructIsParsed(t *testing.T) {
|
||||
type testdata struct {
|
||||
T time.Time `env:"MYTIME"`
|
||||
}
|
||||
|
||||
data := testdata{}
|
||||
|
||||
t.Setenv("MYTIME", "2023-01-02T03:04:05Z")
|
||||
|
||||
err := ApplyEnvOverrides("", &data, ".")
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
tst.AssertEqual(t, data.T.Equal(time.Date(2023, 1, 2, 3, 4, 5, 0, time.UTC)), true)
|
||||
}
|
||||
|
||||
func TestApplyEnvOverridesPointerStringAlias(t *testing.T) {
|
||||
type aliasstr string
|
||||
|
||||
type testdata struct {
|
||||
V1 *aliasstr `env:"PTR_ALIAS_STR"`
|
||||
}
|
||||
|
||||
data := testdata{}
|
||||
|
||||
t.Setenv("PTR_ALIAS_STR", "hello")
|
||||
|
||||
err := ApplyEnvOverrides("", &data, ".")
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
if data.V1 == nil {
|
||||
t.Fatalf("expected V1 to be set")
|
||||
}
|
||||
tst.AssertEqual(t, *data.V1, aliasstr("hello"))
|
||||
}
|
||||
|
||||
func TestApplyEnvOverridesEmptyEnvTagOnSubstruct(t *testing.T) {
|
||||
type subdata struct {
|
||||
V1 int `env:"INNER"`
|
||||
}
|
||||
type testdata struct {
|
||||
Sub subdata `env:""`
|
||||
}
|
||||
|
||||
data := testdata{}
|
||||
|
||||
t.Setenv("INNER", "55")
|
||||
|
||||
err := ApplyEnvOverrides("PRE_", &data, "_")
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
tst.AssertEqual(t, data.Sub.V1, 0)
|
||||
|
||||
t.Setenv("PRE_INNER", "77")
|
||||
|
||||
err = ApplyEnvOverrides("PRE_", &data, "_")
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
tst.AssertEqual(t, data.Sub.V1, 77)
|
||||
}
|
||||
@@ -1,8 +1,8 @@
|
||||
package confext
|
||||
|
||||
import (
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/timeext"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/tst"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/timeext"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/tst"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -0,0 +1,139 @@
|
||||
package cryptext
|
||||
|
||||
import (
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/tst"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAESSimpleEmptyData(t *testing.T) {
|
||||
pw := []byte("password")
|
||||
enc, err := EncryptAESSimple(pw, []byte{}, 256)
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertNotEqual(t, enc, "")
|
||||
|
||||
dec, err := DecryptAESSimple(pw, enc)
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertEqual(t, len(dec), 0)
|
||||
}
|
||||
|
||||
func TestAESSimpleEmptyPassword(t *testing.T) {
|
||||
pw := []byte{}
|
||||
plain := []byte("some content")
|
||||
|
||||
enc, err := EncryptAESSimple(pw, plain, 256)
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
dec, err := DecryptAESSimple(pw, enc)
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertEqual(t, string(dec), string(plain))
|
||||
}
|
||||
|
||||
func TestAESSimpleWrongPassword(t *testing.T) {
|
||||
plain := []byte("Hello World")
|
||||
enc, err := EncryptAESSimple([]byte("right"), plain, 256)
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
_, err = DecryptAESSimple([]byte("wrong"), enc)
|
||||
if err == nil {
|
||||
t.Errorf("expected error when decrypting with wrong password")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAESSimpleInvalidBase32(t *testing.T) {
|
||||
_, err := DecryptAESSimple([]byte("pw"), "!!!not-base32!!!")
|
||||
if err == nil {
|
||||
t.Errorf("expected error on invalid base32 input")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAESSimpleInvalidJSON(t *testing.T) {
|
||||
// "AAAAAAAA" decodes to valid base32 but not valid JSON
|
||||
_, err := DecryptAESSimple([]byte("pw"), "AAAAAAAA")
|
||||
if err == nil {
|
||||
t.Errorf("expected error on invalid JSON payload")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAESSimpleEmptyEncText(t *testing.T) {
|
||||
_, err := DecryptAESSimple([]byte("pw"), "")
|
||||
if err == nil {
|
||||
t.Errorf("expected error on empty text")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAESSimpleLargeData(t *testing.T) {
|
||||
pw := []byte("hunter12")
|
||||
plain := []byte(strings.Repeat("ABCDEFGHIJ", 1024))
|
||||
|
||||
enc, err := EncryptAESSimple(pw, plain, 256)
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
dec, err := DecryptAESSimple(pw, enc)
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertEqual(t, string(dec), string(plain))
|
||||
}
|
||||
|
||||
func TestAESSimpleBinaryData(t *testing.T) {
|
||||
pw := []byte("hunter12")
|
||||
plain := []byte{0x00, 0x01, 0x02, 0x7F, 0x80, 0xFE, 0xFF, 0x00, 0xAA, 0x55}
|
||||
|
||||
enc, err := EncryptAESSimple(pw, plain, 256)
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
dec, err := DecryptAESSimple(pw, enc)
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertArrayEqual(t, dec, plain)
|
||||
}
|
||||
|
||||
func TestAESSimpleDifferentRoundsForEachCall(t *testing.T) {
|
||||
pw := []byte("hunter12")
|
||||
plain := []byte("Hello")
|
||||
|
||||
enc1, err := EncryptAESSimple(pw, plain, 256)
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
enc2, err := EncryptAESSimple(pw, plain, 256)
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
// Two separate encrypt calls on same plaintext should differ (random salt + IV)
|
||||
tst.AssertNotEqual(t, enc1, enc2)
|
||||
|
||||
// Both should decrypt back to the same plaintext
|
||||
d1, err := DecryptAESSimple(pw, enc1)
|
||||
tst.AssertNoErr(t, err)
|
||||
d2, err := DecryptAESSimple(pw, enc2)
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertEqual(t, string(d1), string(plain))
|
||||
tst.AssertEqual(t, string(d2), string(plain))
|
||||
}
|
||||
|
||||
func TestAESSimpleVariableRounds(t *testing.T) {
|
||||
pw := []byte("hunter12")
|
||||
plain := []byte("rounds-test")
|
||||
|
||||
for _, r := range []int{16, 32, 64, 128, 256, 512, 1024} {
|
||||
enc, err := EncryptAESSimple(pw, plain, r)
|
||||
tst.AssertNoErr(t, err)
|
||||
dec, err := DecryptAESSimple(pw, enc)
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertEqual(t, string(dec), string(plain))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAESSimpleResultIsBase32(t *testing.T) {
|
||||
pw := []byte("hunter12")
|
||||
plain := []byte("Hello World")
|
||||
|
||||
enc, err := EncryptAESSimple(pw, plain, 64)
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
for _, c := range enc {
|
||||
isUpper := c >= 'A' && c <= 'Z'
|
||||
isDigit := c >= '2' && c <= '7'
|
||||
if !(isUpper || isDigit) {
|
||||
t.Errorf("non-base32 character %q in output", c)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2,7 +2,7 @@ package cryptext
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/tst"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/tst"
|
||||
"testing"
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
package cryptext
|
||||
|
||||
import (
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/tst"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStrSha256SameAsBytesSha256(t *testing.T) {
|
||||
inputs := []string{"", "a", "Hello World", "lorem ipsum dolor sit amet", "🎉 unicode"}
|
||||
for _, in := range inputs {
|
||||
tst.AssertEqual(t, StrSha256(in), BytesSha256([]byte(in)))
|
||||
}
|
||||
}
|
||||
|
||||
func TestStrSha256Length(t *testing.T) {
|
||||
// SHA-256 hex output must always be 64 characters
|
||||
tst.AssertEqual(t, len(StrSha256("")), 64)
|
||||
tst.AssertEqual(t, len(StrSha256("x")), 64)
|
||||
tst.AssertEqual(t, len(StrSha256(strings.Repeat("x", 10000))), 64)
|
||||
}
|
||||
|
||||
func TestStrSha256Deterministic(t *testing.T) {
|
||||
v := "deterministic input"
|
||||
a := StrSha256(v)
|
||||
b := StrSha256(v)
|
||||
tst.AssertEqual(t, a, b)
|
||||
}
|
||||
|
||||
func TestStrSha256DifferentInputs(t *testing.T) {
|
||||
tst.AssertNotEqual(t, StrSha256("a"), StrSha256("b"))
|
||||
tst.AssertNotEqual(t, StrSha256("Hello"), StrSha256("hello"))
|
||||
tst.AssertNotEqual(t, StrSha256("Hello World"), StrSha256("Hello World "))
|
||||
}
|
||||
|
||||
func TestStrSha256IsHex(t *testing.T) {
|
||||
out := StrSha256("anything")
|
||||
for _, c := range out {
|
||||
isLowerHex := (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')
|
||||
if !isLowerHex {
|
||||
t.Errorf("non-hex char %q in StrSha256 output", c)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBytesSha256NilSameAsEmpty(t *testing.T) {
|
||||
tst.AssertEqual(t, BytesSha256(nil), BytesSha256([]byte{}))
|
||||
}
|
||||
|
||||
func TestBytesSha256KnownVectors(t *testing.T) {
|
||||
// "abc" => sha-256 standard vector
|
||||
tst.AssertEqual(t, StrSha256("abc"), "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
package cryptext
|
||||
|
||||
import (
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/tst"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/tst"
|
||||
"testing"
|
||||
)
|
||||
|
||||
|
||||
+12
-3
@@ -6,13 +6,15 @@ import (
|
||||
"crypto/sha512"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/langext"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/totpext"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/langext"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/totpext"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
const LatestPassHashVersion = 5
|
||||
@@ -317,6 +319,13 @@ func (ph PassHash) String() string {
|
||||
return string(ph)
|
||||
}
|
||||
|
||||
func (ph PassHash) MarshalJSON() ([]byte, error) {
|
||||
if ph == "" {
|
||||
return json.Marshal("")
|
||||
}
|
||||
return json.Marshal("*****")
|
||||
}
|
||||
|
||||
func HashPassword(plainpass string, totpSecret []byte) (PassHash, error) {
|
||||
return HashPasswordV5(plainpass, totpSecret)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,379 @@
|
||||
package cryptext
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/tst"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPassHashInvalidEmpty(t *testing.T) {
|
||||
ph := PassHash("")
|
||||
tst.AssertFalse(t, ph.Valid())
|
||||
tst.AssertFalse(t, ph.HasTOTP())
|
||||
tst.AssertFalse(t, ph.NeedsPasswordUpgrade())
|
||||
}
|
||||
|
||||
func TestPassHashInvalidGarbage(t *testing.T) {
|
||||
for _, raw := range []string{
|
||||
"garbage",
|
||||
"99|nope",
|
||||
"abc|payload",
|
||||
"3|onlytwo",
|
||||
"4|onlytwo",
|
||||
"5|onlytwo",
|
||||
"2|notbase64!|notbase64!",
|
||||
"1|!!!notbase64!!!",
|
||||
"3|!!notb64|!!notb64|0",
|
||||
"3|abc|!!notb64|0",
|
||||
} {
|
||||
ph := PassHash(raw)
|
||||
if ph.Valid() {
|
||||
t.Errorf("expected %q to be invalid", raw)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPassHashVerifyInvalid(t *testing.T) {
|
||||
ph := PassHash("garbage-value")
|
||||
tst.AssertFalse(t, ph.Verify("anything", nil))
|
||||
}
|
||||
|
||||
func TestPassHashUpgradeInvalid(t *testing.T) {
|
||||
ph := PassHash("garbage-value")
|
||||
_, err := ph.Upgrade("anything")
|
||||
if err == nil {
|
||||
t.Errorf("expected error for invalid PassHash upgrade")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPassHashStringRoundtrip(t *testing.T) {
|
||||
ph, err := HashPassword("hunter2", nil)
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertEqual(t, ph.String(), string(ph))
|
||||
}
|
||||
|
||||
func TestPassHashMarshalJSONEmpty(t *testing.T) {
|
||||
ph := PassHash("")
|
||||
data, err := json.Marshal(ph)
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertEqual(t, string(data), `""`)
|
||||
}
|
||||
|
||||
func TestPassHashMarshalJSONMasked(t *testing.T) {
|
||||
ph, err := HashPassword("hunter2", nil)
|
||||
tst.AssertNoErr(t, err)
|
||||
data, err := json.Marshal(ph)
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertEqual(t, string(data), `"*****"`)
|
||||
}
|
||||
|
||||
func TestPassHashDataV0(t *testing.T) {
|
||||
ph, err := HashPasswordV0("test123")
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
v, seed, payload, hastotp, totpsecret, valid := ph.Data()
|
||||
tst.AssertTrue(t, valid)
|
||||
tst.AssertEqual(t, v, 0)
|
||||
tst.AssertEqual(t, len(seed), 0)
|
||||
tst.AssertEqual(t, string(payload), "test123")
|
||||
tst.AssertFalse(t, hastotp)
|
||||
tst.AssertEqual(t, len(totpsecret), 0)
|
||||
}
|
||||
|
||||
func TestPassHashDataV1(t *testing.T) {
|
||||
ph, err := HashPasswordV1("test123")
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
v, seed, payload, hastotp, _, valid := ph.Data()
|
||||
tst.AssertTrue(t, valid)
|
||||
tst.AssertEqual(t, v, 1)
|
||||
tst.AssertEqual(t, len(seed), 0)
|
||||
tst.AssertEqual(t, len(payload), 32) // sha-256 is 32 bytes
|
||||
tst.AssertFalse(t, hastotp)
|
||||
}
|
||||
|
||||
func TestPassHashDataV2(t *testing.T) {
|
||||
ph, err := HashPasswordV2("test123")
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
v, seed, payload, hastotp, _, valid := ph.Data()
|
||||
tst.AssertTrue(t, valid)
|
||||
tst.AssertEqual(t, v, 2)
|
||||
tst.AssertEqual(t, len(seed), 32)
|
||||
tst.AssertEqual(t, len(payload), 32)
|
||||
tst.AssertFalse(t, hastotp)
|
||||
}
|
||||
|
||||
func TestPassHashDataV3(t *testing.T) {
|
||||
ph, err := HashPasswordV3("test123", nil)
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
v, seed, payload, hastotp, _, valid := ph.Data()
|
||||
tst.AssertTrue(t, valid)
|
||||
tst.AssertEqual(t, v, 3)
|
||||
tst.AssertEqual(t, len(seed), 32)
|
||||
tst.AssertEqual(t, len(payload), 32)
|
||||
tst.AssertFalse(t, hastotp)
|
||||
}
|
||||
|
||||
func TestPassHashDataV4(t *testing.T) {
|
||||
ph, err := HashPasswordV4("test123", nil)
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
v, _, _, hastotp, _, valid := ph.Data()
|
||||
tst.AssertTrue(t, valid)
|
||||
tst.AssertEqual(t, v, 4)
|
||||
tst.AssertFalse(t, hastotp)
|
||||
}
|
||||
|
||||
func TestPassHashDataV5(t *testing.T) {
|
||||
ph, err := HashPasswordV5("test123", nil)
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
v, _, _, hastotp, _, valid := ph.Data()
|
||||
tst.AssertTrue(t, valid)
|
||||
tst.AssertEqual(t, v, 5)
|
||||
tst.AssertFalse(t, hastotp)
|
||||
}
|
||||
|
||||
func TestPassHashLatestIsV5(t *testing.T) {
|
||||
ph, err := HashPassword("test", nil)
|
||||
tst.AssertNoErr(t, err)
|
||||
v, _, _, _, _, valid := ph.Data()
|
||||
tst.AssertTrue(t, valid)
|
||||
tst.AssertEqual(t, v, LatestPassHashVersion)
|
||||
tst.AssertEqual(t, v, 5)
|
||||
}
|
||||
|
||||
func TestPassHashUpgradeLatestIsNoop(t *testing.T) {
|
||||
ph, err := HashPassword("test", nil)
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
tst.AssertFalse(t, ph.NeedsPasswordUpgrade())
|
||||
|
||||
ph2, err := ph.Upgrade("test")
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertEqual(t, string(ph), string(ph2))
|
||||
}
|
||||
|
||||
func TestPassHashClearTOTPInvalid(t *testing.T) {
|
||||
_, err := PassHash("garbage").ClearTOTP()
|
||||
if err == nil {
|
||||
t.Errorf("expected error from ClearTOTP on invalid")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPassHashClearTOTPV0V1V2Noop(t *testing.T) {
|
||||
ph0, _ := HashPasswordV0("x")
|
||||
r0, err := ph0.ClearTOTP()
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertEqual(t, string(r0), string(ph0))
|
||||
|
||||
ph1, _ := HashPasswordV1("x")
|
||||
r1, err := ph1.ClearTOTP()
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertEqual(t, string(r1), string(ph1))
|
||||
|
||||
ph2, _ := HashPasswordV2("x")
|
||||
r2, err := ph2.ClearTOTP()
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertEqual(t, string(r2), string(ph2))
|
||||
}
|
||||
|
||||
func TestPassHashClearTOTPV3(t *testing.T) {
|
||||
secret := []byte{0x01, 0x02, 0x03}
|
||||
ph, err := HashPasswordV3("test123", secret)
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertTrue(t, ph.HasTOTP())
|
||||
|
||||
cleared, err := ph.ClearTOTP()
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertFalse(t, cleared.HasTOTP())
|
||||
tst.AssertTrue(t, cleared.Valid())
|
||||
tst.AssertTrue(t, cleared.Verify("test123", nil))
|
||||
}
|
||||
|
||||
func TestPassHashClearTOTPV4(t *testing.T) {
|
||||
secret := []byte{0x01, 0x02, 0x03}
|
||||
ph, err := HashPasswordV4("test123", secret)
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertTrue(t, ph.HasTOTP())
|
||||
|
||||
cleared, err := ph.ClearTOTP()
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertFalse(t, cleared.HasTOTP())
|
||||
tst.AssertTrue(t, cleared.Verify("test123", nil))
|
||||
}
|
||||
|
||||
func TestPassHashClearTOTPV5(t *testing.T) {
|
||||
secret := []byte{0x01, 0x02, 0x03}
|
||||
ph, err := HashPasswordV5("test123", secret)
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertTrue(t, ph.HasTOTP())
|
||||
|
||||
cleared, err := ph.ClearTOTP()
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertFalse(t, cleared.HasTOTP())
|
||||
tst.AssertTrue(t, cleared.Verify("test123", nil))
|
||||
}
|
||||
|
||||
func TestPassHashWithTOTPInvalid(t *testing.T) {
|
||||
_, err := PassHash("garbage").WithTOTP([]byte{0x01})
|
||||
if err == nil {
|
||||
t.Errorf("expected error for WithTOTP on invalid")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPassHashWithTOTPV0V1V2Errors(t *testing.T) {
|
||||
ph0, _ := HashPasswordV0("x")
|
||||
if _, err := ph0.WithTOTP([]byte{0x01}); err == nil {
|
||||
t.Errorf("expected v0 not to support TOTP")
|
||||
}
|
||||
ph1, _ := HashPasswordV1("x")
|
||||
if _, err := ph1.WithTOTP([]byte{0x01}); err == nil {
|
||||
t.Errorf("expected v1 not to support TOTP")
|
||||
}
|
||||
ph2, _ := HashPasswordV2("x")
|
||||
if _, err := ph2.WithTOTP([]byte{0x01}); err == nil {
|
||||
t.Errorf("expected v2 not to support TOTP")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPassHashWithTOTPV3V4V5(t *testing.T) {
|
||||
secret := []byte{0xDE, 0xAD, 0xBE, 0xEF}
|
||||
|
||||
ph3, _ := HashPasswordV3("pw", nil)
|
||||
tst.AssertFalse(t, ph3.HasTOTP())
|
||||
r3, err := ph3.WithTOTP(secret)
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertTrue(t, r3.HasTOTP())
|
||||
|
||||
ph4, _ := HashPasswordV4("pw", nil)
|
||||
tst.AssertFalse(t, ph4.HasTOTP())
|
||||
r4, err := ph4.WithTOTP(secret)
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertTrue(t, r4.HasTOTP())
|
||||
|
||||
ph5, _ := HashPasswordV5("pw", nil)
|
||||
tst.AssertFalse(t, ph5.HasTOTP())
|
||||
r5, err := ph5.WithTOTP(secret)
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertTrue(t, r5.HasTOTP())
|
||||
}
|
||||
|
||||
func TestPassHashChangeInvalid(t *testing.T) {
|
||||
_, err := PassHash("garbage").Change("new-pw")
|
||||
if err == nil {
|
||||
t.Errorf("expected error from Change on invalid")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPassHashChangeKeepsVersion(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
hashed func() (PassHash, error)
|
||||
version int
|
||||
}{
|
||||
{"V0", func() (PassHash, error) { return HashPasswordV0("old") }, 0},
|
||||
{"V1", func() (PassHash, error) { return HashPasswordV1("old") }, 1},
|
||||
{"V2", func() (PassHash, error) { return HashPasswordV2("old") }, 2},
|
||||
{"V3", func() (PassHash, error) { return HashPasswordV3("old", nil) }, 3},
|
||||
{"V4", func() (PassHash, error) { return HashPasswordV4("old", nil) }, 4},
|
||||
{"V5", func() (PassHash, error) { return HashPasswordV5("old", nil) }, 5},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
ph, err := c.hashed()
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
changed, err := ph.Change("new-pw")
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
v, _, _, _, _, valid := changed.Data()
|
||||
tst.AssertTrue(t, valid)
|
||||
tst.AssertEqual(t, v, c.version)
|
||||
|
||||
tst.AssertTrue(t, changed.Verify("new-pw", nil))
|
||||
tst.AssertFalse(t, changed.Verify("old", nil))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPassHashChangeKeepsTOTPV3(t *testing.T) {
|
||||
secret := []byte{0xAB, 0xCD}
|
||||
ph, err := HashPasswordV3("old", secret)
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertTrue(t, ph.HasTOTP())
|
||||
|
||||
changed, err := ph.Change("new")
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertTrue(t, changed.HasTOTP())
|
||||
}
|
||||
|
||||
func TestPassHashV0Format(t *testing.T) {
|
||||
ph, err := HashPasswordV0("plaintext-pw")
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertTrue(t, strings.HasPrefix(string(ph), "0|"))
|
||||
tst.AssertEqual(t, string(ph), "0|plaintext-pw")
|
||||
}
|
||||
|
||||
func TestPassHashV1Format(t *testing.T) {
|
||||
ph, err := HashPasswordV1("test")
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertTrue(t, strings.HasPrefix(string(ph), "1|"))
|
||||
}
|
||||
|
||||
func TestPassHashV2Format(t *testing.T) {
|
||||
ph, err := HashPasswordV2("test")
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertTrue(t, strings.HasPrefix(string(ph), "2|"))
|
||||
tst.AssertEqual(t, strings.Count(string(ph), "|"), 2)
|
||||
}
|
||||
|
||||
func TestPassHashV3Format(t *testing.T) {
|
||||
ph, err := HashPasswordV3("test", nil)
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertTrue(t, strings.HasPrefix(string(ph), "3|"))
|
||||
tst.AssertEqual(t, strings.Count(string(ph), "|"), 3)
|
||||
tst.AssertTrue(t, strings.HasSuffix(string(ph), "|0"))
|
||||
}
|
||||
|
||||
func TestPassHashV4Format(t *testing.T) {
|
||||
ph, err := HashPasswordV4("test", nil)
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertTrue(t, strings.HasPrefix(string(ph), "4|"))
|
||||
tst.AssertTrue(t, strings.HasSuffix(string(ph), "|0"))
|
||||
}
|
||||
|
||||
func TestPassHashV5Format(t *testing.T) {
|
||||
ph, err := HashPasswordV5("test", nil)
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertTrue(t, strings.HasPrefix(string(ph), "5|"))
|
||||
tst.AssertTrue(t, strings.HasSuffix(string(ph), "|0"))
|
||||
}
|
||||
|
||||
func TestPassHashV5VerifyLongPassword(t *testing.T) {
|
||||
// V5 hashes via sha512 first → bcrypt's 72-byte limit shouldn't apply
|
||||
longPw := strings.Repeat("a", 200)
|
||||
ph, err := HashPasswordV5(longPw, nil)
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
tst.AssertTrue(t, ph.Verify(longPw, nil))
|
||||
tst.AssertFalse(t, ph.Verify(longPw+"x", nil))
|
||||
}
|
||||
|
||||
func TestPassHashV5DifferentEachCall(t *testing.T) {
|
||||
ph1, err := HashPasswordV5("samepw", nil)
|
||||
tst.AssertNoErr(t, err)
|
||||
ph2, err := HashPasswordV5("samepw", nil)
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
// Bcrypt salts internally — same password should produce different hashes
|
||||
tst.AssertNotEqual(t, string(ph1), string(ph2))
|
||||
|
||||
// Both must verify
|
||||
tst.AssertTrue(t, ph1.Verify("samepw", nil))
|
||||
tst.AssertTrue(t, ph2.Verify("samepw", nil))
|
||||
}
|
||||
@@ -1,9 +1,8 @@
|
||||
package cryptext
|
||||
|
||||
import (
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/langext"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/totpext"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/tst"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/totpext"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/tst"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -32,7 +31,7 @@ func TestPassHashTOTP(t *testing.T) {
|
||||
|
||||
tst.AssertFalse(t, ph.Verify("test123", nil))
|
||||
tst.AssertFalse(t, ph.Verify("test124", nil))
|
||||
tst.AssertTrue(t, ph.Verify("test123", langext.Ptr(totpext.TOTP(sec))))
|
||||
tst.AssertTrue(t, ph.Verify("test123", new(totpext.TOTP(sec))))
|
||||
tst.AssertFalse(t, ph.Verify("test124", nil))
|
||||
}
|
||||
|
||||
@@ -141,7 +140,7 @@ func TestPassHashUpgrade_V3_TOTP(t *testing.T) {
|
||||
|
||||
tst.AssertFalse(t, ph.Verify("test123", nil))
|
||||
tst.AssertFalse(t, ph.Verify("test124", nil))
|
||||
tst.AssertTrue(t, ph.Verify("test123", langext.Ptr(totpext.TOTP(sec))))
|
||||
tst.AssertTrue(t, ph.Verify("test123", new(totpext.TOTP(sec))))
|
||||
tst.AssertFalse(t, ph.Verify("test124", nil))
|
||||
|
||||
ph, err = ph.Upgrade("test123")
|
||||
@@ -153,7 +152,7 @@ func TestPassHashUpgrade_V3_TOTP(t *testing.T) {
|
||||
|
||||
tst.AssertFalse(t, ph.Verify("test123", nil))
|
||||
tst.AssertFalse(t, ph.Verify("test124", nil))
|
||||
tst.AssertTrue(t, ph.Verify("test123", langext.Ptr(totpext.TOTP(sec))))
|
||||
tst.AssertTrue(t, ph.Verify("test123", new(totpext.TOTP(sec))))
|
||||
tst.AssertFalse(t, ph.Verify("test124", nil))
|
||||
}
|
||||
|
||||
@@ -193,7 +192,7 @@ func TestPassHashUpgrade_V4_TOTP(t *testing.T) {
|
||||
|
||||
tst.AssertFalse(t, ph.Verify("test123", nil))
|
||||
tst.AssertFalse(t, ph.Verify("test124", nil))
|
||||
tst.AssertTrue(t, ph.Verify("test123", langext.Ptr(totpext.TOTP(sec))))
|
||||
tst.AssertTrue(t, ph.Verify("test123", new(totpext.TOTP(sec))))
|
||||
tst.AssertFalse(t, ph.Verify("test124", nil))
|
||||
|
||||
ph, err = ph.Upgrade("test123")
|
||||
@@ -205,6 +204,6 @@ func TestPassHashUpgrade_V4_TOTP(t *testing.T) {
|
||||
|
||||
tst.AssertFalse(t, ph.Verify("test123", nil))
|
||||
tst.AssertFalse(t, ph.Verify("test124", nil))
|
||||
tst.AssertTrue(t, ph.Verify("test123", langext.Ptr(totpext.TOTP(sec))))
|
||||
tst.AssertTrue(t, ph.Verify("test123", new(totpext.TOTP(sec))))
|
||||
tst.AssertFalse(t, ph.Verify("test124", nil))
|
||||
}
|
||||
|
||||
@@ -0,0 +1,180 @@
|
||||
package cryptext
|
||||
|
||||
import (
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/tst"
|
||||
mathrand "math/rand"
|
||||
"strings"
|
||||
"testing"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
func TestPronouncablePasswordLength(t *testing.T) {
|
||||
for _, n := range []int{1, 2, 3, 5, 8, 13, 21, 50, 128} {
|
||||
pw := PronouncablePassword(n)
|
||||
tst.AssertEqual(t, len(pw), n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPronouncablePasswordZeroOrNegative(t *testing.T) {
|
||||
tst.AssertEqual(t, PronouncablePassword(0), "")
|
||||
tst.AssertEqual(t, PronouncablePassword(-1), "")
|
||||
tst.AssertEqual(t, PronouncablePassword(-1000), "")
|
||||
}
|
||||
|
||||
func TestPronouncablePasswordSeededDeterministic(t *testing.T) {
|
||||
pw1 := PronouncablePasswordSeeded(42, 16)
|
||||
pw2 := PronouncablePasswordSeeded(42, 16)
|
||||
tst.AssertEqual(t, pw1, pw2)
|
||||
tst.AssertEqual(t, len(pw1), 16)
|
||||
}
|
||||
|
||||
func TestPronouncablePasswordSeededDifferentSeeds(t *testing.T) {
|
||||
pw1 := PronouncablePasswordSeeded(1, 16)
|
||||
pw2 := PronouncablePasswordSeeded(2, 16)
|
||||
tst.AssertNotEqual(t, pw1, pw2)
|
||||
}
|
||||
|
||||
func TestPronouncablePasswordExtEntropy(t *testing.T) {
|
||||
rng := mathrand.New(mathrand.NewSource(1))
|
||||
pw, entropy := PronouncablePasswordExt(rng, 32)
|
||||
tst.AssertEqual(t, len(pw), 32)
|
||||
if entropy <= 0 {
|
||||
t.Errorf("expected positive entropy, got %f", entropy)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPronouncablePasswordExtZeroLen(t *testing.T) {
|
||||
rng := mathrand.New(mathrand.NewSource(1))
|
||||
pw, entropy := PronouncablePasswordExt(rng, 0)
|
||||
tst.AssertEqual(t, pw, "")
|
||||
tst.AssertEqual(t, entropy, float64(0))
|
||||
}
|
||||
|
||||
func TestPronouncablePasswordCharacters(t *testing.T) {
|
||||
// Output should be only ASCII letters
|
||||
for i := range 50 {
|
||||
pw := PronouncablePasswordSeeded(int64(i), 32)
|
||||
for _, c := range pw {
|
||||
if !unicode.IsLetter(c) || c > unicode.MaxASCII {
|
||||
t.Errorf("non-letter or non-ASCII rune %q in password %q", c, pw)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPronouncablePasswordStartsUpper(t *testing.T) {
|
||||
for i := range 50 {
|
||||
pw := PronouncablePasswordSeeded(int64(i), 16)
|
||||
if pw == "" {
|
||||
continue
|
||||
}
|
||||
first := rune(pw[0])
|
||||
if !unicode.IsUpper(first) {
|
||||
t.Errorf("expected first letter uppercase in %q (seed %d)", pw, i)
|
||||
}
|
||||
if !strings.ContainsRune(ppStartChar, first) {
|
||||
t.Errorf("expected first letter from start-set in %q (seed %d)", pw, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPpMakeSet(t *testing.T) {
|
||||
set := ppMakeSet("ABC")
|
||||
tst.AssertTrue(t, set['A'])
|
||||
tst.AssertTrue(t, set['B'])
|
||||
tst.AssertTrue(t, set['C'])
|
||||
tst.AssertFalse(t, set['D'])
|
||||
tst.AssertEqual(t, len(set), 3)
|
||||
}
|
||||
|
||||
func TestPpMakeSetEmpty(t *testing.T) {
|
||||
set := ppMakeSet("")
|
||||
tst.AssertEqual(t, len(set), 0)
|
||||
}
|
||||
|
||||
func TestPpCharType(t *testing.T) {
|
||||
v, c := ppCharType('A')
|
||||
tst.AssertTrue(t, v)
|
||||
tst.AssertFalse(t, c)
|
||||
|
||||
v, c = ppCharType('B')
|
||||
tst.AssertFalse(t, v)
|
||||
tst.AssertTrue(t, c)
|
||||
|
||||
v, c = ppCharType('Y')
|
||||
tst.AssertTrue(t, v)
|
||||
tst.AssertFalse(t, c)
|
||||
|
||||
v, c = ppCharType('1')
|
||||
tst.AssertFalse(t, v)
|
||||
tst.AssertFalse(t, c)
|
||||
}
|
||||
|
||||
func TestPpCharsetRemove(t *testing.T) {
|
||||
set := ppMakeSet("AEIOU")
|
||||
out := ppCharsetRemove("ABCDEFG", set, false)
|
||||
tst.AssertEqual(t, out, "BCDFG")
|
||||
}
|
||||
|
||||
func TestPpCharsetRemoveEmptyDisallowed(t *testing.T) {
|
||||
set := ppMakeSet("AB")
|
||||
out := ppCharsetRemove("AB", set, false)
|
||||
// when result would be empty and allowEmpty=false, it returns the original
|
||||
tst.AssertEqual(t, out, "AB")
|
||||
}
|
||||
|
||||
func TestPpCharsetRemoveEmptyAllowed(t *testing.T) {
|
||||
set := ppMakeSet("AB")
|
||||
out := ppCharsetRemove("AB", set, true)
|
||||
tst.AssertEqual(t, out, "")
|
||||
}
|
||||
|
||||
func TestPpCharsetFilter(t *testing.T) {
|
||||
set := ppMakeSet("AEIOU")
|
||||
out := ppCharsetFilter("ABCDEFG", set, false)
|
||||
tst.AssertEqual(t, out, "AE")
|
||||
}
|
||||
|
||||
func TestPpCharsetFilterEmptyDisallowed(t *testing.T) {
|
||||
set := ppMakeSet("XYZ")
|
||||
out := ppCharsetFilter("ABC", set, false)
|
||||
tst.AssertEqual(t, out, "ABC") // returns original when result empty & not allowed
|
||||
}
|
||||
|
||||
func TestPpCharsetFilterEmptyAllowed(t *testing.T) {
|
||||
set := ppMakeSet("XYZ")
|
||||
out := ppCharsetFilter("ABC", set, true)
|
||||
tst.AssertEqual(t, out, "")
|
||||
}
|
||||
|
||||
func TestPronouncablePasswordContinuationFollowsRules(t *testing.T) {
|
||||
// Make sure each continuation pair (lowercased) appears in ppContinuation
|
||||
// Note: when a new segment starts (uppercase letter mid-string), the continuation
|
||||
// check does not apply across the segment boundary.
|
||||
for s := range 30 {
|
||||
seed := int64(s)
|
||||
pw := PronouncablePasswordSeeded(seed, 32)
|
||||
if len(pw) < 2 {
|
||||
continue
|
||||
}
|
||||
runes := []byte(strings.ToUpper(pw))
|
||||
for i := 1; i < len(runes); i++ {
|
||||
// Detect new segment (original char was uppercase and it's not the first char)
|
||||
origUpper := pw[i] >= 'A' && pw[i] <= 'Z'
|
||||
if origUpper && i > 0 {
|
||||
continue
|
||||
}
|
||||
prev := runes[i-1]
|
||||
cur := runes[i]
|
||||
cont, ok := ppContinuation[prev]
|
||||
if !ok {
|
||||
t.Errorf("no continuation map for %q (pw=%q)", prev, pw)
|
||||
continue
|
||||
}
|
||||
if !strings.ContainsRune(cont, rune(cur)) {
|
||||
t.Errorf("invalid continuation %q -> %q in %q (seed %d)", prev, cur, pw, seed)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -7,21 +7,21 @@ import (
|
||||
)
|
||||
|
||||
func TestPronouncablePasswordExt(t *testing.T) {
|
||||
for i := 0; i < 20; i++ {
|
||||
for i := range 20 {
|
||||
pw, entropy := PronouncablePasswordExt(rand.New(rand.NewSource(int64(i))), 16)
|
||||
fmt.Printf("[%.2f] => %s\n", entropy, pw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPronouncablePasswordSeeded(t *testing.T) {
|
||||
for i := 0; i < 20; i++ {
|
||||
for i := range 20 {
|
||||
pw := PronouncablePasswordSeeded(int64(i), 8)
|
||||
fmt.Printf("%s\n", pw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPronouncablePassword(t *testing.T) {
|
||||
for i := 0; i < 20; i++ {
|
||||
for i := range 20 {
|
||||
pw := PronouncablePassword(i + 1)
|
||||
fmt.Printf("%s\n", pw)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
package ctxext
|
||||
|
||||
import "context"
|
||||
|
||||
func Value[T any](ctx context.Context, key any) (T, bool) {
|
||||
v := ctx.Value(key)
|
||||
if v == nil {
|
||||
return *new(T), false
|
||||
}
|
||||
if tv, ok := v.(T); !ok {
|
||||
return *new(T), false
|
||||
} else {
|
||||
return tv, true
|
||||
}
|
||||
}
|
||||
|
||||
func ValueOrDefault[T any](ctx context.Context, key any, def T) T {
|
||||
v := ctx.Value(key)
|
||||
if v == nil {
|
||||
return def
|
||||
}
|
||||
if tv, ok := v.(T); !ok {
|
||||
return def
|
||||
} else {
|
||||
return tv
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,237 @@
|
||||
package ctxext
|
||||
|
||||
import (
|
||||
"context"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/tst"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type ctxKey string
|
||||
|
||||
const (
|
||||
keyString ctxKey = "string-key"
|
||||
keyInt ctxKey = "int-key"
|
||||
keyStruct ctxKey = "struct-key"
|
||||
keyPtr ctxKey = "ptr-key"
|
||||
keyMissing ctxKey = "missing-key"
|
||||
)
|
||||
|
||||
type sampleStruct struct {
|
||||
Name string
|
||||
N int
|
||||
}
|
||||
|
||||
func TestValueStringPresent(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), keyString, "hello")
|
||||
v, ok := Value[string](ctx, keyString)
|
||||
tst.AssertEqual(t, ok, true)
|
||||
tst.AssertEqual(t, v, "hello")
|
||||
}
|
||||
|
||||
func TestValueIntPresent(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), keyInt, 42)
|
||||
v, ok := Value[int](ctx, keyInt)
|
||||
tst.AssertEqual(t, ok, true)
|
||||
tst.AssertEqual(t, v, 42)
|
||||
}
|
||||
|
||||
func TestValueStructPresent(t *testing.T) {
|
||||
want := sampleStruct{Name: "abc", N: 7}
|
||||
ctx := context.WithValue(context.Background(), keyStruct, want)
|
||||
v, ok := Value[sampleStruct](ctx, keyStruct)
|
||||
tst.AssertEqual(t, ok, true)
|
||||
tst.AssertEqual(t, v.Name, "abc")
|
||||
tst.AssertEqual(t, v.N, 7)
|
||||
}
|
||||
|
||||
func TestValuePointerPresent(t *testing.T) {
|
||||
want := &sampleStruct{Name: "ptr", N: 99}
|
||||
ctx := context.WithValue(context.Background(), keyPtr, want)
|
||||
v, ok := Value[*sampleStruct](ctx, keyPtr)
|
||||
tst.AssertEqual(t, ok, true)
|
||||
tst.AssertEqual(t, v == want, true)
|
||||
tst.AssertEqual(t, v.Name, "ptr")
|
||||
}
|
||||
|
||||
func TestValueMissing(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
v, ok := Value[string](ctx, keyMissing)
|
||||
tst.AssertEqual(t, ok, false)
|
||||
tst.AssertEqual(t, v, "")
|
||||
}
|
||||
|
||||
func TestValueMissingInt(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
v, ok := Value[int](ctx, keyMissing)
|
||||
tst.AssertEqual(t, ok, false)
|
||||
tst.AssertEqual(t, v, 0)
|
||||
}
|
||||
|
||||
func TestValueMissingStruct(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
v, ok := Value[sampleStruct](ctx, keyMissing)
|
||||
tst.AssertEqual(t, ok, false)
|
||||
tst.AssertEqual(t, v.Name, "")
|
||||
tst.AssertEqual(t, v.N, 0)
|
||||
}
|
||||
|
||||
func TestValueMissingPointer(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
v, ok := Value[*sampleStruct](ctx, keyMissing)
|
||||
tst.AssertEqual(t, ok, false)
|
||||
tst.AssertEqual(t, v == nil, true)
|
||||
}
|
||||
|
||||
func TestValueWrongType(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), keyString, "hello")
|
||||
v, ok := Value[int](ctx, keyString)
|
||||
tst.AssertEqual(t, ok, false)
|
||||
tst.AssertEqual(t, v, 0)
|
||||
}
|
||||
|
||||
func TestValueWrongTypeStructToString(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), keyStruct, sampleStruct{Name: "x"})
|
||||
v, ok := Value[string](ctx, keyStruct)
|
||||
tst.AssertEqual(t, ok, false)
|
||||
tst.AssertEqual(t, v, "")
|
||||
}
|
||||
|
||||
func TestValueNilStoredAsInterface(t *testing.T) {
|
||||
var stored *sampleStruct = nil
|
||||
ctx := context.WithValue(context.Background(), keyPtr, stored)
|
||||
v, ok := Value[*sampleStruct](ctx, keyPtr)
|
||||
tst.AssertEqual(t, ok, true)
|
||||
tst.AssertEqual(t, v == nil, true)
|
||||
}
|
||||
|
||||
func TestValueEmptyString(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), keyString, "")
|
||||
v, ok := Value[string](ctx, keyString)
|
||||
tst.AssertEqual(t, ok, true)
|
||||
tst.AssertEqual(t, v, "")
|
||||
}
|
||||
|
||||
func TestValueZeroInt(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), keyInt, 0)
|
||||
v, ok := Value[int](ctx, keyInt)
|
||||
tst.AssertEqual(t, ok, true)
|
||||
tst.AssertEqual(t, v, 0)
|
||||
}
|
||||
|
||||
func TestValueWithStringKey(t *testing.T) {
|
||||
type stringKey string
|
||||
k := stringKey("my-key")
|
||||
ctx := context.WithValue(context.Background(), k, "value")
|
||||
v, ok := Value[string](ctx, k)
|
||||
tst.AssertEqual(t, ok, true)
|
||||
tst.AssertEqual(t, v, "value")
|
||||
}
|
||||
|
||||
func TestValueOrDefaultPresent(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), keyString, "hello")
|
||||
v := ValueOrDefault(ctx, keyString, "default")
|
||||
tst.AssertEqual(t, v, "hello")
|
||||
}
|
||||
|
||||
func TestValueOrDefaultIntPresent(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), keyInt, 42)
|
||||
v := ValueOrDefault(ctx, keyInt, -1)
|
||||
tst.AssertEqual(t, v, 42)
|
||||
}
|
||||
|
||||
func TestValueOrDefaultMissing(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
v := ValueOrDefault(ctx, keyMissing, "default")
|
||||
tst.AssertEqual(t, v, "default")
|
||||
}
|
||||
|
||||
func TestValueOrDefaultMissingInt(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
v := ValueOrDefault(ctx, keyMissing, 99)
|
||||
tst.AssertEqual(t, v, 99)
|
||||
}
|
||||
|
||||
func TestValueOrDefaultMissingStruct(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
def := sampleStruct{Name: "default", N: 1}
|
||||
v := ValueOrDefault(ctx, keyMissing, def)
|
||||
tst.AssertEqual(t, v.Name, "default")
|
||||
tst.AssertEqual(t, v.N, 1)
|
||||
}
|
||||
|
||||
func TestValueOrDefaultWrongType(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), keyString, "hello")
|
||||
v := ValueOrDefault(ctx, keyString, 7)
|
||||
tst.AssertEqual(t, v, 7)
|
||||
}
|
||||
|
||||
func TestValueOrDefaultWrongTypeStruct(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), keyStruct, sampleStruct{Name: "x"})
|
||||
def := "fallback"
|
||||
v := ValueOrDefault(ctx, keyStruct, def)
|
||||
tst.AssertEqual(t, v, "fallback")
|
||||
}
|
||||
|
||||
func TestValueOrDefaultEmptyStringStored(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), keyString, "")
|
||||
v := ValueOrDefault(ctx, keyString, "default")
|
||||
tst.AssertEqual(t, v, "")
|
||||
}
|
||||
|
||||
func TestValueOrDefaultZeroIntStored(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), keyInt, 0)
|
||||
v := ValueOrDefault(ctx, keyInt, 99)
|
||||
tst.AssertEqual(t, v, 0)
|
||||
}
|
||||
|
||||
func TestValueOrDefaultPointerPresent(t *testing.T) {
|
||||
want := &sampleStruct{Name: "p", N: 5}
|
||||
ctx := context.WithValue(context.Background(), keyPtr, want)
|
||||
def := &sampleStruct{Name: "def", N: 0}
|
||||
v := ValueOrDefault(ctx, keyPtr, def)
|
||||
tst.AssertEqual(t, v == want, true)
|
||||
}
|
||||
|
||||
func TestValueOrDefaultPointerMissing(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
def := &sampleStruct{Name: "def", N: 0}
|
||||
v := ValueOrDefault(ctx, keyMissing, def)
|
||||
tst.AssertEqual(t, v == def, true)
|
||||
}
|
||||
|
||||
func TestValueOrDefaultNilPointerStored(t *testing.T) {
|
||||
var stored *sampleStruct = nil
|
||||
ctx := context.WithValue(context.Background(), keyPtr, stored)
|
||||
def := &sampleStruct{Name: "def"}
|
||||
v := ValueOrDefault(ctx, keyPtr, def)
|
||||
tst.AssertEqual(t, v == nil, true)
|
||||
}
|
||||
|
||||
func TestValueNestedContext(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), keyString, "outer")
|
||||
ctx = context.WithValue(ctx, keyInt, 123)
|
||||
ctx = context.WithValue(ctx, keyString, "inner")
|
||||
|
||||
vs, oks := Value[string](ctx, keyString)
|
||||
tst.AssertEqual(t, oks, true)
|
||||
tst.AssertEqual(t, vs, "inner")
|
||||
|
||||
vi, oki := Value[int](ctx, keyInt)
|
||||
tst.AssertEqual(t, oki, true)
|
||||
tst.AssertEqual(t, vi, 123)
|
||||
}
|
||||
|
||||
func TestValueDifferentKeyTypesDoNotCollide(t *testing.T) {
|
||||
type keyA string
|
||||
type keyB string
|
||||
ctx := context.WithValue(context.Background(), keyA("k"), "a-val")
|
||||
ctx = context.WithValue(ctx, keyB("k"), "b-val")
|
||||
|
||||
va, oka := Value[string](ctx, keyA("k"))
|
||||
tst.AssertEqual(t, oka, true)
|
||||
tst.AssertEqual(t, va, "a-val")
|
||||
|
||||
vb, okb := Value[string](ctx, keyB("k"))
|
||||
tst.AssertEqual(t, okb, true)
|
||||
tst.AssertEqual(t, vb, "b-val")
|
||||
}
|
||||
@@ -6,3 +6,13 @@ const (
|
||||
SortASC SortDirection = "ASC"
|
||||
SortDESC SortDirection = "DESC"
|
||||
)
|
||||
|
||||
func (sd SortDirection) ToMongo() int {
|
||||
if sd == SortASC {
|
||||
return 1
|
||||
} else if sd == SortDESC {
|
||||
return -1
|
||||
} else {
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
package cursortoken
|
||||
|
||||
import (
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/tst"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSortDirectionToMongoASC(t *testing.T) {
|
||||
tst.AssertEqual(t, SortASC.ToMongo(), 1)
|
||||
}
|
||||
|
||||
func TestSortDirectionToMongoDESC(t *testing.T) {
|
||||
tst.AssertEqual(t, SortDESC.ToMongo(), -1)
|
||||
}
|
||||
|
||||
func TestSortDirectionToMongoEmpty(t *testing.T) {
|
||||
var sd SortDirection
|
||||
tst.AssertEqual(t, sd.ToMongo(), 0)
|
||||
}
|
||||
|
||||
func TestSortDirectionToMongoUnknown(t *testing.T) {
|
||||
sd := SortDirection("xyz")
|
||||
tst.AssertEqual(t, sd.ToMongo(), 0)
|
||||
}
|
||||
|
||||
func TestSortDirectionConstants(t *testing.T) {
|
||||
tst.AssertEqual(t, string(SortASC), "ASC")
|
||||
tst.AssertEqual(t, string(SortDESC), "DESC")
|
||||
}
|
||||
@@ -2,7 +2,8 @@ package cursortoken
|
||||
|
||||
import (
|
||||
"context"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
|
||||
"go.mongodb.org/mongo-driver/v2/mongo"
|
||||
)
|
||||
|
||||
type RawFilter interface {
|
||||
|
||||
+51
-139
@@ -3,12 +3,19 @@ package cursortoken
|
||||
import (
|
||||
"encoding/base32"
|
||||
"encoding/json"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/exerr"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/exerr"
|
||||
)
|
||||
|
||||
type CursorToken interface {
|
||||
Token() string
|
||||
IsStart() bool
|
||||
IsEnd() bool
|
||||
}
|
||||
|
||||
type Mode string
|
||||
|
||||
const (
|
||||
@@ -24,97 +31,6 @@ type Extra struct {
|
||||
PageSize *int
|
||||
}
|
||||
|
||||
type CursorToken struct {
|
||||
Mode Mode
|
||||
ValuePrimary string
|
||||
ValueSecondary string
|
||||
Direction SortDirection
|
||||
DirectionSecondary SortDirection
|
||||
PageSize int
|
||||
Extra Extra
|
||||
}
|
||||
|
||||
type cursorTokenSerialize struct {
|
||||
ValuePrimary *string `json:"v1,omitempty"`
|
||||
ValueSecondary *string `json:"v2,omitempty"`
|
||||
Direction *SortDirection `json:"dir,omitempty"`
|
||||
DirectionSecondary *SortDirection `json:"dir2,omitempty"`
|
||||
PageSize *int `json:"size,omitempty"`
|
||||
|
||||
ExtraTimestamp *time.Time `json:"ts,omitempty"`
|
||||
ExtraId *string `json:"id,omitempty"`
|
||||
ExtraPage *int `json:"pg,omitempty"`
|
||||
ExtraPageSize *int `json:"sz,omitempty"`
|
||||
}
|
||||
|
||||
func Start() CursorToken {
|
||||
return CursorToken{
|
||||
Mode: CTMStart,
|
||||
ValuePrimary: "",
|
||||
ValueSecondary: "",
|
||||
Direction: "",
|
||||
DirectionSecondary: "",
|
||||
PageSize: 0,
|
||||
Extra: Extra{},
|
||||
}
|
||||
}
|
||||
|
||||
func End() CursorToken {
|
||||
return CursorToken{
|
||||
Mode: CTMEnd,
|
||||
ValuePrimary: "",
|
||||
ValueSecondary: "",
|
||||
Direction: "",
|
||||
DirectionSecondary: "",
|
||||
PageSize: 0,
|
||||
Extra: Extra{},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *CursorToken) Token() string {
|
||||
if c.Mode == CTMStart {
|
||||
return "@start"
|
||||
}
|
||||
if c.Mode == CTMEnd {
|
||||
return "@end"
|
||||
}
|
||||
|
||||
// We kinda manually implement omitempty for the CursorToken here
|
||||
// because omitempty does not work for time.Time and otherwise we would always
|
||||
// get weird time values when decoding a token that initially didn't have an Timestamp set
|
||||
// For this usecase we treat Unix=0 as an empty timestamp
|
||||
|
||||
sertok := cursorTokenSerialize{}
|
||||
|
||||
if c.ValuePrimary != "" {
|
||||
sertok.ValuePrimary = &c.ValuePrimary
|
||||
}
|
||||
if c.ValueSecondary != "" {
|
||||
sertok.ValueSecondary = &c.ValueSecondary
|
||||
}
|
||||
if c.Direction != "" {
|
||||
sertok.Direction = &c.Direction
|
||||
}
|
||||
if c.DirectionSecondary != "" {
|
||||
sertok.DirectionSecondary = &c.DirectionSecondary
|
||||
}
|
||||
if c.PageSize != 0 {
|
||||
sertok.PageSize = &c.PageSize
|
||||
}
|
||||
|
||||
sertok.ExtraTimestamp = c.Extra.Timestamp
|
||||
sertok.ExtraId = c.Extra.Id
|
||||
sertok.ExtraPage = c.Extra.Page
|
||||
sertok.ExtraPageSize = c.Extra.PageSize
|
||||
|
||||
body, err := json.Marshal(sertok)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return "tok_" + base32.StdEncoding.EncodeToString(body)
|
||||
}
|
||||
|
||||
func Decode(tok string) (CursorToken, error) {
|
||||
if tok == "" {
|
||||
return Start(), nil
|
||||
@@ -125,60 +41,56 @@ func Decode(tok string) (CursorToken, error) {
|
||||
if strings.ToLower(tok) == "@end" {
|
||||
return End(), nil
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(tok, "tok_") {
|
||||
return CursorToken{}, exerr.New(exerr.TypeCursorTokenDecode, "could not decode token, missing prefix").Str("token", tok).Build()
|
||||
if strings.ToLower(tok) == "$end" {
|
||||
return PageEnd(), nil
|
||||
}
|
||||
if strings.HasPrefix(tok, "$") && len(tok) > 1 {
|
||||
n, err := strconv.ParseInt(tok[1:], 10, 64)
|
||||
if err != nil {
|
||||
return nil, exerr.Wrap(err, "failed to deserialize token").Str("token", tok).WithType(exerr.TypeCursorTokenDecode).Build()
|
||||
}
|
||||
return Page(int(n)), nil
|
||||
}
|
||||
|
||||
body, err := base32.StdEncoding.DecodeString(tok[len("tok_"):])
|
||||
if err != nil {
|
||||
return CursorToken{}, err
|
||||
}
|
||||
if strings.HasPrefix(tok, "tok_") {
|
||||
|
||||
var tokenDeserialize cursorTokenSerialize
|
||||
err = json.Unmarshal(body, &tokenDeserialize)
|
||||
if err != nil {
|
||||
return CursorToken{}, exerr.Wrap(err, "failed to deserialize token").Str("token", tok).Build()
|
||||
}
|
||||
body, err := base32.StdEncoding.DecodeString(tok[len("tok_"):])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
token := CursorToken{Mode: CTMNormal}
|
||||
var tokenDeserialize cursorTokenKeySortSerialize
|
||||
err = json.Unmarshal(body, &tokenDeserialize)
|
||||
if err != nil {
|
||||
return nil, exerr.Wrap(err, "failed to deserialize token").Str("token", tok).WithType(exerr.TypeCursorTokenDecode).Build()
|
||||
}
|
||||
|
||||
if tokenDeserialize.ValuePrimary != nil {
|
||||
token.ValuePrimary = *tokenDeserialize.ValuePrimary
|
||||
}
|
||||
if tokenDeserialize.ValueSecondary != nil {
|
||||
token.ValueSecondary = *tokenDeserialize.ValueSecondary
|
||||
}
|
||||
if tokenDeserialize.Direction != nil {
|
||||
token.Direction = *tokenDeserialize.Direction
|
||||
}
|
||||
if tokenDeserialize.DirectionSecondary != nil {
|
||||
token.DirectionSecondary = *tokenDeserialize.DirectionSecondary
|
||||
}
|
||||
if tokenDeserialize.PageSize != nil {
|
||||
token.PageSize = *tokenDeserialize.PageSize
|
||||
}
|
||||
token := CTKeySort{Mode: CTMNormal}
|
||||
|
||||
token.Extra.Timestamp = tokenDeserialize.ExtraTimestamp
|
||||
token.Extra.Id = tokenDeserialize.ExtraId
|
||||
token.Extra.Page = tokenDeserialize.ExtraPage
|
||||
token.Extra.PageSize = tokenDeserialize.ExtraPageSize
|
||||
if tokenDeserialize.ValuePrimary != nil {
|
||||
token.ValuePrimary = *tokenDeserialize.ValuePrimary
|
||||
}
|
||||
if tokenDeserialize.ValueSecondary != nil {
|
||||
token.ValueSecondary = *tokenDeserialize.ValueSecondary
|
||||
}
|
||||
if tokenDeserialize.Direction != nil {
|
||||
token.Direction = *tokenDeserialize.Direction
|
||||
}
|
||||
if tokenDeserialize.DirectionSecondary != nil {
|
||||
token.DirectionSecondary = *tokenDeserialize.DirectionSecondary
|
||||
}
|
||||
if tokenDeserialize.PageSize != nil {
|
||||
token.PageSize = *tokenDeserialize.PageSize
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
token.Extra.Timestamp = tokenDeserialize.ExtraTimestamp
|
||||
token.Extra.Id = tokenDeserialize.ExtraId
|
||||
token.Extra.Page = tokenDeserialize.ExtraPage
|
||||
token.Extra.PageSize = tokenDeserialize.ExtraPageSize
|
||||
|
||||
return token, nil
|
||||
|
||||
func (c *CursorToken) ValuePrimaryObjectId() (primitive.ObjectID, bool) {
|
||||
if oid, err := primitive.ObjectIDFromHex(c.ValuePrimary); err == nil {
|
||||
return oid, true
|
||||
} else {
|
||||
return primitive.ObjectID{}, false
|
||||
}
|
||||
}
|
||||
|
||||
func (c *CursorToken) ValueSecondaryObjectId() (primitive.ObjectID, bool) {
|
||||
if oid, err := primitive.ObjectIDFromHex(c.ValueSecondary); err == nil {
|
||||
return oid, true
|
||||
} else {
|
||||
return primitive.ObjectID{}, false
|
||||
return nil, exerr.New(exerr.TypeCursorTokenDecode, "could not decode token, missing/unknown prefix").Str("token", tok).Build()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,137 @@
|
||||
package cursortoken
|
||||
|
||||
import (
|
||||
"encoding/base32"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"go.mongodb.org/mongo-driver/v2/bson"
|
||||
)
|
||||
|
||||
type CTKeySort struct {
|
||||
Mode Mode
|
||||
ValuePrimary string
|
||||
ValueSecondary string
|
||||
Direction SortDirection
|
||||
DirectionSecondary SortDirection
|
||||
PageSize int
|
||||
Extra Extra
|
||||
}
|
||||
|
||||
type cursorTokenKeySortSerialize struct {
|
||||
ValuePrimary *string `json:"v1,omitempty"`
|
||||
ValueSecondary *string `json:"v2,omitempty"`
|
||||
Direction *SortDirection `json:"dir,omitempty"`
|
||||
DirectionSecondary *SortDirection `json:"dir2,omitempty"`
|
||||
PageSize *int `json:"size,omitempty"`
|
||||
|
||||
ExtraTimestamp *time.Time `json:"ts,omitempty"`
|
||||
ExtraId *string `json:"id,omitempty"`
|
||||
ExtraPage *int `json:"pg,omitempty"`
|
||||
ExtraPageSize *int `json:"sz,omitempty"`
|
||||
}
|
||||
|
||||
func NewKeySortToken(valuePrimary string, valueSecondary string, direction SortDirection, directionSecondary SortDirection, pageSize int, extra Extra) CursorToken {
|
||||
return CTKeySort{
|
||||
Mode: CTMNormal,
|
||||
ValuePrimary: valuePrimary,
|
||||
ValueSecondary: valueSecondary,
|
||||
Direction: direction,
|
||||
DirectionSecondary: directionSecondary,
|
||||
PageSize: pageSize,
|
||||
Extra: extra,
|
||||
}
|
||||
}
|
||||
|
||||
func Start() CursorToken {
|
||||
return CTKeySort{
|
||||
Mode: CTMStart,
|
||||
ValuePrimary: "",
|
||||
ValueSecondary: "",
|
||||
Direction: "",
|
||||
DirectionSecondary: "",
|
||||
PageSize: 0,
|
||||
Extra: Extra{},
|
||||
}
|
||||
}
|
||||
|
||||
func End() CursorToken {
|
||||
return CTKeySort{
|
||||
Mode: CTMEnd,
|
||||
ValuePrimary: "",
|
||||
ValueSecondary: "",
|
||||
Direction: "",
|
||||
DirectionSecondary: "",
|
||||
PageSize: 0,
|
||||
Extra: Extra{},
|
||||
}
|
||||
}
|
||||
|
||||
func (c CTKeySort) Token() string {
|
||||
|
||||
if c.Mode == CTMStart {
|
||||
return "@start"
|
||||
}
|
||||
if c.Mode == CTMEnd {
|
||||
return "@end"
|
||||
}
|
||||
|
||||
// We kinda manually implement omitempty for the CursorToken here
|
||||
// because omitempty does not work for time.Time and otherwise we would always
|
||||
// get weird time values when decoding a token that initially didn't have an Timestamp set
|
||||
// For this usecase we treat Unix=0 as an empty timestamp
|
||||
|
||||
sertok := cursorTokenKeySortSerialize{}
|
||||
|
||||
if c.ValuePrimary != "" {
|
||||
sertok.ValuePrimary = &c.ValuePrimary
|
||||
}
|
||||
if c.ValueSecondary != "" {
|
||||
sertok.ValueSecondary = &c.ValueSecondary
|
||||
}
|
||||
if c.Direction != "" {
|
||||
sertok.Direction = &c.Direction
|
||||
}
|
||||
if c.DirectionSecondary != "" {
|
||||
sertok.DirectionSecondary = &c.DirectionSecondary
|
||||
}
|
||||
if c.PageSize != 0 {
|
||||
sertok.PageSize = &c.PageSize
|
||||
}
|
||||
|
||||
sertok.ExtraTimestamp = c.Extra.Timestamp
|
||||
sertok.ExtraId = c.Extra.Id
|
||||
sertok.ExtraPage = c.Extra.Page
|
||||
sertok.ExtraPageSize = c.Extra.PageSize
|
||||
|
||||
body, err := json.Marshal(sertok)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return "tok_" + base32.StdEncoding.EncodeToString(body)
|
||||
}
|
||||
|
||||
func (c CTKeySort) IsEnd() bool {
|
||||
return c.Mode == CTMEnd
|
||||
}
|
||||
|
||||
func (c CTKeySort) IsStart() bool {
|
||||
return c.Mode == CTMStart
|
||||
}
|
||||
|
||||
func (c CTKeySort) valuePrimaryObjectId() (bson.ObjectID, bool) {
|
||||
if oid, err := bson.ObjectIDFromHex(c.ValuePrimary); err == nil {
|
||||
return oid, true
|
||||
} else {
|
||||
return bson.ObjectID{}, false
|
||||
}
|
||||
}
|
||||
|
||||
func (c CTKeySort) valueSecondaryObjectId() (bson.ObjectID, bool) {
|
||||
if oid, err := bson.ObjectIDFromHex(c.ValueSecondary); err == nil {
|
||||
return oid, true
|
||||
} else {
|
||||
return bson.ObjectID{}, false
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,136 @@
|
||||
package cursortoken
|
||||
|
||||
import (
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/tst"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestStartToken(t *testing.T) {
|
||||
tok := Start()
|
||||
tst.AssertEqual(t, tok.Token(), "@start")
|
||||
tst.AssertTrue(t, tok.IsStart())
|
||||
tst.AssertFalse(t, tok.IsEnd())
|
||||
}
|
||||
|
||||
func TestEndToken(t *testing.T) {
|
||||
tok := End()
|
||||
tst.AssertEqual(t, tok.Token(), "@end")
|
||||
tst.AssertTrue(t, tok.IsEnd())
|
||||
tst.AssertFalse(t, tok.IsStart())
|
||||
}
|
||||
|
||||
func TestNewKeySortTokenBasic(t *testing.T) {
|
||||
tok := NewKeySortToken("alpha", "beta", SortASC, SortDESC, 50, Extra{})
|
||||
tst.AssertFalse(t, tok.IsEnd())
|
||||
tst.AssertFalse(t, tok.IsStart())
|
||||
str := tok.Token()
|
||||
tst.AssertTrue(t, strings.HasPrefix(str, "tok_"))
|
||||
}
|
||||
|
||||
func TestNewKeySortTokenRoundTrip(t *testing.T) {
|
||||
original := NewKeySortToken("primary-val", "secondary-val", SortASC, SortDESC, 25, Extra{})
|
||||
encoded := original.Token()
|
||||
|
||||
decoded, err := Decode(encoded)
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
ks, ok := decoded.(CTKeySort)
|
||||
tst.AssertTrue(t, ok)
|
||||
tst.AssertEqual(t, ks.ValuePrimary, "primary-val")
|
||||
tst.AssertEqual(t, ks.ValueSecondary, "secondary-val")
|
||||
tst.AssertEqual(t, ks.Direction, SortASC)
|
||||
tst.AssertEqual(t, ks.DirectionSecondary, SortDESC)
|
||||
tst.AssertEqual(t, ks.PageSize, 25)
|
||||
tst.AssertEqual(t, ks.Mode, CTMNormal)
|
||||
}
|
||||
|
||||
func TestKeySortTokenWithExtra(t *testing.T) {
|
||||
ts := time.Date(2024, 6, 15, 10, 30, 0, 0, time.UTC)
|
||||
id := "object-id-123"
|
||||
page := 7
|
||||
pageSize := 42
|
||||
|
||||
original := NewKeySortToken("p", "s", SortDESC, SortASC, 10, Extra{
|
||||
Timestamp: &ts,
|
||||
Id: &id,
|
||||
Page: &page,
|
||||
PageSize: &pageSize,
|
||||
})
|
||||
encoded := original.Token()
|
||||
|
||||
decoded, err := Decode(encoded)
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
ks, ok := decoded.(CTKeySort)
|
||||
tst.AssertTrue(t, ok)
|
||||
tst.AssertTrue(t, ks.Extra.Timestamp != nil)
|
||||
tst.AssertTrue(t, ks.Extra.Timestamp.Equal(ts))
|
||||
tst.AssertDeRefEqual(t, ks.Extra.Id, "object-id-123")
|
||||
tst.AssertDeRefEqual(t, ks.Extra.Page, 7)
|
||||
tst.AssertDeRefEqual(t, ks.Extra.PageSize, 42)
|
||||
}
|
||||
|
||||
func TestKeySortTokenStartRoundTrip(t *testing.T) {
|
||||
original := Start()
|
||||
decoded, err := Decode(original.Token())
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertTrue(t, decoded.IsStart())
|
||||
tst.AssertFalse(t, decoded.IsEnd())
|
||||
}
|
||||
|
||||
func TestKeySortTokenEndRoundTrip(t *testing.T) {
|
||||
original := End()
|
||||
decoded, err := Decode(original.Token())
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertTrue(t, decoded.IsEnd())
|
||||
tst.AssertFalse(t, decoded.IsStart())
|
||||
}
|
||||
|
||||
func TestKeySortTokenEmptyValues(t *testing.T) {
|
||||
tok := CTKeySort{Mode: CTMNormal}
|
||||
encoded := tok.Token()
|
||||
tst.AssertTrue(t, strings.HasPrefix(encoded, "tok_"))
|
||||
|
||||
decoded, err := Decode(encoded)
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
ks, ok := decoded.(CTKeySort)
|
||||
tst.AssertTrue(t, ok)
|
||||
tst.AssertEqual(t, ks.ValuePrimary, "")
|
||||
tst.AssertEqual(t, ks.ValueSecondary, "")
|
||||
tst.AssertEqual(t, ks.Direction, SortDirection(""))
|
||||
tst.AssertEqual(t, ks.DirectionSecondary, SortDirection(""))
|
||||
tst.AssertEqual(t, ks.PageSize, 0)
|
||||
}
|
||||
|
||||
func TestKeySortTokenOnlyTimestamp(t *testing.T) {
|
||||
ts := time.Date(2020, 1, 2, 3, 4, 5, 0, time.UTC)
|
||||
tok := CTKeySort{
|
||||
Mode: CTMNormal,
|
||||
Extra: Extra{Timestamp: &ts},
|
||||
}
|
||||
|
||||
decoded, err := Decode(tok.Token())
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
ks, ok := decoded.(CTKeySort)
|
||||
tst.AssertTrue(t, ok)
|
||||
tst.AssertTrue(t, ks.Extra.Timestamp != nil)
|
||||
tst.AssertTrue(t, ks.Extra.Timestamp.Equal(ts))
|
||||
tst.AssertTrue(t, ks.Extra.Id == nil)
|
||||
tst.AssertTrue(t, ks.Extra.Page == nil)
|
||||
tst.AssertTrue(t, ks.Extra.PageSize == nil)
|
||||
}
|
||||
|
||||
func TestKeySortTokenSpecialChars(t *testing.T) {
|
||||
original := NewKeySortToken("hello world / @!#$%", "äöü€", SortASC, SortASC, 1, Extra{})
|
||||
decoded, err := Decode(original.Token())
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
ks, ok := decoded.(CTKeySort)
|
||||
tst.AssertTrue(t, ok)
|
||||
tst.AssertEqual(t, ks.ValuePrimary, "hello world / @!#$%")
|
||||
tst.AssertEqual(t, ks.ValueSecondary, "äöü€")
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
package cursortoken
|
||||
|
||||
import "strconv"
|
||||
|
||||
type CTPaginated struct {
|
||||
Mode Mode
|
||||
Page int
|
||||
}
|
||||
|
||||
func Page(p int) CursorToken {
|
||||
return CTPaginated{
|
||||
Mode: CTMNormal,
|
||||
Page: p,
|
||||
}
|
||||
}
|
||||
|
||||
func PageEnd() CursorToken {
|
||||
return CTPaginated{
|
||||
Mode: CTMEnd,
|
||||
Page: 0,
|
||||
}
|
||||
}
|
||||
|
||||
func (c CTPaginated) Token() string {
|
||||
if c.Mode == CTMStart {
|
||||
return "$1"
|
||||
}
|
||||
if c.Mode == CTMEnd {
|
||||
return "$end"
|
||||
}
|
||||
|
||||
return "$" + strconv.Itoa(c.Page)
|
||||
}
|
||||
|
||||
func (c CTPaginated) IsEnd() bool {
|
||||
return c.Mode == CTMEnd
|
||||
}
|
||||
|
||||
func (c CTPaginated) IsStart() bool {
|
||||
return c.Mode == CTMStart || c.Page == 1
|
||||
}
|
||||
@@ -0,0 +1,61 @@
|
||||
package cursortoken
|
||||
|
||||
import (
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/tst"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPageToken(t *testing.T) {
|
||||
tok := Page(5)
|
||||
tst.AssertEqual(t, tok.Token(), "$5")
|
||||
tst.AssertFalse(t, tok.IsEnd())
|
||||
tst.AssertFalse(t, tok.IsStart())
|
||||
}
|
||||
|
||||
func TestPageTokenOne(t *testing.T) {
|
||||
tok := Page(1)
|
||||
tst.AssertEqual(t, tok.Token(), "$1")
|
||||
tst.AssertFalse(t, tok.IsEnd())
|
||||
tst.AssertTrue(t, tok.IsStart())
|
||||
}
|
||||
|
||||
func TestPageTokenLarge(t *testing.T) {
|
||||
tok := Page(123456)
|
||||
tst.AssertEqual(t, tok.Token(), "$123456")
|
||||
}
|
||||
|
||||
func TestPageTokenZero(t *testing.T) {
|
||||
tok := Page(0)
|
||||
tst.AssertEqual(t, tok.Token(), "$0")
|
||||
tst.AssertFalse(t, tok.IsEnd())
|
||||
tst.AssertFalse(t, tok.IsStart())
|
||||
}
|
||||
|
||||
func TestPageEndToken(t *testing.T) {
|
||||
tok := PageEnd()
|
||||
tst.AssertEqual(t, tok.Token(), "$end")
|
||||
tst.AssertTrue(t, tok.IsEnd())
|
||||
tst.AssertFalse(t, tok.IsStart())
|
||||
}
|
||||
|
||||
func TestPaginatedStartMode(t *testing.T) {
|
||||
tok := CTPaginated{Mode: CTMStart, Page: 0}
|
||||
tst.AssertEqual(t, tok.Token(), "$1")
|
||||
tst.AssertTrue(t, tok.IsStart())
|
||||
tst.AssertFalse(t, tok.IsEnd())
|
||||
}
|
||||
|
||||
func TestPaginatedEndMode(t *testing.T) {
|
||||
tok := CTPaginated{Mode: CTMEnd, Page: 99}
|
||||
tst.AssertEqual(t, tok.Token(), "$end")
|
||||
tst.AssertTrue(t, tok.IsEnd())
|
||||
}
|
||||
|
||||
func TestPaginatedRoundTrip(t *testing.T) {
|
||||
for _, page := range []int{2, 3, 7, 100, 9999} {
|
||||
tok := Page(page)
|
||||
decoded, err := Decode(tok.Token())
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertEqual(t, decoded.Token(), tok.Token())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,125 @@
|
||||
package cursortoken
|
||||
|
||||
import (
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/tst"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDecodeEmpty(t *testing.T) {
|
||||
tok, err := Decode("")
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertTrue(t, tok.IsStart())
|
||||
tst.AssertFalse(t, tok.IsEnd())
|
||||
tst.AssertEqual(t, tok.Token(), "@start")
|
||||
}
|
||||
|
||||
func TestDecodeAtStart(t *testing.T) {
|
||||
tok, err := Decode("@start")
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertTrue(t, tok.IsStart())
|
||||
tst.AssertFalse(t, tok.IsEnd())
|
||||
}
|
||||
|
||||
func TestDecodeAtStartUppercase(t *testing.T) {
|
||||
tok, err := Decode("@START")
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertTrue(t, tok.IsStart())
|
||||
}
|
||||
|
||||
func TestDecodeAtStartMixedCase(t *testing.T) {
|
||||
tok, err := Decode("@StArT")
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertTrue(t, tok.IsStart())
|
||||
}
|
||||
|
||||
func TestDecodeAtEnd(t *testing.T) {
|
||||
tok, err := Decode("@end")
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertTrue(t, tok.IsEnd())
|
||||
tst.AssertFalse(t, tok.IsStart())
|
||||
}
|
||||
|
||||
func TestDecodeAtEndUppercase(t *testing.T) {
|
||||
tok, err := Decode("@END")
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertTrue(t, tok.IsEnd())
|
||||
}
|
||||
|
||||
func TestDecodeDollarEnd(t *testing.T) {
|
||||
tok, err := Decode("$end")
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertTrue(t, tok.IsEnd())
|
||||
_, ok := tok.(CTPaginated)
|
||||
tst.AssertTrue(t, ok)
|
||||
}
|
||||
|
||||
func TestDecodeDollarEndUppercase(t *testing.T) {
|
||||
tok, err := Decode("$END")
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertTrue(t, tok.IsEnd())
|
||||
}
|
||||
|
||||
func TestDecodeDollarPage(t *testing.T) {
|
||||
tok, err := Decode("$5")
|
||||
tst.AssertNoErr(t, err)
|
||||
pg, ok := tok.(CTPaginated)
|
||||
tst.AssertTrue(t, ok)
|
||||
tst.AssertEqual(t, pg.Page, 5)
|
||||
tst.AssertEqual(t, pg.Mode, CTMNormal)
|
||||
}
|
||||
|
||||
func TestDecodeDollarPageOne(t *testing.T) {
|
||||
tok, err := Decode("$1")
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertTrue(t, tok.IsStart())
|
||||
pg, ok := tok.(CTPaginated)
|
||||
tst.AssertTrue(t, ok)
|
||||
tst.AssertEqual(t, pg.Page, 1)
|
||||
}
|
||||
|
||||
func TestDecodeDollarPageInvalid(t *testing.T) {
|
||||
_, err := Decode("$abc")
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for invalid page")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeUnknownPrefix(t *testing.T) {
|
||||
_, err := Decode("foobar")
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for unknown prefix")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeInvalidBase32(t *testing.T) {
|
||||
_, err := Decode("tok_!!!")
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for invalid base32 body")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeInvalidJSON(t *testing.T) {
|
||||
// "tok_" prefix with valid base32 but invalid JSON content
|
||||
_, err := Decode("tok_NBSWY3DP")
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for invalid json body")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeJustDollar(t *testing.T) {
|
||||
// "$" alone (length == 1) should fall through to the unknown-prefix branch
|
||||
_, err := Decode("$")
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for bare $")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeKnownTokenContent(t *testing.T) {
|
||||
tok := NewKeySortToken("k1", "k2", SortASC, SortDESC, 33, Extra{})
|
||||
encoded := tok.Token()
|
||||
|
||||
decoded, err := Decode(encoded)
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
tst.AssertEqual(t, decoded.Token(), encoded)
|
||||
}
|
||||
@@ -0,0 +1,230 @@
|
||||
package dataext
|
||||
|
||||
import (
|
||||
"context"
|
||||
"iter"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/langext"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/syncext"
|
||||
"github.com/rs/xid"
|
||||
)
|
||||
|
||||
// Broadcaster is a simple Broadcaster channel
|
||||
// This is a simpler interface over Broadcaster - which does not have distinct namespaces
|
||||
type Broadcaster[TData any] struct {
|
||||
masterLock *sync.Mutex
|
||||
|
||||
subscriptions []*broadcastSubscription[TData]
|
||||
}
|
||||
|
||||
type BroadcastSubscription interface {
|
||||
Unsubscribe()
|
||||
}
|
||||
|
||||
type broadcastSubscription[TData any] struct {
|
||||
ID string
|
||||
|
||||
parent *Broadcaster[TData]
|
||||
|
||||
subLock *sync.Mutex
|
||||
|
||||
Func func(TData)
|
||||
Chan chan TData
|
||||
|
||||
UnsubChan chan bool
|
||||
}
|
||||
|
||||
func (p *broadcastSubscription[TData]) Unsubscribe() {
|
||||
p.parent.unsubscribe(p)
|
||||
}
|
||||
|
||||
func NewBroadcaster[TData any](capacity int) *Broadcaster[TData] {
|
||||
return &Broadcaster[TData]{
|
||||
masterLock: &sync.Mutex{},
|
||||
subscriptions: make([]*broadcastSubscription[TData], 0, capacity),
|
||||
}
|
||||
}
|
||||
|
||||
func (bb *Broadcaster[TData]) SubscriberCount() int {
|
||||
bb.masterLock.Lock()
|
||||
defer bb.masterLock.Unlock()
|
||||
|
||||
return len(bb.subscriptions)
|
||||
}
|
||||
|
||||
// Publish sends `data` to all subscriber
|
||||
// But unbuffered - if one is currently not listening, we skip (the actualReceiver < subscriber)
|
||||
func (bb *Broadcaster[TData]) Publish(data TData) (subscriber int, actualReceiver int) {
|
||||
bb.masterLock.Lock()
|
||||
subs := langext.ArrCopy(bb.subscriptions)
|
||||
bb.masterLock.Unlock()
|
||||
|
||||
subscriber = len(subs)
|
||||
actualReceiver = 0
|
||||
|
||||
for _, sub := range subs {
|
||||
func() {
|
||||
sub.subLock.Lock()
|
||||
defer sub.subLock.Unlock()
|
||||
|
||||
if sub.Func != nil {
|
||||
go func() { sub.Func(data) }()
|
||||
actualReceiver++
|
||||
} else if sub.Chan != nil {
|
||||
msgSent := syncext.WriteNonBlocking(sub.Chan, data)
|
||||
if msgSent {
|
||||
actualReceiver++
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return subscriber, actualReceiver
|
||||
}
|
||||
|
||||
// PublishWithContext sends `data` to all subscriber
|
||||
// buffered - if one is currently not listening, we wait (but error out when the context runs out)
|
||||
func (bb *Broadcaster[TData]) PublishWithContext(ctx context.Context, data TData) (subscriber int, actualReceiver int, err error) {
|
||||
bb.masterLock.Lock()
|
||||
subs := langext.ArrCopy(bb.subscriptions)
|
||||
bb.masterLock.Unlock()
|
||||
|
||||
subscriber = len(subs)
|
||||
actualReceiver = 0
|
||||
|
||||
for _, sub := range subs {
|
||||
err := func() error {
|
||||
sub.subLock.Lock()
|
||||
defer sub.subLock.Unlock()
|
||||
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if sub.Func != nil {
|
||||
go func() { sub.Func(data) }()
|
||||
actualReceiver++
|
||||
} else if sub.Chan != nil {
|
||||
err := syncext.WriteChannelWithContext(ctx, sub.Chan, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
actualReceiver++
|
||||
}
|
||||
|
||||
return nil
|
||||
}()
|
||||
if err != nil {
|
||||
return subscriber, actualReceiver, err
|
||||
}
|
||||
}
|
||||
|
||||
return subscriber, actualReceiver, nil
|
||||
}
|
||||
|
||||
// PublishWithTimeout sends `data` to all subscriber
|
||||
// buffered - if one is currently not listening, we wait (but wait at most `timeout` - if the timeout is exceeded then actualReceiver < subscriber)
|
||||
func (bb *Broadcaster[TData]) PublishWithTimeout(data TData, timeout time.Duration) (subscriber int, actualReceiver int) {
|
||||
bb.masterLock.Lock()
|
||||
subs := langext.ArrCopy(bb.subscriptions)
|
||||
bb.masterLock.Unlock()
|
||||
|
||||
subscriber = len(subs)
|
||||
actualReceiver = 0
|
||||
|
||||
for _, sub := range subs {
|
||||
func() {
|
||||
sub.subLock.Lock()
|
||||
defer sub.subLock.Unlock()
|
||||
|
||||
if sub.Func != nil {
|
||||
go func() { sub.Func(data) }()
|
||||
actualReceiver++
|
||||
} else if sub.Chan != nil {
|
||||
ok := syncext.WriteChannelWithTimeout(sub.Chan, data, timeout)
|
||||
if ok {
|
||||
actualReceiver++
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return subscriber, actualReceiver
|
||||
}
|
||||
|
||||
func (bb *Broadcaster[TData]) SubscribeByCallback(fn func(TData)) BroadcastSubscription {
|
||||
bb.masterLock.Lock()
|
||||
defer bb.masterLock.Unlock()
|
||||
|
||||
sub := &broadcastSubscription[TData]{ID: xid.New().String(), parent: bb, subLock: &sync.Mutex{}, Func: fn, UnsubChan: nil}
|
||||
|
||||
bb.subscriptions = append(bb.subscriptions, sub)
|
||||
|
||||
return sub
|
||||
}
|
||||
|
||||
func (bb *Broadcaster[TData]) SubscribeByChan(chanBufferSize int) (chan TData, BroadcastSubscription) {
|
||||
bb.masterLock.Lock()
|
||||
defer bb.masterLock.Unlock()
|
||||
|
||||
msgCh := make(chan TData, chanBufferSize)
|
||||
|
||||
sub := &broadcastSubscription[TData]{ID: xid.New().String(), parent: bb, subLock: &sync.Mutex{}, Chan: msgCh, UnsubChan: nil}
|
||||
|
||||
bb.subscriptions = append(bb.subscriptions, sub)
|
||||
|
||||
return msgCh, sub
|
||||
}
|
||||
|
||||
func (bb *Broadcaster[TData]) SubscribeByIter(chanBufferSize int) (iter.Seq[TData], BroadcastSubscription) {
|
||||
bb.masterLock.Lock()
|
||||
defer bb.masterLock.Unlock()
|
||||
|
||||
msgCh := make(chan TData, chanBufferSize)
|
||||
unsubChan := make(chan bool, 8)
|
||||
|
||||
sub := &broadcastSubscription[TData]{ID: xid.New().String(), parent: bb, subLock: &sync.Mutex{}, Chan: msgCh, UnsubChan: unsubChan}
|
||||
|
||||
bb.subscriptions = append(bb.subscriptions, sub)
|
||||
|
||||
iterFun := func(yield func(TData) bool) {
|
||||
for {
|
||||
select {
|
||||
case msg := <-msgCh:
|
||||
if !yield(msg) {
|
||||
sub.Unsubscribe()
|
||||
return
|
||||
}
|
||||
case <-sub.UnsubChan:
|
||||
sub.Unsubscribe()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return iterFun, sub
|
||||
}
|
||||
|
||||
func (bb *Broadcaster[TData]) unsubscribe(p *broadcastSubscription[TData]) {
|
||||
bb.masterLock.Lock()
|
||||
defer bb.masterLock.Unlock()
|
||||
|
||||
p.subLock.Lock()
|
||||
defer p.subLock.Unlock()
|
||||
|
||||
if p.Chan != nil {
|
||||
close(p.Chan)
|
||||
p.Chan = nil
|
||||
}
|
||||
if p.UnsubChan != nil {
|
||||
syncext.WriteNonBlocking(p.UnsubChan, true)
|
||||
close(p.UnsubChan)
|
||||
p.UnsubChan = nil
|
||||
}
|
||||
|
||||
bb.subscriptions = langext.ArrFilter(bb.subscriptions, func(v *broadcastSubscription[TData]) bool {
|
||||
return v.ID != p.ID
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,342 @@
|
||||
package dataext
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewBroadcast(t *testing.T) {
|
||||
bb := NewBroadcaster[string](10)
|
||||
if bb == nil {
|
||||
t.Fatal("NewBroadcaster returned nil")
|
||||
}
|
||||
if bb.masterLock == nil {
|
||||
t.Fatal("masterLock is nil")
|
||||
}
|
||||
if bb.subscriptions == nil {
|
||||
t.Fatal("subscriptions is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBroadcast_SubscribeByCallback(t *testing.T) {
|
||||
bb := NewBroadcaster[string](10)
|
||||
|
||||
var received string
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
callback := func(msg string) {
|
||||
received = msg
|
||||
wg.Done()
|
||||
}
|
||||
|
||||
sub := bb.SubscribeByCallback(callback)
|
||||
defer sub.Unsubscribe()
|
||||
|
||||
// Publish a message
|
||||
subs, receivers := bb.Publish("hello")
|
||||
|
||||
if subs != 1 {
|
||||
t.Fatalf("Expected 1 subscriber, got %d", subs)
|
||||
}
|
||||
|
||||
if receivers != 1 {
|
||||
t.Fatalf("Expected 1 receiver, got %d", receivers)
|
||||
}
|
||||
|
||||
// Wait for the callback to be executed
|
||||
wg.Wait()
|
||||
|
||||
if received != "hello" {
|
||||
t.Fatalf("Expected to receive 'hello', got '%s'", received)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBroadcast_SubscribeByChan(t *testing.T) {
|
||||
bb := NewBroadcaster[string](10)
|
||||
|
||||
ch, sub := bb.SubscribeByChan(1)
|
||||
defer sub.Unsubscribe()
|
||||
|
||||
// Publish a message
|
||||
subs, receivers := bb.Publish("hello")
|
||||
|
||||
if subs != 1 {
|
||||
t.Fatalf("Expected 1 subscriber, got %d", subs)
|
||||
}
|
||||
|
||||
if receivers != 1 {
|
||||
t.Fatalf("Expected 1 receiver, got %d", receivers)
|
||||
}
|
||||
|
||||
// Read from the channel with a timeout to avoid blocking
|
||||
select {
|
||||
case msg := <-ch:
|
||||
if msg != "hello" {
|
||||
t.Fatalf("Expected to receive 'hello', got '%s'", msg)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("Timed out waiting for message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBroadcast_SubscribeByIter(t *testing.T) {
|
||||
bb := NewBroadcaster[string](10)
|
||||
|
||||
iterSeq, sub := bb.SubscribeByIter(1)
|
||||
defer sub.Unsubscribe()
|
||||
|
||||
// Channel to communicate when message is received
|
||||
done := make(chan bool)
|
||||
goroutineDone := make(chan struct{})
|
||||
received := false
|
||||
|
||||
// Start a goroutine to use the iterator
|
||||
go func() {
|
||||
defer close(goroutineDone)
|
||||
for msg := range iterSeq {
|
||||
if msg == "hello" {
|
||||
received = true
|
||||
done <- true
|
||||
return // Stop iteration — triggers Unsubscribe via yield returning false
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Give time for the iterator to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Publish a message
|
||||
bb.Publish("hello")
|
||||
|
||||
// Wait for the message to be received or timeout
|
||||
select {
|
||||
case <-done:
|
||||
if !received {
|
||||
t.Fatal("Message was received but not 'hello'")
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("Timed out waiting for message")
|
||||
}
|
||||
|
||||
// Wait for the goroutine to fully exit so Unsubscribe (triggered by the
|
||||
// iterator cleanup when yield returns false) has completed.
|
||||
select {
|
||||
case <-goroutineDone:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("Timed out waiting for goroutine to finish")
|
||||
}
|
||||
|
||||
subCount := bb.SubscriberCount()
|
||||
if subCount != 0 {
|
||||
t.Fatalf("Expected 0 receivers, got %d", subCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBroadcast_Publish(t *testing.T) {
|
||||
bb := NewBroadcaster[string](10)
|
||||
|
||||
// Test publishing with no subscribers
|
||||
subs, receivers := bb.Publish("hello")
|
||||
if subs != 0 {
|
||||
t.Fatalf("Expected 0 subscribers, got %d", subs)
|
||||
}
|
||||
if receivers != 0 {
|
||||
t.Fatalf("Expected 0 receivers, got %d", receivers)
|
||||
}
|
||||
|
||||
// Add a subscriber
|
||||
ch, sub := bb.SubscribeByChan(1)
|
||||
defer sub.Unsubscribe()
|
||||
|
||||
// Publish a message
|
||||
subs, receivers = bb.Publish("hello")
|
||||
if subs != 1 {
|
||||
t.Fatalf("Expected 1 subscriber, got %d", subs)
|
||||
}
|
||||
if receivers != 1 {
|
||||
t.Fatalf("Expected 1 receiver, got %d", receivers)
|
||||
}
|
||||
|
||||
// Verify the message was received
|
||||
select {
|
||||
case msg := <-ch:
|
||||
if msg != "hello" {
|
||||
t.Fatalf("Expected to receive 'hello', got '%s'", msg)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("Timed out waiting for message")
|
||||
}
|
||||
|
||||
// Test non-blocking behavior with a full channel
|
||||
// First fill the channel
|
||||
bb.Publish("fill")
|
||||
|
||||
// Now publish again - this should not block but may skip the receiver
|
||||
subs, receivers = bb.Publish("overflow")
|
||||
if subs != 1 {
|
||||
t.Fatalf("Expected 1 subscriber, got %d", subs)
|
||||
}
|
||||
_ = receivers // may be 0 if channel is full
|
||||
|
||||
// Drain the channel
|
||||
<-ch
|
||||
}
|
||||
|
||||
func TestBroadcast_PublishWithTimeout(t *testing.T) {
|
||||
bb := NewBroadcaster[string](10)
|
||||
|
||||
// Add a subscriber with a channel
|
||||
ch, sub := bb.SubscribeByChan(1)
|
||||
defer sub.Unsubscribe()
|
||||
|
||||
// Publish with a timeout
|
||||
subs, receivers := bb.PublishWithTimeout("hello", 100*time.Millisecond)
|
||||
if subs != 1 {
|
||||
t.Fatalf("Expected 1 subscriber, got %d", subs)
|
||||
}
|
||||
if receivers != 1 {
|
||||
t.Fatalf("Expected 1 receiver, got %d", receivers)
|
||||
}
|
||||
|
||||
// Verify the message was received
|
||||
select {
|
||||
case msg := <-ch:
|
||||
if msg != "hello" {
|
||||
t.Fatalf("Expected to receive 'hello', got '%s'", msg)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("Timed out waiting for message")
|
||||
}
|
||||
|
||||
// Fill the channel
|
||||
bb.Publish("fill")
|
||||
|
||||
// Test timeout behavior with a full channel
|
||||
start := time.Now()
|
||||
subs, receivers = bb.PublishWithTimeout("timeout-test", 50*time.Millisecond)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if subs != 1 {
|
||||
t.Fatalf("Expected 1 subscriber, got %d", subs)
|
||||
}
|
||||
|
||||
// The receiver count should be 0 if the timeout occurred
|
||||
if elapsed < 50*time.Millisecond {
|
||||
t.Fatalf("Expected to wait at least 50ms, only waited %v", elapsed)
|
||||
}
|
||||
|
||||
// Drain the channel
|
||||
<-ch
|
||||
}
|
||||
|
||||
func TestBroadcast_PublishWithContext(t *testing.T) {
|
||||
bb := NewBroadcaster[string](10)
|
||||
|
||||
// Add a subscriber with a channel
|
||||
ch, sub := bb.SubscribeByChan(1)
|
||||
defer sub.Unsubscribe()
|
||||
|
||||
// Create a context
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
// Publish with context
|
||||
subs, receivers, err := bb.PublishWithContext(ctx, "hello")
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
if subs != 1 {
|
||||
t.Fatalf("Expected 1 subscriber, got %d", subs)
|
||||
}
|
||||
if receivers != 1 {
|
||||
t.Fatalf("Expected 1 receiver, got %d", receivers)
|
||||
}
|
||||
|
||||
// Verify the message was received
|
||||
select {
|
||||
case msg := <-ch:
|
||||
if msg != "hello" {
|
||||
t.Fatalf("Expected to receive 'hello', got '%s'", msg)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("Timed out waiting for message")
|
||||
}
|
||||
|
||||
// Fill the channel
|
||||
bb.Publish("fill")
|
||||
|
||||
// Test context cancellation with a full channel
|
||||
ctx, cancel = context.WithCancel(context.Background())
|
||||
|
||||
// Cancel the context after a short delay
|
||||
go func() {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
start := time.Now()
|
||||
subs, receivers, err = bb.PublishWithContext(ctx, "context-test")
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if subs != 1 {
|
||||
t.Fatalf("Expected 1 subscriber, got %d", subs)
|
||||
}
|
||||
|
||||
// Should get a context canceled error
|
||||
if err == nil {
|
||||
t.Fatal("Expected context canceled error, got nil")
|
||||
}
|
||||
|
||||
if elapsed < 50*time.Millisecond {
|
||||
t.Fatalf("Expected to wait at least 50ms, only waited %v", elapsed)
|
||||
}
|
||||
|
||||
// Drain the channel
|
||||
<-ch
|
||||
}
|
||||
|
||||
func TestBroadcast_Unsubscribe(t *testing.T) {
|
||||
bb := NewBroadcaster[string](10)
|
||||
|
||||
// Add a subscriber
|
||||
ch, sub := bb.SubscribeByChan(1)
|
||||
|
||||
// Publish a message
|
||||
subs, receivers := bb.Publish("hello")
|
||||
if subs != 1 {
|
||||
t.Fatalf("Expected 1 subscriber, got %d", subs)
|
||||
}
|
||||
if receivers != 1 {
|
||||
t.Fatalf("Expected 1 receiver, got %d", receivers)
|
||||
}
|
||||
|
||||
// Verify the message was received
|
||||
select {
|
||||
case msg := <-ch:
|
||||
if msg != "hello" {
|
||||
t.Fatalf("Expected to receive 'hello', got '%s'", msg)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("Timed out waiting for message")
|
||||
}
|
||||
|
||||
// Unsubscribe
|
||||
sub.Unsubscribe()
|
||||
|
||||
// Publish again
|
||||
subs, receivers = bb.Publish("after-unsub")
|
||||
if subs != 0 {
|
||||
t.Fatalf("Expected 0 subscribers after unsubscribe, got %d", subs)
|
||||
}
|
||||
if receivers != 0 {
|
||||
t.Fatalf("Expected 0 receivers after unsubscribe, got %d", receivers)
|
||||
}
|
||||
|
||||
// Check that the subscriber count is 0
|
||||
if bb.SubscriberCount() != 0 {
|
||||
t.Fatalf("Expected SubscriberCount() == 0, got %d", bb.SubscriberCount())
|
||||
}
|
||||
}
|
||||
@@ -115,6 +115,9 @@ func (b *bufferedReadCloser) BufferedAll() ([]byte, error) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if err := b.Reset(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b.buffer, nil
|
||||
|
||||
case modeSourceFinished:
|
||||
@@ -131,10 +134,22 @@ func (b *bufferedReadCloser) BufferedAll() ([]byte, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// Reset resets the buffer to the beginning of the buffer.
|
||||
// If the original source is partially read, we will finish reading it and fill our buffer
|
||||
func (b *bufferedReadCloser) Reset() error {
|
||||
switch b.mode {
|
||||
case modeSourceReading:
|
||||
fallthrough
|
||||
if b.off == 0 {
|
||||
return nil // nobody has read anything yet
|
||||
}
|
||||
err := b.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b.mode = modeBufferReading
|
||||
b.off = 0
|
||||
return nil
|
||||
|
||||
case modeSourceFinished:
|
||||
err := b.Close()
|
||||
if err != nil {
|
||||
|
||||
@@ -0,0 +1,129 @@
|
||||
package dataext
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type fakeReadCloser struct {
|
||||
r *bytes.Reader
|
||||
closed bool
|
||||
}
|
||||
|
||||
func newFakeReadCloser(data []byte) *fakeReadCloser {
|
||||
return &fakeReadCloser{r: bytes.NewReader(data)}
|
||||
}
|
||||
|
||||
func (f *fakeReadCloser) Read(p []byte) (int, error) {
|
||||
return f.r.Read(p)
|
||||
}
|
||||
|
||||
func (f *fakeReadCloser) Close() error {
|
||||
f.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestBufferedReadCloser_ReadAll(t *testing.T) {
|
||||
data := []byte("hello world")
|
||||
brc := NewBufferedReadCloser(newFakeReadCloser(data))
|
||||
|
||||
buf := make([]byte, 64)
|
||||
total := 0
|
||||
for {
|
||||
n, err := brc.Read(buf[total:])
|
||||
total += n
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected err: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if !bytes.Equal(buf[:total], data) {
|
||||
t.Fatalf("got %q want %q", buf[:total], data)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBufferedReadCloser_BufferedAllThenRead(t *testing.T) {
|
||||
data := []byte("foobar baz")
|
||||
brc := NewBufferedReadCloser(newFakeReadCloser(data))
|
||||
|
||||
all, err := brc.BufferedAll()
|
||||
if err != nil {
|
||||
t.Fatalf("BufferedAll err: %v", err)
|
||||
}
|
||||
if !bytes.Equal(all, data) {
|
||||
t.Fatalf("BufferedAll got %q want %q", all, data)
|
||||
}
|
||||
|
||||
// after BufferedAll, Reset put us in BufferReading mode - we can read again
|
||||
out, err := io.ReadAll(brc)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadAll err: %v", err)
|
||||
}
|
||||
if !bytes.Equal(out, data) {
|
||||
t.Fatalf("ReadAll got %q want %q", out, data)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBufferedReadCloser_FullyReadResetReread(t *testing.T) {
|
||||
data := []byte("abcdefghij")
|
||||
brc := NewBufferedReadCloser(newFakeReadCloser(data))
|
||||
|
||||
out, err := io.ReadAll(brc)
|
||||
if err != nil {
|
||||
t.Fatalf("first ReadAll err: %v", err)
|
||||
}
|
||||
if !bytes.Equal(out, data) {
|
||||
t.Fatalf("first read got %q want %q", out, data)
|
||||
}
|
||||
|
||||
if err := brc.Reset(); err != nil {
|
||||
t.Fatalf("reset err: %v", err)
|
||||
}
|
||||
|
||||
out2, err := io.ReadAll(brc)
|
||||
if err != nil {
|
||||
t.Fatalf("second ReadAll err: %v", err)
|
||||
}
|
||||
if !bytes.Equal(out2, data) {
|
||||
t.Fatalf("after reset got %q want %q", out2, data)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBufferedReadCloser_Close(t *testing.T) {
|
||||
data := []byte("xyz")
|
||||
inner := newFakeReadCloser(data)
|
||||
brc := NewBufferedReadCloser(inner)
|
||||
|
||||
if err := brc.Close(); err != nil {
|
||||
t.Fatalf("close err: %v", err)
|
||||
}
|
||||
if !inner.closed {
|
||||
t.Fatal("inner not closed")
|
||||
}
|
||||
|
||||
// double close should be no-op
|
||||
if err := brc.Close(); err != nil {
|
||||
t.Fatalf("second close err: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBufferedReadCloser_ResetWithoutRead(t *testing.T) {
|
||||
data := []byte("abc")
|
||||
brc := NewBufferedReadCloser(newFakeReadCloser(data))
|
||||
|
||||
if err := brc.Reset(); err != nil {
|
||||
t.Fatalf("reset err: %v", err)
|
||||
}
|
||||
|
||||
out, err := io.ReadAll(brc)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadAll err: %v", err)
|
||||
}
|
||||
if !bytes.Equal(out, data) {
|
||||
t.Fatalf("got %q want %q", out, data)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,122 @@
|
||||
package dataext
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCASMutex_LockUnlock(t *testing.T) {
|
||||
m := NewCASMutex()
|
||||
m.Lock()
|
||||
m.Unlock()
|
||||
}
|
||||
|
||||
func TestCASMutex_TryLock(t *testing.T) {
|
||||
m := NewCASMutex()
|
||||
if !m.TryLock() {
|
||||
t.Fatal("TryLock should succeed on fresh mutex")
|
||||
}
|
||||
if m.TryLock() {
|
||||
t.Fatal("TryLock should fail when already locked")
|
||||
}
|
||||
m.Unlock()
|
||||
if !m.TryLock() {
|
||||
t.Fatal("TryLock should succeed after Unlock")
|
||||
}
|
||||
m.Unlock()
|
||||
}
|
||||
|
||||
func TestCASMutex_TryLockWithTimeout(t *testing.T) {
|
||||
m := NewCASMutex()
|
||||
m.Lock()
|
||||
start := time.Now()
|
||||
if m.TryLockWithTimeout(20 * time.Millisecond) {
|
||||
t.Fatal("TryLockWithTimeout should fail when locked")
|
||||
}
|
||||
if time.Since(start) < 15*time.Millisecond {
|
||||
t.Fatal("TryLockWithTimeout returned too quickly")
|
||||
}
|
||||
m.Unlock()
|
||||
|
||||
if !m.TryLockWithTimeout(50 * time.Millisecond) {
|
||||
t.Fatal("TryLockWithTimeout should succeed when unlocked")
|
||||
}
|
||||
m.Unlock()
|
||||
}
|
||||
|
||||
func TestCASMutex_TryLockWithContext_Cancel(t *testing.T) {
|
||||
m := NewCASMutex()
|
||||
m.Lock()
|
||||
defer m.Unlock()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
cancel()
|
||||
}()
|
||||
if m.TryLockWithContext(ctx) {
|
||||
t.Fatal("expected lock to fail after cancel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCASMutex_RLockMultiple(t *testing.T) {
|
||||
m := NewCASMutex()
|
||||
if !m.RTryLock() {
|
||||
t.Fatal("RTryLock should succeed")
|
||||
}
|
||||
if !m.RTryLock() {
|
||||
t.Fatal("Second RTryLock should succeed")
|
||||
}
|
||||
if m.TryLock() {
|
||||
t.Fatal("Write TryLock should fail with read locks held")
|
||||
}
|
||||
m.RUnlock()
|
||||
m.RUnlock()
|
||||
if !m.TryLock() {
|
||||
t.Fatal("Write TryLock should succeed after read unlocks")
|
||||
}
|
||||
m.Unlock()
|
||||
}
|
||||
|
||||
func TestCASMutex_RLocker(t *testing.T) {
|
||||
m := NewCASMutex()
|
||||
rl := m.RLocker()
|
||||
rl.Lock()
|
||||
rl.Unlock()
|
||||
}
|
||||
|
||||
func TestCASMutex_Concurrent(t *testing.T) {
|
||||
m := NewCASMutex()
|
||||
var counter int64
|
||||
const n = 50
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(n)
|
||||
for i := 0; i < n; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
m.Lock()
|
||||
atomic.AddInt64(&counter, 1)
|
||||
m.Unlock()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
if atomic.LoadInt64(&counter) != n {
|
||||
t.Fatalf("counter=%d want %d", counter, n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCASMutex_RTryLockWithTimeout(t *testing.T) {
|
||||
m := NewCASMutex()
|
||||
m.Lock()
|
||||
if m.RTryLockWithTimeout(20 * time.Millisecond) {
|
||||
t.Fatal("RTryLockWithTimeout should fail when write-locked")
|
||||
}
|
||||
m.Unlock()
|
||||
if !m.RTryLockWithTimeout(20 * time.Millisecond) {
|
||||
t.Fatal("RTryLockWithTimeout should succeed when free")
|
||||
}
|
||||
m.RUnlock()
|
||||
}
|
||||
@@ -0,0 +1,199 @@
|
||||
package dataext
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/langext"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/syncext"
|
||||
)
|
||||
|
||||
// DelayedCombiningInvoker is a utility to combine multiple consecutive requests into a single execution
|
||||
//
|
||||
// Requests are made with Request(), and consecutive requests are combined during the `delay` period.
|
||||
//
|
||||
// Can be used, e.g., for search-controls, where we want to init the search when teh user stops typing
|
||||
// Or generally to queue an execution once a burst of requests is over.
|
||||
type DelayedCombiningInvoker struct {
|
||||
syncLock sync.Mutex
|
||||
triggerChan chan bool
|
||||
cancelChan chan bool
|
||||
execNowChan chan bool
|
||||
action func()
|
||||
delay time.Duration
|
||||
maxDelay time.Duration
|
||||
executorRunning *syncext.AtomicBool
|
||||
pendingRequests *syncext.Atomic[int]
|
||||
lastRequestTime time.Time
|
||||
initialRequestTime time.Time
|
||||
|
||||
onExecutionStart []func(immediately bool) // listener ( actual execution of action starts )
|
||||
onExecutionDone []func() // listener ( actual execution of action finished )
|
||||
onRequest []func(pending int, initial bool) // listener ( a request came in, waiting for execution )
|
||||
}
|
||||
|
||||
func NewDelayedCombiningInvoker(action func(), delay time.Duration, maxDelay time.Duration) *DelayedCombiningInvoker {
|
||||
return &DelayedCombiningInvoker{
|
||||
action: action,
|
||||
delay: delay,
|
||||
maxDelay: maxDelay,
|
||||
executorRunning: syncext.NewAtomicBool(false),
|
||||
pendingRequests: syncext.NewAtomic[int](0),
|
||||
triggerChan: make(chan bool),
|
||||
cancelChan: make(chan bool, 1),
|
||||
execNowChan: make(chan bool, 1),
|
||||
lastRequestTime: time.Now(),
|
||||
initialRequestTime: time.Now(),
|
||||
onExecutionStart: make([]func(bool), 0),
|
||||
onExecutionDone: make([]func(), 0),
|
||||
onRequest: make([]func(int, bool), 0),
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DelayedCombiningInvoker) Request() {
|
||||
now := time.Now()
|
||||
|
||||
d.syncLock.Lock()
|
||||
defer d.syncLock.Unlock()
|
||||
|
||||
if d.executorRunning.Get() {
|
||||
d.lastRequestTime = now
|
||||
d.pendingRequests.Update(func(v int) int { return v + 1 })
|
||||
|
||||
for _, fn := range d.onRequest {
|
||||
_ = langext.RunPanicSafe(func() { fn(d.pendingRequests.Get(), true) })
|
||||
}
|
||||
|
||||
d.triggerChan <- true
|
||||
} else {
|
||||
d.initialRequestTime = now
|
||||
d.lastRequestTime = now
|
||||
|
||||
d.executorRunning.Set(true)
|
||||
d.pendingRequests.Set(1)
|
||||
syncext.ReadNonBlocking(d.triggerChan) // clear the channel
|
||||
syncext.ReadNonBlocking(d.cancelChan) // clear the channel
|
||||
syncext.ReadNonBlocking(d.execNowChan) // clear the channel
|
||||
|
||||
for _, fn := range d.onRequest {
|
||||
_ = langext.RunPanicSafe(func() { fn(d.pendingRequests.Get(), false) })
|
||||
}
|
||||
|
||||
go d.run()
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DelayedCombiningInvoker) run() {
|
||||
|
||||
needsExecutorRunningCleanup := true
|
||||
defer func() {
|
||||
if needsExecutorRunningCleanup {
|
||||
d.syncLock.Lock()
|
||||
d.executorRunning.Set(false)
|
||||
d.syncLock.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
d.syncLock.Lock()
|
||||
timeOut := max(min(d.maxDelay-time.Since(d.initialRequestTime), d.delay-time.Since(d.lastRequestTime)), 0)
|
||||
d.syncLock.Unlock()
|
||||
|
||||
immediately := false
|
||||
|
||||
select {
|
||||
case <-d.execNowChan:
|
||||
// run immediately
|
||||
immediately = true
|
||||
break
|
||||
case <-d.triggerChan:
|
||||
// external trigger - needs to re-evaluate
|
||||
break
|
||||
case <-d.cancelChan:
|
||||
// cancel
|
||||
return
|
||||
case <-time.After(timeOut):
|
||||
// time elapsed - check for execution
|
||||
break
|
||||
|
||||
}
|
||||
|
||||
d.syncLock.Lock()
|
||||
execute := immediately || time.Since(d.lastRequestTime) >= d.delay || time.Since(d.initialRequestTime) >= d.maxDelay
|
||||
if !execute {
|
||||
d.syncLock.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
d.pendingRequests.Set(0)
|
||||
|
||||
for _, fn := range d.onExecutionStart {
|
||||
_ = langext.RunPanicSafe(func() { fn(immediately) })
|
||||
}
|
||||
|
||||
// =================================================
|
||||
_ = langext.RunPanicSafe(d.action)
|
||||
// =================================================
|
||||
|
||||
d.executorRunning.Set(false) // ensure HasPendingRequests returns fals ein onExecutionDone listener
|
||||
needsExecutorRunningCleanup = false
|
||||
|
||||
for _, fn := range d.onExecutionDone {
|
||||
_ = langext.RunPanicSafe(fn)
|
||||
}
|
||||
|
||||
d.syncLock.Unlock()
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DelayedCombiningInvoker) CancelPendingRequests() {
|
||||
d.syncLock.Lock()
|
||||
defer d.syncLock.Unlock()
|
||||
|
||||
syncext.WriteNonBlocking(d.cancelChan, true)
|
||||
}
|
||||
|
||||
func (d *DelayedCombiningInvoker) HasPendingRequests() bool {
|
||||
return d.executorRunning.Get()
|
||||
}
|
||||
|
||||
func (d *DelayedCombiningInvoker) CountPendingRequests() int {
|
||||
return d.pendingRequests.Get()
|
||||
}
|
||||
|
||||
func (d *DelayedCombiningInvoker) ExecuteNow() bool {
|
||||
d.syncLock.Lock()
|
||||
defer d.syncLock.Unlock()
|
||||
|
||||
if d.executorRunning.Get() {
|
||||
syncext.WriteNonBlocking(d.execNowChan, true)
|
||||
return true
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DelayedCombiningInvoker) WaitForCompletion(ctx context.Context) error {
|
||||
return d.executorRunning.WaitWithContext(ctx, false)
|
||||
}
|
||||
|
||||
func (d *DelayedCombiningInvoker) RegisterOnExecutionStart(fn func(immediately bool)) {
|
||||
d.syncLock.Lock()
|
||||
defer d.syncLock.Unlock()
|
||||
d.onExecutionStart = append(d.onExecutionStart, fn)
|
||||
}
|
||||
|
||||
func (d *DelayedCombiningInvoker) RegisterOnExecutionDone(fn func()) {
|
||||
d.syncLock.Lock()
|
||||
defer d.syncLock.Unlock()
|
||||
d.onExecutionDone = append(d.onExecutionDone, fn)
|
||||
}
|
||||
|
||||
func (d *DelayedCombiningInvoker) RegisterOnRequest(fn func(pending int, initial bool)) {
|
||||
d.syncLock.Lock()
|
||||
defer d.syncLock.Unlock()
|
||||
d.onRequest = append(d.onRequest, fn)
|
||||
}
|
||||
@@ -0,0 +1,182 @@
|
||||
package dataext
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func waitForCalls(t *testing.T, calls *int64, want int64, max time.Duration) {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(max)
|
||||
for time.Now().Before(deadline) {
|
||||
if atomic.LoadInt64(calls) >= want {
|
||||
return
|
||||
}
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDelayedCombiningInvoker_SingleRequest(t *testing.T) {
|
||||
var calls int64
|
||||
d := NewDelayedCombiningInvoker(func() {
|
||||
atomic.AddInt64(&calls, 1)
|
||||
}, 20*time.Millisecond, 200*time.Millisecond)
|
||||
|
||||
d.Request()
|
||||
|
||||
waitForCalls(t, &calls, 1, 2*time.Second)
|
||||
if c := atomic.LoadInt64(&calls); c != 1 {
|
||||
t.Fatalf("calls=%d want 1", c)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDelayedCombiningInvoker_TwoRequestsCombine(t *testing.T) {
|
||||
var calls int64
|
||||
d := NewDelayedCombiningInvoker(func() {
|
||||
atomic.AddInt64(&calls, 1)
|
||||
}, 50*time.Millisecond, 1*time.Second)
|
||||
|
||||
d.Request()
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
d.Request()
|
||||
|
||||
waitForCalls(t, &calls, 1, 2*time.Second)
|
||||
if c := atomic.LoadInt64(&calls); c != 1 {
|
||||
t.Fatalf("calls=%d want 1 (should be combined)", c)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDelayedCombiningInvoker_SequentialRuns(t *testing.T) {
|
||||
var calls int64
|
||||
d := NewDelayedCombiningInvoker(func() {
|
||||
atomic.AddInt64(&calls, 1)
|
||||
}, 20*time.Millisecond, 200*time.Millisecond)
|
||||
|
||||
d.Request()
|
||||
waitForCalls(t, &calls, 1, 2*time.Second)
|
||||
if c := atomic.LoadInt64(&calls); c != 1 {
|
||||
t.Fatalf("after first wait calls=%d want 1", c)
|
||||
}
|
||||
|
||||
// allow executorRunning to clear
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
d.Request()
|
||||
waitForCalls(t, &calls, 2, 2*time.Second)
|
||||
if c := atomic.LoadInt64(&calls); c != 2 {
|
||||
t.Fatalf("calls=%d want 2", c)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDelayedCombiningInvoker_ExecuteNow(t *testing.T) {
|
||||
var calls int64
|
||||
d := NewDelayedCombiningInvoker(func() {
|
||||
atomic.AddInt64(&calls, 1)
|
||||
}, 5*time.Second, 30*time.Second)
|
||||
|
||||
d.Request()
|
||||
if !d.HasPendingRequests() {
|
||||
t.Fatal("should have pending requests")
|
||||
}
|
||||
|
||||
if !d.ExecuteNow() {
|
||||
t.Fatal("ExecuteNow should return true when running")
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
if atomic.LoadInt64(&calls) >= 1 {
|
||||
break
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
if c := atomic.LoadInt64(&calls); c != 1 {
|
||||
t.Fatalf("calls=%d want 1 (ExecuteNow should fire well before delay)", c)
|
||||
}
|
||||
|
||||
// allow internal state cleanup
|
||||
for i := 0; i < 100; i++ {
|
||||
if !d.HasPendingRequests() {
|
||||
break
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
if d.ExecuteNow() {
|
||||
t.Fatal("ExecuteNow should return false when no pending")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDelayedCombiningInvoker_Cancel(t *testing.T) {
|
||||
var calls int64
|
||||
d := NewDelayedCombiningInvoker(func() {
|
||||
atomic.AddInt64(&calls, 1)
|
||||
}, 500*time.Millisecond, 5*time.Second)
|
||||
|
||||
d.Request()
|
||||
d.CancelPendingRequests()
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
if c := atomic.LoadInt64(&calls); c != 0 {
|
||||
t.Fatalf("calls=%d want 0 after cancel", c)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDelayedCombiningInvoker_HasAndCountPending(t *testing.T) {
|
||||
d := NewDelayedCombiningInvoker(func() {
|
||||
// no-op
|
||||
}, 500*time.Millisecond, 5*time.Second)
|
||||
|
||||
if d.HasPendingRequests() {
|
||||
t.Fatal("should not have pending before any Request")
|
||||
}
|
||||
if d.CountPendingRequests() != 0 {
|
||||
t.Fatalf("count=%d want 0", d.CountPendingRequests())
|
||||
}
|
||||
|
||||
d.Request()
|
||||
if !d.HasPendingRequests() {
|
||||
t.Fatal("should have pending")
|
||||
}
|
||||
if d.CountPendingRequests() < 1 {
|
||||
t.Fatalf("count=%d want >=1", d.CountPendingRequests())
|
||||
}
|
||||
d.CancelPendingRequests()
|
||||
}
|
||||
|
||||
func TestDelayedCombiningInvoker_Listeners(t *testing.T) {
|
||||
var (
|
||||
startCount int64
|
||||
doneCount int64
|
||||
requestCount int64
|
||||
)
|
||||
|
||||
d := NewDelayedCombiningInvoker(func() {
|
||||
// no-op
|
||||
}, 20*time.Millisecond, 200*time.Millisecond)
|
||||
|
||||
d.RegisterOnExecutionStart(func(immediately bool) {
|
||||
atomic.AddInt64(&startCount, 1)
|
||||
})
|
||||
d.RegisterOnExecutionDone(func() {
|
||||
atomic.AddInt64(&doneCount, 1)
|
||||
})
|
||||
d.RegisterOnRequest(func(pending int, initial bool) {
|
||||
atomic.AddInt64(&requestCount, 1)
|
||||
})
|
||||
|
||||
d.Request()
|
||||
|
||||
waitForCalls(t, &doneCount, 1, 2*time.Second)
|
||||
|
||||
if atomic.LoadInt64(&startCount) != 1 {
|
||||
t.Fatalf("startCount=%d want 1", startCount)
|
||||
}
|
||||
if atomic.LoadInt64(&doneCount) != 1 {
|
||||
t.Fatalf("doneCount=%d want 1", doneCount)
|
||||
}
|
||||
if atomic.LoadInt64(&requestCount) != 1 {
|
||||
t.Fatalf("requestCount=%d want 1", requestCount)
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
package dataext
|
||||
|
||||
import (
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/langext"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/langext"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
+1
-1
@@ -14,7 +14,7 @@ func ObjectMerge[T1 any, T2 any](base T1, override T2) T1 {
|
||||
fieldBase := reflBase.Field(i)
|
||||
fieldOvrd := reflOvrd.Field(i)
|
||||
|
||||
if fieldBase.Kind() != reflect.Ptr || fieldOvrd.Kind() != reflect.Ptr {
|
||||
if fieldBase.Kind() != reflect.Pointer || fieldOvrd.Kind() != reflect.Pointer {
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
package dataext
|
||||
|
||||
import (
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/langext"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/tst"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/tst"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -26,17 +25,17 @@ func TestObjectMerge(t *testing.T) {
|
||||
|
||||
valueA := A{
|
||||
Field1: nil,
|
||||
Field2: langext.Ptr("99"),
|
||||
Field3: langext.Ptr(12.2),
|
||||
Field2: new("99"),
|
||||
Field3: new(12.2),
|
||||
Field4: nil,
|
||||
OnlyA: 1,
|
||||
DiffType: 2,
|
||||
}
|
||||
|
||||
valueB := B{
|
||||
Field1: langext.Ptr(12),
|
||||
Field1: new(12),
|
||||
Field2: nil,
|
||||
Field3: langext.Ptr(13.2),
|
||||
Field3: new(13.2),
|
||||
Field4: nil,
|
||||
OnlyB: 1,
|
||||
DiffType: "X",
|
||||
|
||||
@@ -0,0 +1,105 @@
|
||||
package dataext
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MultiMutex is a simple map[key -> mutex]
|
||||
type MultiMutex[TKey comparable] struct {
|
||||
mutextMap *SyncMap[TKey, *CASMutex]
|
||||
}
|
||||
|
||||
func NewMultiMutex[TKey comparable]() *MultiMutex[TKey] {
|
||||
return &MultiMutex[TKey]{
|
||||
mutextMap: NewSyncMap[TKey, *CASMutex](),
|
||||
}
|
||||
}
|
||||
|
||||
// TryLockWithContext attempts to acquire the lock, blocking until resources
|
||||
// are available or ctx is done (timeout or cancellation).
|
||||
func (mm *MultiMutex[TKey]) TryLockWithContext(ctx context.Context, key TKey) bool {
|
||||
lck := mm.mutextMap.GetAndSetIfNotContainsFunc(key, NewCASMutex)
|
||||
return lck.TryLockWithContext(ctx)
|
||||
}
|
||||
|
||||
// Lock acquires the lock.
|
||||
// If it is currently held by others, Lock will wait until it has a chance to acquire it.
|
||||
func (mm *MultiMutex[TKey]) Lock(key TKey) {
|
||||
lck := mm.mutextMap.GetAndSetIfNotContainsFunc(key, NewCASMutex)
|
||||
lck.Lock()
|
||||
}
|
||||
|
||||
// TryLock attempts to acquire the lock without blocking.
|
||||
// Return false if someone is holding it now.
|
||||
func (mm *MultiMutex[TKey]) TryLock(key TKey) bool {
|
||||
lck := mm.mutextMap.GetAndSetIfNotContainsFunc(key, NewCASMutex)
|
||||
return lck.TryLock()
|
||||
}
|
||||
|
||||
// TryLockWithTimeout attempts to acquire the lock within a period of time.
|
||||
// Return false if spending time is more than duration and no chance to acquire it.
|
||||
func (mm *MultiMutex[TKey]) TryLockWithTimeout(key TKey, duration time.Duration) bool {
|
||||
lck := mm.mutextMap.GetAndSetIfNotContainsFunc(key, NewCASMutex)
|
||||
return lck.TryLockWithTimeout(duration)
|
||||
}
|
||||
|
||||
// Unlock releases the lock.
|
||||
func (mm *MultiMutex[TKey]) Unlock(key TKey) {
|
||||
lck := mm.mutextMap.GetAndSetIfNotContainsFunc(key, NewCASMutex)
|
||||
lck.Unlock()
|
||||
}
|
||||
|
||||
// RTryLockWithContext attempts to acquire the read lock, blocking until resources
|
||||
// are available or ctx is done (timeout or cancellation).
|
||||
func (mm *MultiMutex[TKey]) RTryLockWithContext(ctx context.Context, key TKey) bool {
|
||||
lck := mm.mutextMap.GetAndSetIfNotContainsFunc(key, NewCASMutex)
|
||||
return lck.RTryLockWithContext(ctx)
|
||||
}
|
||||
|
||||
// RLock acquires the read lock.
|
||||
// If it is currently held by others writing, RLock will wait until it has a chance to acquire it.
|
||||
func (mm *MultiMutex[TKey]) RLock(key TKey) {
|
||||
lck := mm.mutextMap.GetAndSetIfNotContainsFunc(key, NewCASMutex)
|
||||
lck.RLock()
|
||||
}
|
||||
|
||||
// RTryLock attempts to acquire the read lock without blocking.
|
||||
// Return false if someone is writing it now.
|
||||
func (mm *MultiMutex[TKey]) RTryLock(key TKey) bool {
|
||||
lck := mm.mutextMap.GetAndSetIfNotContainsFunc(key, NewCASMutex)
|
||||
return lck.RTryLock()
|
||||
}
|
||||
|
||||
// RTryLockWithTimeout attempts to acquire the read lock within a period of time.
|
||||
// Return false if spending time is more than duration and no chance to acquire it.
|
||||
func (mm *MultiMutex[TKey]) RTryLockWithTimeout(duration time.Duration, key TKey) bool {
|
||||
lck := mm.mutextMap.GetAndSetIfNotContainsFunc(key, NewCASMutex)
|
||||
return lck.RTryLockWithTimeout(duration)
|
||||
}
|
||||
|
||||
// RUnlock releases the read lock.
|
||||
func (mm *MultiMutex[TKey]) RUnlock(key TKey) {
|
||||
lck := mm.mutextMap.GetAndSetIfNotContainsFunc(key, NewCASMutex)
|
||||
lck.RUnlock()
|
||||
}
|
||||
|
||||
// RLocker returns a Locker interface that implements the Lock and Unlock methods
|
||||
// by calling CASMutex.RLock and CASMutex.RUnlock.
|
||||
func (mm *MultiMutex[TKey]) RLocker(key TKey) sync.Locker {
|
||||
lck := mm.mutextMap.GetAndSetIfNotContainsFunc(key, NewCASMutex)
|
||||
return lck.RLocker()
|
||||
}
|
||||
|
||||
// Get returns a Locker interface
|
||||
func (mm *MultiMutex[TKey]) Get(key TKey) sync.Locker {
|
||||
lck := mm.mutextMap.GetAndSetIfNotContainsFunc(key, NewCASMutex)
|
||||
return lck
|
||||
}
|
||||
|
||||
// GetCAS returns the underlying CASMutex
|
||||
func (mm *MultiMutex[TKey]) GetCAS(key TKey) *CASMutex {
|
||||
lck := mm.mutextMap.GetAndSetIfNotContainsFunc(key, NewCASMutex)
|
||||
return lck
|
||||
}
|
||||
@@ -0,0 +1,89 @@
|
||||
package dataext
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestMultiMutex_LockDifferentKeys(t *testing.T) {
|
||||
mm := NewMultiMutex[string]()
|
||||
mm.Lock("a")
|
||||
mm.Lock("b")
|
||||
mm.Unlock("a")
|
||||
mm.Unlock("b")
|
||||
}
|
||||
|
||||
func TestMultiMutex_TryLockSameKey(t *testing.T) {
|
||||
mm := NewMultiMutex[string]()
|
||||
if !mm.TryLock("k") {
|
||||
t.Fatal("TryLock should succeed first time")
|
||||
}
|
||||
if mm.TryLock("k") {
|
||||
t.Fatal("TryLock should fail second time")
|
||||
}
|
||||
mm.Unlock("k")
|
||||
if !mm.TryLock("k") {
|
||||
t.Fatal("TryLock should succeed after unlock")
|
||||
}
|
||||
mm.Unlock("k")
|
||||
}
|
||||
|
||||
func TestMultiMutex_TryLockDifferentKeys(t *testing.T) {
|
||||
mm := NewMultiMutex[int]()
|
||||
if !mm.TryLock(1) {
|
||||
t.Fatal("TryLock(1) failed")
|
||||
}
|
||||
if !mm.TryLock(2) {
|
||||
t.Fatal("TryLock(2) failed - different keys should be independent")
|
||||
}
|
||||
mm.Unlock(1)
|
||||
mm.Unlock(2)
|
||||
}
|
||||
|
||||
func TestMultiMutex_RLockMultiple(t *testing.T) {
|
||||
mm := NewMultiMutex[string]()
|
||||
if !mm.RTryLock("k") {
|
||||
t.Fatal("first RTryLock failed")
|
||||
}
|
||||
if !mm.RTryLock("k") {
|
||||
t.Fatal("second RTryLock failed")
|
||||
}
|
||||
mm.RUnlock("k")
|
||||
mm.RUnlock("k")
|
||||
}
|
||||
|
||||
func TestMultiMutex_TryLockWithTimeout(t *testing.T) {
|
||||
mm := NewMultiMutex[string]()
|
||||
mm.Lock("k")
|
||||
if mm.TryLockWithTimeout("k", 10*time.Millisecond) {
|
||||
t.Fatal("expected timeout failure")
|
||||
}
|
||||
mm.Unlock("k")
|
||||
}
|
||||
|
||||
func TestMultiMutex_TryLockWithContext(t *testing.T) {
|
||||
mm := NewMultiMutex[string]()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
|
||||
defer cancel()
|
||||
if !mm.TryLockWithContext(ctx, "k") {
|
||||
t.Fatal("TryLockWithContext should succeed on free key")
|
||||
}
|
||||
mm.Unlock("k")
|
||||
}
|
||||
|
||||
func TestMultiMutex_GetAndGetCAS(t *testing.T) {
|
||||
mm := NewMultiMutex[string]()
|
||||
l := mm.Get("a")
|
||||
if l == nil {
|
||||
t.Fatal("Get returned nil")
|
||||
}
|
||||
cas := mm.GetCAS("a")
|
||||
if cas == nil {
|
||||
t.Fatal("GetCAS returned nil")
|
||||
}
|
||||
rl := mm.RLocker("a")
|
||||
if rl == nil {
|
||||
t.Fatal("RLocker returned nil")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
package dataext
|
||||
|
||||
import "sync"
|
||||
|
||||
type MutexSet[T comparable] struct {
|
||||
master sync.RWMutex
|
||||
locks map[T]*sync.RWMutex
|
||||
}
|
||||
|
||||
func NewMutexSet[T comparable]() *MutexSet[T] {
|
||||
return &MutexSet[T]{
|
||||
master: sync.RWMutex{},
|
||||
locks: make(map[T]*sync.RWMutex),
|
||||
}
|
||||
}
|
||||
|
||||
func (ms *MutexSet[T]) get(key T) *sync.RWMutex {
|
||||
ms.master.RLock()
|
||||
if v, ok := ms.locks[key]; ok {
|
||||
ms.master.RUnlock()
|
||||
return v
|
||||
}
|
||||
ms.master.RUnlock()
|
||||
|
||||
// ---------
|
||||
|
||||
ms.master.Lock()
|
||||
defer ms.master.Unlock()
|
||||
|
||||
if v, ok := ms.locks[key]; ok {
|
||||
return v
|
||||
}
|
||||
|
||||
m := &sync.RWMutex{}
|
||||
ms.locks[key] = m
|
||||
return m
|
||||
}
|
||||
|
||||
func (ms *MutexSet[T]) Lock(key T) {
|
||||
ms.get(key).Lock()
|
||||
}
|
||||
|
||||
func (ms *MutexSet[T]) Unlock(key T) {
|
||||
ms.get(key).Unlock()
|
||||
}
|
||||
|
||||
func (ms *MutexSet[T]) RLock(key T) {
|
||||
ms.get(key).RLock()
|
||||
}
|
||||
|
||||
func (ms *MutexSet[T]) RUnlock(key T) {
|
||||
ms.get(key).RUnlock()
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
package dataext
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMutexSet_BasicLockUnlock(t *testing.T) {
|
||||
ms := NewMutexSet[string]()
|
||||
ms.Lock("a")
|
||||
ms.Unlock("a")
|
||||
ms.RLock("b")
|
||||
ms.RUnlock("b")
|
||||
}
|
||||
|
||||
func TestMutexSet_DifferentKeysIndependent(t *testing.T) {
|
||||
ms := NewMutexSet[int]()
|
||||
ms.Lock(1)
|
||||
ms.Lock(2)
|
||||
ms.Unlock(1)
|
||||
ms.Unlock(2)
|
||||
}
|
||||
|
||||
func TestMutexSet_SameKeyMutuallyExclusive(t *testing.T) {
|
||||
ms := NewMutexSet[string]()
|
||||
var counter int64
|
||||
const n = 50
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(n)
|
||||
for i := 0; i < n; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
ms.Lock("shared")
|
||||
atomic.AddInt64(&counter, 1)
|
||||
ms.Unlock("shared")
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
if atomic.LoadInt64(&counter) != n {
|
||||
t.Fatalf("got %d want %d", counter, n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMutexSet_RLockMultiple(t *testing.T) {
|
||||
ms := NewMutexSet[string]()
|
||||
ms.RLock("k")
|
||||
ms.RLock("k")
|
||||
ms.RUnlock("k")
|
||||
ms.RUnlock("k")
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package dataext
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/langext"
|
||||
)
|
||||
|
||||
type JsonOpt[T any] struct {
|
||||
@@ -10,6 +11,14 @@ type JsonOpt[T any] struct {
|
||||
value T
|
||||
}
|
||||
|
||||
func NewJsonOpt[T any](v T) JsonOpt[T] {
|
||||
return JsonOpt[T]{isSet: true, value: v}
|
||||
}
|
||||
|
||||
func EmptyJsonOpt[T any]() JsonOpt[T] {
|
||||
return JsonOpt[T]{isSet: false}
|
||||
}
|
||||
|
||||
// MarshalJSON returns m as the JSON encoding of m.
|
||||
func (m JsonOpt[T]) MarshalJSON() ([]byte, error) {
|
||||
if !m.isSet {
|
||||
@@ -51,9 +60,24 @@ func (m JsonOpt[T]) ValueOrNil() *T {
|
||||
return &m.value
|
||||
}
|
||||
|
||||
func (m JsonOpt[T]) ValueDblPtrOrNil() **T {
|
||||
if !m.isSet {
|
||||
return nil
|
||||
}
|
||||
return langext.DblPtr(m.value)
|
||||
}
|
||||
|
||||
func (m JsonOpt[T]) MustValue() T {
|
||||
if !m.isSet {
|
||||
panic("value not set")
|
||||
}
|
||||
return m.value
|
||||
}
|
||||
|
||||
func (m JsonOpt[T]) IfSet(fn func(v T)) bool {
|
||||
if !m.isSet {
|
||||
return false
|
||||
}
|
||||
fn(m.value)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -0,0 +1,142 @@
|
||||
package dataext
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestJsonOpt_NewAndEmpty(t *testing.T) {
|
||||
o := NewJsonOpt[int](42)
|
||||
if !o.IsSet() {
|
||||
t.Fatal("expected IsSet=true")
|
||||
}
|
||||
if o.IsUnset() {
|
||||
t.Fatal("expected IsUnset=false")
|
||||
}
|
||||
|
||||
e := EmptyJsonOpt[int]()
|
||||
if e.IsSet() {
|
||||
t.Fatal("expected IsSet=false")
|
||||
}
|
||||
if !e.IsUnset() {
|
||||
t.Fatal("expected IsUnset=true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestJsonOpt_Value(t *testing.T) {
|
||||
o := NewJsonOpt[string]("hello")
|
||||
v, ok := o.Value()
|
||||
if !ok || v != "hello" {
|
||||
t.Fatalf("got (%q,%v)", v, ok)
|
||||
}
|
||||
|
||||
e := EmptyJsonOpt[string]()
|
||||
v, ok = e.Value()
|
||||
if ok || v != "" {
|
||||
t.Fatalf("empty got (%q,%v)", v, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJsonOpt_ValueOrNil(t *testing.T) {
|
||||
o := NewJsonOpt[int](7)
|
||||
p := o.ValueOrNil()
|
||||
if p == nil || *p != 7 {
|
||||
t.Fatalf("expected ptr to 7")
|
||||
}
|
||||
e := EmptyJsonOpt[int]()
|
||||
if e.ValueOrNil() != nil {
|
||||
t.Fatal("expected nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestJsonOpt_ValueDblPtrOrNil(t *testing.T) {
|
||||
o := NewJsonOpt[int](7)
|
||||
p := o.ValueDblPtrOrNil()
|
||||
if p == nil || *p == nil || **p != 7 {
|
||||
t.Fatalf("expected double ptr to 7")
|
||||
}
|
||||
e := EmptyJsonOpt[int]()
|
||||
if e.ValueDblPtrOrNil() != nil {
|
||||
t.Fatal("expected nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestJsonOpt_MustValue(t *testing.T) {
|
||||
o := NewJsonOpt[int](9)
|
||||
if o.MustValue() != 9 {
|
||||
t.Fatal("MustValue wrong")
|
||||
}
|
||||
defer func() {
|
||||
if recover() == nil {
|
||||
t.Fatal("expected panic")
|
||||
}
|
||||
}()
|
||||
EmptyJsonOpt[int]().MustValue()
|
||||
}
|
||||
|
||||
func TestJsonOpt_IfSet(t *testing.T) {
|
||||
called := false
|
||||
NewJsonOpt[int](1).IfSet(func(v int) {
|
||||
called = true
|
||||
if v != 1 {
|
||||
t.Fatalf("v=%d", v)
|
||||
}
|
||||
})
|
||||
if !called {
|
||||
t.Fatal("IfSet did not invoke fn")
|
||||
}
|
||||
|
||||
called = false
|
||||
EmptyJsonOpt[int]().IfSet(func(v int) { called = true })
|
||||
if called {
|
||||
t.Fatal("IfSet invoked fn on empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestJsonOpt_MarshalJSON(t *testing.T) {
|
||||
o := NewJsonOpt[int](5)
|
||||
b, err := json.Marshal(o)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(b) != "5" {
|
||||
t.Fatalf("got %s", b)
|
||||
}
|
||||
|
||||
e := EmptyJsonOpt[int]()
|
||||
b, err = json.Marshal(e)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(b) != "null" {
|
||||
t.Fatalf("got %s", b)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJsonOpt_UnmarshalJSON(t *testing.T) {
|
||||
var o JsonOpt[int]
|
||||
if err := json.Unmarshal([]byte("42"), &o); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !o.IsSet() {
|
||||
t.Fatal("should be set")
|
||||
}
|
||||
if v, _ := o.Value(); v != 42 {
|
||||
t.Fatalf("got %d", v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJsonOpt_StructWithJsonOpt(t *testing.T) {
|
||||
type S struct {
|
||||
A JsonOpt[int] `json:"a"`
|
||||
B JsonOpt[string] `json:"b"`
|
||||
}
|
||||
s := S{A: NewJsonOpt[int](1), B: EmptyJsonOpt[string]()}
|
||||
b, err := json.Marshal(s)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(b) != `{"a":1,"b":null}` {
|
||||
t.Fatalf("got %s", b)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,267 @@
|
||||
package dataext
|
||||
|
||||
import (
|
||||
"context"
|
||||
"iter"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/langext"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/syncext"
|
||||
"github.com/rs/xid"
|
||||
)
|
||||
|
||||
// PubSub is a simple Pub/Sub Broker
|
||||
// Clients can subscribe to a namespace and receive published messages on this namespace
|
||||
// Messages are broadcast to all subscribers
|
||||
type PubSub[TNamespace comparable, TData any] struct {
|
||||
masterLock *sync.Mutex
|
||||
|
||||
subscriptions map[TNamespace][]*pubSubSubscription[TNamespace, TData]
|
||||
}
|
||||
|
||||
type PubSubSubscription interface {
|
||||
Unsubscribe()
|
||||
}
|
||||
|
||||
type pubSubSubscription[TNamespace comparable, TData any] struct {
|
||||
ID string
|
||||
|
||||
parent *PubSub[TNamespace, TData]
|
||||
namespace TNamespace
|
||||
|
||||
subLock *sync.Mutex
|
||||
|
||||
Func func(TData)
|
||||
Chan chan TData
|
||||
|
||||
UnsubChan chan bool
|
||||
}
|
||||
|
||||
func (p *pubSubSubscription[TNamespace, TData]) Unsubscribe() {
|
||||
p.parent.unsubscribe(p)
|
||||
}
|
||||
|
||||
func NewPubSub[TNamespace comparable, TData any](capacity int) *PubSub[TNamespace, TData] {
|
||||
return &PubSub[TNamespace, TData]{
|
||||
masterLock: &sync.Mutex{},
|
||||
subscriptions: make(map[TNamespace][]*pubSubSubscription[TNamespace, TData], capacity),
|
||||
}
|
||||
}
|
||||
|
||||
func (ps *PubSub[TNamespace, TData]) Namespaces() []TNamespace {
|
||||
ps.masterLock.Lock()
|
||||
defer ps.masterLock.Unlock()
|
||||
|
||||
return langext.MapKeyArr(ps.subscriptions)
|
||||
}
|
||||
|
||||
func (ps *PubSub[TNamespace, TData]) SubscriberCount(ns TNamespace) int {
|
||||
ps.masterLock.Lock()
|
||||
defer ps.masterLock.Unlock()
|
||||
|
||||
return len(ps.subscriptions[ns])
|
||||
}
|
||||
|
||||
// Publish sends `data` to all subscriber
|
||||
// But unbuffered - if one is currently not listening, we skip (the actualReceiver < subscriber)
|
||||
func (ps *PubSub[TNamespace, TData]) Publish(ns TNamespace, data TData) (subscriber int, actualReceiver int) {
|
||||
ps.masterLock.Lock()
|
||||
subs := langext.ArrCopy(ps.subscriptions[ns])
|
||||
ps.masterLock.Unlock()
|
||||
|
||||
subscriber = len(subs)
|
||||
actualReceiver = 0
|
||||
|
||||
for _, sub := range subs {
|
||||
func() {
|
||||
sub.subLock.Lock()
|
||||
defer sub.subLock.Unlock()
|
||||
|
||||
if sub.Func != nil {
|
||||
go func() { sub.Func(data) }()
|
||||
actualReceiver++
|
||||
} else if sub.Chan != nil {
|
||||
msgSent := syncext.WriteNonBlocking(sub.Chan, data)
|
||||
if msgSent {
|
||||
actualReceiver++
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return subscriber, actualReceiver
|
||||
}
|
||||
|
||||
// PublishWithContext sends `data` to all subscriber
|
||||
// buffered - if one is currently not listening, we wait (but error out when the context runs out)
|
||||
func (ps *PubSub[TNamespace, TData]) PublishWithContext(ctx context.Context, ns TNamespace, data TData) (subscriber int, actualReceiver int, err error) {
|
||||
ps.masterLock.Lock()
|
||||
subs := langext.ArrCopy(ps.subscriptions[ns])
|
||||
ps.masterLock.Unlock()
|
||||
|
||||
subscriber = len(subs)
|
||||
actualReceiver = 0
|
||||
|
||||
for _, sub := range subs {
|
||||
err := func() error {
|
||||
sub.subLock.Lock()
|
||||
defer sub.subLock.Unlock()
|
||||
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if sub.Func != nil {
|
||||
go func() { sub.Func(data) }()
|
||||
actualReceiver++
|
||||
} else if sub.Chan != nil {
|
||||
err := syncext.WriteChannelWithContext(ctx, sub.Chan, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
actualReceiver++
|
||||
}
|
||||
|
||||
return nil
|
||||
}()
|
||||
if err != nil {
|
||||
return subscriber, actualReceiver, err
|
||||
}
|
||||
}
|
||||
|
||||
return subscriber, actualReceiver, nil
|
||||
}
|
||||
|
||||
// PublishWithTimeout sends `data` to all subscriber
|
||||
// buffered - if one is currently not listening, we wait (but wait at most `timeout` - if the timeout is exceeded then actualReceiver < subscriber)
|
||||
func (ps *PubSub[TNamespace, TData]) PublishWithTimeout(ns TNamespace, data TData, timeout time.Duration) (subscriber int, actualReceiver int) {
|
||||
ps.masterLock.Lock()
|
||||
subs := langext.ArrCopy(ps.subscriptions[ns])
|
||||
ps.masterLock.Unlock()
|
||||
|
||||
subscriber = len(subs)
|
||||
actualReceiver = 0
|
||||
|
||||
for _, sub := range subs {
|
||||
func() {
|
||||
sub.subLock.Lock()
|
||||
defer sub.subLock.Unlock()
|
||||
|
||||
if sub.Func != nil {
|
||||
go func() { sub.Func(data) }()
|
||||
actualReceiver++
|
||||
} else if sub.Chan != nil {
|
||||
ok := syncext.WriteChannelWithTimeout(sub.Chan, data, timeout)
|
||||
if ok {
|
||||
actualReceiver++
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return subscriber, actualReceiver
|
||||
}
|
||||
|
||||
// PublishAsync sends `data` to all subscriber
|
||||
// does not wait for subscriber (this method returns immediately), waits at most {timeout} seconds on channels (async)
|
||||
func (ps *PubSub[TNamespace, TData]) PublishAsync(ns TNamespace, data TData, timeout time.Duration) (subscriber int) {
|
||||
ps.masterLock.Lock()
|
||||
subs := langext.ArrCopy(ps.subscriptions[ns])
|
||||
ps.masterLock.Unlock()
|
||||
|
||||
subscriber = len(subs)
|
||||
|
||||
for _, sub := range subs {
|
||||
go func() {
|
||||
sub.subLock.Lock()
|
||||
defer sub.subLock.Unlock()
|
||||
|
||||
if sub.Func != nil {
|
||||
go func() { sub.Func(data) }()
|
||||
} else if sub.Chan != nil {
|
||||
_ = syncext.WriteChannelWithTimeout(sub.Chan, data, timeout)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return subscriber
|
||||
}
|
||||
|
||||
func (ps *PubSub[TNamespace, TData]) SubscribeByCallback(ns TNamespace, fn func(TData)) PubSubSubscription {
|
||||
ps.masterLock.Lock()
|
||||
defer ps.masterLock.Unlock()
|
||||
|
||||
sub := &pubSubSubscription[TNamespace, TData]{ID: xid.New().String(), namespace: ns, parent: ps, subLock: &sync.Mutex{}, Func: fn, UnsubChan: nil}
|
||||
|
||||
ps.subscriptions[ns] = append(ps.subscriptions[ns], sub)
|
||||
|
||||
return sub
|
||||
}
|
||||
|
||||
func (ps *PubSub[TNamespace, TData]) SubscribeByChan(ns TNamespace, chanBufferSize int) (chan TData, PubSubSubscription) {
|
||||
ps.masterLock.Lock()
|
||||
defer ps.masterLock.Unlock()
|
||||
|
||||
msgCh := make(chan TData, chanBufferSize)
|
||||
|
||||
sub := &pubSubSubscription[TNamespace, TData]{ID: xid.New().String(), namespace: ns, parent: ps, subLock: &sync.Mutex{}, Chan: msgCh, UnsubChan: nil}
|
||||
|
||||
ps.subscriptions[ns] = append(ps.subscriptions[ns], sub)
|
||||
|
||||
return msgCh, sub
|
||||
}
|
||||
|
||||
func (ps *PubSub[TNamespace, TData]) SubscribeByIter(ns TNamespace, chanBufferSize int) (iter.Seq[TData], PubSubSubscription) {
|
||||
ps.masterLock.Lock()
|
||||
defer ps.masterLock.Unlock()
|
||||
|
||||
msgCh := make(chan TData, chanBufferSize)
|
||||
unsubChan := make(chan bool, 8)
|
||||
|
||||
sub := &pubSubSubscription[TNamespace, TData]{ID: xid.New().String(), namespace: ns, parent: ps, subLock: &sync.Mutex{}, Chan: msgCh, UnsubChan: unsubChan}
|
||||
|
||||
ps.subscriptions[ns] = append(ps.subscriptions[ns], sub)
|
||||
|
||||
iterFun := func(yield func(TData) bool) {
|
||||
for {
|
||||
select {
|
||||
case msg := <-msgCh:
|
||||
if !yield(msg) {
|
||||
sub.Unsubscribe()
|
||||
return
|
||||
}
|
||||
case <-sub.UnsubChan:
|
||||
sub.Unsubscribe()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return iterFun, sub
|
||||
}
|
||||
|
||||
func (ps *PubSub[TNamespace, TData]) unsubscribe(p *pubSubSubscription[TNamespace, TData]) {
|
||||
ps.masterLock.Lock()
|
||||
defer ps.masterLock.Unlock()
|
||||
|
||||
p.subLock.Lock()
|
||||
defer p.subLock.Unlock()
|
||||
|
||||
if p.Chan != nil {
|
||||
close(p.Chan)
|
||||
p.Chan = nil
|
||||
}
|
||||
if p.UnsubChan != nil {
|
||||
syncext.WriteNonBlocking(p.UnsubChan, true)
|
||||
close(p.UnsubChan)
|
||||
p.UnsubChan = nil
|
||||
}
|
||||
|
||||
ps.subscriptions[p.namespace] = langext.ArrFilter(ps.subscriptions[p.namespace], func(v *pubSubSubscription[TNamespace, TData]) bool {
|
||||
return v.ID != p.ID
|
||||
})
|
||||
if len(ps.subscriptions[p.namespace]) == 0 {
|
||||
delete(ps.subscriptions, p.namespace)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,438 @@
|
||||
package dataext
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewPubSub(t *testing.T) {
|
||||
ps := NewPubSub[string, string](10)
|
||||
if ps == nil {
|
||||
t.Fatal("NewPubSub returned nil")
|
||||
}
|
||||
if ps.masterLock == nil {
|
||||
t.Fatal("masterLock is nil")
|
||||
}
|
||||
if ps.subscriptions == nil {
|
||||
t.Fatal("subscriptions is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPubSub_Namespaces(t *testing.T) {
|
||||
ps := NewPubSub[string, string](10)
|
||||
|
||||
// Initially no namespaces
|
||||
namespaces := ps.Namespaces()
|
||||
if len(namespaces) != 0 {
|
||||
t.Fatalf("Expected 0 namespaces, got %d", len(namespaces))
|
||||
}
|
||||
|
||||
// Add a subscription to create a namespace
|
||||
_, sub1 := ps.SubscribeByChan("test-ns1", 1)
|
||||
defer sub1.Unsubscribe()
|
||||
|
||||
// Add another subscription to a different namespace
|
||||
_, sub2 := ps.SubscribeByChan("test-ns2", 1)
|
||||
defer sub2.Unsubscribe()
|
||||
|
||||
// Check namespaces
|
||||
namespaces = ps.Namespaces()
|
||||
if len(namespaces) != 2 {
|
||||
t.Fatalf("Expected 2 namespaces, got %d", len(namespaces))
|
||||
}
|
||||
|
||||
// Check if namespaces contain the expected values
|
||||
found1, found2 := false, false
|
||||
for _, ns := range namespaces {
|
||||
if ns == "test-ns1" {
|
||||
found1 = true
|
||||
}
|
||||
if ns == "test-ns2" {
|
||||
found2 = true
|
||||
}
|
||||
}
|
||||
|
||||
if !found1 || !found2 {
|
||||
t.Fatalf("Expected to find both namespaces, found ns1: %v, ns2: %v", found1, found2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPubSub_SubscribeByCallback(t *testing.T) {
|
||||
ps := NewPubSub[string, string](10)
|
||||
|
||||
var received string
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
callback := func(msg string) {
|
||||
received = msg
|
||||
wg.Done()
|
||||
}
|
||||
|
||||
sub := ps.SubscribeByCallback("test-ns", callback)
|
||||
defer sub.Unsubscribe()
|
||||
|
||||
// Publish a message
|
||||
subs, receivers := ps.Publish("test-ns", "hello")
|
||||
|
||||
if subs != 1 {
|
||||
t.Fatalf("Expected 1 subscriber, got %d", subs)
|
||||
}
|
||||
|
||||
if receivers != 1 {
|
||||
t.Fatalf("Expected 1 receiver, got %d", receivers)
|
||||
}
|
||||
|
||||
// Wait for the callback to be executed
|
||||
wg.Wait()
|
||||
|
||||
if received != "hello" {
|
||||
t.Fatalf("Expected to receive 'hello', got '%s'", received)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPubSub_SubscribeByChan(t *testing.T) {
|
||||
ps := NewPubSub[string, string](10)
|
||||
|
||||
ch, sub := ps.SubscribeByChan("test-ns", 1)
|
||||
defer sub.Unsubscribe()
|
||||
|
||||
// Publish a message
|
||||
subs, receivers := ps.Publish("test-ns", "hello")
|
||||
|
||||
if subs != 1 {
|
||||
t.Fatalf("Expected 1 subscriber, got %d", subs)
|
||||
}
|
||||
|
||||
if receivers != 1 {
|
||||
t.Fatalf("Expected 1 receiver, got %d", receivers)
|
||||
}
|
||||
|
||||
// Read from the channel with a timeout to avoid blocking
|
||||
select {
|
||||
case msg := <-ch:
|
||||
if msg != "hello" {
|
||||
t.Fatalf("Expected to receive 'hello', got '%s'", msg)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("Timed out waiting for message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPubSub_SubscribeByIter(t *testing.T) {
|
||||
ps := NewPubSub[string, string](10)
|
||||
|
||||
iterSeq, sub := ps.SubscribeByIter("test-ns", 1)
|
||||
defer sub.Unsubscribe()
|
||||
|
||||
// Channel to communicate when message is received
|
||||
done := make(chan bool)
|
||||
goroutineDone := make(chan struct{})
|
||||
received := false
|
||||
|
||||
// Start a goroutine to use the iterator
|
||||
go func() {
|
||||
defer close(goroutineDone)
|
||||
for msg := range iterSeq {
|
||||
if msg == "hello" {
|
||||
received = true
|
||||
done <- true
|
||||
return // Stop iteration — triggers Unsubscribe via yield returning false
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Give time for the iterator to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Publish a message
|
||||
ps.Publish("test-ns", "hello")
|
||||
|
||||
// Wait for the message to be received or timeout
|
||||
select {
|
||||
case <-done:
|
||||
if !received {
|
||||
t.Fatal("Message was received but not 'hello'")
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("Timed out waiting for message")
|
||||
}
|
||||
|
||||
// Wait for the goroutine to fully exit so Unsubscribe (triggered by the
|
||||
// iterator cleanup when yield returns false) has completed.
|
||||
select {
|
||||
case <-goroutineDone:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("Timed out waiting for goroutine to finish")
|
||||
}
|
||||
|
||||
subCount := ps.SubscriberCount("test-ns")
|
||||
if subCount != 0 {
|
||||
t.Fatalf("Expected 0 receivers, got %d", subCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPubSub_Publish(t *testing.T) {
|
||||
ps := NewPubSub[string, string](10)
|
||||
|
||||
// Test publishing to a namespace with no subscribers
|
||||
subs, receivers := ps.Publish("empty-ns", "hello")
|
||||
if subs != 0 {
|
||||
t.Fatalf("Expected 0 subscribers, got %d", subs)
|
||||
}
|
||||
if receivers != 0 {
|
||||
t.Fatalf("Expected 0 receivers, got %d", receivers)
|
||||
}
|
||||
|
||||
// Add a subscriber
|
||||
ch, sub := ps.SubscribeByChan("test-ns", 1)
|
||||
defer sub.Unsubscribe()
|
||||
|
||||
// Publish a message
|
||||
subs, receivers = ps.Publish("test-ns", "hello")
|
||||
if subs != 1 {
|
||||
t.Fatalf("Expected 1 subscriber, got %d", subs)
|
||||
}
|
||||
if receivers != 1 {
|
||||
t.Fatalf("Expected 1 receiver, got %d", receivers)
|
||||
}
|
||||
|
||||
// Verify the message was received
|
||||
select {
|
||||
case msg := <-ch:
|
||||
if msg != "hello" {
|
||||
t.Fatalf("Expected to receive 'hello', got '%s'", msg)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("Timed out waiting for message")
|
||||
}
|
||||
|
||||
// Test non-blocking behavior with a full channel
|
||||
// First fill the channel
|
||||
ps.Publish("test-ns", "fill")
|
||||
|
||||
// Now publish again - this should not block but skip the receiver
|
||||
subs, receivers = ps.Publish("test-ns", "overflow")
|
||||
if subs != 1 {
|
||||
t.Fatalf("Expected 1 subscriber, got %d", subs)
|
||||
}
|
||||
// The receiver count might be 0 if the channel is full
|
||||
|
||||
// Drain the channel
|
||||
<-ch
|
||||
}
|
||||
|
||||
func TestPubSub_PublishWithTimeout(t *testing.T) {
|
||||
ps := NewPubSub[string, string](10)
|
||||
|
||||
// Add a subscriber with a channel
|
||||
ch, sub := ps.SubscribeByChan("test-ns", 1)
|
||||
defer sub.Unsubscribe()
|
||||
|
||||
// Publish with a timeout
|
||||
subs, receivers := ps.PublishWithTimeout("test-ns", "hello", 100*time.Millisecond)
|
||||
if subs != 1 {
|
||||
t.Fatalf("Expected 1 subscriber, got %d", subs)
|
||||
}
|
||||
if receivers != 1 {
|
||||
t.Fatalf("Expected 1 receiver, got %d", receivers)
|
||||
}
|
||||
|
||||
// Verify the message was received
|
||||
select {
|
||||
case msg := <-ch:
|
||||
if msg != "hello" {
|
||||
t.Fatalf("Expected to receive 'hello', got '%s'", msg)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("Timed out waiting for message")
|
||||
}
|
||||
|
||||
// Fill the channel
|
||||
ps.Publish("test-ns", "fill")
|
||||
|
||||
// Test timeout behavior with a full channel
|
||||
start := time.Now()
|
||||
subs, receivers = ps.PublishWithTimeout("test-ns", "timeout-test", 50*time.Millisecond)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if subs != 1 {
|
||||
t.Fatalf("Expected 1 subscriber, got %d", subs)
|
||||
}
|
||||
|
||||
// The receiver count should be 0 if the timeout occurred
|
||||
if elapsed < 50*time.Millisecond {
|
||||
t.Fatalf("Expected to wait at least 50ms, only waited %v", elapsed)
|
||||
}
|
||||
|
||||
// Drain the channel
|
||||
<-ch
|
||||
}
|
||||
|
||||
func TestPubSub_PublishWithContext(t *testing.T) {
|
||||
ps := NewPubSub[string, string](10)
|
||||
|
||||
// Add a subscriber with a channel
|
||||
ch, sub := ps.SubscribeByChan("test-ns", 1)
|
||||
defer sub.Unsubscribe()
|
||||
|
||||
// Create a context
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
// Publish with context
|
||||
subs, receivers, err := ps.PublishWithContext(ctx, "test-ns", "hello")
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
if subs != 1 {
|
||||
t.Fatalf("Expected 1 subscriber, got %d", subs)
|
||||
}
|
||||
if receivers != 1 {
|
||||
t.Fatalf("Expected 1 receiver, got %d", receivers)
|
||||
}
|
||||
|
||||
// Verify the message was received
|
||||
select {
|
||||
case msg := <-ch:
|
||||
if msg != "hello" {
|
||||
t.Fatalf("Expected to receive 'hello', got '%s'", msg)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("Timed out waiting for message")
|
||||
}
|
||||
|
||||
// Fill the channel
|
||||
ps.Publish("test-ns", "fill")
|
||||
|
||||
// Test context cancellation with a full channel
|
||||
ctx, cancel = context.WithCancel(context.Background())
|
||||
|
||||
// Cancel the context after a short delay
|
||||
go func() {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
start := time.Now()
|
||||
subs, receivers, err = ps.PublishWithContext(ctx, "test-ns", "context-test")
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if subs != 1 {
|
||||
t.Fatalf("Expected 1 subscriber, got %d", subs)
|
||||
}
|
||||
|
||||
// Should get a context canceled error
|
||||
if err == nil {
|
||||
t.Fatal("Expected context canceled error, got nil")
|
||||
}
|
||||
|
||||
if elapsed < 50*time.Millisecond {
|
||||
t.Fatalf("Expected to wait at least 50ms, only waited %v", elapsed)
|
||||
}
|
||||
|
||||
// Drain the channel
|
||||
<-ch
|
||||
}
|
||||
|
||||
func TestPubSub_Unsubscribe(t *testing.T) {
|
||||
ps := NewPubSub[string, string](10)
|
||||
|
||||
// Add a subscriber
|
||||
ch, sub := ps.SubscribeByChan("test-ns", 1)
|
||||
|
||||
// Publish a message
|
||||
subs, receivers := ps.Publish("test-ns", "hello")
|
||||
if subs != 1 {
|
||||
t.Fatalf("Expected 1 subscriber, got %d", subs)
|
||||
}
|
||||
if receivers != 1 {
|
||||
t.Fatalf("Expected 1 receiver, got %d", receivers)
|
||||
}
|
||||
|
||||
// Verify the message was received
|
||||
select {
|
||||
case msg := <-ch:
|
||||
if msg != "hello" {
|
||||
t.Fatalf("Expected to receive 'hello', got '%s'", msg)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("Timed out waiting for message")
|
||||
}
|
||||
|
||||
// Unsubscribe
|
||||
sub.Unsubscribe()
|
||||
|
||||
// Publish again
|
||||
subs, receivers = ps.Publish("test-ns", "after-unsub")
|
||||
if subs != 0 {
|
||||
t.Fatalf("Expected 0 subscribers after unsubscribe, got %d", subs)
|
||||
}
|
||||
if receivers != 0 {
|
||||
t.Fatalf("Expected 0 receivers after unsubscribe, got %d", receivers)
|
||||
}
|
||||
|
||||
// Check that the namespace is removed
|
||||
namespaces := ps.Namespaces()
|
||||
if len(namespaces) != 0 {
|
||||
t.Fatalf("Expected 0 namespaces after unsubscribe, got %d", len(namespaces))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPubSub_MultipleSubscribers(t *testing.T) {
|
||||
ps := NewPubSub[string, string](10)
|
||||
|
||||
// Add multiple subscribers
|
||||
ch1, sub1 := ps.SubscribeByChan("test-ns", 1)
|
||||
defer sub1.Unsubscribe()
|
||||
|
||||
ch2, sub2 := ps.SubscribeByChan("test-ns", 1)
|
||||
defer sub2.Unsubscribe()
|
||||
|
||||
var received string
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
sub3 := ps.SubscribeByCallback("test-ns", func(msg string) {
|
||||
received = msg
|
||||
wg.Done()
|
||||
})
|
||||
defer sub3.Unsubscribe()
|
||||
|
||||
// Publish a message
|
||||
subs, receivers := ps.Publish("test-ns", "hello")
|
||||
if subs != 3 {
|
||||
t.Fatalf("Expected 3 subscribers, got %d", subs)
|
||||
}
|
||||
if receivers != 3 {
|
||||
t.Fatalf("Expected 3 receivers, got %d", receivers)
|
||||
}
|
||||
|
||||
// Verify the message was received by all subscribers
|
||||
select {
|
||||
case msg := <-ch1:
|
||||
if msg != "hello" {
|
||||
t.Fatalf("Expected ch1 to receive 'hello', got '%s'", msg)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("Timed out waiting for message on ch1")
|
||||
}
|
||||
|
||||
select {
|
||||
case msg := <-ch2:
|
||||
if msg != "hello" {
|
||||
t.Fatalf("Expected ch2 to receive 'hello', got '%s'", msg)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("Timed out waiting for message on ch2")
|
||||
}
|
||||
|
||||
// Wait for the callback
|
||||
wg.Wait()
|
||||
|
||||
if received != "hello" {
|
||||
t.Fatalf("Expected callback to receive 'hello', got '%s'", received)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,144 @@
|
||||
package dataext
|
||||
|
||||
import "iter"
|
||||
|
||||
type RingBuffer[T any] struct {
|
||||
items []T //
|
||||
capacity int // max number of items the buffer can hold
|
||||
size int // how many items are in the buffer
|
||||
head int // ptr to next item
|
||||
}
|
||||
|
||||
func NewRingBuffer[T any](capacity int) *RingBuffer[T] {
|
||||
return &RingBuffer[T]{
|
||||
items: make([]T, capacity),
|
||||
capacity: capacity,
|
||||
size: 0,
|
||||
head: 0,
|
||||
}
|
||||
}
|
||||
|
||||
func (rb *RingBuffer[T]) Push(item T) {
|
||||
if rb.size < rb.capacity {
|
||||
rb.size++
|
||||
}
|
||||
rb.items[rb.head] = item
|
||||
rb.head = (rb.head + 1) % rb.capacity
|
||||
}
|
||||
|
||||
func (rb *RingBuffer[T]) PushPop(item T) *T {
|
||||
if rb.size < rb.capacity {
|
||||
rb.size++
|
||||
rb.items[rb.head] = item
|
||||
rb.head = (rb.head + 1) % rb.capacity
|
||||
return nil
|
||||
} else {
|
||||
prev := rb.items[rb.head]
|
||||
rb.items[rb.head] = item
|
||||
rb.head = (rb.head + 1) % rb.capacity
|
||||
return &prev
|
||||
}
|
||||
}
|
||||
|
||||
func (rb *RingBuffer[T]) Peek() (T, bool) {
|
||||
if rb.size == 0 {
|
||||
return *new(T), false
|
||||
}
|
||||
return rb.items[(rb.head-1+rb.capacity)%rb.capacity], true
|
||||
}
|
||||
|
||||
func (rb *RingBuffer[T]) Items() []T {
|
||||
if rb.size < rb.capacity {
|
||||
return rb.items[:rb.size]
|
||||
}
|
||||
return append(rb.items[rb.head:], rb.items[:rb.head]...)
|
||||
}
|
||||
|
||||
func (rb *RingBuffer[T]) Size() int {
|
||||
return rb.size
|
||||
}
|
||||
|
||||
func (rb *RingBuffer[T]) Capacity() int {
|
||||
return rb.capacity
|
||||
}
|
||||
|
||||
func (rb *RingBuffer[T]) Clear() {
|
||||
rb.size = 0
|
||||
rb.head = 0
|
||||
}
|
||||
|
||||
func (rb *RingBuffer[T]) IsFull() bool {
|
||||
return rb.size == rb.capacity
|
||||
}
|
||||
|
||||
func (rb *RingBuffer[T]) At(i int) T {
|
||||
if i < 0 || i >= rb.size {
|
||||
panic("Index out of bounds")
|
||||
}
|
||||
if rb.size < rb.capacity {
|
||||
return rb.items[i]
|
||||
}
|
||||
return rb.items[(rb.head+i)%rb.capacity]
|
||||
}
|
||||
|
||||
func (rb *RingBuffer[T]) Get(i int) (T, bool) {
|
||||
if i < 0 || i >= rb.size {
|
||||
return *new(T), false
|
||||
}
|
||||
if rb.size < rb.capacity {
|
||||
return rb.items[i], true
|
||||
}
|
||||
return rb.items[(rb.head+i)%rb.capacity], true
|
||||
}
|
||||
|
||||
func (rb *RingBuffer[T]) Iter() iter.Seq[T] {
|
||||
return func(yield func(T) bool) {
|
||||
for i := 0; i < rb.size; i++ {
|
||||
if !yield(rb.At(i)) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (rb *RingBuffer[T]) Iter2() iter.Seq2[int, T] {
|
||||
return func(yield func(int, T) bool) {
|
||||
for i := 0; i < rb.size; i++ {
|
||||
if !yield(i, rb.At(i)) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (rb *RingBuffer[T]) Remove(fnEqual func(v T) bool) int {
|
||||
// Mike [2024-11-13]: I *really* tried to write an in-place algorithm to remove elements
|
||||
// But after carful consideration, I left that as an exercise for future readers
|
||||
// It is, suprisingly, non-trivial, especially because the head-ptr must be weirdly updated
|
||||
// And out At() method does not work correctly with {head<>0 && size<capacity}
|
||||
|
||||
dc := 0
|
||||
b := make([]T, rb.capacity)
|
||||
bsize := 0
|
||||
|
||||
for i := 0; i < rb.size; i++ {
|
||||
comp := rb.At(i)
|
||||
if fnEqual(comp) {
|
||||
dc++
|
||||
} else {
|
||||
b[bsize] = comp
|
||||
bsize++
|
||||
}
|
||||
}
|
||||
|
||||
if dc == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
rb.items = b
|
||||
rb.size = bsize
|
||||
rb.head = bsize % rb.capacity
|
||||
|
||||
return dc
|
||||
|
||||
}
|
||||
@@ -0,0 +1,447 @@
|
||||
package dataext
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestRingBufferPushAddsItem(t *testing.T) {
|
||||
rb := NewRingBuffer[int](3)
|
||||
rb.Push(1)
|
||||
if rb.Size() != 1 {
|
||||
t.Errorf("Expected size 1, got %d", rb.Size())
|
||||
}
|
||||
if item, _ := rb.Peek(); item != 1 {
|
||||
t.Errorf("Expected item 1, got %d", item)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRingBufferPushPopReturnsOldestItem(t *testing.T) {
|
||||
rb := NewRingBuffer[int](3)
|
||||
rb.Push(1)
|
||||
rb.Push(2)
|
||||
rb.Push(3)
|
||||
if item := rb.PushPop(4); item == nil || *item != 1 {
|
||||
t.Errorf("Expected item 1, got %v", item)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRingBufferPeekReturnsLastPushedItem(t *testing.T) {
|
||||
rb := NewRingBuffer[int](3)
|
||||
rb.Push(1)
|
||||
rb.Push(2)
|
||||
if item, _ := rb.Peek(); item != 2 {
|
||||
t.Errorf("Expected item 2, got %d", item)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRingBufferOverflow1(t *testing.T) {
|
||||
rb := NewRingBuffer[int](5)
|
||||
rb.Push(1) // overriden
|
||||
rb.Push(2) // overriden
|
||||
rb.Push(3)
|
||||
rb.Push(9)
|
||||
rb.Push(4)
|
||||
rb.Push(5)
|
||||
rb.Push(7)
|
||||
if rb.Size() != 5 {
|
||||
t.Errorf("Expected size 4, got %d", rb.Size())
|
||||
}
|
||||
expected := []int{3, 9, 4, 5, 7}
|
||||
items := rb.Items()
|
||||
for i, item := range items {
|
||||
if item != expected[i] {
|
||||
t.Errorf("Expected item %d, got %d", expected[i], item)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRingBufferItemsReturnsAllItems(t *testing.T) {
|
||||
rb := NewRingBuffer[int](3)
|
||||
rb.Push(1)
|
||||
rb.Push(2)
|
||||
rb.Push(3)
|
||||
items := rb.Items()
|
||||
expected := []int{1, 2, 3}
|
||||
for i, item := range items {
|
||||
if item != expected[i] {
|
||||
t.Errorf("Expected item %d, got %d", expected[i], item)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRingBufferClearEmptiesBuffer(t *testing.T) {
|
||||
rb := NewRingBuffer[int](3)
|
||||
rb.Push(1)
|
||||
rb.Clear()
|
||||
if rb.Size() != 0 {
|
||||
t.Errorf("Expected size 0, got %d", rb.Size())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRingBufferIsFullReturnsTrueWhenFull(t *testing.T) {
|
||||
rb := NewRingBuffer[int](3)
|
||||
rb.Push(1)
|
||||
rb.Push(2)
|
||||
rb.Push(3)
|
||||
if !rb.IsFull() {
|
||||
t.Errorf("Expected buffer to be full")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRingBufferAtReturnsCorrectItem(t *testing.T) {
|
||||
rb := NewRingBuffer[int](3)
|
||||
rb.Push(1)
|
||||
rb.Push(2)
|
||||
rb.Push(3)
|
||||
if item := rb.At(1); item != 2 {
|
||||
t.Errorf("Expected item 2, got %d", item)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRingBufferGetReturnsCorrectItem(t *testing.T) {
|
||||
rb := NewRingBuffer[int](3)
|
||||
rb.Push(1)
|
||||
rb.Push(2)
|
||||
rb.Push(3)
|
||||
if item, ok := rb.Get(1); !ok || item != 2 {
|
||||
t.Errorf("Expected item 2, got %d", item)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRingBufferRemoveDeletesMatchingItems(t *testing.T) {
|
||||
rb := NewRingBuffer[int](5)
|
||||
rb.Push(1)
|
||||
rb.Push(2)
|
||||
rb.Push(3)
|
||||
rb.Push(2)
|
||||
rb.Push(4)
|
||||
removed := rb.Remove(func(v int) bool { return v == 2 })
|
||||
if removed != 2 {
|
||||
t.Errorf("Expected 2 items removed, got %d", removed)
|
||||
}
|
||||
if rb.Size() != 3 {
|
||||
t.Errorf("Expected size 3, got %d", rb.Size())
|
||||
}
|
||||
expected := []int{1, 3, 4}
|
||||
items := rb.Items()
|
||||
for i, item := range items {
|
||||
if item != expected[i] {
|
||||
t.Errorf("Expected item %d, got %d", expected[i], item)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRingBufferRemoveDeletesMatchingItems2(t *testing.T) {
|
||||
rb := NewRingBuffer[int](5)
|
||||
rb.Push(1)
|
||||
rb.Push(2)
|
||||
rb.Push(3)
|
||||
rb.Push(2)
|
||||
rb.Push(4)
|
||||
removed := rb.Remove(func(v int) bool { return v == 3 })
|
||||
if removed != 1 {
|
||||
t.Errorf("Expected 2 items removed, got %d", removed)
|
||||
}
|
||||
if rb.Size() != 4 {
|
||||
t.Errorf("Expected size 3, got %d", rb.Size())
|
||||
}
|
||||
expected := []int{1, 2, 2, 4}
|
||||
items := rb.Items()
|
||||
for i, item := range items {
|
||||
if item != expected[i] {
|
||||
t.Errorf("Expected item %d, got %d", expected[i], item)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRingBufferRemoveDeletesMatchingItems3(t *testing.T) {
|
||||
rb := NewRingBuffer[int](5)
|
||||
rb.Push(1)
|
||||
rb.Push(2)
|
||||
rb.Push(3)
|
||||
rb.Push(9)
|
||||
rb.Push(4)
|
||||
removed := rb.Remove(func(v int) bool { return v == 3 })
|
||||
if removed != 1 {
|
||||
t.Errorf("Expected 2 items removed, got %d", removed)
|
||||
}
|
||||
if rb.Size() != 4 {
|
||||
t.Errorf("Expected size 3, got %d", rb.Size())
|
||||
}
|
||||
expected := []int{1, 2, 9, 4}
|
||||
items := rb.Items()
|
||||
for i, item := range items {
|
||||
if item != expected[i] {
|
||||
t.Errorf("Expected item %d, got %d", expected[i], item)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRingBufferRemoveDeletesMatchingItems4(t *testing.T) {
|
||||
rb := NewRingBuffer[int](5)
|
||||
rb.Push(1) // overriden
|
||||
rb.Push(2) // overriden
|
||||
rb.Push(3)
|
||||
rb.Push(9)
|
||||
rb.Push(4)
|
||||
rb.Push(5)
|
||||
rb.Push(7)
|
||||
removed := rb.Remove(func(v int) bool { return v == 7 })
|
||||
if removed != 1 {
|
||||
t.Errorf("Expected 1 items removed, got %d", removed)
|
||||
}
|
||||
if rb.Size() != 4 {
|
||||
t.Errorf("Expected size 4, got %d", rb.Size())
|
||||
}
|
||||
expected := []int{3, 9, 4, 5}
|
||||
items := rb.Items()
|
||||
for i, item := range items {
|
||||
if item != expected[i] {
|
||||
t.Errorf("Expected item %d, got %d", expected[i], item)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRingBufferRemoveDeletesMatchingItems5(t *testing.T) {
|
||||
rb := NewRingBuffer[int](5)
|
||||
rb.Push(1) // overriden
|
||||
rb.Push(2) // overriden
|
||||
rb.Push(3)
|
||||
rb.Push(9)
|
||||
rb.Push(4)
|
||||
rb.Push(5)
|
||||
rb.Push(7)
|
||||
removed := rb.Remove(func(v int) bool { return v == 3 })
|
||||
if removed != 1 {
|
||||
t.Errorf("Expected 1 items removed, got %d", removed)
|
||||
}
|
||||
if rb.Size() != 4 {
|
||||
t.Errorf("Expected size 4, got %d", rb.Size())
|
||||
}
|
||||
expected := []int{9, 4, 5, 7}
|
||||
items := rb.Items()
|
||||
for i, item := range items {
|
||||
if item != expected[i] {
|
||||
t.Errorf("Expected item %d, got %d", expected[i], item)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRingBufferRemoveDeletesMatchingItems6(t *testing.T) {
|
||||
rb := NewRingBuffer[int](5)
|
||||
rb.Push(1) // overriden
|
||||
rb.Push(2) // overriden
|
||||
rb.Push(3)
|
||||
rb.Push(9)
|
||||
rb.Push(4)
|
||||
rb.Push(5)
|
||||
rb.Push(7)
|
||||
removed := rb.Remove(func(v int) bool { return v == 1 })
|
||||
if removed != 0 {
|
||||
t.Errorf("Expected 0 items removed, got %d", removed)
|
||||
}
|
||||
if rb.Size() != 5 {
|
||||
t.Errorf("Expected size 5, got %d", rb.Size())
|
||||
}
|
||||
expected := []int{3, 9, 4, 5, 7}
|
||||
items := rb.Items()
|
||||
for i, item := range items {
|
||||
if item != expected[i] {
|
||||
t.Errorf("Expected item %d, got %d", expected[i], item)
|
||||
}
|
||||
}
|
||||
if !rb.IsFull() {
|
||||
t.Errorf("Expected buffer to not be full")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRingBufferRemoveDeletesMatchingItems7(t *testing.T) {
|
||||
rb := NewRingBuffer[int](5)
|
||||
rb.Push(1) // overriden
|
||||
rb.Push(2) // overriden
|
||||
rb.Push(3)
|
||||
rb.Push(9)
|
||||
rb.Push(4)
|
||||
rb.Push(5)
|
||||
rb.Push(7)
|
||||
removed := rb.Remove(func(v int) bool { return v == 9 })
|
||||
if removed != 1 {
|
||||
t.Errorf("Expected 1 items removed, got %d", removed)
|
||||
}
|
||||
if rb.Size() != 4 {
|
||||
t.Errorf("Expected size 4, got %d", rb.Size())
|
||||
}
|
||||
expected := []int{3, 4, 5, 7}
|
||||
items := rb.Items()
|
||||
for i, item := range items {
|
||||
if item != expected[i] {
|
||||
t.Errorf("Expected item %d, got %d", expected[i], item)
|
||||
}
|
||||
}
|
||||
if rb.IsFull() {
|
||||
t.Errorf("Expected buffer to not be full")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRingBufferAddItemsToFullRingBuffer(t *testing.T) {
|
||||
rb := NewRingBuffer[int](3)
|
||||
rb.Push(1)
|
||||
rb.Push(2)
|
||||
rb.Push(3)
|
||||
rb.Push(4)
|
||||
if rb.Size() != 3 {
|
||||
t.Errorf("Expected size 3, got %d", rb.Size())
|
||||
}
|
||||
expected := []int{2, 3, 4}
|
||||
items := rb.Items()
|
||||
for i, item := range items {
|
||||
if item != expected[i] {
|
||||
t.Errorf("Expected item %d, got %d", expected[i], item)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRingBufferAddItemsToNonFullRingBuffer(t *testing.T) {
|
||||
rb := NewRingBuffer[int](3)
|
||||
rb.Push(1)
|
||||
rb.Push(2)
|
||||
if rb.Size() != 2 {
|
||||
t.Errorf("Expected size 2, got %d", rb.Size())
|
||||
}
|
||||
expected := []int{1, 2}
|
||||
items := rb.Items()
|
||||
for i, item := range items {
|
||||
if item != expected[i] {
|
||||
t.Errorf("Expected item %d, got %d", expected[i], item)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRingBufferRemoveItemsFromNonFullRingBuffer(t *testing.T) {
|
||||
rb := NewRingBuffer[int](3)
|
||||
rb.Push(1)
|
||||
rb.Push(2)
|
||||
removed := rb.Remove(func(v int) bool { return v == 1 })
|
||||
if removed != 1 {
|
||||
t.Errorf("Expected 1 item removed, got %d", removed)
|
||||
}
|
||||
if rb.Size() != 1 {
|
||||
t.Errorf("Expected size 1, got %d", rb.Size())
|
||||
}
|
||||
expected := []int{2}
|
||||
items := rb.Items()
|
||||
for i, item := range items {
|
||||
if item != expected[i] {
|
||||
t.Errorf("Expected item %d, got %d", expected[i], item)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRingBufferRemoveItemsFromFullRingBuffer(t *testing.T) {
|
||||
rb := NewRingBuffer[int](3)
|
||||
rb.Push(1)
|
||||
rb.Push(2)
|
||||
rb.Push(3)
|
||||
removed := rb.Remove(func(v int) bool { return v == 2 })
|
||||
if removed != 1 {
|
||||
t.Errorf("Expected 1 item removed, got %d", removed)
|
||||
}
|
||||
if rb.Size() != 2 {
|
||||
t.Errorf("Expected size 2, got %d", rb.Size())
|
||||
}
|
||||
expected := []int{1, 3}
|
||||
items := rb.Items()
|
||||
for i, item := range items {
|
||||
if item != expected[i] {
|
||||
t.Errorf("Expected item %d, got %d", expected[i], item)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRingBufferRemoveMultipleItemsFromRingBuffer(t *testing.T) {
|
||||
rb := NewRingBuffer[int](5)
|
||||
rb.Push(1)
|
||||
rb.Push(2)
|
||||
rb.Push(3)
|
||||
rb.Push(2)
|
||||
rb.Push(4)
|
||||
removed := rb.Remove(func(v int) bool { return v == 2 })
|
||||
if removed != 2 {
|
||||
t.Errorf("Expected 2 items removed, got %d", removed)
|
||||
}
|
||||
if rb.Size() != 3 {
|
||||
t.Errorf("Expected size 3, got %d", rb.Size())
|
||||
}
|
||||
expected := []int{1, 3, 4}
|
||||
items := rb.Items()
|
||||
for i, item := range items {
|
||||
if item != expected[i] {
|
||||
t.Errorf("Expected item %d, got %d", expected[i], item)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRingBufferRemoveAllItemsFromRingBuffer(t *testing.T) {
|
||||
rb := NewRingBuffer[int](3)
|
||||
rb.Push(1)
|
||||
rb.Push(2)
|
||||
rb.Push(3)
|
||||
removed := rb.Remove(func(v int) bool { return true })
|
||||
if removed != 3 {
|
||||
t.Errorf("Expected 3 items removed, got %d", removed)
|
||||
}
|
||||
if rb.Size() != 0 {
|
||||
t.Errorf("Expected size 0, got %d", rb.Size())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRingBufferRemoveNoItemsFromRingBuffer(t *testing.T) {
|
||||
rb := NewRingBuffer[int](3)
|
||||
rb.Push(1)
|
||||
rb.Push(2)
|
||||
rb.Push(3)
|
||||
removed := rb.Remove(func(v int) bool { return false })
|
||||
if removed != 0 {
|
||||
t.Errorf("Expected 0 items removed, got %d", removed)
|
||||
}
|
||||
if rb.Size() != 3 {
|
||||
t.Errorf("Expected size 3, got %d", rb.Size())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRingBufferIteratesOverAllItems(t *testing.T) {
|
||||
rb := NewRingBuffer[int](3)
|
||||
rb.Push(1)
|
||||
rb.Push(2)
|
||||
rb.Push(3)
|
||||
expected := []int{1, 2, 3}
|
||||
i := 0
|
||||
for item := range rb.Iter() {
|
||||
if item != expected[i] {
|
||||
t.Errorf("Expected item %d, got %d", expected[i], item)
|
||||
}
|
||||
i++
|
||||
}
|
||||
if i != len(expected) {
|
||||
t.Errorf("Expected to iterate over %d items, but iterated over %d", len(expected), i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRingBufferIter2IteratesOverAllItemsWithIndices(t *testing.T) {
|
||||
rb := NewRingBuffer[int](3)
|
||||
rb.Push(1)
|
||||
rb.Push(2)
|
||||
rb.Push(3)
|
||||
expected := []int{1, 2, 3}
|
||||
i := 0
|
||||
for index, item := range rb.Iter2() {
|
||||
if index != i {
|
||||
t.Errorf("Expected index %d, got %d", i, index)
|
||||
}
|
||||
if item != expected[i] {
|
||||
t.Errorf("Expected item %d, got %d", expected[i], item)
|
||||
}
|
||||
i++
|
||||
}
|
||||
if i != len(expected) {
|
||||
t.Errorf("Expected to iterate over %d items, but iterated over %d", len(expected), i)
|
||||
}
|
||||
}
|
||||
+2
-3
@@ -2,7 +2,6 @@ package dataext
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/langext"
|
||||
"sync"
|
||||
)
|
||||
|
||||
@@ -64,7 +63,7 @@ func (s *Stack[T]) OptPop() *T {
|
||||
result := s.data[l-1]
|
||||
s.data = s.data[:l-1]
|
||||
|
||||
return langext.Ptr(result)
|
||||
return new(result)
|
||||
}
|
||||
|
||||
func (s *Stack[T]) Peek() (T, error) {
|
||||
@@ -94,7 +93,7 @@ func (s *Stack[T]) OptPeek() *T {
|
||||
return nil
|
||||
}
|
||||
|
||||
return langext.Ptr(s.data[l-1])
|
||||
return new(s.data[l-1])
|
||||
}
|
||||
|
||||
func (s *Stack[T]) Length() int {
|
||||
|
||||
@@ -0,0 +1,98 @@
|
||||
package dataext
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStack_PushPop(t *testing.T) {
|
||||
s := NewStack[int](false, 4)
|
||||
s.Push(1)
|
||||
s.Push(2)
|
||||
s.Push(3)
|
||||
|
||||
if s.Length() != 3 {
|
||||
t.Fatalf("Length=%d", s.Length())
|
||||
}
|
||||
if s.Empty() {
|
||||
t.Fatal("should not be empty")
|
||||
}
|
||||
|
||||
v, err := s.Pop()
|
||||
if err != nil || v != 3 {
|
||||
t.Fatalf("Pop got (%d,%v)", v, err)
|
||||
}
|
||||
v, err = s.Pop()
|
||||
if err != nil || v != 2 {
|
||||
t.Fatalf("Pop got (%d,%v)", v, err)
|
||||
}
|
||||
v, err = s.Pop()
|
||||
if err != nil || v != 1 {
|
||||
t.Fatalf("Pop got (%d,%v)", v, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStack_PopEmpty(t *testing.T) {
|
||||
s := NewStack[int](false, 0)
|
||||
_, err := s.Pop()
|
||||
if !errors.Is(err, ErrEmptyStack) {
|
||||
t.Fatalf("expected ErrEmptyStack, got %v", err)
|
||||
}
|
||||
if !s.Empty() {
|
||||
t.Fatal("should be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStack_Peek(t *testing.T) {
|
||||
s := NewStack[string](false, 0)
|
||||
if _, err := s.Peek(); !errors.Is(err, ErrEmptyStack) {
|
||||
t.Fatalf("expected ErrEmptyStack got %v", err)
|
||||
}
|
||||
s.Push("a")
|
||||
s.Push("b")
|
||||
v, err := s.Peek()
|
||||
if err != nil || v != "b" {
|
||||
t.Fatalf("Peek got (%q,%v)", v, err)
|
||||
}
|
||||
if s.Length() != 2 {
|
||||
t.Fatal("Peek must not pop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStack_OptPopOptPeek(t *testing.T) {
|
||||
s := NewStack[int](false, 0)
|
||||
if s.OptPop() != nil {
|
||||
t.Fatal("OptPop on empty should return nil")
|
||||
}
|
||||
if s.OptPeek() != nil {
|
||||
t.Fatal("OptPeek on empty should return nil")
|
||||
}
|
||||
s.Push(7)
|
||||
if p := s.OptPeek(); p == nil || *p != 7 {
|
||||
t.Fatalf("OptPeek bad")
|
||||
}
|
||||
if p := s.OptPop(); p == nil || *p != 7 {
|
||||
t.Fatalf("OptPop bad")
|
||||
}
|
||||
if !s.Empty() {
|
||||
t.Fatal("should be empty after OptPop")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStack_ThreadSafe(t *testing.T) {
|
||||
s := NewStack[int](true, 0)
|
||||
var wg sync.WaitGroup
|
||||
const n = 200
|
||||
wg.Add(n)
|
||||
for i := 0; i < n; i++ {
|
||||
go func(v int) {
|
||||
defer wg.Done()
|
||||
s.Push(v)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
if s.Length() != n {
|
||||
t.Fatalf("Length=%d want %d", s.Length(), n)
|
||||
}
|
||||
}
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/langext"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/langext"
|
||||
"hash"
|
||||
"io"
|
||||
"reflect"
|
||||
@@ -82,7 +82,7 @@ func binarize(writer io.Writer, dat reflect.Value, opt StructHashOptions) error
|
||||
|
||||
err = binary.Write(writer, binary.LittleEndian, uint8(dat.Kind()))
|
||||
switch dat.Kind() {
|
||||
case reflect.Ptr, reflect.Map, reflect.Array, reflect.Chan, reflect.Slice, reflect.Interface:
|
||||
case reflect.Pointer, reflect.Map, reflect.Array, reflect.Chan, reflect.Slice, reflect.Interface:
|
||||
if dat.IsNil() {
|
||||
err = binary.Write(writer, binary.LittleEndian, uint64(0))
|
||||
if err != nil {
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
package dataext
|
||||
|
||||
import (
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/langext"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/tst"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/tst"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -46,7 +45,7 @@ func TestStructHashSimpleStruct(t *testing.T) {
|
||||
tst.AssertHexEqual(t, "5d09090dc34ac59dd645f197a255f653387723de3afa1b614721ea5a081c675f", noErrStructHash(t, t0{
|
||||
F1: 10,
|
||||
F2: []string{"1", "2", "3"},
|
||||
F3: langext.Ptr(99),
|
||||
F3: new(99),
|
||||
}))
|
||||
|
||||
}
|
||||
|
||||
@@ -2,11 +2,18 @@ package dataext
|
||||
|
||||
import "sync"
|
||||
|
||||
// SyncMap is a thread-safe map implementation for generic key-value pairs.
|
||||
// All functions aresafe to be called in parallel.
|
||||
type SyncMap[TKey comparable, TData any] struct {
|
||||
data map[TKey]TData
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
func NewSyncMap[TKey comparable, TData any]() *SyncMap[TKey, TData] {
|
||||
return &SyncMap[TKey, TData]{data: make(map[TKey]TData), lock: sync.Mutex{}}
|
||||
}
|
||||
|
||||
// Set sets the value for the provided key
|
||||
func (s *SyncMap[TKey, TData]) Set(key TKey, data TData) {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
@@ -18,6 +25,7 @@ func (s *SyncMap[TKey, TData]) Set(key TKey, data TData) {
|
||||
s.data[key] = data
|
||||
}
|
||||
|
||||
// SetIfNotContains sets the value for the provided key if it does not already exist
|
||||
func (s *SyncMap[TKey, TData]) SetIfNotContains(key TKey, data TData) bool {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
@@ -35,6 +43,7 @@ func (s *SyncMap[TKey, TData]) SetIfNotContains(key TKey, data TData) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// SetIfNotContainsFunc sets the value for the provided key using the provided function
|
||||
func (s *SyncMap[TKey, TData]) SetIfNotContainsFunc(key TKey, data func() TData) bool {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
@@ -52,6 +61,7 @@ func (s *SyncMap[TKey, TData]) SetIfNotContainsFunc(key TKey, data func() TData)
|
||||
return true
|
||||
}
|
||||
|
||||
// Get retrieves the value for the provided key
|
||||
func (s *SyncMap[TKey, TData]) Get(key TKey) (TData, bool) {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
@@ -67,6 +77,8 @@ func (s *SyncMap[TKey, TData]) Get(key TKey) (TData, bool) {
|
||||
}
|
||||
}
|
||||
|
||||
// GetAndSetIfNotContains returns the existing value if the key exists.
|
||||
// Otherwise, it sets the provided value and returns it.
|
||||
func (s *SyncMap[TKey, TData]) GetAndSetIfNotContains(key TKey, data TData) TData {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
@@ -83,6 +95,8 @@ func (s *SyncMap[TKey, TData]) GetAndSetIfNotContains(key TKey, data TData) TDat
|
||||
}
|
||||
}
|
||||
|
||||
// GetAndSetIfNotContainsFunc returns the existing value if the key exists.
|
||||
// Otherwise, it calls the provided function to generate the value, sets it, and returns it.
|
||||
func (s *SyncMap[TKey, TData]) GetAndSetIfNotContainsFunc(key TKey, data func() TData) TData {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
@@ -100,6 +114,7 @@ func (s *SyncMap[TKey, TData]) GetAndSetIfNotContainsFunc(key TKey, data func()
|
||||
}
|
||||
}
|
||||
|
||||
// Delete removes the entry with the provided key and returns true if the key existed before.
|
||||
func (s *SyncMap[TKey, TData]) Delete(key TKey) bool {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
@@ -115,6 +130,70 @@ func (s *SyncMap[TKey, TData]) Delete(key TKey) bool {
|
||||
return ok
|
||||
}
|
||||
|
||||
// DeleteIf deletes all entries that match the provided function and returns the number of removed entries.
|
||||
func (s *SyncMap[TKey, TData]) DeleteIf(fn func(key TKey, data TData) bool) int {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
if s.data == nil {
|
||||
s.data = make(map[TKey]TData)
|
||||
}
|
||||
|
||||
rm := 0
|
||||
for k, v := range s.data {
|
||||
if fn(k, v) {
|
||||
delete(s.data, k)
|
||||
rm++
|
||||
}
|
||||
}
|
||||
|
||||
return rm
|
||||
}
|
||||
|
||||
// UpdateIfExists updates the value if the key exists, otherwise it does nothing.
|
||||
func (s *SyncMap[TKey, TData]) UpdateIfExists(key TKey, fn func(data TData) TData) bool {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
if s.data == nil {
|
||||
s.data = make(map[TKey]TData)
|
||||
}
|
||||
|
||||
if v, ok := s.data[key]; ok {
|
||||
s.data[key] = fn(v)
|
||||
return true
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateOrInsert updates the value if the key exists, otherwise it inserts the provided `insertValue`.
|
||||
func (s *SyncMap[TKey, TData]) UpdateOrInsert(key TKey, fn func(data TData) TData, insertValue TData) bool {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
if s.data == nil {
|
||||
s.data = make(map[TKey]TData)
|
||||
}
|
||||
|
||||
if v, ok := s.data[key]; ok {
|
||||
s.data[key] = fn(v)
|
||||
return true
|
||||
} else {
|
||||
s.data[key] = insertValue
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Clear removes all entries from the map.
|
||||
func (s *SyncMap[TKey, TData]) Clear() {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
s.data = make(map[TKey]TData)
|
||||
}
|
||||
|
||||
// Contains checks if the map contains the provided key.
|
||||
func (s *SyncMap[TKey, TData]) Contains(key TKey) bool {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
@@ -128,6 +207,7 @@ func (s *SyncMap[TKey, TData]) Contains(key TKey) bool {
|
||||
return ok
|
||||
}
|
||||
|
||||
// GetAllKeys returns a copy (!) of all keys in the map.
|
||||
func (s *SyncMap[TKey, TData]) GetAllKeys() []TKey {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
@@ -145,6 +225,7 @@ func (s *SyncMap[TKey, TData]) GetAllKeys() []TKey {
|
||||
return r
|
||||
}
|
||||
|
||||
// GetAllValues returns a copy (!) of all values in the map.
|
||||
func (s *SyncMap[TKey, TData]) GetAllValues() []TData {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
@@ -161,3 +242,15 @@ func (s *SyncMap[TKey, TData]) GetAllValues() []TData {
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// Count returns the number of entries in the map.
|
||||
func (s *SyncMap[TKey, TData]) Count() int {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
if s.data == nil {
|
||||
s.data = make(map[TKey]TData)
|
||||
}
|
||||
|
||||
return len(s.data)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,176 @@
|
||||
package dataext
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSyncMap_SetGet(t *testing.T) {
|
||||
m := NewSyncMap[string, int]()
|
||||
m.Set("a", 1)
|
||||
v, ok := m.Get("a")
|
||||
if !ok || v != 1 {
|
||||
t.Fatalf("got (%d,%v)", v, ok)
|
||||
}
|
||||
if _, ok := m.Get("missing"); ok {
|
||||
t.Fatal("expected missing")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncMap_SetIfNotContains(t *testing.T) {
|
||||
m := NewSyncMap[string, int]()
|
||||
if !m.SetIfNotContains("a", 1) {
|
||||
t.Fatal("first set should succeed")
|
||||
}
|
||||
if m.SetIfNotContains("a", 2) {
|
||||
t.Fatal("second set should fail")
|
||||
}
|
||||
v, _ := m.Get("a")
|
||||
if v != 1 {
|
||||
t.Fatalf("expected unchanged got %d", v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncMap_SetIfNotContainsFunc(t *testing.T) {
|
||||
m := NewSyncMap[string, int]()
|
||||
calls := 0
|
||||
if !m.SetIfNotContainsFunc("a", func() int { calls++; return 5 }) {
|
||||
t.Fatal("first should succeed")
|
||||
}
|
||||
if m.SetIfNotContainsFunc("a", func() int { calls++; return 6 }) {
|
||||
t.Fatal("second should fail")
|
||||
}
|
||||
if calls != 1 {
|
||||
t.Fatalf("calls=%d want 1", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncMap_GetAndSetIfNotContains(t *testing.T) {
|
||||
m := NewSyncMap[string, int]()
|
||||
if v := m.GetAndSetIfNotContains("a", 10); v != 10 {
|
||||
t.Fatalf("got %d", v)
|
||||
}
|
||||
if v := m.GetAndSetIfNotContains("a", 99); v != 10 {
|
||||
t.Fatalf("got %d", v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncMap_GetAndSetIfNotContainsFunc(t *testing.T) {
|
||||
m := NewSyncMap[string, int]()
|
||||
calls := 0
|
||||
if v := m.GetAndSetIfNotContainsFunc("a", func() int { calls++; return 1 }); v != 1 {
|
||||
t.Fatalf("got %d", v)
|
||||
}
|
||||
if v := m.GetAndSetIfNotContainsFunc("a", func() int { calls++; return 2 }); v != 1 {
|
||||
t.Fatalf("got %d", v)
|
||||
}
|
||||
if calls != 1 {
|
||||
t.Fatalf("calls=%d", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncMap_Delete(t *testing.T) {
|
||||
m := NewSyncMap[string, int]()
|
||||
m.Set("a", 1)
|
||||
if !m.Delete("a") {
|
||||
t.Fatal("delete existing returned false")
|
||||
}
|
||||
if m.Delete("a") {
|
||||
t.Fatal("delete missing returned true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncMap_DeleteIf(t *testing.T) {
|
||||
m := NewSyncMap[string, int]()
|
||||
m.Set("a", 1)
|
||||
m.Set("b", 2)
|
||||
m.Set("c", 3)
|
||||
rm := m.DeleteIf(func(k string, v int) bool { return v%2 == 1 })
|
||||
if rm != 2 {
|
||||
t.Fatalf("removed=%d", rm)
|
||||
}
|
||||
if m.Count() != 1 {
|
||||
t.Fatalf("count=%d", m.Count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncMap_UpdateIfExists(t *testing.T) {
|
||||
m := NewSyncMap[string, int]()
|
||||
if m.UpdateIfExists("a", func(v int) int { return v + 1 }) {
|
||||
t.Fatal("should be false on missing key")
|
||||
}
|
||||
m.Set("a", 5)
|
||||
if !m.UpdateIfExists("a", func(v int) int { return v + 1 }) {
|
||||
t.Fatal("should be true on existing")
|
||||
}
|
||||
v, _ := m.Get("a")
|
||||
if v != 6 {
|
||||
t.Fatalf("v=%d", v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncMap_UpdateOrInsert(t *testing.T) {
|
||||
m := NewSyncMap[string, int]()
|
||||
if m.UpdateOrInsert("a", func(v int) int { return v + 1 }, 100) {
|
||||
t.Fatal("should return false on insert")
|
||||
}
|
||||
if v, _ := m.Get("a"); v != 100 {
|
||||
t.Fatalf("v=%d", v)
|
||||
}
|
||||
if !m.UpdateOrInsert("a", func(v int) int { return v + 1 }, 100) {
|
||||
t.Fatal("should return true on update")
|
||||
}
|
||||
if v, _ := m.Get("a"); v != 101 {
|
||||
t.Fatalf("v=%d", v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncMap_ClearContains(t *testing.T) {
|
||||
m := NewSyncMap[string, int]()
|
||||
m.Set("a", 1)
|
||||
if !m.Contains("a") {
|
||||
t.Fatal("Contains should be true")
|
||||
}
|
||||
m.Clear()
|
||||
if m.Contains("a") {
|
||||
t.Fatal("after Clear should be false")
|
||||
}
|
||||
if m.Count() != 0 {
|
||||
t.Fatalf("count=%d", m.Count())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncMap_GetAllKeysValues(t *testing.T) {
|
||||
m := NewSyncMap[string, int]()
|
||||
m.Set("a", 1)
|
||||
m.Set("b", 2)
|
||||
m.Set("c", 3)
|
||||
keys := m.GetAllKeys()
|
||||
sort.Strings(keys)
|
||||
if len(keys) != 3 || keys[0] != "a" || keys[2] != "c" {
|
||||
t.Fatalf("keys=%v", keys)
|
||||
}
|
||||
vals := m.GetAllValues()
|
||||
sort.Ints(vals)
|
||||
if len(vals) != 3 || vals[0] != 1 || vals[2] != 3 {
|
||||
t.Fatalf("vals=%v", vals)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncMap_Concurrent(t *testing.T) {
|
||||
m := NewSyncMap[int, int]()
|
||||
var wg sync.WaitGroup
|
||||
const n = 200
|
||||
wg.Add(n)
|
||||
for i := 0; i < n; i++ {
|
||||
go func(k int) {
|
||||
defer wg.Done()
|
||||
m.Set(k, k*2)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
if m.Count() != n {
|
||||
t.Fatalf("count=%d want %d", m.Count(), n)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,143 @@
|
||||
package dataext
|
||||
|
||||
import "sync"
|
||||
|
||||
type SyncRingSet[TData comparable] struct {
|
||||
data map[TData]bool
|
||||
lock sync.Mutex
|
||||
ring *RingBuffer[TData]
|
||||
}
|
||||
|
||||
func NewSyncRingSet[TData comparable](capacity int) *SyncRingSet[TData] {
|
||||
return &SyncRingSet[TData]{
|
||||
data: make(map[TData]bool, capacity+1),
|
||||
lock: sync.Mutex{},
|
||||
ring: NewRingBuffer[TData](capacity),
|
||||
}
|
||||
}
|
||||
|
||||
// Add adds `value` to the set
|
||||
// returns true if the value was actually inserted (value did not exist beforehand)
|
||||
// returns false if the value already existed
|
||||
func (s *SyncRingSet[TData]) Add(value TData) bool {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
if s.data == nil {
|
||||
s.data = make(map[TData]bool)
|
||||
}
|
||||
|
||||
_, existsInPreState := s.data[value]
|
||||
if existsInPreState {
|
||||
return false
|
||||
}
|
||||
|
||||
prev := s.ring.PushPop(value)
|
||||
|
||||
s.data[value] = true
|
||||
if prev != nil {
|
||||
delete(s.data, *prev)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *SyncRingSet[TData]) AddAll(values []TData) {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
if s.data == nil {
|
||||
s.data = make(map[TData]bool)
|
||||
}
|
||||
|
||||
for _, value := range values {
|
||||
_, existsInPreState := s.data[value]
|
||||
if existsInPreState {
|
||||
continue
|
||||
}
|
||||
|
||||
prev := s.ring.PushPop(value)
|
||||
|
||||
s.data[value] = true
|
||||
if prev != nil {
|
||||
delete(s.data, *prev)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SyncRingSet[TData]) Remove(value TData) bool {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
if s.data == nil {
|
||||
s.data = make(map[TData]bool)
|
||||
}
|
||||
|
||||
_, existsInPreState := s.data[value]
|
||||
if !existsInPreState {
|
||||
return false
|
||||
}
|
||||
|
||||
delete(s.data, value)
|
||||
s.ring.Remove(func(v TData) bool { return value == v })
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *SyncRingSet[TData]) RemoveAll(values []TData) {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
if s.data == nil {
|
||||
s.data = make(map[TData]bool)
|
||||
}
|
||||
|
||||
for _, value := range values {
|
||||
delete(s.data, value)
|
||||
s.ring.Remove(func(v TData) bool { return value == v })
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SyncRingSet[TData]) Contains(value TData) bool {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
if s.data == nil {
|
||||
s.data = make(map[TData]bool)
|
||||
}
|
||||
|
||||
_, ok := s.data[value]
|
||||
|
||||
return ok
|
||||
}
|
||||
|
||||
func (s *SyncRingSet[TData]) Get() []TData {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
if s.data == nil {
|
||||
s.data = make(map[TData]bool)
|
||||
}
|
||||
|
||||
r := make([]TData, 0, len(s.data))
|
||||
|
||||
for k := range s.data {
|
||||
r = append(r, k)
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// AddIfNotContains
|
||||
// returns true if the value was actually added (value did not exist beforehand)
|
||||
// returns false if the value already existed
|
||||
func (s *SyncRingSet[TData]) AddIfNotContains(key TData) bool {
|
||||
return s.Add(key)
|
||||
}
|
||||
|
||||
// RemoveIfContains
|
||||
// returns true if the value was actually removed (value did exist beforehand)
|
||||
// returns false if the value did not exist in the set
|
||||
func (s *SyncRingSet[TData]) RemoveIfContains(key TData) bool {
|
||||
return s.Remove(key)
|
||||
}
|
||||
@@ -0,0 +1,84 @@
|
||||
package dataext
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSyncRingSet_AddAndContains(t *testing.T) {
|
||||
s := NewSyncRingSet[int](3)
|
||||
if !s.Add(1) {
|
||||
t.Fatal("first Add(1) should be true")
|
||||
}
|
||||
if s.Add(1) {
|
||||
t.Fatal("duplicate Add(1) should be false")
|
||||
}
|
||||
if !s.Contains(1) {
|
||||
t.Fatal("expected Contains(1)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncRingSet_CapacityEvicts(t *testing.T) {
|
||||
s := NewSyncRingSet[int](3)
|
||||
s.Add(1)
|
||||
s.Add(2)
|
||||
s.Add(3)
|
||||
s.Add(4) // should evict the oldest (1)
|
||||
if s.Contains(1) {
|
||||
t.Fatal("1 should have been evicted")
|
||||
}
|
||||
for _, v := range []int{2, 3, 4} {
|
||||
if !s.Contains(v) {
|
||||
t.Fatalf("expected %d", v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncRingSet_Remove(t *testing.T) {
|
||||
s := NewSyncRingSet[string](3)
|
||||
s.Add("a")
|
||||
s.Add("b")
|
||||
if !s.Remove("a") {
|
||||
t.Fatal("remove existing failed")
|
||||
}
|
||||
if s.Remove("a") {
|
||||
t.Fatal("remove missing returned true")
|
||||
}
|
||||
if s.Contains("a") {
|
||||
t.Fatal("a should be gone")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncRingSet_AddAllRemoveAll(t *testing.T) {
|
||||
s := NewSyncRingSet[int](10)
|
||||
s.AddAll([]int{1, 2, 3, 2})
|
||||
out := s.Get()
|
||||
sort.Ints(out)
|
||||
if len(out) != 3 {
|
||||
t.Fatalf("got %v", out)
|
||||
}
|
||||
|
||||
s.RemoveAll([]int{1, 99})
|
||||
if s.Contains(1) {
|
||||
t.Fatal("1 should be removed")
|
||||
}
|
||||
if !s.Contains(2) || !s.Contains(3) {
|
||||
t.Fatal("2/3 should remain")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncRingSet_AddIfNotContainsRemoveIfContains(t *testing.T) {
|
||||
s := NewSyncRingSet[string](5)
|
||||
if !s.AddIfNotContains("x") {
|
||||
t.Fatal("first should succeed")
|
||||
}
|
||||
if s.AddIfNotContains("x") {
|
||||
t.Fatal("second should fail")
|
||||
}
|
||||
if !s.RemoveIfContains("x") {
|
||||
t.Fatal("remove existing failed")
|
||||
}
|
||||
if s.RemoveIfContains("x") {
|
||||
t.Fatal("remove missing returned true")
|
||||
}
|
||||
}
|
||||
+54
-3
@@ -7,8 +7,12 @@ type SyncSet[TData comparable] struct {
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
func NewSyncSet[TData comparable]() *SyncSet[TData] {
|
||||
return &SyncSet[TData]{data: make(map[TData]bool), lock: sync.Mutex{}}
|
||||
}
|
||||
|
||||
// Add adds `value` to the set
|
||||
// returns true if the value was actually inserted
|
||||
// returns true if the value was actually inserted (value did not exist beforehand)
|
||||
// returns false if the value already existed
|
||||
func (s *SyncSet[TData]) Add(value TData) bool {
|
||||
s.lock.Lock()
|
||||
@@ -19,9 +23,12 @@ func (s *SyncSet[TData]) Add(value TData) bool {
|
||||
}
|
||||
|
||||
_, existsInPreState := s.data[value]
|
||||
s.data[value] = true
|
||||
if existsInPreState {
|
||||
return false
|
||||
}
|
||||
|
||||
return !existsInPreState
|
||||
s.data[value] = true
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *SyncSet[TData]) AddAll(values []TData) {
|
||||
@@ -37,6 +44,36 @@ func (s *SyncSet[TData]) AddAll(values []TData) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SyncSet[TData]) Remove(value TData) bool {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
if s.data == nil {
|
||||
s.data = make(map[TData]bool)
|
||||
}
|
||||
|
||||
_, existsInPreState := s.data[value]
|
||||
if !existsInPreState {
|
||||
return false
|
||||
}
|
||||
|
||||
delete(s.data, value)
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *SyncSet[TData]) RemoveAll(values []TData) {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
if s.data == nil {
|
||||
s.data = make(map[TData]bool)
|
||||
}
|
||||
|
||||
for _, value := range values {
|
||||
delete(s.data, value)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SyncSet[TData]) Contains(value TData) bool {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
@@ -66,3 +103,17 @@ func (s *SyncSet[TData]) Get() []TData {
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// AddIfNotContains
|
||||
// returns true if the value was actually added (value did not exist beforehand)
|
||||
// returns false if the value already existed
|
||||
func (s *SyncSet[TData]) AddIfNotContains(key TData) bool {
|
||||
return s.Add(key)
|
||||
}
|
||||
|
||||
// RemoveIfContains
|
||||
// returns true if the value was actually removed (value did exist beforehand)
|
||||
// returns false if the value did not exist in the set
|
||||
func (s *SyncSet[TData]) RemoveIfContains(key TData) bool {
|
||||
return s.Remove(key)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,82 @@
|
||||
package dataext
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSyncSet_Add(t *testing.T) {
|
||||
s := NewSyncSet[string]()
|
||||
if !s.Add("a") {
|
||||
t.Fatal("first add should be true")
|
||||
}
|
||||
if s.Add("a") {
|
||||
t.Fatal("duplicate add should be false")
|
||||
}
|
||||
if !s.Contains("a") {
|
||||
t.Fatal("Contains a should be true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncSet_AddAll(t *testing.T) {
|
||||
s := NewSyncSet[int]()
|
||||
s.AddAll([]int{1, 2, 3, 2})
|
||||
if !s.Contains(1) || !s.Contains(2) || !s.Contains(3) {
|
||||
t.Fatal("missing items")
|
||||
}
|
||||
if len(s.Get()) != 3 {
|
||||
t.Fatalf("got len %d", len(s.Get()))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncSet_Remove(t *testing.T) {
|
||||
s := NewSyncSet[string]()
|
||||
s.Add("a")
|
||||
if !s.Remove("a") {
|
||||
t.Fatal("remove existing failed")
|
||||
}
|
||||
if s.Remove("a") {
|
||||
t.Fatal("remove missing returned true")
|
||||
}
|
||||
if s.Contains("a") {
|
||||
t.Fatal("still contains after remove")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncSet_RemoveAll(t *testing.T) {
|
||||
s := NewSyncSet[int]()
|
||||
s.AddAll([]int{1, 2, 3})
|
||||
s.RemoveAll([]int{1, 2, 99})
|
||||
if s.Contains(1) || s.Contains(2) {
|
||||
t.Fatal("should be removed")
|
||||
}
|
||||
if !s.Contains(3) {
|
||||
t.Fatal("3 should remain")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncSet_Get(t *testing.T) {
|
||||
s := NewSyncSet[int]()
|
||||
s.AddAll([]int{3, 1, 2})
|
||||
out := s.Get()
|
||||
sort.Ints(out)
|
||||
if len(out) != 3 || out[0] != 1 || out[2] != 3 {
|
||||
t.Fatalf("out=%v", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncSet_AddIfNotContainsRemoveIfContains(t *testing.T) {
|
||||
s := NewSyncSet[string]()
|
||||
if !s.AddIfNotContains("x") {
|
||||
t.Fatal("first AddIfNotContains failed")
|
||||
}
|
||||
if s.AddIfNotContains("x") {
|
||||
t.Fatal("second AddIfNotContains succeeded")
|
||||
}
|
||||
if !s.RemoveIfContains("x") {
|
||||
t.Fatal("RemoveIfContains failed")
|
||||
}
|
||||
if s.RemoveIfContains("x") {
|
||||
t.Fatal("RemoveIfContains on missing succeeded")
|
||||
}
|
||||
}
|
||||
@@ -19,6 +19,14 @@ func (s Single[T1]) TupleValues() []any {
|
||||
return []any{s.V1}
|
||||
}
|
||||
|
||||
func NewSingle[T1 any](v1 T1) Single[T1] {
|
||||
return Single[T1]{V1: v1}
|
||||
}
|
||||
|
||||
func NewTuple1[T1 any](v1 T1) Single[T1] {
|
||||
return Single[T1]{V1: v1}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
type Tuple[T1 any, T2 any] struct {
|
||||
@@ -34,6 +42,14 @@ func (t Tuple[T1, T2]) TupleValues() []any {
|
||||
return []any{t.V1, t.V2}
|
||||
}
|
||||
|
||||
func NewTuple[T1 any, T2 any](v1 T1, v2 T2) Tuple[T1, T2] {
|
||||
return Tuple[T1, T2]{V1: v1, V2: v2}
|
||||
}
|
||||
|
||||
func NewTuple2[T1 any, T2 any](v1 T1, v2 T2) Tuple[T1, T2] {
|
||||
return Tuple[T1, T2]{V1: v1, V2: v2}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
type Triple[T1 any, T2 any, T3 any] struct {
|
||||
@@ -50,6 +66,14 @@ func (t Triple[T1, T2, T3]) TupleValues() []any {
|
||||
return []any{t.V1, t.V2, t.V3}
|
||||
}
|
||||
|
||||
func NewTriple[T1 any, T2 any, T3 any](v1 T1, v2 T2, v3 T3) Triple[T1, T2, T3] {
|
||||
return Triple[T1, T2, T3]{V1: v1, V2: v2, V3: v3}
|
||||
}
|
||||
|
||||
func NewTuple3[T1 any, T2 any, T3 any](v1 T1, v2 T2, v3 T3) Triple[T1, T2, T3] {
|
||||
return Triple[T1, T2, T3]{V1: v1, V2: v2, V3: v3}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
type Quadruple[T1 any, T2 any, T3 any, T4 any] struct {
|
||||
@@ -67,6 +91,14 @@ func (t Quadruple[T1, T2, T3, T4]) TupleValues() []any {
|
||||
return []any{t.V1, t.V2, t.V3, t.V4}
|
||||
}
|
||||
|
||||
func NewQuadruple[T1 any, T2 any, T3 any, T4 any](v1 T1, v2 T2, v3 T3, v4 T4) Quadruple[T1, T2, T3, T4] {
|
||||
return Quadruple[T1, T2, T3, T4]{V1: v1, V2: v2, V3: v3, V4: v4}
|
||||
}
|
||||
|
||||
func NewTuple4[T1 any, T2 any, T3 any, T4 any](v1 T1, v2 T2, v3 T3, v4 T4) Quadruple[T1, T2, T3, T4] {
|
||||
return Quadruple[T1, T2, T3, T4]{V1: v1, V2: v2, V3: v3, V4: v4}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
type Quintuple[T1 any, T2 any, T3 any, T4 any, T5 any] struct {
|
||||
@@ -86,6 +118,14 @@ func (t Quintuple[T1, T2, T3, T4, T5]) TupleValues() []any {
|
||||
|
||||
}
|
||||
|
||||
func NewQuintuple[T1 any, T2 any, T3 any, T4 any, T5 any](v1 T1, v2 T2, v3 T3, v4 T4, v5 T5) Quintuple[T1, T2, T3, T4, T5] {
|
||||
return Quintuple[T1, T2, T3, T4, T5]{V1: v1, V2: v2, V3: v3, V4: v4, V5: v5}
|
||||
}
|
||||
|
||||
func NewTuple5[T1 any, T2 any, T3 any, T4 any, T5 any](v1 T1, v2 T2, v3 T3, v4 T4, v5 T5) Quintuple[T1, T2, T3, T4, T5] {
|
||||
return Quintuple[T1, T2, T3, T4, T5]{V1: v1, V2: v2, V3: v3, V4: v4, V5: v5}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
type Sextuple[T1 any, T2 any, T3 any, T4 any, T5 any, T6 any] struct {
|
||||
@@ -106,6 +146,14 @@ func (t Sextuple[T1, T2, T3, T4, T5, T6]) TupleValues() []any {
|
||||
|
||||
}
|
||||
|
||||
func NewSextuple[T1 any, T2 any, T3 any, T4 any, T5 any, T6 any](v1 T1, v2 T2, v3 T3, v4 T4, v5 T5, v6 T6) Sextuple[T1, T2, T3, T4, T5, T6] {
|
||||
return Sextuple[T1, T2, T3, T4, T5, T6]{V1: v1, V2: v2, V3: v3, V4: v4, V5: v5, V6: v6}
|
||||
}
|
||||
|
||||
func NewTuple6[T1 any, T2 any, T3 any, T4 any, T5 any, T6 any](v1 T1, v2 T2, v3 T3, v4 T4, v5 T5, v6 T6) Sextuple[T1, T2, T3, T4, T5, T6] {
|
||||
return Sextuple[T1, T2, T3, T4, T5, T6]{V1: v1, V2: v2, V3: v3, V4: v4, V5: v5, V6: v6}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
type Septuple[T1 any, T2 any, T3 any, T4 any, T5 any, T6 any, T7 any] struct {
|
||||
@@ -126,6 +174,14 @@ func (t Septuple[T1, T2, T3, T4, T5, T6, T7]) TupleValues() []any {
|
||||
return []any{t.V1, t.V2, t.V3, t.V4, t.V5, t.V6, t.V7}
|
||||
}
|
||||
|
||||
func NewSeptuple[T1 any, T2 any, T3 any, T4 any, T5 any, T6 any, T7 any](v1 T1, v2 T2, v3 T3, v4 T4, v5 T5, v6 T6, v7 T7) Septuple[T1, T2, T3, T4, T5, T6, T7] {
|
||||
return Septuple[T1, T2, T3, T4, T5, T6, T7]{V1: v1, V2: v2, V3: v3, V4: v4, V5: v5, V6: v6, V7: v7}
|
||||
}
|
||||
|
||||
func NewTuple7[T1 any, T2 any, T3 any, T4 any, T5 any, T6 any, T7 any](v1 T1, v2 T2, v3 T3, v4 T4, v5 T5, v6 T6, v7 T7) Septuple[T1, T2, T3, T4, T5, T6, T7] {
|
||||
return Septuple[T1, T2, T3, T4, T5, T6, T7]{V1: v1, V2: v2, V3: v3, V4: v4, V5: v5, V6: v6, V7: v7}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
type Octuple[T1 any, T2 any, T3 any, T4 any, T5 any, T6 any, T7 any, T8 any] struct {
|
||||
@@ -147,6 +203,14 @@ func (t Octuple[T1, T2, T3, T4, T5, T6, T7, T8]) TupleValues() []any {
|
||||
return []any{t.V1, t.V2, t.V3, t.V4, t.V5, t.V6, t.V7, t.V8}
|
||||
}
|
||||
|
||||
func NewOctuple[T1 any, T2 any, T3 any, T4 any, T5 any, T6 any, T7 any, T8 any](v1 T1, v2 T2, v3 T3, v4 T4, v5 T5, v6 T6, v7 T7, v8 T8) Octuple[T1, T2, T3, T4, T5, T6, T7, T8] {
|
||||
return Octuple[T1, T2, T3, T4, T5, T6, T7, T8]{V1: v1, V2: v2, V3: v3, V4: v4, V5: v5, V6: v6, V7: v7, V8: v8}
|
||||
}
|
||||
|
||||
func NewTuple8[T1 any, T2 any, T3 any, T4 any, T5 any, T6 any, T7 any, T8 any](v1 T1, v2 T2, v3 T3, v4 T4, v5 T5, v6 T6, v7 T7, v8 T8) Octuple[T1, T2, T3, T4, T5, T6, T7, T8] {
|
||||
return Octuple[T1, T2, T3, T4, T5, T6, T7, T8]{V1: v1, V2: v2, V3: v3, V4: v4, V5: v5, V6: v6, V7: v7, V8: v8}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
type Nonuple[T1 any, T2 any, T3 any, T4 any, T5 any, T6 any, T7 any, T8 any, T9 any] struct {
|
||||
@@ -168,3 +232,10 @@ func (t Nonuple[T1, T2, T3, T4, T5, T6, T7, T8, T9]) TupleLength() int {
|
||||
func (t Nonuple[T1, T2, T3, T4, T5, T6, T7, T8, T9]) TupleValues() []any {
|
||||
return []any{t.V1, t.V2, t.V3, t.V4, t.V5, t.V6, t.V7, t.V8, t.V9}
|
||||
}
|
||||
|
||||
func NewNonuple[T1 any, T2 any, T3 any, T4 any, T5 any, T6 any, T7 any, T8 any, T9 any](v1 T1, v2 T2, v3 T3, v4 T4, v5 T5, v6 T6, v7 T7, v8 T8, v9 T9) Nonuple[T1, T2, T3, T4, T5, T6, T7, T8, T9] {
|
||||
return Nonuple[T1, T2, T3, T4, T5, T6, T7, T8, T9]{V1: v1, V2: v2, V3: v3, V4: v4, V5: v5, V6: v6, V7: v7, V8: v8, V9: v9}
|
||||
}
|
||||
func NewTuple9[T1 any, T2 any, T3 any, T4 any, T5 any, T6 any, T7 any, T8 any, T9 any](v1 T1, v2 T2, v3 T3, v4 T4, v5 T5, v6 T6, v7 T7, v8 T8, v9 T9) Nonuple[T1, T2, T3, T4, T5, T6, T7, T8, T9] {
|
||||
return Nonuple[T1, T2, T3, T4, T5, T6, T7, T8, T9]{V1: v1, V2: v2, V3: v3, V4: v4, V5: v5, V6: v6, V7: v7, V8: v8, V9: v9}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,136 @@
|
||||
package dataext
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSingle(t *testing.T) {
|
||||
s := NewSingle[int](7)
|
||||
if s.V1 != 7 {
|
||||
t.Fatalf("V1=%d", s.V1)
|
||||
}
|
||||
if s.TupleLength() != 1 {
|
||||
t.Fatalf("len=%d", s.TupleLength())
|
||||
}
|
||||
if !reflect.DeepEqual(s.TupleValues(), []any{7}) {
|
||||
t.Fatalf("values=%v", s.TupleValues())
|
||||
}
|
||||
if NewTuple1[int](7).V1 != 7 {
|
||||
t.Fatal("NewTuple1 mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTuple(t *testing.T) {
|
||||
tp := NewTuple[int, string](1, "two")
|
||||
if tp.V1 != 1 || tp.V2 != "two" {
|
||||
t.Fatal("values wrong")
|
||||
}
|
||||
if tp.TupleLength() != 2 {
|
||||
t.Fatal("len wrong")
|
||||
}
|
||||
if !reflect.DeepEqual(tp.TupleValues(), []any{1, "two"}) {
|
||||
t.Fatalf("values=%v", tp.TupleValues())
|
||||
}
|
||||
if NewTuple2[int, string](1, "two") != tp {
|
||||
t.Fatal("NewTuple2 mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTriple(t *testing.T) {
|
||||
tr := NewTriple[int, string, bool](1, "x", true)
|
||||
if tr.TupleLength() != 3 {
|
||||
t.Fatal("len wrong")
|
||||
}
|
||||
if !reflect.DeepEqual(tr.TupleValues(), []any{1, "x", true}) {
|
||||
t.Fatalf("values=%v", tr.TupleValues())
|
||||
}
|
||||
if NewTuple3[int, string, bool](1, "x", true) != tr {
|
||||
t.Fatal("NewTuple3 mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuadruple(t *testing.T) {
|
||||
q := NewQuadruple[int, int, int, int](1, 2, 3, 4)
|
||||
if q.TupleLength() != 4 {
|
||||
t.Fatal("len wrong")
|
||||
}
|
||||
if !reflect.DeepEqual(q.TupleValues(), []any{1, 2, 3, 4}) {
|
||||
t.Fatalf("values=%v", q.TupleValues())
|
||||
}
|
||||
if NewTuple4[int, int, int, int](1, 2, 3, 4) != q {
|
||||
t.Fatal("NewTuple4 mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuintuple(t *testing.T) {
|
||||
q := NewQuintuple[int, int, int, int, int](1, 2, 3, 4, 5)
|
||||
if q.TupleLength() != 5 {
|
||||
t.Fatal("len wrong")
|
||||
}
|
||||
if !reflect.DeepEqual(q.TupleValues(), []any{1, 2, 3, 4, 5}) {
|
||||
t.Fatalf("values=%v", q.TupleValues())
|
||||
}
|
||||
if NewTuple5[int, int, int, int, int](1, 2, 3, 4, 5) != q {
|
||||
t.Fatal("NewTuple5 mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSextuple(t *testing.T) {
|
||||
s := NewSextuple[int, int, int, int, int, int](1, 2, 3, 4, 5, 6)
|
||||
if s.TupleLength() != 6 {
|
||||
t.Fatal("len wrong")
|
||||
}
|
||||
if !reflect.DeepEqual(s.TupleValues(), []any{1, 2, 3, 4, 5, 6}) {
|
||||
t.Fatalf("values=%v", s.TupleValues())
|
||||
}
|
||||
if NewTuple6[int, int, int, int, int, int](1, 2, 3, 4, 5, 6) != s {
|
||||
t.Fatal("NewTuple6 mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSeptuple(t *testing.T) {
|
||||
s := NewSeptuple[int, int, int, int, int, int, int](1, 2, 3, 4, 5, 6, 7)
|
||||
if s.TupleLength() != 7 {
|
||||
t.Fatal("len wrong")
|
||||
}
|
||||
if !reflect.DeepEqual(s.TupleValues(), []any{1, 2, 3, 4, 5, 6, 7}) {
|
||||
t.Fatalf("values=%v", s.TupleValues())
|
||||
}
|
||||
if NewTuple7[int, int, int, int, int, int, int](1, 2, 3, 4, 5, 6, 7) != s {
|
||||
t.Fatal("NewTuple7 mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOctuple(t *testing.T) {
|
||||
o := NewOctuple[int, int, int, int, int, int, int, int](1, 2, 3, 4, 5, 6, 7, 8)
|
||||
if o.TupleLength() != 8 {
|
||||
t.Fatal("len wrong")
|
||||
}
|
||||
if !reflect.DeepEqual(o.TupleValues(), []any{1, 2, 3, 4, 5, 6, 7, 8}) {
|
||||
t.Fatalf("values=%v", o.TupleValues())
|
||||
}
|
||||
if NewTuple8[int, int, int, int, int, int, int, int](1, 2, 3, 4, 5, 6, 7, 8) != o {
|
||||
t.Fatal("NewTuple8 mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNonuple(t *testing.T) {
|
||||
n := NewNonuple[int, int, int, int, int, int, int, int, int](1, 2, 3, 4, 5, 6, 7, 8, 9)
|
||||
if n.TupleLength() != 9 {
|
||||
t.Fatal("len wrong")
|
||||
}
|
||||
if !reflect.DeepEqual(n.TupleValues(), []any{1, 2, 3, 4, 5, 6, 7, 8, 9}) {
|
||||
t.Fatalf("values=%v", n.TupleValues())
|
||||
}
|
||||
if NewTuple9[int, int, int, int, int, int, int, int, int](1, 2, 3, 4, 5, 6, 7, 8, 9) != n {
|
||||
t.Fatal("NewTuple9 mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValueGroupInterface(t *testing.T) {
|
||||
var vg ValueGroup = NewTuple[int, string](1, "a")
|
||||
if vg.TupleLength() != 2 {
|
||||
t.Fatal("interface length wrong")
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,9 @@
|
||||
package enums
|
||||
|
||||
import "maps"
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
type Enum interface {
|
||||
Valid() bool
|
||||
ValuesAny() []any
|
||||
@@ -31,3 +35,23 @@ type EnumDescriptionMetaValue struct {
|
||||
Value Enum `json:"value"`
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
type EnumDataMetaValue struct {
|
||||
VarName string `json:"varName"`
|
||||
Value Enum `json:"value"`
|
||||
Description *string `json:"description"`
|
||||
|
||||
Data map[string]any `json:"-"` //handled by MarshalJSON
|
||||
}
|
||||
|
||||
func (v EnumDataMetaValue) MarshalJSON() ([]byte, error) {
|
||||
m := make(map[string]any, 8)
|
||||
|
||||
maps.Copy(m, v.Data)
|
||||
|
||||
m["varName"] = v.VarName
|
||||
m["value"] = v.Value
|
||||
m["description"] = v.Description
|
||||
|
||||
return json.Marshal(m)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,258 @@
|
||||
package enums
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type mockEnum struct {
|
||||
name string
|
||||
}
|
||||
|
||||
func (m mockEnum) Valid() bool { return m.name != "" }
|
||||
func (m mockEnum) ValuesAny() []any { return []any{mockEnum{name: "a"}, mockEnum{name: "b"}} }
|
||||
func (m mockEnum) ValuesMeta() []EnumMetaValue { return nil }
|
||||
func (m mockEnum) VarName() string { return m.name }
|
||||
func (m mockEnum) TypeName() string { return "mockEnum" }
|
||||
func (m mockEnum) PackageName() string { return "enums_test" }
|
||||
func (m mockEnum) String() string { return "str:" + m.name }
|
||||
func (m mockEnum) Description() string { return "desc:" + m.name }
|
||||
func (m mockEnum) DescriptionMeta() EnumDescriptionMetaValue {
|
||||
return EnumDescriptionMetaValue{VarName: m.name, Value: m, Description: "desc:" + m.name}
|
||||
}
|
||||
|
||||
func (m mockEnum) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(m.name)
|
||||
}
|
||||
|
||||
func TestMockEnumImplementsInterfaces(t *testing.T) {
|
||||
var _ Enum = mockEnum{}
|
||||
var _ StringEnum = mockEnum{}
|
||||
var _ DescriptionEnum = mockEnum{}
|
||||
}
|
||||
|
||||
func TestEnumValid(t *testing.T) {
|
||||
if !(mockEnum{name: "x"}).Valid() {
|
||||
t.Errorf("expected Valid() == true")
|
||||
}
|
||||
if (mockEnum{}).Valid() {
|
||||
t.Errorf("expected Valid() == false for zero value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnumMetaValueJSON(t *testing.T) {
|
||||
desc := "the-description"
|
||||
mv := EnumMetaValue{
|
||||
VarName: "Foo",
|
||||
Value: mockEnum{name: "foo"},
|
||||
Description: &desc,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(mv)
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal failed: %v", err)
|
||||
}
|
||||
|
||||
var got map[string]any
|
||||
if err := json.Unmarshal(data, &got); err != nil {
|
||||
t.Fatalf("json.Unmarshal failed: %v", err)
|
||||
}
|
||||
|
||||
if got["varName"] != "Foo" {
|
||||
t.Errorf("varName == %v, want Foo", got["varName"])
|
||||
}
|
||||
if got["value"] != "foo" {
|
||||
t.Errorf("value == %v, want foo", got["value"])
|
||||
}
|
||||
if got["description"] != "the-description" {
|
||||
t.Errorf("description == %v, want the-description", got["description"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnumMetaValueJSONNilDescription(t *testing.T) {
|
||||
mv := EnumMetaValue{
|
||||
VarName: "Foo",
|
||||
Value: mockEnum{name: "foo"},
|
||||
Description: nil,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(mv)
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal failed: %v", err)
|
||||
}
|
||||
|
||||
var got map[string]any
|
||||
if err := json.Unmarshal(data, &got); err != nil {
|
||||
t.Fatalf("json.Unmarshal failed: %v", err)
|
||||
}
|
||||
|
||||
if got["description"] != nil {
|
||||
t.Errorf("description == %v, want nil", got["description"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnumDescriptionMetaValueJSON(t *testing.T) {
|
||||
mv := EnumDescriptionMetaValue{
|
||||
VarName: "Bar",
|
||||
Value: mockEnum{name: "bar"},
|
||||
Description: "bar-desc",
|
||||
}
|
||||
|
||||
data, err := json.Marshal(mv)
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal failed: %v", err)
|
||||
}
|
||||
|
||||
var got map[string]any
|
||||
if err := json.Unmarshal(data, &got); err != nil {
|
||||
t.Fatalf("json.Unmarshal failed: %v", err)
|
||||
}
|
||||
|
||||
expected := map[string]any{
|
||||
"varName": "Bar",
|
||||
"value": "bar",
|
||||
"description": "bar-desc",
|
||||
}
|
||||
if !reflect.DeepEqual(got, expected) {
|
||||
t.Errorf("json output == %v, want %v", got, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnumDataMetaValueMarshalJSON(t *testing.T) {
|
||||
desc := "data-desc"
|
||||
mv := EnumDataMetaValue{
|
||||
VarName: "Baz",
|
||||
Value: mockEnum{name: "baz"},
|
||||
Description: &desc,
|
||||
Data: map[string]any{
|
||||
"extra1": "hello",
|
||||
"extra2": float64(42),
|
||||
},
|
||||
}
|
||||
|
||||
data, err := json.Marshal(mv)
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal failed: %v", err)
|
||||
}
|
||||
|
||||
var got map[string]any
|
||||
if err := json.Unmarshal(data, &got); err != nil {
|
||||
t.Fatalf("json.Unmarshal failed: %v", err)
|
||||
}
|
||||
|
||||
if got["varName"] != "Baz" {
|
||||
t.Errorf("varName == %v, want Baz", got["varName"])
|
||||
}
|
||||
if got["value"] != "baz" {
|
||||
t.Errorf("value == %v, want baz", got["value"])
|
||||
}
|
||||
if got["description"] != "data-desc" {
|
||||
t.Errorf("description == %v, want data-desc", got["description"])
|
||||
}
|
||||
if got["extra1"] != "hello" {
|
||||
t.Errorf("extra1 == %v, want hello", got["extra1"])
|
||||
}
|
||||
if got["extra2"] != float64(42) {
|
||||
t.Errorf("extra2 == %v, want 42", got["extra2"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnumDataMetaValueMarshalJSONNilData(t *testing.T) {
|
||||
mv := EnumDataMetaValue{
|
||||
VarName: "Baz",
|
||||
Value: mockEnum{name: "baz"},
|
||||
Description: nil,
|
||||
Data: nil,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(mv)
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal failed: %v", err)
|
||||
}
|
||||
|
||||
var got map[string]any
|
||||
if err := json.Unmarshal(data, &got); err != nil {
|
||||
t.Fatalf("json.Unmarshal failed: %v", err)
|
||||
}
|
||||
|
||||
if got["varName"] != "Baz" {
|
||||
t.Errorf("varName == %v, want Baz", got["varName"])
|
||||
}
|
||||
if got["value"] != "baz" {
|
||||
t.Errorf("value == %v, want baz", got["value"])
|
||||
}
|
||||
if _, ok := got["description"]; !ok {
|
||||
t.Errorf("description key missing in JSON output")
|
||||
}
|
||||
if got["description"] != nil {
|
||||
t.Errorf("description == %v, want nil", got["description"])
|
||||
}
|
||||
if len(got) != 3 {
|
||||
t.Errorf("expected 3 keys with nil Data, got %d: %v", len(got), got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnumDataMetaValueMarshalJSONDataDoesNotOverrideStandardFields(t *testing.T) {
|
||||
desc := "real-desc"
|
||||
mv := EnumDataMetaValue{
|
||||
VarName: "Real",
|
||||
Value: mockEnum{name: "real"},
|
||||
Description: &desc,
|
||||
Data: map[string]any{
|
||||
"varName": "ShouldBeOverwritten",
|
||||
"value": "ShouldBeOverwritten",
|
||||
"description": "ShouldBeOverwritten",
|
||||
"keep": "kept",
|
||||
},
|
||||
}
|
||||
|
||||
data, err := json.Marshal(mv)
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal failed: %v", err)
|
||||
}
|
||||
|
||||
var got map[string]any
|
||||
if err := json.Unmarshal(data, &got); err != nil {
|
||||
t.Fatalf("json.Unmarshal failed: %v", err)
|
||||
}
|
||||
|
||||
if got["varName"] != "Real" {
|
||||
t.Errorf("varName == %v, want Real (standard field must override Data)", got["varName"])
|
||||
}
|
||||
if got["value"] != "real" {
|
||||
t.Errorf("value == %v, want real (standard field must override Data)", got["value"])
|
||||
}
|
||||
if got["description"] != "real-desc" {
|
||||
t.Errorf("description == %v, want real-desc (standard field must override Data)", got["description"])
|
||||
}
|
||||
if got["keep"] != "kept" {
|
||||
t.Errorf("keep == %v, want kept", got["keep"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnumDataMetaValueMarshalJSONEmptyData(t *testing.T) {
|
||||
mv := EnumDataMetaValue{
|
||||
VarName: "E",
|
||||
Value: mockEnum{name: "e"},
|
||||
Description: nil,
|
||||
Data: map[string]any{},
|
||||
}
|
||||
|
||||
data, err := json.Marshal(mv)
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal failed: %v", err)
|
||||
}
|
||||
|
||||
var got map[string]any
|
||||
if err := json.Unmarshal(data, &got); err != nil {
|
||||
t.Fatalf("json.Unmarshal failed: %v", err)
|
||||
}
|
||||
|
||||
if got["varName"] != "E" {
|
||||
t.Errorf("varName == %v, want E", got["varName"])
|
||||
}
|
||||
if got["value"] != "e" {
|
||||
t.Errorf("value == %v, want e", got["value"])
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,300 @@
|
||||
package excelext
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/dataext"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/exerr"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/langext"
|
||||
excelize360 "github.com/360EntSecGroup-Skylar/excelize"
|
||||
"github.com/xuri/excelize/v2"
|
||||
)
|
||||
|
||||
type excelMapperColDefinition[T any] struct {
|
||||
style *int
|
||||
header string
|
||||
width *float64
|
||||
fn func(T) (any, error)
|
||||
}
|
||||
|
||||
type ExcelMapper[T any] struct {
|
||||
StyleDate *int
|
||||
StyleDatetime *int
|
||||
StyleEUR *int
|
||||
StylePercentage *int
|
||||
StyleHeader *int
|
||||
StyleWSHeader *int
|
||||
|
||||
SkipColumnHeader bool
|
||||
|
||||
sheetName string
|
||||
wsHeader []dataext.Tuple[string, *int]
|
||||
colDefinitions []excelMapperColDefinition[T]
|
||||
colFilter []func(v T) bool
|
||||
}
|
||||
|
||||
func NewExcelMapper[T any]() (*ExcelMapper[T], error) {
|
||||
|
||||
em := &ExcelMapper[T]{
|
||||
StyleDate: nil,
|
||||
StyleDatetime: nil,
|
||||
StyleEUR: nil,
|
||||
StylePercentage: nil,
|
||||
StyleHeader: nil,
|
||||
StyleWSHeader: nil,
|
||||
sheetName: "",
|
||||
|
||||
SkipColumnHeader: false,
|
||||
|
||||
wsHeader: make([]dataext.Tuple[string, *int], 0),
|
||||
colDefinitions: make([]excelMapperColDefinition[T], 0),
|
||||
}
|
||||
|
||||
return em, nil
|
||||
}
|
||||
|
||||
func (em *ExcelMapper[T]) InitNewFile(sheetName string) (*excelize.File, error) {
|
||||
f := excelize.NewFile()
|
||||
|
||||
defSheet := f.GetSheetList()[0]
|
||||
|
||||
sheet1 := sheetName
|
||||
|
||||
sheetIdx, err := f.NewSheet(sheet1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f.SetActiveSheet(sheetIdx)
|
||||
err = f.DeleteSheet(defSheet)
|
||||
|
||||
err = em.InitStyles(f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return f, nil
|
||||
}
|
||||
|
||||
func (em *ExcelMapper[T]) InitStyles(f *excelize.File) error {
|
||||
styleDate, err := f.NewStyle(&excelize.Style{
|
||||
CustomNumFmt: new("dd.mm.yyyy"),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
styleDatetime, err := f.NewStyle(&excelize.Style{
|
||||
NumFmt: 22,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
styleEUR, err := f.NewStyle(&excelize.Style{
|
||||
NumFmt: 218,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
stylePercentage, err := f.NewStyle(&excelize.Style{
|
||||
NumFmt: 10,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
styleHeader, err := f.NewStyle(&excelize.Style{
|
||||
Font: &excelize.Font{Bold: true, Size: 11},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
styleWSHeader, err := f.NewStyle(&excelize.Style{
|
||||
Font: &excelize.Font{Bold: true, Size: 24},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
em.StyleDate = &styleDate
|
||||
em.StyleDatetime = &styleDatetime
|
||||
em.StyleEUR = &styleEUR
|
||||
em.StylePercentage = &stylePercentage
|
||||
em.StyleHeader = &styleHeader
|
||||
em.StyleWSHeader = &styleWSHeader
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (em *ExcelMapper[T]) AddWorksheetHeader(header string, style *int) {
|
||||
em.wsHeader = append(em.wsHeader, dataext.NewTuple(header, style))
|
||||
}
|
||||
|
||||
func (em *ExcelMapper[T]) AddColumn(header string, style *int, width *float64, fn func(T) any) {
|
||||
em.colDefinitions = append(em.colDefinitions, excelMapperColDefinition[T]{
|
||||
style: style,
|
||||
header: header,
|
||||
width: width,
|
||||
fn: func(t T) (any, error) { return fn(t), nil },
|
||||
})
|
||||
}
|
||||
|
||||
func (em *ExcelMapper[T]) AddColumnErr(header string, style *int, width *float64, fn func(T) (any, error)) {
|
||||
em.colDefinitions = append(em.colDefinitions, excelMapperColDefinition[T]{
|
||||
style: style,
|
||||
header: header,
|
||||
width: width,
|
||||
fn: fn,
|
||||
})
|
||||
}
|
||||
|
||||
func (em *ExcelMapper[T]) Build(sheetName string, data []T) ([]byte, error) {
|
||||
f, err := em.InitNewFile(sheetName)
|
||||
if err != nil {
|
||||
return nil, exerr.Wrap(err, "failed to init new file").Build()
|
||||
}
|
||||
|
||||
err = em.BuildSingleSheet(f, sheetName, data)
|
||||
if err != nil {
|
||||
return nil, exerr.Wrap(err, "").Build()
|
||||
}
|
||||
|
||||
buffer, err := f.WriteToBuffer()
|
||||
if err != nil {
|
||||
return nil, exerr.Wrap(err, "failed to build xls").Build()
|
||||
}
|
||||
|
||||
return buffer.Bytes(), nil
|
||||
}
|
||||
|
||||
func (em *ExcelMapper[T]) BuildSingleSheet(f *excelize.File, sheetName string, data []T) error {
|
||||
if em.StyleHeader == nil || em.StyleDate == nil || em.StyleDatetime == nil || em.StyleEUR == nil || em.StylePercentage == nil || em.StyleWSHeader == nil {
|
||||
err := em.InitStyles(f)
|
||||
if err != nil {
|
||||
return exerr.Wrap(err, "failed to init styles").Build()
|
||||
}
|
||||
}
|
||||
|
||||
rowOffset := 0
|
||||
|
||||
if len(em.wsHeader) > 0 {
|
||||
for range em.wsHeader {
|
||||
rowOffset += 1
|
||||
}
|
||||
rowOffset += 1
|
||||
}
|
||||
|
||||
if !em.SkipColumnHeader {
|
||||
for i, col := range em.colDefinitions {
|
||||
err := f.SetCellValue(sheetName, c(rowOffset+1, i), col.header)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for i, col := range em.colDefinitions {
|
||||
if col.style != nil {
|
||||
err := f.SetColStyle(sheetName, excelize360.ToAlphaString(i), *col.style)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for i, col := range em.colDefinitions {
|
||||
if col.width != nil {
|
||||
err := f.SetColWidth(sheetName, excelize360.ToAlphaString(i), excelize360.ToAlphaString(i), *col.width)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err := f.SetRowStyle(sheetName, rowOffset+1, rowOffset+1, *em.StyleHeader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(em.wsHeader) > 0 {
|
||||
for i, hdr := range em.wsHeader {
|
||||
style := *langext.CoalesceOpt(hdr.V2, em.StyleWSHeader)
|
||||
|
||||
err = f.SetCellValue(sheetName, c(i+1, 0), hdr.V1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = f.MergeCell(sheetName, c(i+1, 0), c(i+1, len(em.colDefinitions)-1))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = f.SetRowStyle(sheetName, 1, 1, style)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
iRow := rowOffset + 1
|
||||
if !em.SkipColumnHeader {
|
||||
iRow += 1
|
||||
}
|
||||
|
||||
for _, dat := range data {
|
||||
|
||||
skip := false
|
||||
for _, filter := range em.colFilter {
|
||||
if !filter(dat) {
|
||||
skip = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if skip {
|
||||
continue
|
||||
}
|
||||
|
||||
for iCol, col := range em.colDefinitions {
|
||||
|
||||
cellVal, err := col.fn(dat)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for reflect.ValueOf(cellVal).Kind() == reflect.Pointer && !reflect.ValueOf(cellVal).IsNil() {
|
||||
cellVal = reflect.ValueOf(cellVal).Elem().Interface()
|
||||
}
|
||||
|
||||
if langext.IsNil(cellVal) {
|
||||
err = f.SetCellValue(sheetName, c(iRow, iCol), "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
err = f.SetCellValue(sheetName, c(iRow, iCol), cellVal)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
iRow++
|
||||
}
|
||||
|
||||
//for i, col := range em.colDefinitions {
|
||||
// if col.width == nil {
|
||||
// //TODO https://github.com/qax-os/excelize/pull/1386
|
||||
// }
|
||||
//}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (em *ExcelMapper[T]) AddFilter(f func(v T) bool) {
|
||||
em.colFilter = append(em.colFilter, f)
|
||||
}
|
||||
@@ -0,0 +1,303 @@
|
||||
package excelext
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/tst"
|
||||
"github.com/xuri/excelize/v2"
|
||||
)
|
||||
|
||||
type testRow struct {
|
||||
Name string
|
||||
Age int
|
||||
Score float64
|
||||
}
|
||||
|
||||
func openBytes(t *testing.T, data []byte) *excelize.File {
|
||||
t.Helper()
|
||||
f, err := excelize.OpenReader(bytes.NewReader(data))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open xlsx bytes: %v", err)
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
func cellValue(t *testing.T, f *excelize.File, sheet, axis string) string {
|
||||
t.Helper()
|
||||
v, err := f.GetCellValue(sheet, axis)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCellValue(%s, %s) failed: %v", sheet, axis, err)
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func TestNewExcelMapper(t *testing.T) {
|
||||
em, err := NewExcelMapper[testRow]()
|
||||
tst.AssertNoErr(t, err)
|
||||
if em == nil {
|
||||
t.Fatal("expected non-nil mapper")
|
||||
}
|
||||
tst.AssertEqual(t, em.SkipColumnHeader, false)
|
||||
tst.AssertEqual(t, len(em.colDefinitions), 0)
|
||||
tst.AssertEqual(t, len(em.wsHeader), 0)
|
||||
tst.AssertEqual(t, len(em.colFilter), 0)
|
||||
if em.StyleDate != nil || em.StyleHeader != nil {
|
||||
t.Errorf("expected styles to be nil before init")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInitNewFileAndStyles(t *testing.T) {
|
||||
em, _ := NewExcelMapper[testRow]()
|
||||
f, err := em.InitNewFile("Sheet-Foo")
|
||||
tst.AssertNoErr(t, err)
|
||||
if f == nil {
|
||||
t.Fatal("expected non-nil file")
|
||||
}
|
||||
|
||||
sheets := f.GetSheetList()
|
||||
tst.AssertEqual(t, len(sheets), 1)
|
||||
tst.AssertEqual(t, sheets[0], "Sheet-Foo")
|
||||
|
||||
if em.StyleDate == nil || em.StyleDatetime == nil || em.StyleEUR == nil ||
|
||||
em.StylePercentage == nil || em.StyleHeader == nil || em.StyleWSHeader == nil {
|
||||
t.Errorf("expected all styles to be initialized")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddColumn(t *testing.T) {
|
||||
em, _ := NewExcelMapper[testRow]()
|
||||
em.AddColumn("Name", nil, nil, func(r testRow) any { return r.Name })
|
||||
em.AddColumn("Age", nil, new(12.0), func(r testRow) any { return r.Age })
|
||||
|
||||
tst.AssertEqual(t, len(em.colDefinitions), 2)
|
||||
tst.AssertEqual(t, em.colDefinitions[0].header, "Name")
|
||||
tst.AssertEqual(t, em.colDefinitions[1].header, "Age")
|
||||
if em.colDefinitions[1].width == nil || *em.colDefinitions[1].width != 12.0 {
|
||||
t.Errorf("expected width 12.0")
|
||||
}
|
||||
|
||||
val, err := em.colDefinitions[0].fn(testRow{Name: "Alice"})
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertEqual(t, val.(string), "Alice")
|
||||
}
|
||||
|
||||
func TestAddColumnErr(t *testing.T) {
|
||||
em, _ := NewExcelMapper[testRow]()
|
||||
sentinel := errors.New("boom")
|
||||
em.AddColumnErr("X", nil, nil, func(r testRow) (any, error) {
|
||||
if r.Age < 0 {
|
||||
return nil, sentinel
|
||||
}
|
||||
return r.Age, nil
|
||||
})
|
||||
|
||||
tst.AssertEqual(t, len(em.colDefinitions), 1)
|
||||
|
||||
v, err := em.colDefinitions[0].fn(testRow{Age: 5})
|
||||
tst.AssertNoErr(t, err)
|
||||
tst.AssertEqual(t, v.(int), 5)
|
||||
|
||||
_, err = em.colDefinitions[0].fn(testRow{Age: -1})
|
||||
if !errors.Is(err, sentinel) {
|
||||
t.Errorf("expected sentinel error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddWorksheetHeader(t *testing.T) {
|
||||
em, _ := NewExcelMapper[testRow]()
|
||||
em.AddWorksheetHeader("Title 1", nil)
|
||||
em.AddWorksheetHeader("Title 2", new(7))
|
||||
|
||||
tst.AssertEqual(t, len(em.wsHeader), 2)
|
||||
tst.AssertEqual(t, em.wsHeader[0].V1, "Title 1")
|
||||
tst.AssertEqual(t, em.wsHeader[1].V1, "Title 2")
|
||||
if em.wsHeader[1].V2 == nil || *em.wsHeader[1].V2 != 7 {
|
||||
t.Errorf("expected style ptr 7")
|
||||
}
|
||||
if em.wsHeader[0].V2 != nil {
|
||||
t.Errorf("expected nil style for first header")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddFilter(t *testing.T) {
|
||||
em, _ := NewExcelMapper[testRow]()
|
||||
em.AddFilter(func(v testRow) bool { return v.Age >= 18 })
|
||||
em.AddFilter(func(v testRow) bool { return v.Score > 0 })
|
||||
tst.AssertEqual(t, len(em.colFilter), 2)
|
||||
}
|
||||
|
||||
func TestBuildBasic(t *testing.T) {
|
||||
em, _ := NewExcelMapper[testRow]()
|
||||
em.AddColumn("Name", nil, nil, func(r testRow) any { return r.Name })
|
||||
em.AddColumn("Age", nil, nil, func(r testRow) any { return r.Age })
|
||||
|
||||
rows := []testRow{
|
||||
{Name: "Alice", Age: 30},
|
||||
{Name: "Bob", Age: 25},
|
||||
}
|
||||
|
||||
data, err := em.Build("Sheet1", rows)
|
||||
tst.AssertNoErr(t, err)
|
||||
if len(data) == 0 {
|
||||
t.Fatal("expected non-empty xlsx output")
|
||||
}
|
||||
|
||||
f := openBytes(t, data)
|
||||
defer f.Close()
|
||||
|
||||
tst.AssertEqual(t, cellValue(t, f, "Sheet1", "A1"), "Name")
|
||||
tst.AssertEqual(t, cellValue(t, f, "Sheet1", "B1"), "Age")
|
||||
tst.AssertEqual(t, cellValue(t, f, "Sheet1", "A2"), "Alice")
|
||||
tst.AssertEqual(t, cellValue(t, f, "Sheet1", "B2"), "30")
|
||||
tst.AssertEqual(t, cellValue(t, f, "Sheet1", "A3"), "Bob")
|
||||
tst.AssertEqual(t, cellValue(t, f, "Sheet1", "B3"), "25")
|
||||
}
|
||||
|
||||
func TestBuildSkipColumnHeader(t *testing.T) {
|
||||
em, _ := NewExcelMapper[testRow]()
|
||||
em.SkipColumnHeader = true
|
||||
em.AddColumn("Name", nil, nil, func(r testRow) any { return r.Name })
|
||||
|
||||
rows := []testRow{{Name: "Alice"}, {Name: "Bob"}}
|
||||
|
||||
data, err := em.Build("Data", rows)
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
f := openBytes(t, data)
|
||||
defer f.Close()
|
||||
|
||||
tst.AssertEqual(t, cellValue(t, f, "Data", "A1"), "Alice")
|
||||
tst.AssertEqual(t, cellValue(t, f, "Data", "A2"), "Bob")
|
||||
}
|
||||
|
||||
func TestBuildWithFilter(t *testing.T) {
|
||||
em, _ := NewExcelMapper[testRow]()
|
||||
em.AddColumn("Name", nil, nil, func(r testRow) any { return r.Name })
|
||||
em.AddFilter(func(v testRow) bool { return v.Age >= 18 })
|
||||
|
||||
rows := []testRow{
|
||||
{Name: "Alice", Age: 30},
|
||||
{Name: "Charlie", Age: 12},
|
||||
{Name: "Bob", Age: 25},
|
||||
}
|
||||
|
||||
data, err := em.Build("S", rows)
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
f := openBytes(t, data)
|
||||
defer f.Close()
|
||||
|
||||
tst.AssertEqual(t, cellValue(t, f, "S", "A1"), "Name")
|
||||
tst.AssertEqual(t, cellValue(t, f, "S", "A2"), "Alice")
|
||||
tst.AssertEqual(t, cellValue(t, f, "S", "A3"), "Bob")
|
||||
tst.AssertEqual(t, cellValue(t, f, "S", "A4"), "")
|
||||
}
|
||||
|
||||
func TestBuildWithWorksheetHeader(t *testing.T) {
|
||||
em, _ := NewExcelMapper[testRow]()
|
||||
em.AddWorksheetHeader("My Big Title", nil)
|
||||
em.AddColumn("Name", nil, nil, func(r testRow) any { return r.Name })
|
||||
em.AddColumn("Age", nil, nil, func(r testRow) any { return r.Age })
|
||||
|
||||
rows := []testRow{{Name: "Alice", Age: 30}}
|
||||
|
||||
data, err := em.Build("S", rows)
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
f := openBytes(t, data)
|
||||
defer f.Close()
|
||||
|
||||
tst.AssertEqual(t, cellValue(t, f, "S", "A1"), "My Big Title")
|
||||
tst.AssertEqual(t, cellValue(t, f, "S", "A3"), "Name")
|
||||
tst.AssertEqual(t, cellValue(t, f, "S", "B3"), "Age")
|
||||
tst.AssertEqual(t, cellValue(t, f, "S", "A4"), "Alice")
|
||||
tst.AssertEqual(t, cellValue(t, f, "S", "B4"), "30")
|
||||
}
|
||||
|
||||
func TestBuildHandlesNilPointer(t *testing.T) {
|
||||
type ptrRow struct {
|
||||
Name *string
|
||||
}
|
||||
|
||||
em, _ := NewExcelMapper[ptrRow]()
|
||||
em.AddColumn("Name", nil, nil, func(r ptrRow) any { return r.Name })
|
||||
|
||||
name := "Alice"
|
||||
rows := []ptrRow{
|
||||
{Name: &name},
|
||||
{Name: nil},
|
||||
}
|
||||
|
||||
data, err := em.Build("S", rows)
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
f := openBytes(t, data)
|
||||
defer f.Close()
|
||||
|
||||
tst.AssertEqual(t, cellValue(t, f, "S", "A2"), "Alice")
|
||||
tst.AssertEqual(t, cellValue(t, f, "S", "A3"), "")
|
||||
}
|
||||
|
||||
func TestBuildPropagatesColumnError(t *testing.T) {
|
||||
em, _ := NewExcelMapper[testRow]()
|
||||
sentinel := errors.New("col fail")
|
||||
em.AddColumnErr("Bad", nil, nil, func(r testRow) (any, error) {
|
||||
return nil, sentinel
|
||||
})
|
||||
|
||||
_, err := em.Build("S", []testRow{{Name: "X"}})
|
||||
if err == nil {
|
||||
t.Fatal("expected error from column fn to propagate")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildEmptyData(t *testing.T) {
|
||||
em, _ := NewExcelMapper[testRow]()
|
||||
em.AddColumn("Name", nil, nil, func(r testRow) any { return r.Name })
|
||||
|
||||
data, err := em.Build("S", []testRow{})
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
f := openBytes(t, data)
|
||||
defer f.Close()
|
||||
|
||||
tst.AssertEqual(t, cellValue(t, f, "S", "A1"), "Name")
|
||||
tst.AssertEqual(t, cellValue(t, f, "S", "A2"), "")
|
||||
}
|
||||
|
||||
func TestBuildSingleSheetWithExistingFile(t *testing.T) {
|
||||
em, _ := NewExcelMapper[testRow]()
|
||||
em.AddColumn("Name", nil, nil, func(r testRow) any { return r.Name })
|
||||
|
||||
f, err := em.InitNewFile("S1")
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
_, err = f.NewSheet("S2")
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
err = em.BuildSingleSheet(f, "S2", []testRow{{Name: "Bob"}})
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
tst.AssertEqual(t, cellValue(t, f, "S2", "A1"), "Name")
|
||||
tst.AssertEqual(t, cellValue(t, f, "S2", "A2"), "Bob")
|
||||
}
|
||||
|
||||
func TestBuildWithColumnWidthAndStyle(t *testing.T) {
|
||||
em, _ := NewExcelMapper[testRow]()
|
||||
|
||||
f, err := em.InitNewFile("S")
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
em.AddColumn("Name", em.StyleHeader, new(20.5), func(r testRow) any { return r.Name })
|
||||
|
||||
err = em.BuildSingleSheet(f, "S", []testRow{{Name: "Alice"}})
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
w, err := f.GetColWidth("S", "A")
|
||||
tst.AssertNoErr(t, err)
|
||||
if w < 20.0 || w > 21.0 {
|
||||
t.Errorf("expected column width near 20.5, got %v", w)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
package excelext
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/rfctime"
|
||||
"github.com/360EntSecGroup-Skylar/excelize"
|
||||
)
|
||||
|
||||
func c(row int, col int) string {
|
||||
return excelize.ToAlphaString(col) + strconv.Itoa(row)
|
||||
}
|
||||
|
||||
func excelizeOptTime(t *rfctime.RFC3339NanoTime) any {
|
||||
if t == nil {
|
||||
return ""
|
||||
}
|
||||
return t.Time()
|
||||
}
|
||||
|
||||
func excelizeOptDate(t *rfctime.Date) any {
|
||||
if t == nil {
|
||||
return ""
|
||||
}
|
||||
return t.TimeUTC()
|
||||
}
|
||||
@@ -0,0 +1,61 @@
|
||||
package excelext
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/rfctime"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/tst"
|
||||
)
|
||||
|
||||
func TestCellAddress(t *testing.T) {
|
||||
tst.AssertEqual(t, c(1, 0), "A1")
|
||||
tst.AssertEqual(t, c(1, 1), "B1")
|
||||
tst.AssertEqual(t, c(2, 0), "A2")
|
||||
tst.AssertEqual(t, c(10, 25), "Z10")
|
||||
tst.AssertEqual(t, c(1, 26), "AA1")
|
||||
tst.AssertEqual(t, c(99, 27), "AB99")
|
||||
tst.AssertEqual(t, c(100, 51), "AZ100")
|
||||
tst.AssertEqual(t, c(1, 52), "BA1")
|
||||
}
|
||||
|
||||
func TestExcelizeOptTimeNil(t *testing.T) {
|
||||
got := excelizeOptTime(nil)
|
||||
if got != "" {
|
||||
t.Errorf("expected empty string for nil time, got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExcelizeOptTimeValue(t *testing.T) {
|
||||
now := time.Date(2024, 5, 17, 13, 45, 30, 0, time.UTC)
|
||||
rt := rfctime.RFC3339NanoTime(now)
|
||||
got := excelizeOptTime(&rt)
|
||||
|
||||
gt, ok := got.(time.Time)
|
||||
if !ok {
|
||||
t.Fatalf("expected time.Time, got %T", got)
|
||||
}
|
||||
if !gt.Equal(now) {
|
||||
t.Errorf("expected %v, got %v", now, gt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExcelizeOptDateNil(t *testing.T) {
|
||||
got := excelizeOptDate(nil)
|
||||
if got != "" {
|
||||
t.Errorf("expected empty string for nil date, got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExcelizeOptDateValue(t *testing.T) {
|
||||
d := rfctime.NewDate(time.Date(2025, 11, 3, 0, 0, 0, 0, time.UTC))
|
||||
got := excelizeOptDate(&d)
|
||||
|
||||
gt, ok := got.(time.Time)
|
||||
if !ok {
|
||||
t.Fatalf("expected time.Time, got %T", got)
|
||||
}
|
||||
if gt.Year() != 2025 || gt.Month() != 11 || gt.Day() != 3 {
|
||||
t.Errorf("unexpected date returned: %v", gt)
|
||||
}
|
||||
}
|
||||
+117
-47
@@ -5,17 +5,18 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/rs/zerolog"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/dataext"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/enums"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/langext"
|
||||
"net/http"
|
||||
"os"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/dataext"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/enums"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/langext"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/rs/zerolog"
|
||||
"go.mongodb.org/mongo-driver/v2/bson"
|
||||
)
|
||||
|
||||
//
|
||||
@@ -30,6 +31,10 @@ import (
|
||||
// If possible add metadata to the error (eg the id that was not found, ...), the methods are the same as in zerolog
|
||||
// return nil, exerror.Wrap(err, "do something failed").Str("someid", id).Int("count", in.Count).Build()
|
||||
//
|
||||
// You can also add extra-data to an error with Extra(..)
|
||||
// in contrast to metadata is extradata always printed in the resulting error and is more intended for additional (programmatically readable) data in addition to the errortype
|
||||
// (metadata is more internal debug info/help)
|
||||
//
|
||||
// You can change the errortype with `.User()` and `.System()` (User-errors are 400 and System-errors 500)
|
||||
// You can also manually set the statuscode with `.WithStatuscode(http.NotFound)`
|
||||
// You can set the type with `WithType(..)`
|
||||
@@ -76,12 +81,14 @@ func Wrap(err error, msg string) *Builder {
|
||||
return &Builder{errorData: newExErr(CatSystem, TypeInternal, msg)} // prevent NPE if we call Wrap with err==nil
|
||||
}
|
||||
|
||||
v := FromError(err)
|
||||
|
||||
if !pkgconfig.RecursiveErrors {
|
||||
v := FromError(err)
|
||||
v.Message = msg
|
||||
return &Builder{wrappedErr: err, errorData: v}
|
||||
} else {
|
||||
return &Builder{wrappedErr: err, errorData: wrapExErr(v, msg, CatWrap, 1)}
|
||||
}
|
||||
return &Builder{wrappedErr: err, errorData: wrapExErr(FromError(err), msg, CatWrap, 1)}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
@@ -101,6 +108,16 @@ func (b *Builder) WithMessage(msg string) *Builder {
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *Builder) WithSeverity(v ErrorSeverity) *Builder {
|
||||
b.errorData.Severity = v
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *Builder) WithCategory(v ErrorCategory) *Builder {
|
||||
b.errorData.Category = v
|
||||
return b
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// Err changes the Severity to ERROR (default)
|
||||
@@ -236,7 +253,7 @@ func (b *Builder) Bytes(key string, val []byte) *Builder {
|
||||
return b.addMeta(key, MDTBytes, val)
|
||||
}
|
||||
|
||||
func (b *Builder) ObjectID(key string, val primitive.ObjectID) *Builder {
|
||||
func (b *Builder) ObjectID(key string, val bson.ObjectID) *Builder {
|
||||
return b.addMeta(key, MDTObjectID, val)
|
||||
}
|
||||
|
||||
@@ -260,11 +277,11 @@ func (b *Builder) Ints32(key string, val []int32) *Builder {
|
||||
return b.addMeta(key, MDTInt32Array, val)
|
||||
}
|
||||
|
||||
func (b *Builder) Type(key string, cls interface{}) *Builder {
|
||||
func (b *Builder) Type(key string, cls any) *Builder {
|
||||
return b.addMeta(key, MDTString, fmt.Sprintf("%T", cls))
|
||||
}
|
||||
|
||||
func (b *Builder) Interface(key string, val interface{}) *Builder {
|
||||
func (b *Builder) Interface(key string, val any) *Builder {
|
||||
return b.addMeta(key, MDTAny, newAnyWrap(val))
|
||||
}
|
||||
|
||||
@@ -303,6 +320,7 @@ func (b *Builder) GinReq(ctx context.Context, g *gin.Context, req *http.Request)
|
||||
}
|
||||
}
|
||||
b.Str("gin_method", req.Method)
|
||||
b.Str("gin_host", req.Host)
|
||||
b.Str("gin_path", g.FullPath())
|
||||
b.Strs("gin_header", extractHeader(g.Request.Header))
|
||||
if req.URL != nil {
|
||||
@@ -368,29 +386,6 @@ func (b *Builder) CtxData(method Method, ctx context.Context) *Builder {
|
||||
return b
|
||||
}
|
||||
|
||||
func formatHeader(header map[string][]string) string {
|
||||
ml := 1
|
||||
for k, _ := range header {
|
||||
if len(k) > ml {
|
||||
ml = len(k)
|
||||
}
|
||||
}
|
||||
r := ""
|
||||
for k, v := range header {
|
||||
if r != "" {
|
||||
r += "\n"
|
||||
}
|
||||
for _, hval := range v {
|
||||
value := hval
|
||||
value = strings.ReplaceAll(value, "\n", "\\n")
|
||||
value = strings.ReplaceAll(value, "\r", "\\r")
|
||||
value = strings.ReplaceAll(value, "\t", "\\t")
|
||||
r += langext.StrPadRight(k, " ", ml) + " := " + value
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func extractHeader(header map[string][]string) []string {
|
||||
r := make([]string, 0, len(header))
|
||||
for k, v := range header {
|
||||
@@ -407,11 +402,25 @@ func extractHeader(header map[string][]string) []string {
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// Extra adds additional data to the error
|
||||
// this is not like the other metadata (like Id(), Str(), etc)
|
||||
// this data is public and will be printed/outputted
|
||||
func (b *Builder) Extra(key string, val any) *Builder {
|
||||
b.errorData.Extra[key] = val
|
||||
return b
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// Build creates a new error, ready to pass up the stack
|
||||
// If the errors is not SevWarn or SevInfo it gets also logged (in short form, without stacktrace) onto stdout
|
||||
// Can be gloablly configured with ZeroLogErrTraces and ZeroLogAllTraces
|
||||
// Can be locally suppressed with Builder.NoLog()
|
||||
func (b *Builder) Build(ctxs ...context.Context) error {
|
||||
return b.BuildAsExerr(ctxs...)
|
||||
}
|
||||
|
||||
func (b *Builder) BuildAsExerr(ctxs ...context.Context) *ExErr {
|
||||
warnOnPkgConfigNotInitialized()
|
||||
|
||||
for _, dctx := range ctxs {
|
||||
@@ -419,16 +428,26 @@ func (b *Builder) Build(ctxs ...context.Context) error {
|
||||
}
|
||||
|
||||
if pkgconfig.DisableErrorWrapping && b.wrappedErr != nil {
|
||||
return b.wrappedErr
|
||||
return FromError(b.wrappedErr)
|
||||
}
|
||||
|
||||
if pkgconfig.ZeroLogErrTraces && !b.noLog && (b.errorData.Severity == SevErr || b.errorData.Severity == SevFatal) {
|
||||
b.errorData.ShortLog(pkgconfig.ZeroLogger.Error())
|
||||
} else if pkgconfig.ZeroLogAllTraces && !b.noLog {
|
||||
b.errorData.ShortLog(pkgconfig.ZeroLogger.Error())
|
||||
}
|
||||
if pkgconfig.ZeroLogErrTraces && !b.noLog {
|
||||
if b.errorData.Severity == SevErr || b.errorData.Severity == SevFatal {
|
||||
b.errorData.ShortLog(pkgconfig.ZeroLogger.Error())
|
||||
} else if b.errorData.Severity == SevWarn {
|
||||
b.errorData.ShortLog(pkgconfig.ZeroLogger.Warn())
|
||||
} else if b.errorData.Severity == SevInfo {
|
||||
b.errorData.ShortLog(pkgconfig.ZeroLogger.Info())
|
||||
} else if b.errorData.Severity == SevDebug {
|
||||
b.errorData.ShortLog(pkgconfig.ZeroLogger.Debug())
|
||||
} else if b.errorData.Severity == SevTrace {
|
||||
b.errorData.ShortLog(pkgconfig.ZeroLogger.Trace())
|
||||
} else {
|
||||
b.errorData.ShortLog(pkgconfig.ZeroLogger.Error()) // ?!? unknown severity
|
||||
}
|
||||
|
||||
b.errorData.CallListener(MethodBuild)
|
||||
}
|
||||
b.errorData.CallListener(MethodBuild, ListenerOpt{NoLog: b.noLog})
|
||||
|
||||
return b.errorData
|
||||
}
|
||||
@@ -437,6 +456,8 @@ func (b *Builder) Build(ctxs ...context.Context) error {
|
||||
// The error also gets printed to stdout/stderr
|
||||
// If the error is SevErr|SevFatal we also send it to the error-service
|
||||
func (b *Builder) Output(ctx context.Context, g *gin.Context) {
|
||||
warnOnPkgConfigNotInitialized()
|
||||
|
||||
if !b.containsGinData && g.Request != nil {
|
||||
// Auto-Add gin metadata if the caller hasn't already done it
|
||||
b.GinReq(ctx, g, g.Request)
|
||||
@@ -444,6 +465,21 @@ func (b *Builder) Output(ctx context.Context, g *gin.Context) {
|
||||
|
||||
b.CtxData(MethodOutput, ctx)
|
||||
|
||||
// this is only here to add one level to the trace
|
||||
// so that .Build() and .Output() and .Print() have the same depth and our stack-skip logger can have the same skip-count
|
||||
b.doGinOutput(ctx, g)
|
||||
}
|
||||
|
||||
// OutputRaw works teh same as Output() - but does not depend on gin and works with a raw http.ResponseWriter
|
||||
func (b *Builder) OutputRaw(w http.ResponseWriter) {
|
||||
warnOnPkgConfigNotInitialized()
|
||||
|
||||
// this is only here to add one level to the trace
|
||||
// so that .Build() and .Output() and .Print() have the same depth and our stack-skip logger can have the same skip-count
|
||||
b.doRawOutput(w)
|
||||
}
|
||||
|
||||
func (b *Builder) doGinOutput(ctx context.Context, g *gin.Context) {
|
||||
b.errorData.Output(g)
|
||||
|
||||
if (b.errorData.Severity == SevErr || b.errorData.Severity == SevFatal) && (pkgconfig.ZeroLogErrGinOutput || pkgconfig.ZeroLogAllGinOutput) {
|
||||
@@ -452,25 +488,53 @@ func (b *Builder) Output(ctx context.Context, g *gin.Context) {
|
||||
b.errorData.Log(pkgconfig.ZeroLogger.Warn())
|
||||
}
|
||||
|
||||
b.errorData.CallListener(MethodOutput)
|
||||
b.errorData.CallListener(MethodOutput, ListenerOpt{NoLog: b.noLog})
|
||||
}
|
||||
|
||||
func (b *Builder) doRawOutput(w http.ResponseWriter) {
|
||||
b.errorData.OutputRaw(w)
|
||||
|
||||
if (b.errorData.Severity == SevErr || b.errorData.Severity == SevFatal) && (pkgconfig.ZeroLogErrGinOutput || pkgconfig.ZeroLogAllGinOutput) {
|
||||
b.errorData.Log(pkgconfig.ZeroLogger.Error())
|
||||
} else if (b.errorData.Severity == SevWarn) && (pkgconfig.ZeroLogAllGinOutput) {
|
||||
b.errorData.Log(pkgconfig.ZeroLogger.Warn())
|
||||
}
|
||||
|
||||
b.errorData.CallListener(MethodOutput, ListenerOpt{NoLog: b.noLog})
|
||||
}
|
||||
|
||||
// Print prints the error
|
||||
// If the error is SevErr we also send it to the error-service
|
||||
func (b *Builder) Print(ctxs ...context.Context) {
|
||||
func (b *Builder) Print(ctxs ...context.Context) Proxy {
|
||||
warnOnPkgConfigNotInitialized()
|
||||
|
||||
for _, dctx := range ctxs {
|
||||
b.CtxData(MethodPrint, dctx)
|
||||
}
|
||||
|
||||
// this is only here to add one level to the trace
|
||||
// so that .Build() and .Output() and .Print() have the same depth and our stack-skip logger can have the same skip-count
|
||||
return b.doPrint()
|
||||
}
|
||||
|
||||
func (b *Builder) doPrint() Proxy {
|
||||
if b.errorData.Severity == SevErr || b.errorData.Severity == SevFatal {
|
||||
b.errorData.Log(pkgconfig.ZeroLogger.Error())
|
||||
} else if b.errorData.Severity == SevWarn {
|
||||
b.errorData.ShortLog(pkgconfig.ZeroLogger.Warn())
|
||||
} else if b.errorData.Severity == SevInfo {
|
||||
b.errorData.ShortLog(pkgconfig.ZeroLogger.Info())
|
||||
} else if b.errorData.Severity == SevDebug {
|
||||
b.errorData.ShortLog(pkgconfig.ZeroLogger.Debug())
|
||||
} else if b.errorData.Severity == SevTrace {
|
||||
b.errorData.ShortLog(pkgconfig.ZeroLogger.Trace())
|
||||
} else {
|
||||
b.errorData.Log(pkgconfig.ZeroLogger.Error()) // ?!? unknown severity
|
||||
}
|
||||
|
||||
b.errorData.CallListener(MethodPrint)
|
||||
b.errorData.CallListener(MethodPrint, ListenerOpt{NoLog: b.noLog})
|
||||
|
||||
return Proxy{v: *b.errorData} // we return Proxy<Exerr> here instead of Exerr to prevent warnings on ignored err-returns
|
||||
}
|
||||
|
||||
func (b *Builder) Format(level LogPrintLevel) string {
|
||||
@@ -487,16 +551,22 @@ func (b *Builder) Fatal(ctxs ...context.Context) {
|
||||
b.CtxData(MethodFatal, dctx)
|
||||
}
|
||||
|
||||
b.errorData.Log(pkgconfig.ZeroLogger.WithLevel(zerolog.FatalLevel))
|
||||
// this is only here to add one level to the trace
|
||||
// so that .Build() and .Output() and .Print() have the same depth and our stack-skip logger can have the same skip-count
|
||||
b.doLogFatal()
|
||||
|
||||
b.errorData.CallListener(MethodFatal)
|
||||
b.errorData.CallListener(MethodFatal, ListenerOpt{NoLog: b.noLog})
|
||||
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
func (b *Builder) doLogFatal() {
|
||||
b.errorData.Log(pkgconfig.ZeroLogger.WithLevel(zerolog.FatalLevel))
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (b *Builder) addMeta(key string, mdtype metaDataType, val interface{}) *Builder {
|
||||
func (b *Builder) addMeta(key string, mdtype metaDataType, val any) *Builder {
|
||||
b.errorData.Meta.add(key, mdtype, val)
|
||||
return b
|
||||
}
|
||||
|
||||
+81
-13
@@ -3,20 +3,89 @@ package exerr
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/langext"
|
||||
"maps"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/langext"
|
||||
"go.mongodb.org/mongo-driver/v2/bson"
|
||||
)
|
||||
|
||||
var reflectTypeStr = reflect.TypeOf("")
|
||||
var reflectTypeStr = reflect.TypeFor[string]()
|
||||
|
||||
func FromError(err error) *ExErr {
|
||||
|
||||
if err == nil {
|
||||
// prevent NPE if we call FromError with err==nil
|
||||
return &ExErr{
|
||||
UniqueID: newID(),
|
||||
Category: CatForeign,
|
||||
Type: TypeInternal,
|
||||
Severity: SevErr,
|
||||
Timestamp: time.Time{},
|
||||
StatusCode: nil,
|
||||
Message: "",
|
||||
WrappedErrType: "nil",
|
||||
WrappedErr: err,
|
||||
Caller: "",
|
||||
OriginalError: nil,
|
||||
Meta: make(MetaMap),
|
||||
Extra: make(map[string]any),
|
||||
}
|
||||
}
|
||||
|
||||
//goland:noinspection GoTypeAssertionOnErrors
|
||||
if verr, ok := err.(*ExErr); ok {
|
||||
// A simple ExErr
|
||||
return verr
|
||||
}
|
||||
|
||||
//goland:noinspection GoTypeAssertionOnErrors
|
||||
if verr, ok := err.(langext.PanicWrappedErr); ok {
|
||||
return &ExErr{
|
||||
UniqueID: newID(),
|
||||
Category: CatForeign,
|
||||
Type: TypePanic,
|
||||
Severity: SevErr,
|
||||
Timestamp: time.Time{},
|
||||
StatusCode: nil,
|
||||
Message: "A panic occured",
|
||||
WrappedErrType: fmt.Sprintf("%T", verr),
|
||||
WrappedErr: err,
|
||||
Caller: "",
|
||||
OriginalError: nil,
|
||||
Meta: MetaMap{
|
||||
"panic_object": {DataType: MDTString, Value: fmt.Sprintf("%+v", verr.RecoveredObj())},
|
||||
"panic_type": {DataType: MDTString, Value: fmt.Sprintf("%T", verr.RecoveredObj())},
|
||||
"stack": {DataType: MDTString, Value: verr.Stack},
|
||||
},
|
||||
Extra: make(map[string]any),
|
||||
}
|
||||
}
|
||||
|
||||
//goland:noinspection GoTypeAssertionOnErrors
|
||||
if verr, ok := err.(*langext.PanicWrappedErr); ok && verr != nil {
|
||||
return &ExErr{
|
||||
UniqueID: newID(),
|
||||
Category: CatForeign,
|
||||
Type: TypePanic,
|
||||
Severity: SevErr,
|
||||
Timestamp: time.Time{},
|
||||
StatusCode: nil,
|
||||
Message: "A panic occured",
|
||||
WrappedErrType: fmt.Sprintf("%T", verr),
|
||||
WrappedErr: err,
|
||||
Caller: "",
|
||||
OriginalError: nil,
|
||||
Meta: MetaMap{
|
||||
"panic_object": {DataType: MDTString, Value: fmt.Sprintf("%+v", verr.RecoveredObj())},
|
||||
"panic_type": {DataType: MDTString, Value: fmt.Sprintf("%T", verr.RecoveredObj())},
|
||||
"stack": {DataType: MDTString, Value: verr.Stack},
|
||||
},
|
||||
Extra: make(map[string]any),
|
||||
}
|
||||
}
|
||||
|
||||
// A foreign error (eg a MongoDB exception)
|
||||
return &ExErr{
|
||||
UniqueID: newID(),
|
||||
@@ -31,6 +100,7 @@ func FromError(err error) *ExErr {
|
||||
Caller: "",
|
||||
OriginalError: nil,
|
||||
Meta: getForeignMeta(err),
|
||||
Extra: make(map[string]any),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,6 +118,7 @@ func newExErr(cat ErrorCategory, errtype ErrorType, msg string) *ExErr {
|
||||
Caller: callername(2),
|
||||
OriginalError: nil,
|
||||
Meta: make(map[string]MetaValue),
|
||||
Extra: make(map[string]any),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -65,6 +136,7 @@ func wrapExErr(e *ExErr, msg string, cat ErrorCategory, stacktraceskip int) *ExE
|
||||
Caller: callername(1 + stacktraceskip),
|
||||
OriginalError: e,
|
||||
Meta: make(map[string]MetaValue),
|
||||
Extra: langext.CopyMap(langext.ForceMap(e.Extra)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -82,20 +154,18 @@ func getForeignMeta(err error) (mm MetaMap) {
|
||||
}()
|
||||
|
||||
rval := reflect.ValueOf(err)
|
||||
if rval.Kind() == reflect.Interface || rval.Kind() == reflect.Ptr {
|
||||
if rval.Kind() == reflect.Interface || rval.Kind() == reflect.Pointer {
|
||||
rval = reflect.ValueOf(err).Elem()
|
||||
}
|
||||
|
||||
mm.add("foreign.errortype", MDTString, rval.Type().String())
|
||||
|
||||
for k, v := range addMetaPrefix("foreign", getReflectedMetaValues(err, 8)) {
|
||||
mm[k] = v
|
||||
}
|
||||
maps.Copy(mm, addMetaPrefix("foreign", getReflectedMetaValues(err, 8)))
|
||||
|
||||
return mm
|
||||
}
|
||||
|
||||
func getReflectedMetaValues(value interface{}, remainingDepth int) map[string]MetaValue {
|
||||
func getReflectedMetaValues(value any, remainingDepth int) map[string]MetaValue {
|
||||
|
||||
if remainingDepth <= 0 {
|
||||
return map[string]MetaValue{}
|
||||
@@ -107,7 +177,7 @@ func getReflectedMetaValues(value interface{}, remainingDepth int) map[string]Me
|
||||
|
||||
rval := reflect.ValueOf(value)
|
||||
|
||||
if rval.Type().Kind() == reflect.Ptr {
|
||||
if rval.Type().Kind() == reflect.Pointer {
|
||||
|
||||
if rval.IsNil() {
|
||||
return map[string]MetaValue{"*": {DataType: MDTNil, Value: nil}}
|
||||
@@ -153,7 +223,7 @@ func getReflectedMetaValues(value interface{}, remainingDepth int) map[string]Me
|
||||
return map[string]MetaValue{"": {DataType: MDTIntArray, Value: ifraw}}
|
||||
case []int32:
|
||||
return map[string]MetaValue{"": {DataType: MDTInt32Array, Value: ifraw}}
|
||||
case primitive.ObjectID:
|
||||
case bson.ObjectID:
|
||||
return map[string]MetaValue{"": {DataType: MDTObjectID, Value: ifraw}}
|
||||
case []string:
|
||||
return map[string]MetaValue{"": {DataType: MDTStringArray, Value: ifraw}}
|
||||
@@ -167,9 +237,7 @@ func getReflectedMetaValues(value interface{}, remainingDepth int) map[string]Me
|
||||
fieldname := fieldtype.Name
|
||||
|
||||
if fieldtype.IsExported() {
|
||||
for k, v := range addMetaPrefix(fieldname, getReflectedMetaValues(rval.Field(i).Interface(), remainingDepth-1)) {
|
||||
m[k] = v
|
||||
}
|
||||
maps.Copy(m, addMetaPrefix(fieldname, getReflectedMetaValues(rval.Field(i).Interface(), remainingDepth-1)))
|
||||
}
|
||||
}
|
||||
return m
|
||||
|
||||
+10
-44
@@ -4,11 +4,8 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/bson/bsoncodec"
|
||||
"go.mongodb.org/mongo-driver/bson/bsonrw"
|
||||
"go.mongodb.org/mongo-driver/bson/bsontype"
|
||||
"reflect"
|
||||
|
||||
"go.mongodb.org/mongo-driver/v2/bson"
|
||||
)
|
||||
|
||||
type ErrorCategory struct{ Category string }
|
||||
@@ -28,8 +25,8 @@ func (e ErrorCategory) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(e.Category)
|
||||
}
|
||||
|
||||
func (e *ErrorCategory) UnmarshalBSONValue(bt bsontype.Type, data []byte) error {
|
||||
if bt == bson.TypeNull {
|
||||
func (e *ErrorCategory) UnmarshalBSONValue(bt byte, data []byte) error {
|
||||
if bson.Type(bt) == bson.TypeNull {
|
||||
// we can't set nil in UnmarshalBSONValue (so we use default(struct))
|
||||
// Use mongoext.CreateGoExtBsonRegistry if you need to unmarsh pointer values
|
||||
// https://stackoverflow.com/questions/75167597
|
||||
@@ -37,11 +34,11 @@ func (e *ErrorCategory) UnmarshalBSONValue(bt bsontype.Type, data []byte) error
|
||||
*e = ErrorCategory{}
|
||||
return nil
|
||||
}
|
||||
if bt != bson.TypeString {
|
||||
return errors.New(fmt.Sprintf("cannot unmarshal %v into String", bt))
|
||||
if bson.Type(bt) != bson.TypeString {
|
||||
return errors.New(fmt.Sprintf("cannot unmarshal %v into String", bson.Type(bt)))
|
||||
}
|
||||
var tt string
|
||||
err := bson.RawValue{Type: bt, Value: data}.Unmarshal(&tt)
|
||||
err := bson.RawValue{Type: bson.Type(bt), Value: data}.Unmarshal(&tt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -49,40 +46,9 @@ func (e *ErrorCategory) UnmarshalBSONValue(bt bsontype.Type, data []byte) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e ErrorCategory) MarshalBSONValue() (bsontype.Type, []byte, error) {
|
||||
return bson.MarshalValue(e.Category)
|
||||
}
|
||||
|
||||
func (e ErrorCategory) DecodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
|
||||
if val.Kind() == reflect.Ptr && val.IsNil() {
|
||||
if !val.CanSet() {
|
||||
return errors.New("ValueUnmarshalerDecodeValue")
|
||||
}
|
||||
val.Set(reflect.New(val.Type().Elem()))
|
||||
}
|
||||
|
||||
tp, src, err := bsonrw.Copier{}.CopyValueToBytes(vr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if val.Kind() == reflect.Ptr && len(src) == 0 {
|
||||
val.Set(reflect.Zero(val.Type()))
|
||||
return nil
|
||||
}
|
||||
|
||||
err = e.UnmarshalBSONValue(tp, src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if val.Kind() == reflect.Ptr {
|
||||
val.Set(reflect.ValueOf(&e))
|
||||
} else {
|
||||
val.Set(reflect.ValueOf(e))
|
||||
}
|
||||
|
||||
return nil
|
||||
func (e ErrorCategory) MarshalBSONValue() (byte, []byte, error) {
|
||||
tp, data, err := bson.MarshalValue(e.Category)
|
||||
return byte(tp), data, err
|
||||
}
|
||||
|
||||
//goland:noinspection GoUnusedGlobalVariable
|
||||
|
||||
+10
-44
@@ -4,11 +4,8 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/bson/bsoncodec"
|
||||
"go.mongodb.org/mongo-driver/bson/bsonrw"
|
||||
"go.mongodb.org/mongo-driver/bson/bsontype"
|
||||
"reflect"
|
||||
|
||||
"go.mongodb.org/mongo-driver/v2/bson"
|
||||
)
|
||||
|
||||
type ErrorSeverity struct{ Severity string }
|
||||
@@ -30,8 +27,8 @@ func (e ErrorSeverity) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(e.Severity)
|
||||
}
|
||||
|
||||
func (e *ErrorSeverity) UnmarshalBSONValue(bt bsontype.Type, data []byte) error {
|
||||
if bt == bson.TypeNull {
|
||||
func (e *ErrorSeverity) UnmarshalBSONValue(bt byte, data []byte) error {
|
||||
if bson.Type(bt) == bson.TypeNull {
|
||||
// we can't set nil in UnmarshalBSONValue (so we use default(struct))
|
||||
// Use mongoext.CreateGoExtBsonRegistry if you need to unmarsh pointer values
|
||||
// https://stackoverflow.com/questions/75167597
|
||||
@@ -39,11 +36,11 @@ func (e *ErrorSeverity) UnmarshalBSONValue(bt bsontype.Type, data []byte) error
|
||||
*e = ErrorSeverity{}
|
||||
return nil
|
||||
}
|
||||
if bt != bson.TypeString {
|
||||
return errors.New(fmt.Sprintf("cannot unmarshal %v into String", bt))
|
||||
if bson.Type(bt) != bson.TypeString {
|
||||
return errors.New(fmt.Sprintf("cannot unmarshal %v into String", bson.Type(bt)))
|
||||
}
|
||||
var tt string
|
||||
err := bson.RawValue{Type: bt, Value: data}.Unmarshal(&tt)
|
||||
err := bson.RawValue{Type: bson.Type(bt), Value: data}.Unmarshal(&tt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -51,40 +48,9 @@ func (e *ErrorSeverity) UnmarshalBSONValue(bt bsontype.Type, data []byte) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e ErrorSeverity) MarshalBSONValue() (bsontype.Type, []byte, error) {
|
||||
return bson.MarshalValue(e.Severity)
|
||||
}
|
||||
|
||||
func (e ErrorSeverity) DecodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
|
||||
if val.Kind() == reflect.Ptr && val.IsNil() {
|
||||
if !val.CanSet() {
|
||||
return errors.New("ValueUnmarshalerDecodeValue")
|
||||
}
|
||||
val.Set(reflect.New(val.Type().Elem()))
|
||||
}
|
||||
|
||||
tp, src, err := bsonrw.Copier{}.CopyValueToBytes(vr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if val.Kind() == reflect.Ptr && len(src) == 0 {
|
||||
val.Set(reflect.Zero(val.Type()))
|
||||
return nil
|
||||
}
|
||||
|
||||
err = e.UnmarshalBSONValue(tp, src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if val.Kind() == reflect.Ptr {
|
||||
val.Set(reflect.ValueOf(&e))
|
||||
} else {
|
||||
val.Set(reflect.ValueOf(e))
|
||||
}
|
||||
|
||||
return nil
|
||||
func (e ErrorSeverity) MarshalBSONValue() (byte, []byte, error) {
|
||||
tp, data, err := bson.MarshalValue(e.Severity)
|
||||
return byte(tp), data, err
|
||||
}
|
||||
|
||||
//goland:noinspection GoUnusedGlobalVariable
|
||||
|
||||
+38
-71
@@ -4,13 +4,9 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/bson/bsoncodec"
|
||||
"go.mongodb.org/mongo-driver/bson/bsonrw"
|
||||
"go.mongodb.org/mongo-driver/bson/bsontype"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/dataext"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/langext"
|
||||
"reflect"
|
||||
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/dataext"
|
||||
"go.mongodb.org/mongo-driver/v2/bson"
|
||||
)
|
||||
|
||||
type ErrorType struct {
|
||||
@@ -20,40 +16,42 @@ type ErrorType struct {
|
||||
|
||||
//goland:noinspection GoUnusedGlobalVariable
|
||||
var (
|
||||
TypeInternal = NewType("INTERNAL_ERROR", langext.Ptr(500))
|
||||
TypePanic = NewType("PANIC", langext.Ptr(500))
|
||||
TypeNotImplemented = NewType("NOT_IMPLEMENTED", langext.Ptr(500))
|
||||
TypeAssert = NewType("ASSERT", langext.Ptr(500))
|
||||
TypeInternal = NewType("INTERNAL_ERROR", new(500))
|
||||
TypePanic = NewType("PANIC", new(500))
|
||||
TypeNotImplemented = NewType("NOT_IMPLEMENTED", new(500))
|
||||
TypeAssert = NewType("ASSERT", new(500))
|
||||
|
||||
TypeMongoQuery = NewType("MONGO_QUERY", langext.Ptr(500))
|
||||
TypeCursorTokenDecode = NewType("CURSOR_TOKEN_DECODE", langext.Ptr(500))
|
||||
TypeMongoFilter = NewType("MONGO_FILTER", langext.Ptr(500))
|
||||
TypeMongoReflection = NewType("MONGO_REFLECTION", langext.Ptr(500))
|
||||
TypeMongoInvalidOpt = NewType("MONGO_INVALIDOPT", langext.Ptr(500))
|
||||
TypeMongoQuery = NewType("MONGO_QUERY", new(500))
|
||||
TypeCursorTokenDecode = NewType("CURSOR_TOKEN_DECODE", new(500))
|
||||
TypeMongoFilter = NewType("MONGO_FILTER", new(500))
|
||||
TypeMongoReflection = NewType("MONGO_REFLECTION", new(500))
|
||||
TypeMongoInvalidOpt = NewType("MONGO_INVALIDOPT", new(500))
|
||||
|
||||
TypeSQLQuery = NewType("SQL_QUERY", langext.Ptr(500))
|
||||
TypeSQLBuild = NewType("SQL_BUILD", langext.Ptr(500))
|
||||
TypeSQLDecode = NewType("SQL_DECODE", langext.Ptr(500))
|
||||
TypeSQLQuery = NewType("SQL_QUERY", new(500))
|
||||
TypeSQLBuild = NewType("SQL_BUILD", new(500))
|
||||
TypeSQLDecode = NewType("SQL_DECODE", new(500))
|
||||
|
||||
TypeWrap = NewType("Wrap", nil)
|
||||
|
||||
TypeBindFailURI = NewType("BINDFAIL_URI", langext.Ptr(400))
|
||||
TypeBindFailQuery = NewType("BINDFAIL_QUERY", langext.Ptr(400))
|
||||
TypeBindFailJSON = NewType("BINDFAIL_JSON", langext.Ptr(400))
|
||||
TypeBindFailFormData = NewType("BINDFAIL_FORMDATA", langext.Ptr(400))
|
||||
TypeBindFailHeader = NewType("BINDFAIL_HEADER", langext.Ptr(400))
|
||||
TypeBindFailURI = NewType("BINDFAIL_URI", new(400))
|
||||
TypeBindFailQuery = NewType("BINDFAIL_QUERY", new(400))
|
||||
TypeBindFailJSON = NewType("BINDFAIL_JSON", new(400))
|
||||
TypeBindFailFormData = NewType("BINDFAIL_FORMDATA", new(400))
|
||||
TypeBindFailHeader = NewType("BINDFAIL_HEADER", new(400))
|
||||
|
||||
TypeMarshalEntityID = NewType("MARSHAL_ENTITY_ID", langext.Ptr(400))
|
||||
TypeInvalidCSID = NewType("INVALID_CSID", langext.Ptr(400))
|
||||
TypeMarshalEntityID = NewType("MARSHAL_ENTITY_ID", new(400))
|
||||
TypeInvalidCSID = NewType("INVALID_CSID", new(400))
|
||||
|
||||
TypeGoogleStatuscode = NewType("GOOGLE_STATUSCODE", langext.Ptr(400))
|
||||
TypeGoogleResponse = NewType("GOOGLE_RESPONSE", langext.Ptr(400))
|
||||
TypeGoogleStatuscode = NewType("GOOGLE_STATUSCODE", new(400))
|
||||
TypeGoogleResponse = NewType("GOOGLE_RESPONSE", new(400))
|
||||
|
||||
TypeUnauthorized = NewType("UNAUTHORIZED", langext.Ptr(401))
|
||||
TypeAuthFailed = NewType("AUTH_FAILED", langext.Ptr(401))
|
||||
TypeUnauthorized = NewType("UNAUTHORIZED", new(401))
|
||||
TypeAuthFailed = NewType("AUTH_FAILED", new(401))
|
||||
|
||||
TypeInvalidImage = NewType("IMAGEEXT_INVALID_IMAGE", langext.Ptr(400))
|
||||
TypeInvalidMimeType = NewType("IMAGEEXT_INVALID_MIMETYPE", langext.Ptr(400))
|
||||
TypeInvalidImage = NewType("IMAGEEXT_INVALID_IMAGE", new(400))
|
||||
TypeInvalidMimeType = NewType("IMAGEEXT_INVALID_MIMETYPE", new(400))
|
||||
|
||||
TypeWebsocket = NewType("WEBSOCKET", new(500))
|
||||
|
||||
// other values come from the downstream application that uses goext
|
||||
)
|
||||
@@ -78,8 +76,8 @@ func (e ErrorType) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(e.Key)
|
||||
}
|
||||
|
||||
func (e *ErrorType) UnmarshalBSONValue(bt bsontype.Type, data []byte) error {
|
||||
if bt == bson.TypeNull {
|
||||
func (e *ErrorType) UnmarshalBSONValue(bt byte, data []byte) error {
|
||||
if bson.Type(bt) == bson.TypeNull {
|
||||
// we can't set nil in UnmarshalBSONValue (so we use default(struct))
|
||||
// Use mongoext.CreateGoExtBsonRegistry if you need to unmarsh pointer values
|
||||
// https://stackoverflow.com/questions/75167597
|
||||
@@ -87,11 +85,11 @@ func (e *ErrorType) UnmarshalBSONValue(bt bsontype.Type, data []byte) error {
|
||||
*e = ErrorType{}
|
||||
return nil
|
||||
}
|
||||
if bt != bson.TypeString {
|
||||
return errors.New(fmt.Sprintf("cannot unmarshal %v into String", bt))
|
||||
if bson.Type(bt) != bson.TypeString {
|
||||
return errors.New(fmt.Sprintf("cannot unmarshal %v into String", bson.Type(bt)))
|
||||
}
|
||||
var tt string
|
||||
err := bson.RawValue{Type: bt, Value: data}.Unmarshal(&tt)
|
||||
err := bson.RawValue{Type: bson.Type(bt), Value: data}.Unmarshal(&tt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -105,40 +103,9 @@ func (e *ErrorType) UnmarshalBSONValue(bt bsontype.Type, data []byte) error {
|
||||
}
|
||||
}
|
||||
|
||||
func (e ErrorType) MarshalBSONValue() (bsontype.Type, []byte, error) {
|
||||
return bson.MarshalValue(e.Key)
|
||||
}
|
||||
|
||||
func (e ErrorType) DecodeValue(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
|
||||
if val.Kind() == reflect.Ptr && val.IsNil() {
|
||||
if !val.CanSet() {
|
||||
return errors.New("ValueUnmarshalerDecodeValue")
|
||||
}
|
||||
val.Set(reflect.New(val.Type().Elem()))
|
||||
}
|
||||
|
||||
tp, src, err := bsonrw.Copier{}.CopyValueToBytes(vr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if val.Kind() == reflect.Ptr && len(src) == 0 {
|
||||
val.Set(reflect.Zero(val.Type()))
|
||||
return nil
|
||||
}
|
||||
|
||||
err = e.UnmarshalBSONValue(tp, src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if val.Kind() == reflect.Ptr {
|
||||
val.Set(reflect.ValueOf(&e))
|
||||
} else {
|
||||
val.Set(reflect.ValueOf(e))
|
||||
}
|
||||
|
||||
return nil
|
||||
func (e ErrorType) MarshalBSONValue() (byte, []byte, error) {
|
||||
tp, data, err := bson.MarshalValue(e.Key)
|
||||
return byte(tp), data, err
|
||||
}
|
||||
|
||||
var registeredTypes = dataext.SyncMap[string, ErrorType]{}
|
||||
|
||||
+16
-16
@@ -3,12 +3,12 @@ package exerr
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/tst"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/tst"
|
||||
"go.mongodb.org/mongo-driver/v2/bson"
|
||||
"go.mongodb.org/mongo-driver/v2/mongo"
|
||||
)
|
||||
|
||||
func TestJSONMarshalErrorCategory(t *testing.T) {
|
||||
@@ -57,7 +57,7 @@ func TestBSONMarshalErrorCategory(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 350*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
client, err := mongo.Connect(ctx)
|
||||
client, err := mongo.Connect()
|
||||
if err != nil {
|
||||
t.Skip("Skip test - no local mongo found")
|
||||
return
|
||||
@@ -68,7 +68,7 @@ func TestBSONMarshalErrorCategory(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
primimd := primitive.NewObjectID()
|
||||
primimd := bson.NewObjectID()
|
||||
|
||||
_, err = client.Database("_test").Collection("goext-cicd").InsertOne(ctx, bson.M{"_id": primimd, "val": CatSystem})
|
||||
tst.AssertNoErr(t, err)
|
||||
@@ -76,8 +76,8 @@ func TestBSONMarshalErrorCategory(t *testing.T) {
|
||||
cursor := client.Database("_test").Collection("goext-cicd").FindOne(ctx, bson.M{"_id": primimd, "val": bson.M{"$type": "string"}})
|
||||
|
||||
var c1 struct {
|
||||
ID primitive.ObjectID `bson:"_id"`
|
||||
Val ErrorCategory `bson:"val"`
|
||||
ID bson.ObjectID `bson:"_id"`
|
||||
Val ErrorCategory `bson:"val"`
|
||||
}
|
||||
|
||||
err = cursor.Decode(&c1)
|
||||
@@ -90,7 +90,7 @@ func TestBSONMarshalErrorSeverity(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 350*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
client, err := mongo.Connect(ctx)
|
||||
client, err := mongo.Connect()
|
||||
if err != nil {
|
||||
t.Skip("Skip test - no local mongo found")
|
||||
return
|
||||
@@ -101,7 +101,7 @@ func TestBSONMarshalErrorSeverity(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
primimd := primitive.NewObjectID()
|
||||
primimd := bson.NewObjectID()
|
||||
|
||||
_, err = client.Database("_test").Collection("goext-cicd").InsertOne(ctx, bson.M{"_id": primimd, "val": SevErr})
|
||||
tst.AssertNoErr(t, err)
|
||||
@@ -109,8 +109,8 @@ func TestBSONMarshalErrorSeverity(t *testing.T) {
|
||||
cursor := client.Database("_test").Collection("goext-cicd").FindOne(ctx, bson.M{"_id": primimd, "val": bson.M{"$type": "string"}})
|
||||
|
||||
var c1 struct {
|
||||
ID primitive.ObjectID `bson:"_id"`
|
||||
Val ErrorSeverity `bson:"val"`
|
||||
ID bson.ObjectID `bson:"_id"`
|
||||
Val ErrorSeverity `bson:"val"`
|
||||
}
|
||||
|
||||
err = cursor.Decode(&c1)
|
||||
@@ -123,7 +123,7 @@ func TestBSONMarshalErrorType(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 350*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
client, err := mongo.Connect(ctx)
|
||||
client, err := mongo.Connect()
|
||||
if err != nil {
|
||||
t.Skip("Skip test - no local mongo found")
|
||||
return
|
||||
@@ -134,7 +134,7 @@ func TestBSONMarshalErrorType(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
primimd := primitive.NewObjectID()
|
||||
primimd := bson.NewObjectID()
|
||||
|
||||
_, err = client.Database("_test").Collection("goext-cicd").InsertOne(ctx, bson.M{"_id": primimd, "val": TypeNotImplemented})
|
||||
tst.AssertNoErr(t, err)
|
||||
@@ -142,8 +142,8 @@ func TestBSONMarshalErrorType(t *testing.T) {
|
||||
cursor := client.Database("_test").Collection("goext-cicd").FindOne(ctx, bson.M{"_id": primimd, "val": bson.M{"$type": "string"}})
|
||||
|
||||
var c1 struct {
|
||||
ID primitive.ObjectID `bson:"_id"`
|
||||
Val ErrorType `bson:"val"`
|
||||
ID bson.ObjectID `bson:"_id"`
|
||||
Val ErrorType `bson:"val"`
|
||||
}
|
||||
|
||||
err = cursor.Decode(&c1)
|
||||
|
||||
+2
-2
@@ -3,9 +3,9 @@ package exerr
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/langext"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/rs/zerolog"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/langext"
|
||||
"net/http"
|
||||
"os"
|
||||
)
|
||||
@@ -119,7 +119,7 @@ func newDefaultLogger() zerolog.Logger {
|
||||
|
||||
multi := zerolog.MultiLevelWriter(cw)
|
||||
|
||||
return zerolog.New(multi).With().Timestamp().CallerWithSkipFrameCount(4).Logger()
|
||||
return zerolog.New(multi).With().Timestamp().CallerWithSkipFrameCount(5).Logger()
|
||||
}
|
||||
|
||||
func Initialized() bool {
|
||||
|
||||
+88
-35
@@ -1,12 +1,14 @@
|
||||
package exerr
|
||||
|
||||
import (
|
||||
"github.com/rs/xid"
|
||||
"github.com/rs/zerolog"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/langext"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/langext"
|
||||
"github.com/rs/xid"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
type ExErr struct {
|
||||
@@ -26,7 +28,8 @@ type ExErr struct {
|
||||
|
||||
OriginalError *ExErr `json:"originalError"`
|
||||
|
||||
Meta MetaMap `json:"meta"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
Meta MetaMap `json:"meta"`
|
||||
}
|
||||
|
||||
func (ee *ExErr) Error() string {
|
||||
@@ -36,6 +39,13 @@ func (ee *ExErr) Error() string {
|
||||
// Unwrap must be implemented so that some error.XXX methods work
|
||||
func (ee *ExErr) Unwrap() error {
|
||||
if ee.OriginalError == nil {
|
||||
|
||||
if ee.WrappedErr != nil {
|
||||
if werr, ok := ee.WrappedErr.(error); ok {
|
||||
return werr
|
||||
}
|
||||
}
|
||||
|
||||
return nil // this is neccessary - otherwise we return a wrapped nil and the `x == nil` comparison fails (= panic in errors.Is and other failures)
|
||||
}
|
||||
return ee.OriginalError
|
||||
@@ -81,6 +91,23 @@ func (ee *ExErr) Log(evt *zerolog.Event) {
|
||||
}
|
||||
|
||||
func (ee *ExErr) FormatLog(lvl LogPrintLevel) string {
|
||||
|
||||
// [LogPrintShort]
|
||||
//
|
||||
// - Only print message and type
|
||||
// - Used e.g. for logging to the console when Build is called
|
||||
// - also used in Print() if level == Warn/Info
|
||||
//
|
||||
// [LogPrintOverview]
|
||||
//
|
||||
// - print message, extra and errortrace
|
||||
//
|
||||
// [LogPrintFull]
|
||||
//
|
||||
// - print full error, with meta and extra, and trace, etc
|
||||
// - Used in Output() and Print()
|
||||
//
|
||||
|
||||
if lvl == LogPrintShort {
|
||||
|
||||
msg := ee.Message
|
||||
@@ -99,65 +126,75 @@ func (ee *ExErr) FormatLog(lvl LogPrintLevel) string {
|
||||
|
||||
} else if lvl == LogPrintOverview {
|
||||
|
||||
str := "[" + ee.RecursiveType().Key + "] <" + ee.UniqueID + "> " + strings.ReplaceAll(ee.RecursiveMessage(), "\n", " ") + "\n"
|
||||
var str strings.Builder
|
||||
str.WriteString("[" + ee.RecursiveType().Key + "] <" + ee.UniqueID + "> " + strings.ReplaceAll(ee.RecursiveMessage(), "\n", " ") + "\n")
|
||||
|
||||
indent := ""
|
||||
for exk, exv := range ee.Extra {
|
||||
str.WriteString(fmt.Sprintf(" # [[[ %s ==> %v ]]]\n", exk, exv))
|
||||
}
|
||||
|
||||
var indent strings.Builder
|
||||
for curr := ee; curr != nil; curr = curr.OriginalError {
|
||||
indent += " "
|
||||
indent.WriteString(" ")
|
||||
|
||||
str += indent
|
||||
str += "-> "
|
||||
str.WriteString(indent.String())
|
||||
str.WriteString("-> ")
|
||||
strmsg := strings.Trim(curr.Message, " \r\n\t")
|
||||
if lbidx := strings.Index(curr.Message, "\n"); lbidx >= 0 {
|
||||
strmsg = strmsg[0:lbidx]
|
||||
}
|
||||
strmsg = langext.StrLimit(strmsg, 61, "...")
|
||||
str += strmsg
|
||||
str += "\n"
|
||||
str.WriteString(strmsg)
|
||||
str.WriteString("\n")
|
||||
|
||||
}
|
||||
return str
|
||||
return str.String()
|
||||
|
||||
} else if lvl == LogPrintFull {
|
||||
|
||||
str := "[" + ee.RecursiveType().Key + "] <" + ee.UniqueID + "> " + strings.ReplaceAll(ee.RecursiveMessage(), "\n", " ") + "\n"
|
||||
var str strings.Builder
|
||||
str.WriteString("[" + ee.RecursiveType().Key + "] <" + ee.UniqueID + "> " + strings.ReplaceAll(ee.RecursiveMessage(), "\n", " ") + "\n")
|
||||
|
||||
for exk, exv := range ee.Extra {
|
||||
str.WriteString(fmt.Sprintf(" # [[[ %s ==> %v ]]]\n", exk, exv))
|
||||
}
|
||||
|
||||
indent := ""
|
||||
for curr := ee; curr != nil; curr = curr.OriginalError {
|
||||
indent += " "
|
||||
|
||||
etype := ee.Type.Key
|
||||
if ee.Type == TypeWrap {
|
||||
etype := curr.Type.Key
|
||||
if curr.Type == TypeWrap {
|
||||
etype = "~"
|
||||
}
|
||||
|
||||
str += indent
|
||||
str += "-> ["
|
||||
str += etype
|
||||
str.WriteString(indent)
|
||||
str.WriteString("-> [")
|
||||
str.WriteString(etype)
|
||||
if curr.Category == CatForeign {
|
||||
str += "|Foreign"
|
||||
str.WriteString("|Foreign")
|
||||
}
|
||||
str += "] "
|
||||
str += strings.ReplaceAll(curr.Message, "\n", " ")
|
||||
str.WriteString("] ")
|
||||
str.WriteString(strings.ReplaceAll(curr.Message, "\n", " "))
|
||||
if curr.Caller != "" {
|
||||
str += " (@ "
|
||||
str += curr.Caller
|
||||
str += ")"
|
||||
str.WriteString(" (@ ")
|
||||
str.WriteString(curr.Caller)
|
||||
str.WriteString(")")
|
||||
}
|
||||
str += "\n"
|
||||
str.WriteString("\n")
|
||||
|
||||
if curr.Meta.Any() {
|
||||
meta := indent + " {" + curr.Meta.FormatOneLine(240) + "}"
|
||||
if len(meta) < 200 {
|
||||
str += meta
|
||||
str += "\n"
|
||||
str.WriteString(meta)
|
||||
str.WriteString("\n")
|
||||
} else {
|
||||
str += curr.Meta.FormatMultiLine(indent+" ", " ", 1024)
|
||||
str += "\n"
|
||||
str.WriteString(curr.Meta.FormatMultiLine(indent+" ", " ", 1024))
|
||||
str.WriteString("\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
return str
|
||||
return str.String()
|
||||
|
||||
} else {
|
||||
|
||||
@@ -167,11 +204,11 @@ func (ee *ExErr) FormatLog(lvl LogPrintLevel) string {
|
||||
}
|
||||
|
||||
func (ee *ExErr) ShortLog(evt *zerolog.Event) {
|
||||
ee.Meta.Apply(evt, langext.Ptr(240)).Msg(ee.FormatLog(LogPrintShort))
|
||||
ee.Meta.Apply(evt, new(240)).Msg(ee.FormatLog(LogPrintShort))
|
||||
}
|
||||
|
||||
// RecursiveMessage returns the message to show
|
||||
// = first error (top-down) that is not wrapping/foreign/empty
|
||||
// = first error (top-down) that is not foreign/empty
|
||||
// = lowest level error (that is not empty)
|
||||
// = fallback to self.message
|
||||
func (ee *ExErr) RecursiveMessage() string {
|
||||
@@ -179,7 +216,7 @@ func (ee *ExErr) RecursiveMessage() string {
|
||||
// ==== [1] ==== first error (top-down) that is not wrapping/foreign/empty
|
||||
|
||||
for curr := ee; curr != nil; curr = curr.OriginalError {
|
||||
if curr.Message != "" && curr.Category != CatWrap && curr.Category != CatForeign {
|
||||
if curr.Message != "" && curr.Category != CatForeign {
|
||||
return curr.Message
|
||||
}
|
||||
}
|
||||
@@ -220,7 +257,7 @@ func (ee *ExErr) RecursiveType() ErrorType {
|
||||
func (ee *ExErr) RecursiveStatuscode() *int {
|
||||
for curr := ee; curr != nil; curr = curr.OriginalError {
|
||||
if curr.StatusCode != nil {
|
||||
return langext.Ptr(*curr.StatusCode)
|
||||
return new(*curr.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -245,7 +282,7 @@ func (ee *ExErr) RecursiveCategory() ErrorCategory {
|
||||
func (ee *ExErr) RecursiveMeta(key string) *MetaValue {
|
||||
for curr := ee; curr != nil; curr = curr.OriginalError {
|
||||
if metaval, ok := curr.Meta[key]; ok {
|
||||
return langext.Ptr(metaval)
|
||||
return new(metaval)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -328,6 +365,22 @@ func (ee *ExErr) GetMetaTime(key string) (time.Time, bool) {
|
||||
return time.Time{}, false
|
||||
}
|
||||
|
||||
func (ee *ExErr) GetExtra(key string) (any, bool) {
|
||||
if v, ok := ee.Extra[key]; ok {
|
||||
return v, true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (ee *ExErr) UniqueIDs() []string {
|
||||
ids := make([]string, 0, 1)
|
||||
for curr := ee; curr != nil; curr = curr.OriginalError {
|
||||
ids = append(ids, curr.UniqueID)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
// contains test if the supplied error is contained in this error (anywhere in the chain)
|
||||
func (ee *ExErr) contains(original *ExErr) (*ExErr, bool) {
|
||||
if original == nil {
|
||||
|
||||
+2
-2
@@ -2,8 +2,8 @@ package exerr
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/langext"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/tst"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/langext"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/tst"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
+53
-3
@@ -1,9 +1,11 @@
|
||||
package exerr
|
||||
|
||||
import (
|
||||
"maps"
|
||||
|
||||
json "git.blackforestbytes.com/BlackForestBytes/goext/gojson"
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/langext"
|
||||
"github.com/gin-gonic/gin"
|
||||
json "gogs.mikescher.com/BlackForestBytes/goext/gojson"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/langext"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
@@ -48,6 +50,10 @@ func (ee *ExErr) toJson(depth int, applyExtendListener bool, outputMeta bool) la
|
||||
metaJson[metaKey] = metaVal.rawValueForJson()
|
||||
}
|
||||
ginJson["meta"] = metaJson
|
||||
|
||||
extraJson := langext.H{}
|
||||
maps.Copy(extraJson, ee.Extra)
|
||||
ginJson["extra"] = extraJson
|
||||
}
|
||||
|
||||
if applyExtendListener {
|
||||
@@ -62,7 +68,6 @@ func (ee *ExErr) ToDefaultAPIJson() (string, error) {
|
||||
gjr := json.GoJsonRender{Data: ee.ToAPIJson(true, pkgconfig.ExtendedGinOutput, pkgconfig.IncludeMetaInGinOutput), NilSafeSlices: true, NilSafeMaps: true}
|
||||
|
||||
r, err := gjr.RenderString()
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -90,6 +95,20 @@ func (ee *ExErr) ToAPIJson(applyExtendListener bool, includeWrappedErrors bool,
|
||||
apiOutput["__data"] = ee.toJson(0, applyExtendListener, includeMetaFields)
|
||||
}
|
||||
|
||||
for exkey, exval := range ee.Extra {
|
||||
|
||||
// ensure we do not override existing values
|
||||
for {
|
||||
if _, ok := apiOutput[exkey]; ok {
|
||||
exkey = "_" + exkey
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
apiOutput[exkey] = exval
|
||||
}
|
||||
|
||||
if applyExtendListener {
|
||||
pkgconfig.ExtendGinOutput(ee, apiOutput)
|
||||
}
|
||||
@@ -123,3 +142,34 @@ func (ee *ExErr) Output(g *gin.Context) {
|
||||
|
||||
g.Render(statuscode, json.GoJsonRender{Data: ginOutput, NilSafeSlices: true, NilSafeMaps: true})
|
||||
}
|
||||
|
||||
func (ee *ExErr) OutputRaw(w http.ResponseWriter) {
|
||||
|
||||
warnOnPkgConfigNotInitialized()
|
||||
|
||||
var statuscode = http.StatusInternalServerError
|
||||
|
||||
var baseCat = ee.RecursiveCategory()
|
||||
var baseType = ee.RecursiveType()
|
||||
var baseStatuscode = ee.RecursiveStatuscode()
|
||||
|
||||
if baseCat == CatUser {
|
||||
statuscode = http.StatusBadRequest
|
||||
} else if baseCat == CatSystem {
|
||||
statuscode = http.StatusInternalServerError
|
||||
}
|
||||
|
||||
if baseStatuscode != nil {
|
||||
statuscode = *ee.StatusCode
|
||||
} else if baseType.DefaultStatusCode != nil {
|
||||
statuscode = *baseType.DefaultStatusCode
|
||||
}
|
||||
|
||||
ginOutput, err := ee.ToDefaultAPIJson()
|
||||
if err != nil {
|
||||
panic(err) // cannot happen
|
||||
}
|
||||
|
||||
w.WriteHeader(statuscode)
|
||||
_, _ = w.Write([]byte(ginOutput))
|
||||
}
|
||||
|
||||
@@ -86,3 +86,41 @@ func MessageMatch(e error, matcher func(string) bool) bool {
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// OriginalError returns the lowest level error, probably the original/external error that was originally wrapped
|
||||
func OriginalError(e error) error {
|
||||
if e == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
//goland:noinspection GoTypeAssertionOnErrors
|
||||
bmerr, ok := e.(*ExErr)
|
||||
for !ok {
|
||||
return e
|
||||
}
|
||||
|
||||
for bmerr.OriginalError != nil {
|
||||
bmerr = bmerr.OriginalError
|
||||
}
|
||||
|
||||
if bmerr.WrappedErr != nil {
|
||||
if werr, ok := bmerr.WrappedErr.(error); ok {
|
||||
return werr
|
||||
}
|
||||
}
|
||||
|
||||
return bmerr
|
||||
}
|
||||
|
||||
func UniqueID(v error) *string {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
//goland:noinspection GoTypeAssertionOnErrors
|
||||
if verr, ok := v.(*ExErr); ok {
|
||||
return &verr.UniqueID
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
+7
-3
@@ -4,7 +4,11 @@ import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Listener = func(method Method, v *ExErr)
|
||||
type ListenerOpt struct {
|
||||
NoLog bool
|
||||
}
|
||||
|
||||
type Listener = func(method Method, v *ExErr, opt ListenerOpt)
|
||||
|
||||
var listenerLock = sync.Mutex{}
|
||||
var listener = make([]Listener, 0)
|
||||
@@ -16,11 +20,11 @@ func RegisterListener(l Listener) {
|
||||
listener = append(listener, l)
|
||||
}
|
||||
|
||||
func (ee *ExErr) CallListener(m Method) {
|
||||
func (ee *ExErr) CallListener(m Method, opt ListenerOpt) {
|
||||
listenerLock.Lock()
|
||||
defer listenerLock.Unlock()
|
||||
|
||||
for _, v := range listener {
|
||||
v(m, ee)
|
||||
v(m, ee, opt)
|
||||
}
|
||||
}
|
||||
|
||||
+35
-32
@@ -5,14 +5,14 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/rs/zerolog"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/bson/primitive"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/langext"
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.blackforestbytes.com/BlackForestBytes/goext/langext"
|
||||
"github.com/rs/zerolog"
|
||||
"go.mongodb.org/mongo-driver/v2/bson"
|
||||
)
|
||||
|
||||
// This is a buffed up map[string]any
|
||||
@@ -49,13 +49,13 @@ const (
|
||||
|
||||
type MetaValue struct {
|
||||
DataType metaDataType `json:"dataType"`
|
||||
Value interface{} `json:"value"`
|
||||
Value any `json:"value"`
|
||||
}
|
||||
|
||||
type metaValueSerialization struct {
|
||||
DataType metaDataType `bson:"dataType"`
|
||||
Value string `bson:"value"`
|
||||
Raw interface{} `bson:"raw"`
|
||||
Raw any `bson:"raw"`
|
||||
}
|
||||
|
||||
func (v MetaValue) SerializeValue() (string, error) {
|
||||
@@ -99,7 +99,7 @@ func (v MetaValue) SerializeValue() (string, error) {
|
||||
case MDTBytes:
|
||||
return hex.EncodeToString(v.Value.([]byte)), nil
|
||||
case MDTObjectID:
|
||||
return v.Value.(primitive.ObjectID).Hex(), nil
|
||||
return v.Value.(bson.ObjectID).Hex(), nil
|
||||
case MDTTime:
|
||||
return strconv.FormatInt(v.Value.(time.Time).Unix(), 10) + "|" + strconv.FormatInt(int64(v.Value.(time.Time).Nanosecond()), 10), nil
|
||||
case MDTDuration:
|
||||
@@ -178,7 +178,7 @@ func (v MetaValue) ShortString(lim int) string {
|
||||
case MDTBytes:
|
||||
return langext.StrLimit(hex.EncodeToString(v.Value.([]byte)), lim, "...")
|
||||
case MDTObjectID:
|
||||
return v.Value.(primitive.ObjectID).Hex()
|
||||
return v.Value.(bson.ObjectID).Hex()
|
||||
case MDTTime:
|
||||
return v.Value.(time.Time).Format(time.RFC3339)
|
||||
case MDTDuration:
|
||||
@@ -266,7 +266,7 @@ func (v MetaValue) Apply(key string, evt *zerolog.Event, limitLen *int) *zerolog
|
||||
case MDTBytes:
|
||||
return evt.Bytes(key, v.Value.([]byte))
|
||||
case MDTObjectID:
|
||||
return evt.Str(key, v.Value.(primitive.ObjectID).Hex())
|
||||
return evt.Str(key, v.Value.(bson.ObjectID).Hex())
|
||||
case MDTTime:
|
||||
return evt.Time(key, v.Value.(time.Time))
|
||||
case MDTDuration:
|
||||
@@ -379,7 +379,7 @@ func (v *MetaValue) Deserialize(value string, datatype metaDataType) error {
|
||||
v.DataType = datatype
|
||||
return nil
|
||||
} else {
|
||||
v.Value = langext.Ptr(value[1:])
|
||||
v.Value = new(value[1:])
|
||||
v.DataType = datatype
|
||||
return nil
|
||||
}
|
||||
@@ -460,7 +460,7 @@ func (v *MetaValue) Deserialize(value string, datatype metaDataType) error {
|
||||
v.DataType = datatype
|
||||
return nil
|
||||
case MDTObjectID:
|
||||
r, err := primitive.ObjectIDFromHex(value)
|
||||
r, err := bson.ObjectIDFromHex(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -577,7 +577,7 @@ func (v MetaValue) ValueString() string {
|
||||
case MDTBytes:
|
||||
return hex.EncodeToString(v.Value.([]byte))
|
||||
case MDTObjectID:
|
||||
return v.Value.(primitive.ObjectID).Hex()
|
||||
return v.Value.(bson.ObjectID).Hex()
|
||||
case MDTTime:
|
||||
return v.Value.(time.Time).Format(time.RFC3339Nano)
|
||||
case MDTDuration:
|
||||
@@ -628,8 +628,8 @@ func (v MetaValue) rawValueForJson() any {
|
||||
if v.Value.(AnyWrap).IsError {
|
||||
return bson.M{"@error": true}
|
||||
}
|
||||
jsonobj := primitive.M{}
|
||||
jsonarr := primitive.A{}
|
||||
jsonobj := bson.M{}
|
||||
jsonarr := bson.A{}
|
||||
if err := json.Unmarshal([]byte(v.Value.(AnyWrap).Json), &jsonobj); err == nil {
|
||||
return jsonobj
|
||||
} else if err := json.Unmarshal([]byte(v.Value.(AnyWrap).Json), &jsonarr); err == nil {
|
||||
@@ -654,7 +654,7 @@ func (v MetaValue) rawValueForJson() any {
|
||||
return v.Value.(time.Time).Format(time.RFC3339Nano)
|
||||
}
|
||||
if v.DataType == MDTObjectID {
|
||||
return v.Value.(primitive.ObjectID).Hex()
|
||||
return v.Value.(bson.ObjectID).Hex()
|
||||
}
|
||||
if v.DataType == MDTNil {
|
||||
return nil
|
||||
@@ -694,43 +694,46 @@ func (v MetaValue) rawValueForJson() any {
|
||||
}
|
||||
|
||||
func (mm MetaMap) FormatOneLine(singleMaxLen int) string {
|
||||
r := ""
|
||||
var r strings.Builder
|
||||
|
||||
i := 0
|
||||
for key, val := range mm {
|
||||
if i > 0 {
|
||||
r += ", "
|
||||
r.WriteString(", ")
|
||||
}
|
||||
|
||||
r += "\"" + key + "\""
|
||||
r += ": "
|
||||
r += "\"" + val.ShortString(singleMaxLen) + "\""
|
||||
r.WriteString("\"" + key + "\"")
|
||||
r.WriteString(": ")
|
||||
r.WriteString("\"" + val.ShortString(singleMaxLen) + "\"")
|
||||
|
||||
i++
|
||||
}
|
||||
|
||||
return r
|
||||
return r.String()
|
||||
}
|
||||
|
||||
func (mm MetaMap) FormatMultiLine(indentFront string, indentKeys string, maxLenValue int) string {
|
||||
r := ""
|
||||
var r strings.Builder
|
||||
|
||||
r += indentFront + "{" + "\n"
|
||||
r.WriteString(indentFront + "{" + "\n")
|
||||
for key, val := range mm {
|
||||
if key == "gin.body" {
|
||||
continue
|
||||
}
|
||||
if key == "gin_body" {
|
||||
continue
|
||||
}
|
||||
|
||||
r += indentFront
|
||||
r += indentKeys
|
||||
r += "\"" + key + "\""
|
||||
r += ": "
|
||||
r += "\"" + val.ShortString(maxLenValue) + "\""
|
||||
r += ",\n"
|
||||
r.WriteString(indentFront)
|
||||
r.WriteString(indentKeys)
|
||||
r.WriteString("\"" + key + "\"")
|
||||
r.WriteString(": ")
|
||||
r.WriteString("\"" + val.ShortString(maxLenValue) + "\"")
|
||||
r.WriteString(",\n")
|
||||
}
|
||||
r += indentFront + "}"
|
||||
r.WriteString(indentFront + "}")
|
||||
|
||||
return r
|
||||
return r.String()
|
||||
}
|
||||
|
||||
func (mm MetaMap) Any() bool {
|
||||
@@ -744,7 +747,7 @@ func (mm MetaMap) Apply(evt *zerolog.Event, limitLen *int) *zerolog.Event {
|
||||
return evt
|
||||
}
|
||||
|
||||
func (mm MetaMap) add(key string, mdtype metaDataType, val interface{}) {
|
||||
func (mm MetaMap) add(key string, mdtype metaDataType, val any) {
|
||||
if _, ok := mm[key]; !ok {
|
||||
mm[key] = MetaValue{DataType: mdtype, Value: val}
|
||||
return
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user