diff --git a/.claude/settings.json b/.claude/settings.json new file mode 100644 index 0000000..7e8cddc --- /dev/null +++ b/.claude/settings.json @@ -0,0 +1,44 @@ +{ + "permissions": { + "allow": [ + "Bash(make:*)", + + "Bash(mkdir:*)", + + "Bash(go build:*)", + "Bash(go test:*)", + "Bash(go get:*)", + "Bash(go mod:*)", + "Bash(go clean:*)", + "Bash(go doc:*)", + + "Bash(grep:*)", + "Bash(find:*)", + "Bash(rg:*)", + "Bash(base64:*)", + "Bash(sed:*)", + "Bash(ls:*)", + + "Bash(curl:*)", + + "Bash(timeout 60s go test -v -count=1 ./...)", + "Bash(timeout 60s go test -v -count=1 ./tests/integration/...)", + "Bash(timeout 60s go test:*)", + "Bash(timeout 300 make test)", + "Bash(timeout 30s go test ./tests/integration -run:*)", + + "Bash(done)", + "Bash(awk:*)", + + "WebFetch(domain:platform.openai.com)" + ], + "deny": [ + ], + "defaultMode": "acceptEdits" + }, + "env": { + "CLAUDE_CODE_ENABLE_TELEMETRY": "0", + "DISABLE_ERROR_REPORTING": "1", + "DISABLE_TELEMETRY": "1" + } +} diff --git a/.gitea/workflows/tests.yml b/.gitea/workflows/tests.yml index 171f66e..fe2f79e 100644 --- a/.gitea/workflows/tests.yml +++ b/.gitea/workflows/tests.yml @@ -4,7 +4,7 @@ # https://docs.github.com/en/actions/learn-github-actions/contexts#github-context name: Build Docker and Deploy -run-name: Build & Deploy ${{ gitea.ref }} on ${{ gitea.actor }} +run-name: "[test]: ${{ github.event.head_commit.message }}" on: push: diff --git a/.gitignore b/.gitignore index 616f9dd..dda84a5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,6 @@ +.claude-queue + ########################################################################## .idea/**/workspace.xml diff --git a/bfcodegen/unit_csid_test.go b/bfcodegen/unit_csid_test.go new file mode 100644 index 0000000..8630f8f --- /dev/null +++ b/bfcodegen/unit_csid_test.go @@ -0,0 +1,160 @@ +package bfcodegen + +import ( + "git.blackforestbytes.com/BlackForestBytes/goext/langext" + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + "os" + "path/filepath" 
+ "strings" + "testing" +) + +func TestProcessCSIDFileSimple(t *testing.T) { + dir := t.TempDir() + + src := `package mymodels + +type UserID string // @csid:type [USR] +type OrderID string // @csid:type [ORD] +` + fp := writeTestFile(t, dir, "models.go", src) + + ids, pkg, err := processCSIDFile(dir, fp, false) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, pkg, "mymodels") + tst.AssertEqual(t, len(ids), 2) + tst.AssertEqual(t, ids[0].Name, "UserID") + tst.AssertEqual(t, ids[0].Prefix, "USR") + tst.AssertEqual(t, ids[1].Name, "OrderID") + tst.AssertEqual(t, ids[1].Prefix, "ORD") + tst.AssertEqual(t, ids[0].FileRelative, "models.go") +} + +func TestProcessCSIDFilePrefixMustBeUppercase(t *testing.T) { + dir := t.TempDir() + + // lowercase prefix should not match the regex (only [A-Z0-9]{3}) + src := `package x + +type FooID string // @csid:type [usr] +` + fp := writeTestFile(t, dir, "x.go", src) + + ids, pkg, err := processCSIDFile(dir, fp, false) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, pkg, "x") + tst.AssertEqual(t, len(ids), 0) +} + +func TestProcessCSIDFileGeneratedHeaderSkipped(t *testing.T) { + dir := t.TempDir() + + src := `// Code generated by csid-generate.go DO NOT EDIT. 
+ +package x + +type SkipMeID string // @csid:type [SKP] +` + fp := writeTestFile(t, dir, "skip.go", src) + + ids, pkg, err := processCSIDFile(dir, fp, false) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, pkg, "") + tst.AssertEqual(t, len(ids), 0) +} + +func TestGenerateCharsetIDSpecsEndToEnd(t *testing.T) { + dir := t.TempDir() + + src1 := `package models + +type EntityID string // @csid:type [ENT] +type UserID string // @csid:type [USR] +` + writeTestFile(t, dir, "a_models.go", src1) + + src2 := `package models + +type OrderID string // @csid:type [ORD] +` + writeTestFile(t, dir, "b_models.go", src2) + + dest := filepath.Join(dir, "csid_gen.go") + + err := GenerateCharsetIDSpecs(dir, dest, CSIDGenOptions{DebugOutput: langext.PFalse}) + tst.AssertNoErr(t, err) + + out, err := os.ReadFile(dest) + tst.AssertNoErr(t, err) + outStr := string(out) + + tst.AssertTrue(t, strings.Contains(outStr, "package models")) + tst.AssertTrue(t, strings.Contains(outStr, "ChecksumCharsetIDGenerator")) + tst.AssertTrue(t, strings.Contains(outStr, "func NewUserID()")) + tst.AssertTrue(t, strings.Contains(outStr, "func NewOrderID()")) + tst.AssertTrue(t, strings.Contains(outStr, "func NewEntityID()")) + tst.AssertTrue(t, strings.Contains(outStr, `prefixUserID`) && strings.Contains(outStr, `"USR"`)) + tst.AssertTrue(t, strings.Contains(outStr, `prefixOrderID`) && strings.Contains(outStr, `"ORD"`)) + tst.AssertTrue(t, strings.Contains(outStr, `prefixEntityID`) && strings.Contains(outStr, `"ENT"`)) +} + +func TestGenerateCharsetIDSpecsIdempotentWhenUnchanged(t *testing.T) { + dir := t.TempDir() + + src := `package models + +type SomeID string // @csid:type [SOM] +` + writeTestFile(t, dir, "models.go", src) + dest := filepath.Join(dir, "csid_gen.go") + + err := GenerateCharsetIDSpecs(dir, dest, CSIDGenOptions{}) + tst.AssertNoErr(t, err) + + content1, err := os.ReadFile(dest) + tst.AssertNoErr(t, err) + + err = GenerateCharsetIDSpecs(dir, dest, CSIDGenOptions{}) + tst.AssertNoErr(t, err) + 
+ content2, err := os.ReadFile(dest) + tst.AssertNoErr(t, err) + + tst.AssertEqual(t, string(content1), string(content2)) +} + +func TestGenerateCharsetIDSpecsErrorsWithoutPackage(t *testing.T) { + dir := t.TempDir() + + src := `// Code generated by csid-generate.go DO NOT EDIT. + +package x + +type SkippedID string // @csid:type [SKP] +` + writeTestFile(t, dir, "z.go", src) + dest := filepath.Join(dir, "csid_gen.go") + + err := GenerateCharsetIDSpecs(dir, dest, CSIDGenOptions{}) + tst.AssertTrue(t, err != nil) +} + +func TestGenerateCharsetIDSpecsMissingDir(t *testing.T) { + dir := filepath.Join(t.TempDir(), "definitely-missing") + err := GenerateCharsetIDSpecs(dir, filepath.Join(dir, "csid_gen.go"), CSIDGenOptions{}) + tst.AssertTrue(t, err != nil) +} + +func TestFmtCSIDOutputContainsAllNames(t *testing.T) { + ids := []CSIDDef{ + {File: "a.go", FileRelative: "a.go", Name: "AlphaID", Prefix: "ALP"}, + {File: "b.go", FileRelative: "b.go", Name: "BetaID", Prefix: "BET"}, + } + out := fmtCSIDOutput("CHK_XYZ", ids, "models") + tst.AssertTrue(t, strings.Contains(out, "package models")) + tst.AssertTrue(t, strings.Contains(out, "CHK_XYZ")) + tst.AssertTrue(t, strings.Contains(out, "AlphaID")) + tst.AssertTrue(t, strings.Contains(out, "BetaID")) + tst.AssertTrue(t, strings.Contains(out, `prefixAlphaID`) && strings.Contains(out, `"ALP"`)) + tst.AssertTrue(t, strings.Contains(out, `prefixBetaID`) && strings.Contains(out, `"BET"`)) +} diff --git a/bfcodegen/unit_enum_test.go b/bfcodegen/unit_enum_test.go new file mode 100644 index 0000000..44d809e --- /dev/null +++ b/bfcodegen/unit_enum_test.go @@ -0,0 +1,369 @@ +package bfcodegen + +import ( + "git.blackforestbytes.com/BlackForestBytes/goext/langext" + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestProcessEnumFileBasicStringEnum(t *testing.T) { + dir := t.TempDir() + + src := `package mymodels + +type Color string // @enum:type + +const ( + ColorRed 
Color = "red" + ColorBlue Color = "blue" + ColorGreen Color = "green" +) +` + fp := writeTestFile(t, dir, "color.go", src) + + enums, pkg, err := processEnumFile(dir, fp, false) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, pkg, "mymodels") + tst.AssertEqual(t, len(enums), 1) + tst.AssertEqual(t, enums[0].EnumTypeName, "Color") + tst.AssertEqual(t, enums[0].Type, "string") + tst.AssertEqual(t, len(enums[0].Values), 3) + tst.AssertEqual(t, enums[0].Values[0].VarName, "ColorRed") + tst.AssertEqual(t, enums[0].Values[0].Value, `"red"`) + tst.AssertEqual(t, enums[0].Values[1].VarName, "ColorBlue") + tst.AssertEqual(t, enums[0].Values[2].VarName, "ColorGreen") +} + +func TestProcessEnumFileIntEnum(t *testing.T) { + dir := t.TempDir() + + src := `package m + +type Priority int // @enum:type + +const ( + PriorityLow Priority = 0 + PriorityMedium Priority = 1 + PriorityHigh Priority = 2 +) +` + fp := writeTestFile(t, dir, "prio.go", src) + + enums, pkg, err := processEnumFile(dir, fp, false) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, pkg, "m") + tst.AssertEqual(t, len(enums), 1) + tst.AssertEqual(t, enums[0].EnumTypeName, "Priority") + tst.AssertEqual(t, enums[0].Type, "int") + tst.AssertEqual(t, len(enums[0].Values), 3) + tst.AssertEqual(t, enums[0].Values[0].Value, "0") + tst.AssertEqual(t, enums[0].Values[2].Value, "2") +} + +func TestProcessEnumFileWithDescriptions(t *testing.T) { + dir := t.TempDir() + + src := `package m + +type Status string // @enum:type + +const ( + StatusActive Status = "active" // The active status + StatusInactive Status = "inactive" // The inactive status +) +` + fp := writeTestFile(t, dir, "s.go", src) + + enums, _, err := processEnumFile(dir, fp, false) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, len(enums), 1) + tst.AssertEqual(t, len(enums[0].Values), 2) + + v0 := enums[0].Values[0] + tst.AssertTrue(t, v0.Description != nil) + tst.AssertEqual(t, *v0.Description, "The active status") + tst.AssertTrue(t, v0.Data == nil) +} + +func 
TestProcessEnumFileWithDataComment(t *testing.T) { + dir := t.TempDir() + + src := `package m + +type Severity string // @enum:type + +const ( + SeverityLow Severity = "low" // {"description": "Low severity", "weight": 1} + SeverityHigh Severity = "high" // {"description": "High severity", "weight": 9} +) +` + fp := writeTestFile(t, dir, "sev.go", src) + + enums, _, err := processEnumFile(dir, fp, false) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, len(enums), 1) + tst.AssertEqual(t, len(enums[0].Values), 2) + + v0 := enums[0].Values[0] + tst.AssertTrue(t, v0.Data != nil) + tst.AssertTrue(t, v0.Description != nil) + tst.AssertEqual(t, *v0.Description, "Low severity") +} + +func TestProcessEnumFileNonMatchingValuesNotAttached(t *testing.T) { + dir := t.TempDir() + + src := `package m + +type Color string // @enum:type + +const ( + ColorRed Color = "red" +) + +const ( + OtherX OtherType = "x" +) +` + fp := writeTestFile(t, dir, "c.go", src) + + enums, _, err := processEnumFile(dir, fp, false) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, len(enums), 1) + tst.AssertEqual(t, len(enums[0].Values), 1) + tst.AssertEqual(t, enums[0].Values[0].VarName, "ColorRed") +} + +func TestProcessEnumFileGeneratedHeaderSkipped(t *testing.T) { + dir := t.TempDir() + + src := `// Code generated by enum-generate.go DO NOT EDIT. 
+ +package x + +type Foo string // @enum:type +` + fp := writeTestFile(t, dir, "skip.go", src) + + enums, pkg, err := processEnumFile(dir, fp, false) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, pkg, "") + tst.AssertEqual(t, len(enums), 0) +} + +func TestTryParseDataCommentValid(t *testing.T) { + m, ok := tryParseDataComment(`{"description": "hello", "weight": 5}`) + tst.AssertTrue(t, ok) + descr, _ := m["description"].(string) + tst.AssertEqual(t, descr, "hello") + weight, _ := m["weight"].(float64) + tst.AssertEqual(t, weight, float64(5)) +} + +func TestTryParseDataCommentBool(t *testing.T) { + m, ok := tryParseDataComment(`{"a": true, "b": false}`) + tst.AssertTrue(t, ok) + a, _ := m["a"].(bool) + tst.AssertTrue(t, a) + b, _ := m["b"].(bool) + tst.AssertFalse(t, b) +} + +func TestTryParseDataCommentRejectsNull(t *testing.T) { + // null becomes a nil interface — its reflect.Kind is Invalid, not Pointer, + // so it does not match any of the allowed kinds and is rejected. + _, ok := tryParseDataComment(`{"x": null}`) + tst.AssertFalse(t, ok) +} + +func TestTryParseDataCommentInvalidJSON(t *testing.T) { + _, ok := tryParseDataComment(`{not valid json}`) + tst.AssertFalse(t, ok) +} + +func TestTryParseDataCommentRejectsArrays(t *testing.T) { + // arrays as values are not in the supported kinds list + _, ok := tryParseDataComment(`{"x": [1, 2, 3]}`) + tst.AssertFalse(t, ok) +} + +func TestTryParseDataCommentRejectsObjects(t *testing.T) { + _, ok := tryParseDataComment(`{"x": {"nested": 1}}`) + tst.AssertFalse(t, ok) +} + +func TestGenerateEnumSpecsEndToEnd(t *testing.T) { + dir := t.TempDir() + + src := `package models + +type Color string // @enum:type + +const ( + ColorRed Color = "red" + ColorGreen Color = "green" + ColorBlue Color = "blue" +) +` + writeTestFile(t, dir, "color.go", src) + dest := filepath.Join(dir, "enum_gen.go") + + err := GenerateEnumSpecs(dir, dest, EnumGenOptions{ + DebugOutput: langext.PFalse, + GoFormat: langext.PTrue, + }) + 
tst.AssertNoErr(t, err) + + out, err := os.ReadFile(dest) + tst.AssertNoErr(t, err) + outStr := string(out) + + tst.AssertTrue(t, strings.Contains(outStr, "package models")) + tst.AssertTrue(t, strings.Contains(outStr, "ChecksumEnumGenerator")) + tst.AssertTrue(t, strings.Contains(outStr, "ParseColor")) + tst.AssertTrue(t, strings.Contains(outStr, "ColorValues")) + tst.AssertTrue(t, strings.Contains(outStr, "ColorRed")) + tst.AssertTrue(t, strings.Contains(outStr, "ColorBlue")) + tst.AssertTrue(t, strings.Contains(outStr, "ColorGreen")) +} + +func TestGenerateEnumSpecsDeterministic(t *testing.T) { + dir := t.TempDir() + + src := `package models + +type Status string // @enum:type + +const ( + StatusActive Status = "active" // The active one + StatusOff Status = "off" // The off one +) +` + writeTestFile(t, dir, "s.go", src) + + s1, cs1, changed1, err := _generateEnumSpecs(dir, "", "N/A", true, false) + tst.AssertNoErr(t, err) + tst.AssertTrue(t, changed1) + + s2, cs2, changed2, err := _generateEnumSpecs(dir, "", "N/A", true, false) + tst.AssertNoErr(t, err) + tst.AssertTrue(t, changed2) + + tst.AssertEqual(t, cs1, cs2) + tst.AssertEqual(t, s1, s2) +} + +func TestGenerateEnumSpecsNoChangeWhenChecksumMatches(t *testing.T) { + dir := t.TempDir() + + src := `package models + +type Status string // @enum:type + +const ( + StatusActive Status = "active" +) +` + writeTestFile(t, dir, "s.go", src) + + _, cs, changed, err := _generateEnumSpecs(dir, "", "N/A", true, false) + tst.AssertNoErr(t, err) + tst.AssertTrue(t, changed) + + s2, cs2, changed2, err := _generateEnumSpecs(dir, "", cs, true, false) + tst.AssertNoErr(t, err) + tst.AssertFalse(t, changed2) + tst.AssertEqual(t, cs2, cs) + tst.AssertEqual(t, s2, "") +} + +func TestGenerateEnumSpecsWithoutGoFormat(t *testing.T) { + dir := t.TempDir() + + src := `package models + +type Color string // @enum:type + +const ( + ColorRed Color = "red" +) +` + writeTestFile(t, dir, "c.go", src) + + out, _, _, err := 
_generateEnumSpecs(dir, "", "N/A", false, false) + tst.AssertNoErr(t, err) + tst.AssertTrue(t, strings.Contains(out, "ColorRed")) + tst.AssertTrue(t, strings.Contains(out, "package models")) +} + +func TestGenerateEnumSpecsErrorsWithoutPackage(t *testing.T) { + dir := t.TempDir() + + src := `// Code generated by enum-generate.go DO NOT EDIT. + +package x + +type Foo string // @enum:type +` + writeTestFile(t, dir, "z.go", src) + + _, _, _, err := _generateEnumSpecs(dir, "", "N/A", false, false) + tst.AssertTrue(t, err != nil) +} + +func TestGenerateEnumSpecsMissingDir(t *testing.T) { + dir := filepath.Join(t.TempDir(), "definitely-missing") + _, _, _, err := _generateEnumSpecs(dir, "", "N/A", false, false) + tst.AssertTrue(t, err != nil) +} + +func TestFmtEnumOutputContainsTypes(t *testing.T) { + descr := "the red one" + enums := []EnumDef{ + { + File: "color.go", + FileRelative: "color.go", + EnumTypeName: "Color", + Type: "string", + Values: []EnumDefVal{ + {VarName: "ColorRed", Value: `"red"`, Description: &descr}, + {VarName: "ColorBlue", Value: `"blue"`, Description: &descr}, + }, + }, + } + out := fmtEnumOutput("CHK1", enums, "models") + tst.AssertTrue(t, strings.Contains(out, "package models")) + tst.AssertTrue(t, strings.Contains(out, "CHK1")) + tst.AssertTrue(t, strings.Contains(out, "ColorRed")) + tst.AssertTrue(t, strings.Contains(out, "ColorBlue")) + tst.AssertTrue(t, strings.Contains(out, "ParseColor")) +} + +func TestGenerateEnumSpecsSkipsGenFile(t *testing.T) { + dir := t.TempDir() + + src := `package models + +type Color string // @enum:type + +const ( + ColorRed Color = "red" +) +` + writeTestFile(t, dir, "c.go", src) + + // generated file in same dir - should be filtered out + gensrc := `package models + +type ShouldBeIgnored string // @enum:type +` + writeTestFile(t, dir, "ignored_gen.go", gensrc) + + out, _, _, err := _generateEnumSpecs(dir, filepath.Join(dir, "enum_gen.go"), "N/A", false, false) + tst.AssertNoErr(t, err) + tst.AssertTrue(t, 
strings.Contains(out, "ColorRed")) + tst.AssertFalse(t, strings.Contains(out, "ShouldBeIgnored")) +} diff --git a/bfcodegen/unit_id_test.go b/bfcodegen/unit_id_test.go new file mode 100644 index 0000000..feeafb6 --- /dev/null +++ b/bfcodegen/unit_id_test.go @@ -0,0 +1,209 @@ +package bfcodegen + +import ( + "git.blackforestbytes.com/BlackForestBytes/goext/langext" + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + "os" + "path/filepath" + "strings" + "testing" +) + +func writeTestFile(t *testing.T, dir string, name string, content string) string { + t.Helper() + p := filepath.Join(dir, name) + err := os.WriteFile(p, []byte(content), 0o644) + tst.AssertNoErr(t, err) + return p +} + +func TestProcessIDFileSimple(t *testing.T) { + dir := t.TempDir() + + src := `package mymodels + +type UserID string // @id:type +type OrderID string // @id:type +` + fp := writeTestFile(t, dir, "models.go", src) + + ids, pkg, err := processIDFile(dir, fp, false) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, pkg, "mymodels") + tst.AssertEqual(t, len(ids), 2) + tst.AssertEqual(t, ids[0].Name, "UserID") + tst.AssertEqual(t, ids[1].Name, "OrderID") + tst.AssertEqual(t, ids[0].FileRelative, "models.go") +} + +func TestProcessIDFileNoMatches(t *testing.T) { + dir := t.TempDir() + + src := `package x + +type Foo string +type Bar int +type Baz string // not the right marker +` + fp := writeTestFile(t, dir, "x.go", src) + + ids, pkg, err := processIDFile(dir, fp, false) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, pkg, "x") + tst.AssertEqual(t, len(ids), 0) +} + +func TestProcessIDFileGeneratedHeaderSkipped(t *testing.T) { + dir := t.TempDir() + + src := `// Code generated by id-generate.go DO NOT EDIT. 
+ +package x + +type SkipMeID string // @id:type +` + fp := writeTestFile(t, dir, "skip.go", src) + + ids, pkg, err := processIDFile(dir, fp, false) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, pkg, "") + tst.AssertEqual(t, len(ids), 0) +} + +func TestProcessIDFileMissingFile(t *testing.T) { + _, _, err := processIDFile(t.TempDir(), filepath.Join(t.TempDir(), "does_not_exist.go"), false) + tst.AssertTrue(t, err != nil) +} + +func TestGenerateIDSpecsEndToEnd(t *testing.T) { + dir := t.TempDir() + + src1 := `package models + +type UserID string // @id:type +type AnyID string // @id:type +` + writeTestFile(t, dir, "a_models.go", src1) + + src2 := `package models + +type OrderID string // @id:type +` + writeTestFile(t, dir, "b_models.go", src2) + + dest := filepath.Join(dir, "id_gen.go") + + err := GenerateIDSpecs(dir, dest, IDGenOptions{DebugOutput: langext.PFalse}) + tst.AssertNoErr(t, err) + + out, err := os.ReadFile(dest) + tst.AssertNoErr(t, err) + outStr := string(out) + + tst.AssertTrue(t, strings.Contains(outStr, "package models")) + tst.AssertTrue(t, strings.Contains(outStr, "ChecksumIDGenerator")) + tst.AssertTrue(t, strings.Contains(outStr, "func NewUserID()")) + tst.AssertTrue(t, strings.Contains(outStr, "func NewOrderID()")) + tst.AssertTrue(t, strings.Contains(outStr, "func NewAnyID()")) + tst.AssertTrue(t, strings.Contains(outStr, "AsAny()")) +} + +func TestGenerateIDSpecsIdempotentWhenUnchanged(t *testing.T) { + dir := t.TempDir() + + src := `package models + +type SomeID string // @id:type +` + writeTestFile(t, dir, "models.go", src) + dest := filepath.Join(dir, "id_gen.go") + + err := GenerateIDSpecs(dir, dest, IDGenOptions{}) + tst.AssertNoErr(t, err) + + stat1, err := os.Stat(dest) + tst.AssertNoErr(t, err) + + content1, err := os.ReadFile(dest) + tst.AssertNoErr(t, err) + + err = GenerateIDSpecs(dir, dest, IDGenOptions{}) + tst.AssertNoErr(t, err) + + stat2, err := os.Stat(dest) + tst.AssertNoErr(t, err) + + content2, err := os.ReadFile(dest) + 
tst.AssertNoErr(t, err) + + tst.AssertEqual(t, stat1.ModTime().Equal(stat2.ModTime()), true) + tst.AssertEqual(t, string(content1), string(content2)) +} + +func TestGenerateIDSpecsRegeneratesAfterChange(t *testing.T) { + dir := t.TempDir() + + src := `package models + +type FirstID string // @id:type +` + fp := writeTestFile(t, dir, "models.go", src) + dest := filepath.Join(dir, "id_gen.go") + + err := GenerateIDSpecs(dir, dest, IDGenOptions{}) + tst.AssertNoErr(t, err) + + content1, err := os.ReadFile(dest) + tst.AssertNoErr(t, err) + tst.AssertTrue(t, strings.Contains(string(content1), "FirstID")) + tst.AssertFalse(t, strings.Contains(string(content1), "SecondID")) + + src2 := `package models + +type FirstID string // @id:type +type SecondID string // @id:type +` + err = os.WriteFile(fp, []byte(src2), 0o644) + tst.AssertNoErr(t, err) + + err = GenerateIDSpecs(dir, dest, IDGenOptions{}) + tst.AssertNoErr(t, err) + + content2, err := os.ReadFile(dest) + tst.AssertNoErr(t, err) + tst.AssertTrue(t, strings.Contains(string(content2), "SecondID")) +} + +func TestGenerateIDSpecsErrorsWithoutPackage(t *testing.T) { + dir := t.TempDir() + + src := `// Code generated by id-generate.go DO NOT EDIT. 
+ +package x + +type SkippedID string // @id:type +` + writeTestFile(t, dir, "z.go", src) + dest := filepath.Join(dir, "id_gen.go") + + err := GenerateIDSpecs(dir, dest, IDGenOptions{}) + tst.AssertTrue(t, err != nil) +} + +func TestGenerateIDSpecsMissingDir(t *testing.T) { + dir := filepath.Join(t.TempDir(), "definitely-missing") + err := GenerateIDSpecs(dir, filepath.Join(dir, "id_gen.go"), IDGenOptions{}) + tst.AssertTrue(t, err != nil) +} + +func TestFmtIDOutputContainsAllNames(t *testing.T) { + ids := []IDDef{ + {File: "a.go", FileRelative: "a.go", Name: "AlphaID"}, + {File: "b.go", FileRelative: "b.go", Name: "BetaID"}, + } + out := fmtIDOutput("CHK_ABC", ids, "models") + tst.AssertTrue(t, strings.Contains(out, "package models")) + tst.AssertTrue(t, strings.Contains(out, "CHK_ABC")) + tst.AssertTrue(t, strings.Contains(out, "AlphaID")) + tst.AssertTrue(t, strings.Contains(out, "BetaID")) +} diff --git a/cmdext/builder_test.go b/cmdext/builder_test.go new file mode 100644 index 0000000..6c576f1 --- /dev/null +++ b/cmdext/builder_test.go @@ -0,0 +1,324 @@ +package cmdext + +import ( + "reflect" + "testing" + "time" +) + +func TestRunnerInit(t *testing.T) { + r := Runner("myprog") + + if r == nil { + t.Fatalf("Runner returned nil") + } + if r.program != "myprog" { + t.Errorf("program == %v, want myprog", r.program) + } + if r.args == nil { + t.Errorf("args is nil, want empty slice") + } + if len(r.args) != 0 { + t.Errorf("len(args) == %v, want 0", len(r.args)) + } + if r.env == nil { + t.Errorf("env is nil, want empty slice") + } + if len(r.env) != 0 { + t.Errorf("len(env) == %v, want 0", len(r.env)) + } + if r.listener == nil { + t.Errorf("listener is nil, want empty slice") + } + if len(r.listener) != 0 { + t.Errorf("len(listener) == %v, want 0", len(r.listener)) + } + if r.timeout != nil { + t.Errorf("timeout == %v, want nil", r.timeout) + } + if r.enforceExitCodes != nil { + t.Errorf("enforceExitCodes == %v, want nil", r.enforceExitCodes) + } + if 
r.enforceNoTimeout { + t.Errorf("enforceNoTimeout == true, want false") + } + if r.enforceNoStderr { + t.Errorf("enforceNoStderr == true, want false") + } +} + +func TestArgSingle(t *testing.T) { + r := Runner("p").Arg("a") + if !reflect.DeepEqual(r.args, []string{"a"}) { + t.Errorf("args == %v, want [a]", r.args) + } +} + +func TestArgMultiple(t *testing.T) { + r := Runner("p").Arg("a").Arg("b").Arg("c") + if !reflect.DeepEqual(r.args, []string{"a", "b", "c"}) { + t.Errorf("args == %v, want [a b c]", r.args) + } +} + +func TestArgsAppendsAll(t *testing.T) { + r := Runner("p").Args([]string{"x", "y"}).Args([]string{"z"}) + if !reflect.DeepEqual(r.args, []string{"x", "y", "z"}) { + t.Errorf("args == %v, want [x y z]", r.args) + } +} + +func TestArgAndArgsMixed(t *testing.T) { + r := Runner("p").Arg("a").Args([]string{"b", "c"}).Arg("d") + if !reflect.DeepEqual(r.args, []string{"a", "b", "c", "d"}) { + t.Errorf("args == %v, want [a b c d]", r.args) + } +} + +func TestArgsEmptySlice(t *testing.T) { + r := Runner("p").Args([]string{}) + if len(r.args) != 0 { + t.Errorf("len(args) == %v, want 0", len(r.args)) + } +} + +func TestTimeoutSet(t *testing.T) { + d := 500 * time.Millisecond + r := Runner("p").Timeout(d) + if r.timeout == nil { + t.Fatalf("timeout is nil") + } + if *r.timeout != d { + t.Errorf("timeout == %v, want %v", *r.timeout, d) + } +} + +func TestTimeoutOverride(t *testing.T) { + r := Runner("p").Timeout(1 * time.Second).Timeout(2 * time.Second) + if *r.timeout != 2*time.Second { + t.Errorf("timeout == %v, want 2s", *r.timeout) + } +} + +func TestEnv(t *testing.T) { + r := Runner("p").Env("KEY", "VALUE") + if !reflect.DeepEqual(r.env, []string{"KEY=VALUE"}) { + t.Errorf("env == %v, want [KEY=VALUE]", r.env) + } +} + +func TestEnvMultiple(t *testing.T) { + r := Runner("p").Env("A", "1").Env("B", "2") + if !reflect.DeepEqual(r.env, []string{"A=1", "B=2"}) { + t.Errorf("env == %v, want [A=1 B=2]", r.env) + } +} + +func TestEnvWithEmptyValue(t *testing.T) { + 
r := Runner("p").Env("KEY", "") + if !reflect.DeepEqual(r.env, []string{"KEY="}) { + t.Errorf("env == %v, want [KEY=]", r.env) + } +} + +func TestRawEnv(t *testing.T) { + r := Runner("p").RawEnv("FOO=BAR=BAZ") + if !reflect.DeepEqual(r.env, []string{"FOO=BAR=BAZ"}) { + t.Errorf("env == %v, want [FOO=BAR=BAZ]", r.env) + } +} + +func TestEnvs(t *testing.T) { + r := Runner("p").Envs([]string{"A=1", "B=2"}) + if !reflect.DeepEqual(r.env, []string{"A=1", "B=2"}) { + t.Errorf("env == %v, want [A=1 B=2]", r.env) + } +} + +func TestEnvMixed(t *testing.T) { + r := Runner("p").Env("A", "1").RawEnv("B=2").Envs([]string{"C=3", "D=4"}) + if !reflect.DeepEqual(r.env, []string{"A=1", "B=2", "C=3", "D=4"}) { + t.Errorf("env == %v, want [A=1 B=2 C=3 D=4]", r.env) + } +} + +func TestEnsureExitcodeSingle(t *testing.T) { + r := Runner("p").EnsureExitcode(2) + if r.enforceExitCodes == nil { + t.Fatalf("enforceExitCodes is nil") + } + if !reflect.DeepEqual(*r.enforceExitCodes, []int{2}) { + t.Errorf("enforceExitCodes == %v, want [2]", *r.enforceExitCodes) + } +} + +func TestEnsureExitcodeMultiple(t *testing.T) { + r := Runner("p").EnsureExitcode(0, 1, 2) + if r.enforceExitCodes == nil { + t.Fatalf("enforceExitCodes is nil") + } + if !reflect.DeepEqual(*r.enforceExitCodes, []int{0, 1, 2}) { + t.Errorf("enforceExitCodes == %v, want [0 1 2]", *r.enforceExitCodes) + } +} + +func TestFailOnExitCode(t *testing.T) { + r := Runner("p").FailOnExitCode() + if r.enforceExitCodes == nil { + t.Fatalf("enforceExitCodes is nil") + } + if !reflect.DeepEqual(*r.enforceExitCodes, []int{0}) { + t.Errorf("enforceExitCodes == %v, want [0]", *r.enforceExitCodes) + } +} + +func TestFailOnTimeoutFlag(t *testing.T) { + r := Runner("p") + if r.enforceNoTimeout { + t.Errorf("enforceNoTimeout was true before set") + } + r = r.FailOnTimeout() + if !r.enforceNoTimeout { + t.Errorf("enforceNoTimeout == false after FailOnTimeout()") + } +} + +func TestFailOnStderrFlag(t *testing.T) { + r := Runner("p") + if 
r.enforceNoStderr { + t.Errorf("enforceNoStderr was true before set") + } + r = r.FailOnStderr() + if !r.enforceNoStderr { + t.Errorf("enforceNoStderr == false after FailOnStderr()") + } +} + +func TestListen(t *testing.T) { + r := Runner("p").Listen(genericCommandListener{}) + if len(r.listener) != 1 { + t.Errorf("len(listener) == %v, want 1", len(r.listener)) + } +} + +func TestListenMultiple(t *testing.T) { + r := Runner("p"). + Listen(genericCommandListener{}). + Listen(genericCommandListener{}). + Listen(genericCommandListener{}) + if len(r.listener) != 3 { + t.Errorf("len(listener) == %v, want 3", len(r.listener)) + } +} + +func TestListenStdoutAddsListener(t *testing.T) { + r := Runner("p").ListenStdout(func(string) {}) + if len(r.listener) != 1 { + t.Errorf("len(listener) == %v, want 1", len(r.listener)) + } +} + +func TestListenStdoutForwardsCalls(t *testing.T) { + got := "" + r := Runner("p").ListenStdout(func(s string) { got = s }) + if len(r.listener) != 1 { + t.Fatalf("len(listener) == %v, want 1", len(r.listener)) + } + r.listener[0].ReadStdoutLine("hello") + if got != "hello" { + t.Errorf("listener got %q, want hello", got) + } + // non-stdout methods should not panic and should not affect state + r.listener[0].ReadStderrLine("nope") + r.listener[0].ReadRawStdout([]byte("raw")) + r.listener[0].ReadRawStderr([]byte("raw")) + r.listener[0].Finished(0) + r.listener[0].Timeout() + if got != "hello" { + t.Errorf("listener got mutated to %q, want hello", got) + } +} + +func TestListenStderrAddsListener(t *testing.T) { + r := Runner("p").ListenStderr(func(string) {}) + if len(r.listener) != 1 { + t.Errorf("len(listener) == %v, want 1", len(r.listener)) + } +} + +func TestListenStderrForwardsCalls(t *testing.T) { + got := "" + r := Runner("p").ListenStderr(func(s string) { got = s }) + if len(r.listener) != 1 { + t.Fatalf("len(listener) == %v, want 1", len(r.listener)) + } + r.listener[0].ReadStderrLine("oops") + if got != "oops" { + t.Errorf("listener got 
%q, want oops", got) + } + r.listener[0].ReadStdoutLine("nope") + if got != "oops" { + t.Errorf("listener got mutated to %q, want oops", got) + } +} + +func TestChainReturnsSameInstance(t *testing.T) { + r := Runner("p") + if r.Arg("a") != r { + t.Errorf("Arg returned different instance") + } + if r.Args([]string{"b"}) != r { + t.Errorf("Args returned different instance") + } + if r.Timeout(time.Second) != r { + t.Errorf("Timeout returned different instance") + } + if r.Env("K", "V") != r { + t.Errorf("Env returned different instance") + } + if r.RawEnv("K=V") != r { + t.Errorf("RawEnv returned different instance") + } + if r.Envs([]string{"K=V"}) != r { + t.Errorf("Envs returned different instance") + } + if r.EnsureExitcode(0) != r { + t.Errorf("EnsureExitcode returned different instance") + } + if r.FailOnExitCode() != r { + t.Errorf("FailOnExitCode returned different instance") + } + if r.FailOnTimeout() != r { + t.Errorf("FailOnTimeout returned different instance") + } + if r.FailOnStderr() != r { + t.Errorf("FailOnStderr returned different instance") + } + if r.Listen(genericCommandListener{}) != r { + t.Errorf("Listen returned different instance") + } + if r.ListenStdout(func(string) {}) != r { + t.Errorf("ListenStdout returned different instance") + } + if r.ListenStderr(func(string) {}) != r { + t.Errorf("ListenStderr returned different instance") + } +} + +func TestSeparateInstancesIndependent(t *testing.T) { + r1 := Runner("p1").Arg("a") + r2 := Runner("p2").Arg("b") + + if r1.program != "p1" { + t.Errorf("r1.program == %v, want p1", r1.program) + } + if r2.program != "p2" { + t.Errorf("r2.program == %v, want p2", r2.program) + } + if !reflect.DeepEqual(r1.args, []string{"a"}) { + t.Errorf("r1.args == %v, want [a]", r1.args) + } + if !reflect.DeepEqual(r2.args, []string{"b"}) { + t.Errorf("r2.args == %v, want [b]", r2.args) + } +} diff --git a/cmdext/listener_test.go b/cmdext/listener_test.go new file mode 100644 index 0000000..c63d8b1 --- /dev/null +++ 
b/cmdext/listener_test.go @@ -0,0 +1,143 @@ +package cmdext + +import ( + "reflect" + "testing" +) + +func TestGenericListenerEmptyDoesNotPanic(t *testing.T) { + l := genericCommandListener{} + l.ReadRawStdout([]byte("x")) + l.ReadRawStderr([]byte("x")) + l.ReadStdoutLine("x") + l.ReadStderrLine("x") + l.Finished(0) + l.Timeout() +} + +func TestGenericListenerReadRawStdout(t *testing.T) { + var got []byte + fn := func(b []byte) { got = append(got, b...) } + l := genericCommandListener{_readRawStdout: &fn} + + l.ReadRawStdout([]byte("hello")) + l.ReadRawStdout([]byte(" world")) + + if string(got) != "hello world" { + t.Errorf("got %q, want %q", string(got), "hello world") + } +} + +func TestGenericListenerReadRawStderr(t *testing.T) { + var got []byte + fn := func(b []byte) { got = append(got, b...) } + l := genericCommandListener{_readRawStderr: &fn} + + l.ReadRawStderr([]byte("err")) + + if string(got) != "err" { + t.Errorf("got %q, want %q", string(got), "err") + } +} + +func TestGenericListenerReadStdoutLine(t *testing.T) { + var got []string + fn := func(s string) { got = append(got, s) } + l := genericCommandListener{_readStdoutLine: &fn} + + l.ReadStdoutLine("line1") + l.ReadStdoutLine("line2") + + if !reflect.DeepEqual(got, []string{"line1", "line2"}) { + t.Errorf("got %v, want [line1 line2]", got) + } +} + +func TestGenericListenerReadStderrLine(t *testing.T) { + var got []string + fn := func(s string) { got = append(got, s) } + l := genericCommandListener{_readStderrLine: &fn} + + l.ReadStderrLine("line1") + l.ReadStderrLine("line2") + + if !reflect.DeepEqual(got, []string{"line1", "line2"}) { + t.Errorf("got %v, want [line1 line2]", got) + } +} + +func TestGenericListenerFinished(t *testing.T) { + var got int + called := false + fn := func(v int) { got = v; called = true } + l := genericCommandListener{_finished: &fn} + + l.Finished(42) + + if !called { + t.Errorf("Finished callback was not called") + } + if got != 42 { + t.Errorf("got %v, want 42", got) 
+ } +} + +func TestGenericListenerTimeout(t *testing.T) { + called := false + fn := func() { called = true } + l := genericCommandListener{_timeout: &fn} + + l.Timeout() + + if !called { + t.Errorf("Timeout callback was not called") + } +} + +func TestGenericListenerOnlySpecifiedCalled(t *testing.T) { + stdoutCalled := false + stderrCalled := false + stdoutFn := func(string) { stdoutCalled = true } + stderrFn := func(string) { stderrCalled = true } + l := genericCommandListener{_readStdoutLine: &stdoutFn, _readStderrLine: &stderrFn} + + l.ReadStdoutLine("x") + if !stdoutCalled { + t.Errorf("stdout callback not called") + } + if stderrCalled { + t.Errorf("stderr callback called when it shouldn't be") + } + + stdoutCalled = false + l.ReadStderrLine("x") + if stdoutCalled { + t.Errorf("stdout callback called when it shouldn't be") + } + if !stderrCalled { + t.Errorf("stderr callback not called") + } + + // these have no callbacks set; should be no-ops + l.ReadRawStdout([]byte("x")) + l.ReadRawStderr([]byte("x")) + l.Finished(0) + l.Timeout() +} + +func TestGenericListenerImplementsCommandListener(t *testing.T) { + var _ CommandListener = genericCommandListener{} +} + +func TestGenericListenerEmptyByteSlice(t *testing.T) { + calls := 0 + fn := func(b []byte) { calls++ } + l := genericCommandListener{_readRawStdout: &fn} + + l.ReadRawStdout([]byte{}) + l.ReadRawStdout(nil) + + if calls != 2 { + t.Errorf("calls == %v, want 2", calls) + } +} diff --git a/confext/confParser_extra_test.go b/confext/confParser_extra_test.go new file mode 100644 index 0000000..8e0f059 --- /dev/null +++ b/confext/confParser_extra_test.go @@ -0,0 +1,390 @@ +package confext + +import ( + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + "testing" + "time" +) + +func TestApplyEnvOverridesPrefix(t *testing.T) { + type testdata struct { + V1 int `env:"V1"` + V2 string `env:"V2"` + } + + data := testdata{V1: 1, V2: "x"} + + t.Setenv("MYAPP_V1", "42") + t.Setenv("MYAPP_V2", "hello") + 
t.Setenv("V1", "111") + t.Setenv("V2", "noprefix") + + err := ApplyEnvOverrides("MYAPP_", &data, ".") + tst.AssertNoErr(t, err) + + tst.AssertEqual(t, data.V1, 42) + tst.AssertEqual(t, data.V2, "hello") +} + +func TestApplyEnvOverridesUnexportedFieldsIgnored(t *testing.T) { + type testdata struct { + V1 int `env:"TEST_V1"` + v2 int `env:"TEST_V2"` + } + + data := testdata{V1: 1, v2: 2} + + t.Setenv("TEST_V1", "11") + t.Setenv("TEST_V2", "22") + + err := ApplyEnvOverrides("", &data, ".") + tst.AssertNoErr(t, err) + + tst.AssertEqual(t, data.V1, 11) + tst.AssertEqual(t, data.v2, 2) +} + +func TestApplyEnvOverridesNoEnvTagIgnored(t *testing.T) { + type testdata struct { + V1 int `env:"TEST_V1"` + V2 int `` + } + + data := testdata{V1: 1, V2: 2} + + t.Setenv("TEST_V1", "11") + t.Setenv("V2", "22") + + err := ApplyEnvOverrides("", &data, ".") + tst.AssertNoErr(t, err) + + tst.AssertEqual(t, data.V1, 11) + tst.AssertEqual(t, data.V2, 2) +} + +func TestApplyEnvOverridesDashTagIgnored(t *testing.T) { + type testdata struct { + V1 int `env:"TEST_V1"` + V2 string `env:"-"` + } + + data := testdata{V1: 1, V2: "no"} + + t.Setenv("TEST_V1", "11") + t.Setenv("-", "yes") + + err := ApplyEnvOverrides("", &data, ".") + tst.AssertNoErr(t, err) + + tst.AssertEqual(t, data.V1, 11) + tst.AssertEqual(t, data.V2, "no") +} + +func TestApplyEnvOverridesEnvNotSetKeepsValue(t *testing.T) { + type testdata struct { + V1 int `env:"NOT_SET_INT_KEY_XYZ"` + V2 string `env:"NOT_SET_STR_KEY_XYZ"` + } + + data := testdata{V1: 7, V2: "keep"} + + err := ApplyEnvOverrides("", &data, ".") + tst.AssertNoErr(t, err) + + tst.AssertEqual(t, data.V1, 7) + tst.AssertEqual(t, data.V2, "keep") +} + +func TestApplyEnvOverridesBoolVariants(t *testing.T) { + type testdata struct { + B1 bool `env:"B1"` + B2 bool `env:"B2"` + B3 bool `env:"B3"` + B4 bool `env:"B4"` + B5 bool `env:"B5"` + B6 bool `env:"B6"` + } + + data := testdata{} + + t.Setenv("B1", "true") + t.Setenv("B2", "false") + t.Setenv("B3", "1") + 
t.Setenv("B4", "0") + t.Setenv("B5", " TRUE ") + t.Setenv("B6", "FaLsE") + + err := ApplyEnvOverrides("", &data, ".") + tst.AssertNoErr(t, err) + + tst.AssertEqual(t, data.B1, true) + tst.AssertEqual(t, data.B2, false) + tst.AssertEqual(t, data.B3, true) + tst.AssertEqual(t, data.B4, false) + tst.AssertEqual(t, data.B5, true) + tst.AssertEqual(t, data.B6, false) +} + +func TestApplyEnvOverridesInvalidIntReturnsError(t *testing.T) { + type testdata struct { + V1 int `env:"BAD_INT"` + } + + data := testdata{} + + t.Setenv("BAD_INT", "not_a_number") + + err := ApplyEnvOverrides("", &data, ".") + if err == nil { + t.Errorf("expected error for invalid int, got nil") + } +} + +func TestApplyEnvOverridesInvalidInt8ReturnsError(t *testing.T) { + type testdata struct { + V1 int8 `env:"BAD_INT8"` + } + + data := testdata{} + + t.Setenv("BAD_INT8", "9999") + + err := ApplyEnvOverrides("", &data, ".") + if err == nil { + t.Errorf("expected error for invalid int8, got nil") + } +} + +func TestApplyEnvOverridesInvalidInt32ReturnsError(t *testing.T) { + type testdata struct { + V1 int32 `env:"BAD_INT32"` + } + + data := testdata{} + + t.Setenv("BAD_INT32", "not_an_int32") + + err := ApplyEnvOverrides("", &data, ".") + if err == nil { + t.Errorf("expected error for invalid int32, got nil") + } +} + +func TestApplyEnvOverridesInvalidInt64ReturnsError(t *testing.T) { + type testdata struct { + V1 int64 `env:"BAD_INT64"` + } + + data := testdata{} + + t.Setenv("BAD_INT64", "not_an_int64") + + err := ApplyEnvOverrides("", &data, ".") + if err == nil { + t.Errorf("expected error for invalid int64, got nil") + } +} + +func TestApplyEnvOverridesInvalidDurationReturnsError(t *testing.T) { + type testdata struct { + V1 time.Duration `env:"BAD_DUR"` + } + + data := testdata{} + + t.Setenv("BAD_DUR", "not_a_duration") + + err := ApplyEnvOverrides("", &data, ".") + if err == nil { + t.Errorf("expected error for invalid duration, got nil") + } +} + +func 
TestApplyEnvOverridesInvalidTimeReturnsError(t *testing.T) { + type testdata struct { + V1 time.Time `env:"BAD_TIME"` + } + + data := testdata{} + + t.Setenv("BAD_TIME", "not_a_time") + + err := ApplyEnvOverrides("", &data, ".") + if err == nil { + t.Errorf("expected error for invalid time, got nil") + } +} + +func TestApplyEnvOverridesInvalidBoolReturnsError(t *testing.T) { + type testdata struct { + V1 bool `env:"BAD_BOOL"` + } + + data := testdata{} + + t.Setenv("BAD_BOOL", "yesno") + + err := ApplyEnvOverrides("", &data, ".") + if err == nil { + t.Errorf("expected error for invalid bool, got nil") + } +} + +func TestApplyEnvOverridesUnsupportedTypeReturnsError(t *testing.T) { + type testdata struct { + V1 []int `env:"UNSUPPORTED"` + } + + data := testdata{} + + t.Setenv("UNSUPPORTED", "1,2,3") + + err := ApplyEnvOverrides("", &data, ".") + if err == nil { + t.Errorf("expected error for unsupported type, got nil") + } +} + +func TestApplyEnvOverridesFloatUnsupportedReturnsError(t *testing.T) { + type testdata struct { + V1 float64 `env:"UNSUPPORTED_FLOAT"` + } + + data := testdata{} + + t.Setenv("UNSUPPORTED_FLOAT", "1.5") + + err := ApplyEnvOverrides("", &data, ".") + if err == nil { + t.Errorf("expected error for float64, got nil") + } +} + +func TestApplyEnvOverridesPointerInvalidReturnsError(t *testing.T) { + type testdata struct { + V1 *int `env:"PTR_BAD_INT"` + } + + data := testdata{} + + t.Setenv("PTR_BAD_INT", "not_a_number") + + err := ApplyEnvOverrides("", &data, ".") + if err == nil { + t.Errorf("expected error for invalid pointer int, got nil") + } +} + +func TestApplyEnvOverridesPointerNotSetStaysNil(t *testing.T) { + type testdata struct { + V1 *int `env:"PTR_NOT_SET_KEY_ABC"` + V2 *string `env:"PTR_NOT_SET_KEY_DEF"` + } + + data := testdata{} + + err := ApplyEnvOverrides("", &data, ".") + tst.AssertNoErr(t, err) + + if data.V1 != nil { + t.Errorf("expected V1 to remain nil, got %v", *data.V1) + } + if data.V2 != nil { + t.Errorf("expected V2 to 
remain nil, got %v", *data.V2) + } +} + +func TestApplyEnvOverridesAliasBool(t *testing.T) { + type aliasbool bool + + type testdata struct { + V1 aliasbool `env:"ALIAS_BOOL"` + } + + data := testdata{} + + t.Setenv("ALIAS_BOOL", "true") + + err := ApplyEnvOverrides("", &data, ".") + tst.AssertNoErr(t, err) + + tst.AssertEqual(t, data.V1, aliasbool(true)) +} + +func TestApplyEnvOverridesNestedRecursiveError(t *testing.T) { + type subdata struct { + V1 int `env:"V1"` + } + type testdata struct { + Sub subdata `env:"SUB"` + } + + data := testdata{} + + t.Setenv("SUB.V1", "not_a_number") + + err := ApplyEnvOverrides("", &data, ".") + if err == nil { + t.Errorf("expected error from nested struct invalid value, got nil") + } +} + +func TestApplyEnvOverridesTimeFieldInsideStructIsParsed(t *testing.T) { + type testdata struct { + T time.Time `env:"MYTIME"` + } + + data := testdata{} + + t.Setenv("MYTIME", "2023-01-02T03:04:05Z") + + err := ApplyEnvOverrides("", &data, ".") + tst.AssertNoErr(t, err) + + tst.AssertEqual(t, data.T.Equal(time.Date(2023, 1, 2, 3, 4, 5, 0, time.UTC)), true) +} + +func TestApplyEnvOverridesPointerStringAlias(t *testing.T) { + type aliasstr string + + type testdata struct { + V1 *aliasstr `env:"PTR_ALIAS_STR"` + } + + data := testdata{} + + t.Setenv("PTR_ALIAS_STR", "hello") + + err := ApplyEnvOverrides("", &data, ".") + tst.AssertNoErr(t, err) + + if data.V1 == nil { + t.Fatalf("expected V1 to be set") + } + tst.AssertEqual(t, *data.V1, aliasstr("hello")) +} + +func TestApplyEnvOverridesEmptyEnvTagOnSubstruct(t *testing.T) { + type subdata struct { + V1 int `env:"INNER"` + } + type testdata struct { + Sub subdata `env:""` + } + + data := testdata{} + + t.Setenv("INNER", "55") + + err := ApplyEnvOverrides("PRE_", &data, "_") + tst.AssertNoErr(t, err) + + tst.AssertEqual(t, data.Sub.V1, 0) + + t.Setenv("PRE_INNER", "77") + + err = ApplyEnvOverrides("PRE_", &data, "_") + tst.AssertNoErr(t, err) + + tst.AssertEqual(t, data.Sub.V1, 77) +} diff --git 
a/cryptext/aes_additional_test.go b/cryptext/aes_additional_test.go new file mode 100644 index 0000000..ab3ae02 --- /dev/null +++ b/cryptext/aes_additional_test.go @@ -0,0 +1,139 @@ +package cryptext + +import ( + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + "strings" + "testing" +) + +func TestAESSimpleEmptyData(t *testing.T) { + pw := []byte("password") + enc, err := EncryptAESSimple(pw, []byte{}, 256) + tst.AssertNoErr(t, err) + tst.AssertNotEqual(t, enc, "") + + dec, err := DecryptAESSimple(pw, enc) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, len(dec), 0) +} + +func TestAESSimpleEmptyPassword(t *testing.T) { + pw := []byte{} + plain := []byte("some content") + + enc, err := EncryptAESSimple(pw, plain, 256) + tst.AssertNoErr(t, err) + + dec, err := DecryptAESSimple(pw, enc) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, string(dec), string(plain)) +} + +func TestAESSimpleWrongPassword(t *testing.T) { + plain := []byte("Hello World") + enc, err := EncryptAESSimple([]byte("right"), plain, 256) + tst.AssertNoErr(t, err) + + _, err = DecryptAESSimple([]byte("wrong"), enc) + if err == nil { + t.Errorf("expected error when decrypting with wrong password") + } +} + +func TestAESSimpleInvalidBase32(t *testing.T) { + _, err := DecryptAESSimple([]byte("pw"), "!!!not-base32!!!") + if err == nil { + t.Errorf("expected error on invalid base32 input") + } +} + +func TestAESSimpleInvalidJSON(t *testing.T) { + // "AAAAAAAA" decodes to valid base32 but not valid JSON + _, err := DecryptAESSimple([]byte("pw"), "AAAAAAAA") + if err == nil { + t.Errorf("expected error on invalid JSON payload") + } +} + +func TestAESSimpleEmptyEncText(t *testing.T) { + _, err := DecryptAESSimple([]byte("pw"), "") + if err == nil { + t.Errorf("expected error on empty text") + } +} + +func TestAESSimpleLargeData(t *testing.T) { + pw := []byte("hunter12") + plain := []byte(strings.Repeat("ABCDEFGHIJ", 1024)) + + enc, err := EncryptAESSimple(pw, plain, 256) + tst.AssertNoErr(t, err) + + 
dec, err := DecryptAESSimple(pw, enc) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, string(dec), string(plain)) +} + +func TestAESSimpleBinaryData(t *testing.T) { + pw := []byte("hunter12") + plain := []byte{0x00, 0x01, 0x02, 0x7F, 0x80, 0xFE, 0xFF, 0x00, 0xAA, 0x55} + + enc, err := EncryptAESSimple(pw, plain, 256) + tst.AssertNoErr(t, err) + + dec, err := DecryptAESSimple(pw, enc) + tst.AssertNoErr(t, err) + tst.AssertArrayEqual(t, dec, plain) +} + +func TestAESSimpleDifferentRoundsForEachCall(t *testing.T) { + pw := []byte("hunter12") + plain := []byte("Hello") + + enc1, err := EncryptAESSimple(pw, plain, 256) + tst.AssertNoErr(t, err) + + enc2, err := EncryptAESSimple(pw, plain, 256) + tst.AssertNoErr(t, err) + + // Two separate encrypt calls on same plaintext should differ (random salt + IV) + tst.AssertNotEqual(t, enc1, enc2) + + // Both should decrypt back to the same plaintext + d1, err := DecryptAESSimple(pw, enc1) + tst.AssertNoErr(t, err) + d2, err := DecryptAESSimple(pw, enc2) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, string(d1), string(plain)) + tst.AssertEqual(t, string(d2), string(plain)) +} + +func TestAESSimpleVariableRounds(t *testing.T) { + pw := []byte("hunter12") + plain := []byte("rounds-test") + + for _, r := range []int{16, 32, 64, 128, 256, 512, 1024} { + enc, err := EncryptAESSimple(pw, plain, r) + tst.AssertNoErr(t, err) + dec, err := DecryptAESSimple(pw, enc) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, string(dec), string(plain)) + } +} + +func TestAESSimpleResultIsBase32(t *testing.T) { + pw := []byte("hunter12") + plain := []byte("Hello World") + + enc, err := EncryptAESSimple(pw, plain, 64) + tst.AssertNoErr(t, err) + + for _, c := range enc { + isUpper := c >= 'A' && c <= 'Z' + isDigit := c >= '2' && c <= '7' + if !(isUpper || isDigit) { + t.Errorf("non-base32 character %q in output", c) + break + } + } +} diff --git a/cryptext/hash_additional_test.go b/cryptext/hash_additional_test.go new file mode 100644 index 
0000000..4ad5713 --- /dev/null +++ b/cryptext/hash_additional_test.go @@ -0,0 +1,54 @@ +package cryptext + +import ( + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + "strings" + "testing" +) + +func TestStrSha256SameAsBytesSha256(t *testing.T) { + inputs := []string{"", "a", "Hello World", "lorem ipsum dolor sit amet", "🎉 unicode"} + for _, in := range inputs { + tst.AssertEqual(t, StrSha256(in), BytesSha256([]byte(in))) + } +} + +func TestStrSha256Length(t *testing.T) { + // SHA-256 hex output must always be 64 characters + tst.AssertEqual(t, len(StrSha256("")), 64) + tst.AssertEqual(t, len(StrSha256("x")), 64) + tst.AssertEqual(t, len(StrSha256(strings.Repeat("x", 10000))), 64) +} + +func TestStrSha256Deterministic(t *testing.T) { + v := "deterministic input" + a := StrSha256(v) + b := StrSha256(v) + tst.AssertEqual(t, a, b) +} + +func TestStrSha256DifferentInputs(t *testing.T) { + tst.AssertNotEqual(t, StrSha256("a"), StrSha256("b")) + tst.AssertNotEqual(t, StrSha256("Hello"), StrSha256("hello")) + tst.AssertNotEqual(t, StrSha256("Hello World"), StrSha256("Hello World ")) +} + +func TestStrSha256IsHex(t *testing.T) { + out := StrSha256("anything") + for _, c := range out { + isLowerHex := (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') + if !isLowerHex { + t.Errorf("non-hex char %q in StrSha256 output", c) + return + } + } +} + +func TestBytesSha256NilSameAsEmpty(t *testing.T) { + tst.AssertEqual(t, BytesSha256(nil), BytesSha256([]byte{})) +} + +func TestBytesSha256KnownVectors(t *testing.T) { + // "abc" => sha-256 standard vector + tst.AssertEqual(t, StrSha256("abc"), "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad") +} diff --git a/cryptext/passHash_additional_test.go b/cryptext/passHash_additional_test.go new file mode 100644 index 0000000..0d21a84 --- /dev/null +++ b/cryptext/passHash_additional_test.go @@ -0,0 +1,379 @@ +package cryptext + +import ( + "encoding/json" + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + 
"strings" + "testing" +) + +func TestPassHashInvalidEmpty(t *testing.T) { + ph := PassHash("") + tst.AssertFalse(t, ph.Valid()) + tst.AssertFalse(t, ph.HasTOTP()) + tst.AssertFalse(t, ph.NeedsPasswordUpgrade()) +} + +func TestPassHashInvalidGarbage(t *testing.T) { + for _, raw := range []string{ + "garbage", + "99|nope", + "abc|payload", + "3|onlytwo", + "4|onlytwo", + "5|onlytwo", + "2|notbase64!|notbase64!", + "1|!!!notbase64!!!", + "3|!!notb64|!!notb64|0", + "3|abc|!!notb64|0", + } { + ph := PassHash(raw) + if ph.Valid() { + t.Errorf("expected %q to be invalid", raw) + } + } +} + +func TestPassHashVerifyInvalid(t *testing.T) { + ph := PassHash("garbage-value") + tst.AssertFalse(t, ph.Verify("anything", nil)) +} + +func TestPassHashUpgradeInvalid(t *testing.T) { + ph := PassHash("garbage-value") + _, err := ph.Upgrade("anything") + if err == nil { + t.Errorf("expected error for invalid PassHash upgrade") + } +} + +func TestPassHashStringRoundtrip(t *testing.T) { + ph, err := HashPassword("hunter2", nil) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, ph.String(), string(ph)) +} + +func TestPassHashMarshalJSONEmpty(t *testing.T) { + ph := PassHash("") + data, err := json.Marshal(ph) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, string(data), `""`) +} + +func TestPassHashMarshalJSONMasked(t *testing.T) { + ph, err := HashPassword("hunter2", nil) + tst.AssertNoErr(t, err) + data, err := json.Marshal(ph) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, string(data), `"*****"`) +} + +func TestPassHashDataV0(t *testing.T) { + ph, err := HashPasswordV0("test123") + tst.AssertNoErr(t, err) + + v, seed, payload, hastotp, totpsecret, valid := ph.Data() + tst.AssertTrue(t, valid) + tst.AssertEqual(t, v, 0) + tst.AssertEqual(t, len(seed), 0) + tst.AssertEqual(t, string(payload), "test123") + tst.AssertFalse(t, hastotp) + tst.AssertEqual(t, len(totpsecret), 0) +} + +func TestPassHashDataV1(t *testing.T) { + ph, err := HashPasswordV1("test123") + tst.AssertNoErr(t, err) + + v, 
seed, payload, hastotp, _, valid := ph.Data() + tst.AssertTrue(t, valid) + tst.AssertEqual(t, v, 1) + tst.AssertEqual(t, len(seed), 0) + tst.AssertEqual(t, len(payload), 32) // sha-256 is 32 bytes + tst.AssertFalse(t, hastotp) +} + +func TestPassHashDataV2(t *testing.T) { + ph, err := HashPasswordV2("test123") + tst.AssertNoErr(t, err) + + v, seed, payload, hastotp, _, valid := ph.Data() + tst.AssertTrue(t, valid) + tst.AssertEqual(t, v, 2) + tst.AssertEqual(t, len(seed), 32) + tst.AssertEqual(t, len(payload), 32) + tst.AssertFalse(t, hastotp) +} + +func TestPassHashDataV3(t *testing.T) { + ph, err := HashPasswordV3("test123", nil) + tst.AssertNoErr(t, err) + + v, seed, payload, hastotp, _, valid := ph.Data() + tst.AssertTrue(t, valid) + tst.AssertEqual(t, v, 3) + tst.AssertEqual(t, len(seed), 32) + tst.AssertEqual(t, len(payload), 32) + tst.AssertFalse(t, hastotp) +} + +func TestPassHashDataV4(t *testing.T) { + ph, err := HashPasswordV4("test123", nil) + tst.AssertNoErr(t, err) + + v, _, _, hastotp, _, valid := ph.Data() + tst.AssertTrue(t, valid) + tst.AssertEqual(t, v, 4) + tst.AssertFalse(t, hastotp) +} + +func TestPassHashDataV5(t *testing.T) { + ph, err := HashPasswordV5("test123", nil) + tst.AssertNoErr(t, err) + + v, _, _, hastotp, _, valid := ph.Data() + tst.AssertTrue(t, valid) + tst.AssertEqual(t, v, 5) + tst.AssertFalse(t, hastotp) +} + +func TestPassHashLatestIsV5(t *testing.T) { + ph, err := HashPassword("test", nil) + tst.AssertNoErr(t, err) + v, _, _, _, _, valid := ph.Data() + tst.AssertTrue(t, valid) + tst.AssertEqual(t, v, LatestPassHashVersion) + tst.AssertEqual(t, v, 5) +} + +func TestPassHashUpgradeLatestIsNoop(t *testing.T) { + ph, err := HashPassword("test", nil) + tst.AssertNoErr(t, err) + + tst.AssertFalse(t, ph.NeedsPasswordUpgrade()) + + ph2, err := ph.Upgrade("test") + tst.AssertNoErr(t, err) + tst.AssertEqual(t, string(ph), string(ph2)) +} + +func TestPassHashClearTOTPInvalid(t *testing.T) { + _, err := PassHash("garbage").ClearTOTP() 
+ if err == nil { + t.Errorf("expected error from ClearTOTP on invalid") + } +} + +func TestPassHashClearTOTPV0V1V2Noop(t *testing.T) { + ph0, _ := HashPasswordV0("x") + r0, err := ph0.ClearTOTP() + tst.AssertNoErr(t, err) + tst.AssertEqual(t, string(r0), string(ph0)) + + ph1, _ := HashPasswordV1("x") + r1, err := ph1.ClearTOTP() + tst.AssertNoErr(t, err) + tst.AssertEqual(t, string(r1), string(ph1)) + + ph2, _ := HashPasswordV2("x") + r2, err := ph2.ClearTOTP() + tst.AssertNoErr(t, err) + tst.AssertEqual(t, string(r2), string(ph2)) +} + +func TestPassHashClearTOTPV3(t *testing.T) { + secret := []byte{0x01, 0x02, 0x03} + ph, err := HashPasswordV3("test123", secret) + tst.AssertNoErr(t, err) + tst.AssertTrue(t, ph.HasTOTP()) + + cleared, err := ph.ClearTOTP() + tst.AssertNoErr(t, err) + tst.AssertFalse(t, cleared.HasTOTP()) + tst.AssertTrue(t, cleared.Valid()) + tst.AssertTrue(t, cleared.Verify("test123", nil)) +} + +func TestPassHashClearTOTPV4(t *testing.T) { + secret := []byte{0x01, 0x02, 0x03} + ph, err := HashPasswordV4("test123", secret) + tst.AssertNoErr(t, err) + tst.AssertTrue(t, ph.HasTOTP()) + + cleared, err := ph.ClearTOTP() + tst.AssertNoErr(t, err) + tst.AssertFalse(t, cleared.HasTOTP()) + tst.AssertTrue(t, cleared.Verify("test123", nil)) +} + +func TestPassHashClearTOTPV5(t *testing.T) { + secret := []byte{0x01, 0x02, 0x03} + ph, err := HashPasswordV5("test123", secret) + tst.AssertNoErr(t, err) + tst.AssertTrue(t, ph.HasTOTP()) + + cleared, err := ph.ClearTOTP() + tst.AssertNoErr(t, err) + tst.AssertFalse(t, cleared.HasTOTP()) + tst.AssertTrue(t, cleared.Verify("test123", nil)) +} + +func TestPassHashWithTOTPInvalid(t *testing.T) { + _, err := PassHash("garbage").WithTOTP([]byte{0x01}) + if err == nil { + t.Errorf("expected error for WithTOTP on invalid") + } +} + +func TestPassHashWithTOTPV0V1V2Errors(t *testing.T) { + ph0, _ := HashPasswordV0("x") + if _, err := ph0.WithTOTP([]byte{0x01}); err == nil { + t.Errorf("expected v0 not to support TOTP") 
+ } + ph1, _ := HashPasswordV1("x") + if _, err := ph1.WithTOTP([]byte{0x01}); err == nil { + t.Errorf("expected v1 not to support TOTP") + } + ph2, _ := HashPasswordV2("x") + if _, err := ph2.WithTOTP([]byte{0x01}); err == nil { + t.Errorf("expected v2 not to support TOTP") + } +} + +func TestPassHashWithTOTPV3V4V5(t *testing.T) { + secret := []byte{0xDE, 0xAD, 0xBE, 0xEF} + + ph3, _ := HashPasswordV3("pw", nil) + tst.AssertFalse(t, ph3.HasTOTP()) + r3, err := ph3.WithTOTP(secret) + tst.AssertNoErr(t, err) + tst.AssertTrue(t, r3.HasTOTP()) + + ph4, _ := HashPasswordV4("pw", nil) + tst.AssertFalse(t, ph4.HasTOTP()) + r4, err := ph4.WithTOTP(secret) + tst.AssertNoErr(t, err) + tst.AssertTrue(t, r4.HasTOTP()) + + ph5, _ := HashPasswordV5("pw", nil) + tst.AssertFalse(t, ph5.HasTOTP()) + r5, err := ph5.WithTOTP(secret) + tst.AssertNoErr(t, err) + tst.AssertTrue(t, r5.HasTOTP()) +} + +func TestPassHashChangeInvalid(t *testing.T) { + _, err := PassHash("garbage").Change("new-pw") + if err == nil { + t.Errorf("expected error from Change on invalid") + } +} + +func TestPassHashChangeKeepsVersion(t *testing.T) { + cases := []struct { + name string + hashed func() (PassHash, error) + version int + }{ + {"V0", func() (PassHash, error) { return HashPasswordV0("old") }, 0}, + {"V1", func() (PassHash, error) { return HashPasswordV1("old") }, 1}, + {"V2", func() (PassHash, error) { return HashPasswordV2("old") }, 2}, + {"V3", func() (PassHash, error) { return HashPasswordV3("old", nil) }, 3}, + {"V4", func() (PassHash, error) { return HashPasswordV4("old", nil) }, 4}, + {"V5", func() (PassHash, error) { return HashPasswordV5("old", nil) }, 5}, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + ph, err := c.hashed() + tst.AssertNoErr(t, err) + + changed, err := ph.Change("new-pw") + tst.AssertNoErr(t, err) + + v, _, _, _, _, valid := changed.Data() + tst.AssertTrue(t, valid) + tst.AssertEqual(t, v, c.version) + + tst.AssertTrue(t, changed.Verify("new-pw", 
nil)) + tst.AssertFalse(t, changed.Verify("old", nil)) + }) + } +} + +func TestPassHashChangeKeepsTOTPV3(t *testing.T) { + secret := []byte{0xAB, 0xCD} + ph, err := HashPasswordV3("old", secret) + tst.AssertNoErr(t, err) + tst.AssertTrue(t, ph.HasTOTP()) + + changed, err := ph.Change("new") + tst.AssertNoErr(t, err) + tst.AssertTrue(t, changed.HasTOTP()) +} + +func TestPassHashV0Format(t *testing.T) { + ph, err := HashPasswordV0("plaintext-pw") + tst.AssertNoErr(t, err) + tst.AssertTrue(t, strings.HasPrefix(string(ph), "0|")) + tst.AssertEqual(t, string(ph), "0|plaintext-pw") +} + +func TestPassHashV1Format(t *testing.T) { + ph, err := HashPasswordV1("test") + tst.AssertNoErr(t, err) + tst.AssertTrue(t, strings.HasPrefix(string(ph), "1|")) +} + +func TestPassHashV2Format(t *testing.T) { + ph, err := HashPasswordV2("test") + tst.AssertNoErr(t, err) + tst.AssertTrue(t, strings.HasPrefix(string(ph), "2|")) + tst.AssertEqual(t, strings.Count(string(ph), "|"), 2) +} + +func TestPassHashV3Format(t *testing.T) { + ph, err := HashPasswordV3("test", nil) + tst.AssertNoErr(t, err) + tst.AssertTrue(t, strings.HasPrefix(string(ph), "3|")) + tst.AssertEqual(t, strings.Count(string(ph), "|"), 3) + tst.AssertTrue(t, strings.HasSuffix(string(ph), "|0")) +} + +func TestPassHashV4Format(t *testing.T) { + ph, err := HashPasswordV4("test", nil) + tst.AssertNoErr(t, err) + tst.AssertTrue(t, strings.HasPrefix(string(ph), "4|")) + tst.AssertTrue(t, strings.HasSuffix(string(ph), "|0")) +} + +func TestPassHashV5Format(t *testing.T) { + ph, err := HashPasswordV5("test", nil) + tst.AssertNoErr(t, err) + tst.AssertTrue(t, strings.HasPrefix(string(ph), "5|")) + tst.AssertTrue(t, strings.HasSuffix(string(ph), "|0")) +} + +func TestPassHashV5VerifyLongPassword(t *testing.T) { + // V5 hashes via sha512 first → bcrypt's 72-byte limit shouldn't apply + longPw := strings.Repeat("a", 200) + ph, err := HashPasswordV5(longPw, nil) + tst.AssertNoErr(t, err) + + tst.AssertTrue(t, ph.Verify(longPw, nil)) 
+ tst.AssertFalse(t, ph.Verify(longPw+"x", nil)) +} + +func TestPassHashV5DifferentEachCall(t *testing.T) { + ph1, err := HashPasswordV5("samepw", nil) + tst.AssertNoErr(t, err) + ph2, err := HashPasswordV5("samepw", nil) + tst.AssertNoErr(t, err) + + // Bcrypt salts internally — same password should produce different hashes + tst.AssertNotEqual(t, string(ph1), string(ph2)) + + // Both must verify + tst.AssertTrue(t, ph1.Verify("samepw", nil)) + tst.AssertTrue(t, ph2.Verify("samepw", nil)) +} diff --git a/cryptext/pronouncablePassword_additional_test.go b/cryptext/pronouncablePassword_additional_test.go new file mode 100644 index 0000000..0cc13af --- /dev/null +++ b/cryptext/pronouncablePassword_additional_test.go @@ -0,0 +1,180 @@ +package cryptext + +import ( + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + mathrand "math/rand" + "strings" + "testing" + "unicode" +) + +func TestPronouncablePasswordLength(t *testing.T) { + for _, n := range []int{1, 2, 3, 5, 8, 13, 21, 50, 128} { + pw := PronouncablePassword(n) + tst.AssertEqual(t, len(pw), n) + } +} + +func TestPronouncablePasswordZeroOrNegative(t *testing.T) { + tst.AssertEqual(t, PronouncablePassword(0), "") + tst.AssertEqual(t, PronouncablePassword(-1), "") + tst.AssertEqual(t, PronouncablePassword(-1000), "") +} + +func TestPronouncablePasswordSeededDeterministic(t *testing.T) { + pw1 := PronouncablePasswordSeeded(42, 16) + pw2 := PronouncablePasswordSeeded(42, 16) + tst.AssertEqual(t, pw1, pw2) + tst.AssertEqual(t, len(pw1), 16) +} + +func TestPronouncablePasswordSeededDifferentSeeds(t *testing.T) { + pw1 := PronouncablePasswordSeeded(1, 16) + pw2 := PronouncablePasswordSeeded(2, 16) + tst.AssertNotEqual(t, pw1, pw2) +} + +func TestPronouncablePasswordExtEntropy(t *testing.T) { + rng := mathrand.New(mathrand.NewSource(1)) + pw, entropy := PronouncablePasswordExt(rng, 32) + tst.AssertEqual(t, len(pw), 32) + if entropy <= 0 { + t.Errorf("expected positive entropy, got %f", entropy) + } +} + +func 
TestPronouncablePasswordExtZeroLen(t *testing.T) { + rng := mathrand.New(mathrand.NewSource(1)) + pw, entropy := PronouncablePasswordExt(rng, 0) + tst.AssertEqual(t, pw, "") + tst.AssertEqual(t, entropy, float64(0)) +} + +func TestPronouncablePasswordCharacters(t *testing.T) { + // Output should be only ASCII letters + for i := range 50 { + pw := PronouncablePasswordSeeded(int64(i), 32) + for _, c := range pw { + if !unicode.IsLetter(c) || c > unicode.MaxASCII { + t.Errorf("non-letter or non-ASCII rune %q in password %q", c, pw) + break + } + } + } +} + +func TestPronouncablePasswordStartsUpper(t *testing.T) { + for i := range 50 { + pw := PronouncablePasswordSeeded(int64(i), 16) + if pw == "" { + continue + } + first := rune(pw[0]) + if !unicode.IsUpper(first) { + t.Errorf("expected first letter uppercase in %q (seed %d)", pw, i) + } + if !strings.ContainsRune(ppStartChar, first) { + t.Errorf("expected first letter from start-set in %q (seed %d)", pw, i) + } + } +} + +func TestPpMakeSet(t *testing.T) { + set := ppMakeSet("ABC") + tst.AssertTrue(t, set['A']) + tst.AssertTrue(t, set['B']) + tst.AssertTrue(t, set['C']) + tst.AssertFalse(t, set['D']) + tst.AssertEqual(t, len(set), 3) +} + +func TestPpMakeSetEmpty(t *testing.T) { + set := ppMakeSet("") + tst.AssertEqual(t, len(set), 0) +} + +func TestPpCharType(t *testing.T) { + v, c := ppCharType('A') + tst.AssertTrue(t, v) + tst.AssertFalse(t, c) + + v, c = ppCharType('B') + tst.AssertFalse(t, v) + tst.AssertTrue(t, c) + + v, c = ppCharType('Y') + tst.AssertTrue(t, v) + tst.AssertFalse(t, c) + + v, c = ppCharType('1') + tst.AssertFalse(t, v) + tst.AssertFalse(t, c) +} + +func TestPpCharsetRemove(t *testing.T) { + set := ppMakeSet("AEIOU") + out := ppCharsetRemove("ABCDEFG", set, false) + tst.AssertEqual(t, out, "BCDFG") +} + +func TestPpCharsetRemoveEmptyDisallowed(t *testing.T) { + set := ppMakeSet("AB") + out := ppCharsetRemove("AB", set, false) + // when result would be empty and allowEmpty=false, it returns the 
original + tst.AssertEqual(t, out, "AB") +} + +func TestPpCharsetRemoveEmptyAllowed(t *testing.T) { + set := ppMakeSet("AB") + out := ppCharsetRemove("AB", set, true) + tst.AssertEqual(t, out, "") +} + +func TestPpCharsetFilter(t *testing.T) { + set := ppMakeSet("AEIOU") + out := ppCharsetFilter("ABCDEFG", set, false) + tst.AssertEqual(t, out, "AE") +} + +func TestPpCharsetFilterEmptyDisallowed(t *testing.T) { + set := ppMakeSet("XYZ") + out := ppCharsetFilter("ABC", set, false) + tst.AssertEqual(t, out, "ABC") // returns original when result empty & not allowed +} + +func TestPpCharsetFilterEmptyAllowed(t *testing.T) { + set := ppMakeSet("XYZ") + out := ppCharsetFilter("ABC", set, true) + tst.AssertEqual(t, out, "") +} + +func TestPronouncablePasswordContinuationFollowsRules(t *testing.T) { + // Make sure each continuation pair (lowercased) appears in ppContinuation + // Note: when a new segment starts (uppercase letter mid-string), the continuation + // check does not apply across the segment boundary. 
+ for s := range 30 { + seed := int64(s) + pw := PronouncablePasswordSeeded(seed, 32) + if len(pw) < 2 { + continue + } + runes := []byte(strings.ToUpper(pw)) + for i := 1; i < len(runes); i++ { + // Detect new segment (original char was uppercase and it's not the first char) + origUpper := pw[i] >= 'A' && pw[i] <= 'Z' + if origUpper && i > 0 { + continue + } + prev := runes[i-1] + cur := runes[i] + cont, ok := ppContinuation[prev] + if !ok { + t.Errorf("no continuation map for %q (pw=%q)", prev, pw) + continue + } + if !strings.ContainsRune(cont, rune(cur)) { + t.Errorf("invalid continuation %q -> %q in %q (seed %d)", prev, cur, pw, seed) + } + } + } +} diff --git a/ctxext/getter_test.go b/ctxext/getter_test.go new file mode 100644 index 0000000..9bb16a7 --- /dev/null +++ b/ctxext/getter_test.go @@ -0,0 +1,237 @@ +package ctxext + +import ( + "context" + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + "testing" +) + +type ctxKey string + +const ( + keyString ctxKey = "string-key" + keyInt ctxKey = "int-key" + keyStruct ctxKey = "struct-key" + keyPtr ctxKey = "ptr-key" + keyMissing ctxKey = "missing-key" +) + +type sampleStruct struct { + Name string + N int +} + +func TestValueStringPresent(t *testing.T) { + ctx := context.WithValue(context.Background(), keyString, "hello") + v, ok := Value[string](ctx, keyString) + tst.AssertEqual(t, ok, true) + tst.AssertEqual(t, v, "hello") +} + +func TestValueIntPresent(t *testing.T) { + ctx := context.WithValue(context.Background(), keyInt, 42) + v, ok := Value[int](ctx, keyInt) + tst.AssertEqual(t, ok, true) + tst.AssertEqual(t, v, 42) +} + +func TestValueStructPresent(t *testing.T) { + want := sampleStruct{Name: "abc", N: 7} + ctx := context.WithValue(context.Background(), keyStruct, want) + v, ok := Value[sampleStruct](ctx, keyStruct) + tst.AssertEqual(t, ok, true) + tst.AssertEqual(t, v.Name, "abc") + tst.AssertEqual(t, v.N, 7) +} + +func TestValuePointerPresent(t *testing.T) { + want := &sampleStruct{Name: "ptr", 
N: 99} + ctx := context.WithValue(context.Background(), keyPtr, want) + v, ok := Value[*sampleStruct](ctx, keyPtr) + tst.AssertEqual(t, ok, true) + tst.AssertEqual(t, v == want, true) + tst.AssertEqual(t, v.Name, "ptr") +} + +func TestValueMissing(t *testing.T) { + ctx := context.Background() + v, ok := Value[string](ctx, keyMissing) + tst.AssertEqual(t, ok, false) + tst.AssertEqual(t, v, "") +} + +func TestValueMissingInt(t *testing.T) { + ctx := context.Background() + v, ok := Value[int](ctx, keyMissing) + tst.AssertEqual(t, ok, false) + tst.AssertEqual(t, v, 0) +} + +func TestValueMissingStruct(t *testing.T) { + ctx := context.Background() + v, ok := Value[sampleStruct](ctx, keyMissing) + tst.AssertEqual(t, ok, false) + tst.AssertEqual(t, v.Name, "") + tst.AssertEqual(t, v.N, 0) +} + +func TestValueMissingPointer(t *testing.T) { + ctx := context.Background() + v, ok := Value[*sampleStruct](ctx, keyMissing) + tst.AssertEqual(t, ok, false) + tst.AssertEqual(t, v == nil, true) +} + +func TestValueWrongType(t *testing.T) { + ctx := context.WithValue(context.Background(), keyString, "hello") + v, ok := Value[int](ctx, keyString) + tst.AssertEqual(t, ok, false) + tst.AssertEqual(t, v, 0) +} + +func TestValueWrongTypeStructToString(t *testing.T) { + ctx := context.WithValue(context.Background(), keyStruct, sampleStruct{Name: "x"}) + v, ok := Value[string](ctx, keyStruct) + tst.AssertEqual(t, ok, false) + tst.AssertEqual(t, v, "") +} + +func TestValueNilStoredAsInterface(t *testing.T) { + var stored *sampleStruct = nil + ctx := context.WithValue(context.Background(), keyPtr, stored) + v, ok := Value[*sampleStruct](ctx, keyPtr) + tst.AssertEqual(t, ok, true) + tst.AssertEqual(t, v == nil, true) +} + +func TestValueEmptyString(t *testing.T) { + ctx := context.WithValue(context.Background(), keyString, "") + v, ok := Value[string](ctx, keyString) + tst.AssertEqual(t, ok, true) + tst.AssertEqual(t, v, "") +} + +func TestValueZeroInt(t *testing.T) { + ctx := 
context.WithValue(context.Background(), keyInt, 0) + v, ok := Value[int](ctx, keyInt) + tst.AssertEqual(t, ok, true) + tst.AssertEqual(t, v, 0) +} + +func TestValueWithStringKey(t *testing.T) { + type stringKey string + k := stringKey("my-key") + ctx := context.WithValue(context.Background(), k, "value") + v, ok := Value[string](ctx, k) + tst.AssertEqual(t, ok, true) + tst.AssertEqual(t, v, "value") +} + +func TestValueOrDefaultPresent(t *testing.T) { + ctx := context.WithValue(context.Background(), keyString, "hello") + v := ValueOrDefault(ctx, keyString, "default") + tst.AssertEqual(t, v, "hello") +} + +func TestValueOrDefaultIntPresent(t *testing.T) { + ctx := context.WithValue(context.Background(), keyInt, 42) + v := ValueOrDefault(ctx, keyInt, -1) + tst.AssertEqual(t, v, 42) +} + +func TestValueOrDefaultMissing(t *testing.T) { + ctx := context.Background() + v := ValueOrDefault(ctx, keyMissing, "default") + tst.AssertEqual(t, v, "default") +} + +func TestValueOrDefaultMissingInt(t *testing.T) { + ctx := context.Background() + v := ValueOrDefault(ctx, keyMissing, 99) + tst.AssertEqual(t, v, 99) +} + +func TestValueOrDefaultMissingStruct(t *testing.T) { + ctx := context.Background() + def := sampleStruct{Name: "default", N: 1} + v := ValueOrDefault(ctx, keyMissing, def) + tst.AssertEqual(t, v.Name, "default") + tst.AssertEqual(t, v.N, 1) +} + +func TestValueOrDefaultWrongType(t *testing.T) { + ctx := context.WithValue(context.Background(), keyString, "hello") + v := ValueOrDefault(ctx, keyString, 7) + tst.AssertEqual(t, v, 7) +} + +func TestValueOrDefaultWrongTypeStruct(t *testing.T) { + ctx := context.WithValue(context.Background(), keyStruct, sampleStruct{Name: "x"}) + def := "fallback" + v := ValueOrDefault(ctx, keyStruct, def) + tst.AssertEqual(t, v, "fallback") +} + +func TestValueOrDefaultEmptyStringStored(t *testing.T) { + ctx := context.WithValue(context.Background(), keyString, "") + v := ValueOrDefault(ctx, keyString, "default") + tst.AssertEqual(t, v, 
"") +} + +func TestValueOrDefaultZeroIntStored(t *testing.T) { + ctx := context.WithValue(context.Background(), keyInt, 0) + v := ValueOrDefault(ctx, keyInt, 99) + tst.AssertEqual(t, v, 0) +} + +func TestValueOrDefaultPointerPresent(t *testing.T) { + want := &sampleStruct{Name: "p", N: 5} + ctx := context.WithValue(context.Background(), keyPtr, want) + def := &sampleStruct{Name: "def", N: 0} + v := ValueOrDefault(ctx, keyPtr, def) + tst.AssertEqual(t, v == want, true) +} + +func TestValueOrDefaultPointerMissing(t *testing.T) { + ctx := context.Background() + def := &sampleStruct{Name: "def", N: 0} + v := ValueOrDefault(ctx, keyMissing, def) + tst.AssertEqual(t, v == def, true) +} + +func TestValueOrDefaultNilPointerStored(t *testing.T) { + var stored *sampleStruct = nil + ctx := context.WithValue(context.Background(), keyPtr, stored) + def := &sampleStruct{Name: "def"} + v := ValueOrDefault(ctx, keyPtr, def) + tst.AssertEqual(t, v == nil, true) +} + +func TestValueNestedContext(t *testing.T) { + ctx := context.WithValue(context.Background(), keyString, "outer") + ctx = context.WithValue(ctx, keyInt, 123) + ctx = context.WithValue(ctx, keyString, "inner") + + vs, oks := Value[string](ctx, keyString) + tst.AssertEqual(t, oks, true) + tst.AssertEqual(t, vs, "inner") + + vi, oki := Value[int](ctx, keyInt) + tst.AssertEqual(t, oki, true) + tst.AssertEqual(t, vi, 123) +} + +func TestValueDifferentKeyTypesDoNotCollide(t *testing.T) { + type keyA string + type keyB string + ctx := context.WithValue(context.Background(), keyA("k"), "a-val") + ctx = context.WithValue(ctx, keyB("k"), "b-val") + + va, oka := Value[string](ctx, keyA("k")) + tst.AssertEqual(t, oka, true) + tst.AssertEqual(t, va, "a-val") + + vb, okb := Value[string](ctx, keyB("k")) + tst.AssertEqual(t, okb, true) + tst.AssertEqual(t, vb, "b-val") +} diff --git a/cursortoken/direction_test.go b/cursortoken/direction_test.go new file mode 100644 index 0000000..f95633c --- /dev/null +++ 
b/cursortoken/direction_test.go
@@ -0,0 +1,29 @@
+package cursortoken
+
+import (
+	"git.blackforestbytes.com/BlackForestBytes/goext/tst"
+	"testing"
+)
+
+func TestSortDirectionToMongoASC(t *testing.T) {
+	tst.AssertEqual(t, SortASC.ToMongo(), 1)
+}
+
+func TestSortDirectionToMongoDESC(t *testing.T) {
+	tst.AssertEqual(t, SortDESC.ToMongo(), -1)
+}
+
+func TestSortDirectionToMongoEmpty(t *testing.T) {
+	var sd SortDirection
+	tst.AssertEqual(t, sd.ToMongo(), 0)
+}
+
+func TestSortDirectionToMongoUnknown(t *testing.T) {
+	sd := SortDirection("xyz")
+	tst.AssertEqual(t, sd.ToMongo(), 0)
+}
+
+func TestSortDirectionConstants(t *testing.T) {
+	tst.AssertEqual(t, string(SortASC), "ASC")
+	tst.AssertEqual(t, string(SortDESC), "DESC")
+}
diff --git a/cursortoken/tokenKeySort_test.go b/cursortoken/tokenKeySort_test.go
new file mode 100644
index 0000000..3190a5f
--- /dev/null
+++ b/cursortoken/tokenKeySort_test.go
@@ -0,0 +1,136 @@
+package cursortoken
+
+import (
+	"git.blackforestbytes.com/BlackForestBytes/goext/tst"
+	"strings"
+	"testing"
+	"time"
+)
+
+func TestStartToken(t *testing.T) {
+	tok := Start()
+	tst.AssertEqual(t, tok.Token(), "@start")
+	tst.AssertTrue(t, tok.IsStart())
+	tst.AssertFalse(t, tok.IsEnd())
+}
+
+func TestEndToken(t *testing.T) {
+	tok := End()
+	tst.AssertEqual(t, tok.Token(), "@end")
+	tst.AssertTrue(t, tok.IsEnd())
+	tst.AssertFalse(t, tok.IsStart())
+}
+
+func TestNewKeySortTokenBasic(t *testing.T) {
+	tok := NewKeySortToken("alpha", "beta", SortASC, SortDESC, 50, Extra{})
+	tst.AssertFalse(t, tok.IsEnd())
+	tst.AssertFalse(t, tok.IsStart())
+	str := tok.Token()
+	tst.AssertTrue(t, strings.HasPrefix(str, "tok_"))
+}
+
+func TestNewKeySortTokenRoundTrip(t *testing.T) {
+	original := NewKeySortToken("primary-val", "secondary-val", SortASC, SortDESC, 25, Extra{})
+	encoded := original.Token()
+
+	decoded, err := Decode(encoded)
+	tst.AssertNoErr(t, err)
+
+	ks, ok := decoded.(CTKeySort)
+	tst.AssertTrue(t, ok)
+	tst.AssertEqual(t,
ks.ValuePrimary, "primary-val") + tst.AssertEqual(t, ks.ValueSecondary, "secondary-val") + tst.AssertEqual(t, ks.Direction, SortASC) + tst.AssertEqual(t, ks.DirectionSecondary, SortDESC) + tst.AssertEqual(t, ks.PageSize, 25) + tst.AssertEqual(t, ks.Mode, CTMNormal) +} + +func TestKeySortTokenWithExtra(t *testing.T) { + ts := time.Date(2024, 6, 15, 10, 30, 0, 0, time.UTC) + id := "object-id-123" + page := 7 + pageSize := 42 + + original := NewKeySortToken("p", "s", SortDESC, SortASC, 10, Extra{ + Timestamp: &ts, + Id: &id, + Page: &page, + PageSize: &pageSize, + }) + encoded := original.Token() + + decoded, err := Decode(encoded) + tst.AssertNoErr(t, err) + + ks, ok := decoded.(CTKeySort) + tst.AssertTrue(t, ok) + tst.AssertTrue(t, ks.Extra.Timestamp != nil) + tst.AssertTrue(t, ks.Extra.Timestamp.Equal(ts)) + tst.AssertDeRefEqual(t, ks.Extra.Id, "object-id-123") + tst.AssertDeRefEqual(t, ks.Extra.Page, 7) + tst.AssertDeRefEqual(t, ks.Extra.PageSize, 42) +} + +func TestKeySortTokenStartRoundTrip(t *testing.T) { + original := Start() + decoded, err := Decode(original.Token()) + tst.AssertNoErr(t, err) + tst.AssertTrue(t, decoded.IsStart()) + tst.AssertFalse(t, decoded.IsEnd()) +} + +func TestKeySortTokenEndRoundTrip(t *testing.T) { + original := End() + decoded, err := Decode(original.Token()) + tst.AssertNoErr(t, err) + tst.AssertTrue(t, decoded.IsEnd()) + tst.AssertFalse(t, decoded.IsStart()) +} + +func TestKeySortTokenEmptyValues(t *testing.T) { + tok := CTKeySort{Mode: CTMNormal} + encoded := tok.Token() + tst.AssertTrue(t, strings.HasPrefix(encoded, "tok_")) + + decoded, err := Decode(encoded) + tst.AssertNoErr(t, err) + + ks, ok := decoded.(CTKeySort) + tst.AssertTrue(t, ok) + tst.AssertEqual(t, ks.ValuePrimary, "") + tst.AssertEqual(t, ks.ValueSecondary, "") + tst.AssertEqual(t, ks.Direction, SortDirection("")) + tst.AssertEqual(t, ks.DirectionSecondary, SortDirection("")) + tst.AssertEqual(t, ks.PageSize, 0) +} + +func TestKeySortTokenOnlyTimestamp(t 
*testing.T) { + ts := time.Date(2020, 1, 2, 3, 4, 5, 0, time.UTC) + tok := CTKeySort{ + Mode: CTMNormal, + Extra: Extra{Timestamp: &ts}, + } + + decoded, err := Decode(tok.Token()) + tst.AssertNoErr(t, err) + + ks, ok := decoded.(CTKeySort) + tst.AssertTrue(t, ok) + tst.AssertTrue(t, ks.Extra.Timestamp != nil) + tst.AssertTrue(t, ks.Extra.Timestamp.Equal(ts)) + tst.AssertTrue(t, ks.Extra.Id == nil) + tst.AssertTrue(t, ks.Extra.Page == nil) + tst.AssertTrue(t, ks.Extra.PageSize == nil) +} + +func TestKeySortTokenSpecialChars(t *testing.T) { + original := NewKeySortToken("hello world / @!#$%", "äöü€", SortASC, SortASC, 1, Extra{}) + decoded, err := Decode(original.Token()) + tst.AssertNoErr(t, err) + + ks, ok := decoded.(CTKeySort) + tst.AssertTrue(t, ok) + tst.AssertEqual(t, ks.ValuePrimary, "hello world / @!#$%") + tst.AssertEqual(t, ks.ValueSecondary, "äöü€") +} diff --git a/cursortoken/tokenPaginate_test.go b/cursortoken/tokenPaginate_test.go new file mode 100644 index 0000000..c5f08b4 --- /dev/null +++ b/cursortoken/tokenPaginate_test.go @@ -0,0 +1,61 @@ +package cursortoken + +import ( + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + "testing" +) + +func TestPageToken(t *testing.T) { + tok := Page(5) + tst.AssertEqual(t, tok.Token(), "$5") + tst.AssertFalse(t, tok.IsEnd()) + tst.AssertFalse(t, tok.IsStart()) +} + +func TestPageTokenOne(t *testing.T) { + tok := Page(1) + tst.AssertEqual(t, tok.Token(), "$1") + tst.AssertFalse(t, tok.IsEnd()) + tst.AssertTrue(t, tok.IsStart()) +} + +func TestPageTokenLarge(t *testing.T) { + tok := Page(123456) + tst.AssertEqual(t, tok.Token(), "$123456") +} + +func TestPageTokenZero(t *testing.T) { + tok := Page(0) + tst.AssertEqual(t, tok.Token(), "$0") + tst.AssertFalse(t, tok.IsEnd()) + tst.AssertFalse(t, tok.IsStart()) +} + +func TestPageEndToken(t *testing.T) { + tok := PageEnd() + tst.AssertEqual(t, tok.Token(), "$end") + tst.AssertTrue(t, tok.IsEnd()) + tst.AssertFalse(t, tok.IsStart()) +} + +func 
TestPaginatedStartMode(t *testing.T) {
+	tok := CTPaginated{Mode: CTMStart, Page: 0}
+	tst.AssertEqual(t, tok.Token(), "$1")
+	tst.AssertTrue(t, tok.IsStart())
+	tst.AssertFalse(t, tok.IsEnd())
+}
+
+func TestPaginatedEndMode(t *testing.T) {
+	tok := CTPaginated{Mode: CTMEnd, Page: 99}
+	tst.AssertEqual(t, tok.Token(), "$end")
+	tst.AssertTrue(t, tok.IsEnd())
+}
+
+func TestPaginatedRoundTrip(t *testing.T) {
+	for _, page := range []int{2, 3, 7, 100, 9999} {
+		tok := Page(page)
+		decoded, err := Decode(tok.Token())
+		tst.AssertNoErr(t, err)
+		tst.AssertEqual(t, decoded.Token(), tok.Token())
+	}
+}
diff --git a/cursortoken/token_test.go b/cursortoken/token_test.go
new file mode 100644
index 0000000..926624d
--- /dev/null
+++ b/cursortoken/token_test.go
@@ -0,0 +1,125 @@
+package cursortoken
+
+import (
+	"git.blackforestbytes.com/BlackForestBytes/goext/tst"
+	"testing"
+)
+
+func TestDecodeEmpty(t *testing.T) {
+	tok, err := Decode("")
+	tst.AssertNoErr(t, err)
+	tst.AssertTrue(t, tok.IsStart())
+	tst.AssertFalse(t, tok.IsEnd())
+	tst.AssertEqual(t, tok.Token(), "@start")
+}
+
+func TestDecodeAtStart(t *testing.T) {
+	tok, err := Decode("@start")
+	tst.AssertNoErr(t, err)
+	tst.AssertTrue(t, tok.IsStart())
+	tst.AssertFalse(t, tok.IsEnd())
+}
+
+func TestDecodeAtStartUppercase(t *testing.T) {
+	tok, err := Decode("@START")
+	tst.AssertNoErr(t, err)
+	tst.AssertTrue(t, tok.IsStart())
+}
+
+func TestDecodeAtStartMixedCase(t *testing.T) {
+	tok, err := Decode("@StArT")
+	tst.AssertNoErr(t, err)
+	tst.AssertTrue(t, tok.IsStart())
+}
+
+func TestDecodeAtEnd(t *testing.T) {
+	tok, err := Decode("@end")
+	tst.AssertNoErr(t, err)
+	tst.AssertTrue(t, tok.IsEnd())
+	tst.AssertFalse(t, tok.IsStart())
+}
+
+func TestDecodeAtEndUppercase(t *testing.T) {
+	tok, err := Decode("@END")
+	tst.AssertNoErr(t, err)
+	tst.AssertTrue(t, tok.IsEnd())
+}
+
+func TestDecodeDollarEnd(t *testing.T) {
+	tok, err := Decode("$end")
+	tst.AssertNoErr(t, err)
+	tst.AssertTrue(t, tok.IsEnd())
+	_,
ok := tok.(CTPaginated) + tst.AssertTrue(t, ok) +} + +func TestDecodeDollarEndUppercase(t *testing.T) { + tok, err := Decode("$END") + tst.AssertNoErr(t, err) + tst.AssertTrue(t, tok.IsEnd()) +} + +func TestDecodeDollarPage(t *testing.T) { + tok, err := Decode("$5") + tst.AssertNoErr(t, err) + pg, ok := tok.(CTPaginated) + tst.AssertTrue(t, ok) + tst.AssertEqual(t, pg.Page, 5) + tst.AssertEqual(t, pg.Mode, CTMNormal) +} + +func TestDecodeDollarPageOne(t *testing.T) { + tok, err := Decode("$1") + tst.AssertNoErr(t, err) + tst.AssertTrue(t, tok.IsStart()) + pg, ok := tok.(CTPaginated) + tst.AssertTrue(t, ok) + tst.AssertEqual(t, pg.Page, 1) +} + +func TestDecodeDollarPageInvalid(t *testing.T) { + _, err := Decode("$abc") + if err == nil { + t.Fatalf("expected error for invalid page") + } +} + +func TestDecodeUnknownPrefix(t *testing.T) { + _, err := Decode("foobar") + if err == nil { + t.Fatalf("expected error for unknown prefix") + } +} + +func TestDecodeInvalidBase32(t *testing.T) { + _, err := Decode("tok_!!!") + if err == nil { + t.Fatalf("expected error for invalid base32 body") + } +} + +func TestDecodeInvalidJSON(t *testing.T) { + // "tok_" prefix with valid base32 but invalid JSON content + _, err := Decode("tok_NBSWY3DP") + if err == nil { + t.Fatalf("expected error for invalid json body") + } +} + +func TestDecodeJustDollar(t *testing.T) { + // "$" alone (length == 1) should fall through to the unknown-prefix branch + _, err := Decode("$") + if err == nil { + t.Fatalf("expected error for bare $") + } +} + +func TestDecodeKnownTokenContent(t *testing.T) { + tok := NewKeySortToken("k1", "k2", SortASC, SortDESC, 33, Extra{}) + encoded := tok.Token() + + decoded, err := Decode(encoded) + tst.AssertNoErr(t, err) + + tst.AssertEqual(t, decoded.Token(), encoded) +} diff --git a/dataext/bufferedReadCloser_test.go b/dataext/bufferedReadCloser_test.go new file mode 100644 index 0000000..1890298 --- /dev/null +++ b/dataext/bufferedReadCloser_test.go @@ -0,0 +1,129 @@ 
+package dataext + +import ( + "bytes" + "io" + "testing" +) + +type fakeReadCloser struct { + r *bytes.Reader + closed bool +} + +func newFakeReadCloser(data []byte) *fakeReadCloser { + return &fakeReadCloser{r: bytes.NewReader(data)} +} + +func (f *fakeReadCloser) Read(p []byte) (int, error) { + return f.r.Read(p) +} + +func (f *fakeReadCloser) Close() error { + f.closed = true + return nil +} + +func TestBufferedReadCloser_ReadAll(t *testing.T) { + data := []byte("hello world") + brc := NewBufferedReadCloser(newFakeReadCloser(data)) + + buf := make([]byte, 64) + total := 0 + for { + n, err := brc.Read(buf[total:]) + total += n + if err == io.EOF { + break + } + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + } + + if !bytes.Equal(buf[:total], data) { + t.Fatalf("got %q want %q", buf[:total], data) + } +} + +func TestBufferedReadCloser_BufferedAllThenRead(t *testing.T) { + data := []byte("foobar baz") + brc := NewBufferedReadCloser(newFakeReadCloser(data)) + + all, err := brc.BufferedAll() + if err != nil { + t.Fatalf("BufferedAll err: %v", err) + } + if !bytes.Equal(all, data) { + t.Fatalf("BufferedAll got %q want %q", all, data) + } + + // after BufferedAll, Reset put us in BufferReading mode - we can read again + out, err := io.ReadAll(brc) + if err != nil { + t.Fatalf("ReadAll err: %v", err) + } + if !bytes.Equal(out, data) { + t.Fatalf("ReadAll got %q want %q", out, data) + } +} + +func TestBufferedReadCloser_FullyReadResetReread(t *testing.T) { + data := []byte("abcdefghij") + brc := NewBufferedReadCloser(newFakeReadCloser(data)) + + out, err := io.ReadAll(brc) + if err != nil { + t.Fatalf("first ReadAll err: %v", err) + } + if !bytes.Equal(out, data) { + t.Fatalf("first read got %q want %q", out, data) + } + + if err := brc.Reset(); err != nil { + t.Fatalf("reset err: %v", err) + } + + out2, err := io.ReadAll(brc) + if err != nil { + t.Fatalf("second ReadAll err: %v", err) + } + if !bytes.Equal(out2, data) { + t.Fatalf("after reset got %q want 
%q", out2, data) + } +} + +func TestBufferedReadCloser_Close(t *testing.T) { + data := []byte("xyz") + inner := newFakeReadCloser(data) + brc := NewBufferedReadCloser(inner) + + if err := brc.Close(); err != nil { + t.Fatalf("close err: %v", err) + } + if !inner.closed { + t.Fatal("inner not closed") + } + + // double close should be no-op + if err := brc.Close(); err != nil { + t.Fatalf("second close err: %v", err) + } +} + +func TestBufferedReadCloser_ResetWithoutRead(t *testing.T) { + data := []byte("abc") + brc := NewBufferedReadCloser(newFakeReadCloser(data)) + + if err := brc.Reset(); err != nil { + t.Fatalf("reset err: %v", err) + } + + out, err := io.ReadAll(brc) + if err != nil { + t.Fatalf("ReadAll err: %v", err) + } + if !bytes.Equal(out, data) { + t.Fatalf("got %q want %q", out, data) + } +} diff --git a/dataext/casMutex_test.go b/dataext/casMutex_test.go new file mode 100644 index 0000000..98b0de9 --- /dev/null +++ b/dataext/casMutex_test.go @@ -0,0 +1,122 @@ +package dataext + +import ( + "context" + "sync" + "sync/atomic" + "testing" + "time" +) + +func TestCASMutex_LockUnlock(t *testing.T) { + m := NewCASMutex() + m.Lock() + m.Unlock() +} + +func TestCASMutex_TryLock(t *testing.T) { + m := NewCASMutex() + if !m.TryLock() { + t.Fatal("TryLock should succeed on fresh mutex") + } + if m.TryLock() { + t.Fatal("TryLock should fail when already locked") + } + m.Unlock() + if !m.TryLock() { + t.Fatal("TryLock should succeed after Unlock") + } + m.Unlock() +} + +func TestCASMutex_TryLockWithTimeout(t *testing.T) { + m := NewCASMutex() + m.Lock() + start := time.Now() + if m.TryLockWithTimeout(20 * time.Millisecond) { + t.Fatal("TryLockWithTimeout should fail when locked") + } + if time.Since(start) < 15*time.Millisecond { + t.Fatal("TryLockWithTimeout returned too quickly") + } + m.Unlock() + + if !m.TryLockWithTimeout(50 * time.Millisecond) { + t.Fatal("TryLockWithTimeout should succeed when unlocked") + } + m.Unlock() +} + +func 
TestCASMutex_TryLockWithContext_Cancel(t *testing.T) { + m := NewCASMutex() + m.Lock() + defer m.Unlock() + + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(10 * time.Millisecond) + cancel() + }() + if m.TryLockWithContext(ctx) { + t.Fatal("expected lock to fail after cancel") + } +} + +func TestCASMutex_RLockMultiple(t *testing.T) { + m := NewCASMutex() + if !m.RTryLock() { + t.Fatal("RTryLock should succeed") + } + if !m.RTryLock() { + t.Fatal("Second RTryLock should succeed") + } + if m.TryLock() { + t.Fatal("Write TryLock should fail with read locks held") + } + m.RUnlock() + m.RUnlock() + if !m.TryLock() { + t.Fatal("Write TryLock should succeed after read unlocks") + } + m.Unlock() +} + +func TestCASMutex_RLocker(t *testing.T) { + m := NewCASMutex() + rl := m.RLocker() + rl.Lock() + rl.Unlock() +} + +func TestCASMutex_Concurrent(t *testing.T) { + m := NewCASMutex() + var counter int64 + const n = 50 + var wg sync.WaitGroup + wg.Add(n) + for i := 0; i < n; i++ { + go func() { + defer wg.Done() + m.Lock() + atomic.AddInt64(&counter, 1) + m.Unlock() + }() + } + wg.Wait() + if atomic.LoadInt64(&counter) != n { + t.Fatalf("counter=%d want %d", counter, n) + } +} + +func TestCASMutex_RTryLockWithTimeout(t *testing.T) { + m := NewCASMutex() + m.Lock() + if m.RTryLockWithTimeout(20 * time.Millisecond) { + t.Fatal("RTryLockWithTimeout should fail when write-locked") + } + m.Unlock() + if !m.RTryLockWithTimeout(20 * time.Millisecond) { + t.Fatal("RTryLockWithTimeout should succeed when free") + } + m.RUnlock() +} diff --git a/dataext/delayedCombiningInvoker_test.go b/dataext/delayedCombiningInvoker_test.go new file mode 100644 index 0000000..0783f33 --- /dev/null +++ b/dataext/delayedCombiningInvoker_test.go @@ -0,0 +1,182 @@ +package dataext + +import ( + "sync/atomic" + "testing" + "time" +) + +func waitForCalls(t *testing.T, calls *int64, want int64, max time.Duration) { + t.Helper() + deadline := time.Now().Add(max) + for 
time.Now().Before(deadline) { + if atomic.LoadInt64(calls) >= want { + return + } + time.Sleep(5 * time.Millisecond) + } +} + +func TestDelayedCombiningInvoker_SingleRequest(t *testing.T) { + var calls int64 + d := NewDelayedCombiningInvoker(func() { + atomic.AddInt64(&calls, 1) + }, 20*time.Millisecond, 200*time.Millisecond) + + d.Request() + + waitForCalls(t, &calls, 1, 2*time.Second) + if c := atomic.LoadInt64(&calls); c != 1 { + t.Fatalf("calls=%d want 1", c) + } +} + +func TestDelayedCombiningInvoker_TwoRequestsCombine(t *testing.T) { + var calls int64 + d := NewDelayedCombiningInvoker(func() { + atomic.AddInt64(&calls, 1) + }, 50*time.Millisecond, 1*time.Second) + + d.Request() + time.Sleep(10 * time.Millisecond) + d.Request() + + waitForCalls(t, &calls, 1, 2*time.Second) + if c := atomic.LoadInt64(&calls); c != 1 { + t.Fatalf("calls=%d want 1 (should be combined)", c) + } +} + +func TestDelayedCombiningInvoker_SequentialRuns(t *testing.T) { + var calls int64 + d := NewDelayedCombiningInvoker(func() { + atomic.AddInt64(&calls, 1) + }, 20*time.Millisecond, 200*time.Millisecond) + + d.Request() + waitForCalls(t, &calls, 1, 2*time.Second) + if c := atomic.LoadInt64(&calls); c != 1 { + t.Fatalf("after first wait calls=%d want 1", c) + } + + // allow executorRunning to clear + time.Sleep(50 * time.Millisecond) + + d.Request() + waitForCalls(t, &calls, 2, 2*time.Second) + if c := atomic.LoadInt64(&calls); c != 2 { + t.Fatalf("calls=%d want 2", c) + } +} + +func TestDelayedCombiningInvoker_ExecuteNow(t *testing.T) { + var calls int64 + d := NewDelayedCombiningInvoker(func() { + atomic.AddInt64(&calls, 1) + }, 5*time.Second, 30*time.Second) + + d.Request() + if !d.HasPendingRequests() { + t.Fatal("should have pending requests") + } + + if !d.ExecuteNow() { + t.Fatal("ExecuteNow should return true when running") + } + + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if atomic.LoadInt64(&calls) >= 1 { + break + } + time.Sleep(10 * 
time.Millisecond) + } + if c := atomic.LoadInt64(&calls); c != 1 { + t.Fatalf("calls=%d want 1 (ExecuteNow should fire well before delay)", c) + } + + // allow internal state cleanup + for i := 0; i < 100; i++ { + if !d.HasPendingRequests() { + break + } + time.Sleep(10 * time.Millisecond) + } + if d.ExecuteNow() { + t.Fatal("ExecuteNow should return false when no pending") + } +} + +func TestDelayedCombiningInvoker_Cancel(t *testing.T) { + var calls int64 + d := NewDelayedCombiningInvoker(func() { + atomic.AddInt64(&calls, 1) + }, 500*time.Millisecond, 5*time.Second) + + d.Request() + d.CancelPendingRequests() + + time.Sleep(200 * time.Millisecond) + + if c := atomic.LoadInt64(&calls); c != 0 { + t.Fatalf("calls=%d want 0 after cancel", c) + } +} + +func TestDelayedCombiningInvoker_HasAndCountPending(t *testing.T) { + d := NewDelayedCombiningInvoker(func() { + // no-op + }, 500*time.Millisecond, 5*time.Second) + + if d.HasPendingRequests() { + t.Fatal("should not have pending before any Request") + } + if d.CountPendingRequests() != 0 { + t.Fatalf("count=%d want 0", d.CountPendingRequests()) + } + + d.Request() + if !d.HasPendingRequests() { + t.Fatal("should have pending") + } + if d.CountPendingRequests() < 1 { + t.Fatalf("count=%d want >=1", d.CountPendingRequests()) + } + d.CancelPendingRequests() +} + +func TestDelayedCombiningInvoker_Listeners(t *testing.T) { + var ( + startCount int64 + doneCount int64 + requestCount int64 + ) + + d := NewDelayedCombiningInvoker(func() { + // no-op + }, 20*time.Millisecond, 200*time.Millisecond) + + d.RegisterOnExecutionStart(func(immediately bool) { + atomic.AddInt64(&startCount, 1) + }) + d.RegisterOnExecutionDone(func() { + atomic.AddInt64(&doneCount, 1) + }) + d.RegisterOnRequest(func(pending int, initial bool) { + atomic.AddInt64(&requestCount, 1) + }) + + d.Request() + + waitForCalls(t, &doneCount, 1, 2*time.Second) + + if atomic.LoadInt64(&startCount) != 1 { + t.Fatalf("startCount=%d want 1", startCount) + } + if 
atomic.LoadInt64(&doneCount) != 1 { + t.Fatalf("doneCount=%d want 1", doneCount) + } + if atomic.LoadInt64(&requestCount) != 1 { + t.Fatalf("requestCount=%d want 1", requestCount) + } +} diff --git a/dataext/multiMutex_test.go b/dataext/multiMutex_test.go new file mode 100644 index 0000000..0f5168f --- /dev/null +++ b/dataext/multiMutex_test.go @@ -0,0 +1,89 @@ +package dataext + +import ( + "context" + "testing" + "time" +) + +func TestMultiMutex_LockDifferentKeys(t *testing.T) { + mm := NewMultiMutex[string]() + mm.Lock("a") + mm.Lock("b") + mm.Unlock("a") + mm.Unlock("b") +} + +func TestMultiMutex_TryLockSameKey(t *testing.T) { + mm := NewMultiMutex[string]() + if !mm.TryLock("k") { + t.Fatal("TryLock should succeed first time") + } + if mm.TryLock("k") { + t.Fatal("TryLock should fail second time") + } + mm.Unlock("k") + if !mm.TryLock("k") { + t.Fatal("TryLock should succeed after unlock") + } + mm.Unlock("k") +} + +func TestMultiMutex_TryLockDifferentKeys(t *testing.T) { + mm := NewMultiMutex[int]() + if !mm.TryLock(1) { + t.Fatal("TryLock(1) failed") + } + if !mm.TryLock(2) { + t.Fatal("TryLock(2) failed - different keys should be independent") + } + mm.Unlock(1) + mm.Unlock(2) +} + +func TestMultiMutex_RLockMultiple(t *testing.T) { + mm := NewMultiMutex[string]() + if !mm.RTryLock("k") { + t.Fatal("first RTryLock failed") + } + if !mm.RTryLock("k") { + t.Fatal("second RTryLock failed") + } + mm.RUnlock("k") + mm.RUnlock("k") +} + +func TestMultiMutex_TryLockWithTimeout(t *testing.T) { + mm := NewMultiMutex[string]() + mm.Lock("k") + if mm.TryLockWithTimeout("k", 10*time.Millisecond) { + t.Fatal("expected timeout failure") + } + mm.Unlock("k") +} + +func TestMultiMutex_TryLockWithContext(t *testing.T) { + mm := NewMultiMutex[string]() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + if !mm.TryLockWithContext(ctx, "k") { + t.Fatal("TryLockWithContext should succeed on free key") + } + mm.Unlock("k") +} + +func 
TestMultiMutex_GetAndGetCAS(t *testing.T) {
+	mm := NewMultiMutex[string]()
+	l := mm.Get("a")
+	if l == nil {
+		t.Fatal("Get returned nil")
+	}
+	cas := mm.GetCAS("a")
+	if cas == nil {
+		t.Fatal("GetCAS returned nil")
+	}
+	rl := mm.RLocker("a")
+	if rl == nil {
+		t.Fatal("RLocker returned nil")
+	}
+}
diff --git a/dataext/mutexSet_test.go b/dataext/mutexSet_test.go
new file mode 100644
index 0000000..475a434
--- /dev/null
+++ b/dataext/mutexSet_test.go
@@ -0,0 +1,51 @@
+package dataext
+
+import (
+	"sync"
+	"sync/atomic"
+	"testing"
+)
+
+func TestMutexSet_BasicLockUnlock(t *testing.T) {
+	ms := NewMutexSet[string]()
+	ms.Lock("a")
+	ms.Unlock("a")
+	ms.RLock("b")
+	ms.RUnlock("b")
+}
+
+func TestMutexSet_DifferentKeysIndependent(t *testing.T) {
+	ms := NewMutexSet[int]()
+	ms.Lock(1)
+	ms.Lock(2)
+	ms.Unlock(1)
+	ms.Unlock(2)
+}
+
+func TestMutexSet_SameKeyMutuallyExclusive(t *testing.T) {
+	ms := NewMutexSet[string]()
+	var counter int64
+	const n = 50
+	var wg sync.WaitGroup
+	wg.Add(n)
+	for i := 0; i < n; i++ {
+		go func() {
+			defer wg.Done()
+			ms.Lock("shared")
+			atomic.AddInt64(&counter, 1)
+			ms.Unlock("shared")
+		}()
+	}
+	wg.Wait()
+	if atomic.LoadInt64(&counter) != n {
+		t.Fatalf("got %d want %d", counter, n)
+	}
+}
+
+func TestMutexSet_RLockMultiple(t *testing.T) {
+	ms := NewMutexSet[string]()
+	ms.RLock("k")
+	ms.RLock("k")
+	ms.RUnlock("k")
+	ms.RUnlock("k")
+}
diff --git a/dataext/optional_test.go b/dataext/optional_test.go
new file mode 100644
index 0000000..69df944
--- /dev/null
+++ b/dataext/optional_test.go
@@ -0,0 +1,142 @@
+package dataext
+
+import (
+	"encoding/json"
+	"testing"
+)
+
+func TestJsonOpt_NewAndEmpty(t *testing.T) {
+	o := NewJsonOpt[int](42)
+	if !o.IsSet() {
+		t.Fatal("expected IsSet=true")
+	}
+	if o.IsUnset() {
+		t.Fatal("expected IsUnset=false")
+	}
+
+	e := EmptyJsonOpt[int]()
+	if e.IsSet() {
+		t.Fatal("expected IsSet=false")
+	}
+	if !e.IsUnset() {
+		t.Fatal("expected IsUnset=true")
+	}
+}
+
+func TestJsonOpt_Value(t
*testing.T) { + o := NewJsonOpt[string]("hello") + v, ok := o.Value() + if !ok || v != "hello" { + t.Fatalf("got (%q,%v)", v, ok) + } + + e := EmptyJsonOpt[string]() + v, ok = e.Value() + if ok || v != "" { + t.Fatalf("empty got (%q,%v)", v, ok) + } +} + +func TestJsonOpt_ValueOrNil(t *testing.T) { + o := NewJsonOpt[int](7) + p := o.ValueOrNil() + if p == nil || *p != 7 { + t.Fatalf("expected ptr to 7") + } + e := EmptyJsonOpt[int]() + if e.ValueOrNil() != nil { + t.Fatal("expected nil") + } +} + +func TestJsonOpt_ValueDblPtrOrNil(t *testing.T) { + o := NewJsonOpt[int](7) + p := o.ValueDblPtrOrNil() + if p == nil || *p == nil || **p != 7 { + t.Fatalf("expected double ptr to 7") + } + e := EmptyJsonOpt[int]() + if e.ValueDblPtrOrNil() != nil { + t.Fatal("expected nil") + } +} + +func TestJsonOpt_MustValue(t *testing.T) { + o := NewJsonOpt[int](9) + if o.MustValue() != 9 { + t.Fatal("MustValue wrong") + } + defer func() { + if recover() == nil { + t.Fatal("expected panic") + } + }() + EmptyJsonOpt[int]().MustValue() +} + +func TestJsonOpt_IfSet(t *testing.T) { + called := false + NewJsonOpt[int](1).IfSet(func(v int) { + called = true + if v != 1 { + t.Fatalf("v=%d", v) + } + }) + if !called { + t.Fatal("IfSet did not invoke fn") + } + + called = false + EmptyJsonOpt[int]().IfSet(func(v int) { called = true }) + if called { + t.Fatal("IfSet invoked fn on empty") + } +} + +func TestJsonOpt_MarshalJSON(t *testing.T) { + o := NewJsonOpt[int](5) + b, err := json.Marshal(o) + if err != nil { + t.Fatal(err) + } + if string(b) != "5" { + t.Fatalf("got %s", b) + } + + e := EmptyJsonOpt[int]() + b, err = json.Marshal(e) + if err != nil { + t.Fatal(err) + } + if string(b) != "null" { + t.Fatalf("got %s", b) + } +} + +func TestJsonOpt_UnmarshalJSON(t *testing.T) { + var o JsonOpt[int] + if err := json.Unmarshal([]byte("42"), &o); err != nil { + t.Fatal(err) + } + if !o.IsSet() { + t.Fatal("should be set") + } + if v, _ := o.Value(); v != 42 { + t.Fatalf("got %d", v) + } +} + 
+func TestJsonOpt_StructWithJsonOpt(t *testing.T) { + type S struct { + A JsonOpt[int] `json:"a"` + B JsonOpt[string] `json:"b"` + } + s := S{A: NewJsonOpt[int](1), B: EmptyJsonOpt[string]()} + b, err := json.Marshal(s) + if err != nil { + t.Fatal(err) + } + if string(b) != `{"a":1,"b":null}` { + t.Fatalf("got %s", b) + } +} diff --git a/dataext/stack_test.go b/dataext/stack_test.go new file mode 100644 index 0000000..599af70 --- /dev/null +++ b/dataext/stack_test.go @@ -0,0 +1,98 @@ +package dataext + +import ( + "errors" + "sync" + "testing" +) + +func TestStack_PushPop(t *testing.T) { + s := NewStack[int](false, 4) + s.Push(1) + s.Push(2) + s.Push(3) + + if s.Length() != 3 { + t.Fatalf("Length=%d", s.Length()) + } + if s.Empty() { + t.Fatal("should not be empty") + } + + v, err := s.Pop() + if err != nil || v != 3 { + t.Fatalf("Pop got (%d,%v)", v, err) + } + v, err = s.Pop() + if err != nil || v != 2 { + t.Fatalf("Pop got (%d,%v)", v, err) + } + v, err = s.Pop() + if err != nil || v != 1 { + t.Fatalf("Pop got (%d,%v)", v, err) + } +} + +func TestStack_PopEmpty(t *testing.T) { + s := NewStack[int](false, 0) + _, err := s.Pop() + if !errors.Is(err, ErrEmptyStack) { + t.Fatalf("expected ErrEmptyStack, got %v", err) + } + if !s.Empty() { + t.Fatal("should be empty") + } +} + +func TestStack_Peek(t *testing.T) { + s := NewStack[string](false, 0) + if _, err := s.Peek(); !errors.Is(err, ErrEmptyStack) { + t.Fatalf("expected ErrEmptyStack got %v", err) + } + s.Push("a") + s.Push("b") + v, err := s.Peek() + if err != nil || v != "b" { + t.Fatalf("Peek got (%q,%v)", v, err) + } + if s.Length() != 2 { + t.Fatal("Peek must not pop") + } +} + +func TestStack_OptPopOptPeek(t *testing.T) { + s := NewStack[int](false, 0) + if s.OptPop() != nil { + t.Fatal("OptPop on empty should return nil") + } + if s.OptPeek() != nil { + t.Fatal("OptPeek on empty should return nil") + } + s.Push(7) + if p := s.OptPeek(); p == nil || *p != 7 { + t.Fatalf("OptPeek bad") + } + if p := 
s.OptPop(); p == nil || *p != 7 { + t.Fatalf("OptPop bad") + } + if !s.Empty() { + t.Fatal("should be empty after OptPop") + } +} + +func TestStack_ThreadSafe(t *testing.T) { + s := NewStack[int](true, 0) + var wg sync.WaitGroup + const n = 200 + wg.Add(n) + for i := 0; i < n; i++ { + go func(v int) { + defer wg.Done() + s.Push(v) + }(i) + } + wg.Wait() + if s.Length() != n { + t.Fatalf("Length=%d want %d", s.Length(), n) + } +} diff --git a/dataext/syncMap_test.go b/dataext/syncMap_test.go new file mode 100644 index 0000000..7a6c2c0 --- /dev/null +++ b/dataext/syncMap_test.go @@ -0,0 +1,176 @@ +package dataext + +import ( + "sort" + "sync" + "testing" +) + +func TestSyncMap_SetGet(t *testing.T) { + m := NewSyncMap[string, int]() + m.Set("a", 1) + v, ok := m.Get("a") + if !ok || v != 1 { + t.Fatalf("got (%d,%v)", v, ok) + } + if _, ok := m.Get("missing"); ok { + t.Fatal("expected missing") + } +} + +func TestSyncMap_SetIfNotContains(t *testing.T) { + m := NewSyncMap[string, int]() + if !m.SetIfNotContains("a", 1) { + t.Fatal("first set should succeed") + } + if m.SetIfNotContains("a", 2) { + t.Fatal("second set should fail") + } + v, _ := m.Get("a") + if v != 1 { + t.Fatalf("expected unchanged got %d", v) + } +} + +func TestSyncMap_SetIfNotContainsFunc(t *testing.T) { + m := NewSyncMap[string, int]() + calls := 0 + if !m.SetIfNotContainsFunc("a", func() int { calls++; return 5 }) { + t.Fatal("first should succeed") + } + if m.SetIfNotContainsFunc("a", func() int { calls++; return 6 }) { + t.Fatal("second should fail") + } + if calls != 1 { + t.Fatalf("calls=%d want 1", calls) + } +} + +func TestSyncMap_GetAndSetIfNotContains(t *testing.T) { + m := NewSyncMap[string, int]() + if v := m.GetAndSetIfNotContains("a", 10); v != 10 { + t.Fatalf("got %d", v) + } + if v := m.GetAndSetIfNotContains("a", 99); v != 10 { + t.Fatalf("got %d", v) + } +} + +func TestSyncMap_GetAndSetIfNotContainsFunc(t *testing.T) { + m := NewSyncMap[string, int]() + calls := 0 + if v := 
m.GetAndSetIfNotContainsFunc("a", func() int { calls++; return 1 }); v != 1 { + t.Fatalf("got %d", v) + } + if v := m.GetAndSetIfNotContainsFunc("a", func() int { calls++; return 2 }); v != 1 { + t.Fatalf("got %d", v) + } + if calls != 1 { + t.Fatalf("calls=%d", calls) + } +} + +func TestSyncMap_Delete(t *testing.T) { + m := NewSyncMap[string, int]() + m.Set("a", 1) + if !m.Delete("a") { + t.Fatal("delete existing returned false") + } + if m.Delete("a") { + t.Fatal("delete missing returned true") + } +} + +func TestSyncMap_DeleteIf(t *testing.T) { + m := NewSyncMap[string, int]() + m.Set("a", 1) + m.Set("b", 2) + m.Set("c", 3) + rm := m.DeleteIf(func(k string, v int) bool { return v%2 == 1 }) + if rm != 2 { + t.Fatalf("removed=%d", rm) + } + if m.Count() != 1 { + t.Fatalf("count=%d", m.Count()) + } +} + +func TestSyncMap_UpdateIfExists(t *testing.T) { + m := NewSyncMap[string, int]() + if m.UpdateIfExists("a", func(v int) int { return v + 1 }) { + t.Fatal("should be false on missing key") + } + m.Set("a", 5) + if !m.UpdateIfExists("a", func(v int) int { return v + 1 }) { + t.Fatal("should be true on existing") + } + v, _ := m.Get("a") + if v != 6 { + t.Fatalf("v=%d", v) + } +} + +func TestSyncMap_UpdateOrInsert(t *testing.T) { + m := NewSyncMap[string, int]() + if m.UpdateOrInsert("a", func(v int) int { return v + 1 }, 100) { + t.Fatal("should return false on insert") + } + if v, _ := m.Get("a"); v != 100 { + t.Fatalf("v=%d", v) + } + if !m.UpdateOrInsert("a", func(v int) int { return v + 1 }, 100) { + t.Fatal("should return true on update") + } + if v, _ := m.Get("a"); v != 101 { + t.Fatalf("v=%d", v) + } +} + +func TestSyncMap_ClearContains(t *testing.T) { + m := NewSyncMap[string, int]() + m.Set("a", 1) + if !m.Contains("a") { + t.Fatal("Contains should be true") + } + m.Clear() + if m.Contains("a") { + t.Fatal("after Clear should be false") + } + if m.Count() != 0 { + t.Fatalf("count=%d", m.Count()) + } +} + +func TestSyncMap_GetAllKeysValues(t *testing.T) { + 
m := NewSyncMap[string, int]() + m.Set("a", 1) + m.Set("b", 2) + m.Set("c", 3) + keys := m.GetAllKeys() + sort.Strings(keys) + if len(keys) != 3 || keys[0] != "a" || keys[2] != "c" { + t.Fatalf("keys=%v", keys) + } + vals := m.GetAllValues() + sort.Ints(vals) + if len(vals) != 3 || vals[0] != 1 || vals[2] != 3 { + t.Fatalf("vals=%v", vals) + } +} + +func TestSyncMap_Concurrent(t *testing.T) { + m := NewSyncMap[int, int]() + var wg sync.WaitGroup + const n = 200 + wg.Add(n) + for i := 0; i < n; i++ { + go func(k int) { + defer wg.Done() + m.Set(k, k*2) + }(i) + } + wg.Wait() + if m.Count() != n { + t.Fatalf("count=%d want %d", m.Count(), n) + } +} diff --git a/dataext/syncRingSet_test.go b/dataext/syncRingSet_test.go new file mode 100644 index 0000000..fd48837 --- /dev/null +++ b/dataext/syncRingSet_test.go @@ -0,0 +1,84 @@ +package dataext + +import ( + "sort" + "testing" +) + +func TestSyncRingSet_AddAndContains(t *testing.T) { + s := NewSyncRingSet[int](3) + if !s.Add(1) { + t.Fatal("first Add(1) should be true") + } + if s.Add(1) { + t.Fatal("duplicate Add(1) should be false") + } + if !s.Contains(1) { + t.Fatal("expected Contains(1)") + } +} + +func TestSyncRingSet_CapacityEvicts(t *testing.T) { + s := NewSyncRingSet[int](3) + s.Add(1) + s.Add(2) + s.Add(3) + s.Add(4) // should evict the oldest (1) + if s.Contains(1) { + t.Fatal("1 should have been evicted") + } + for _, v := range []int{2, 3, 4} { + if !s.Contains(v) { + t.Fatalf("expected %d", v) + } + } +} + +func TestSyncRingSet_Remove(t *testing.T) { + s := NewSyncRingSet[string](3) + s.Add("a") + s.Add("b") + if !s.Remove("a") { + t.Fatal("remove existing failed") + } + if s.Remove("a") { + t.Fatal("remove missing returned true") + } + if s.Contains("a") { + t.Fatal("a should be gone") + } +} + +func TestSyncRingSet_AddAllRemoveAll(t *testing.T) { + s := NewSyncRingSet[int](10) + s.AddAll([]int{1, 2, 3, 2}) + out := s.Get() + sort.Ints(out) + if len(out) != 3 { + t.Fatalf("got %v", out) + } + + 
s.RemoveAll([]int{1, 99}) + if s.Contains(1) { + t.Fatal("1 should be removed") + } + if !s.Contains(2) || !s.Contains(3) { + t.Fatal("2/3 should remain") + } +} + +func TestSyncRingSet_AddIfNotContainsRemoveIfContains(t *testing.T) { + s := NewSyncRingSet[string](5) + if !s.AddIfNotContains("x") { + t.Fatal("first should succeed") + } + if s.AddIfNotContains("x") { + t.Fatal("second should fail") + } + if !s.RemoveIfContains("x") { + t.Fatal("remove existing failed") + } + if s.RemoveIfContains("x") { + t.Fatal("remove missing returned true") + } +} diff --git a/dataext/syncSet_test.go b/dataext/syncSet_test.go new file mode 100644 index 0000000..bf752cf --- /dev/null +++ b/dataext/syncSet_test.go @@ -0,0 +1,82 @@ +package dataext + +import ( + "sort" + "testing" +) + +func TestSyncSet_Add(t *testing.T) { + s := NewSyncSet[string]() + if !s.Add("a") { + t.Fatal("first add should be true") + } + if s.Add("a") { + t.Fatal("duplicate add should be false") + } + if !s.Contains("a") { + t.Fatal("Contains a should be true") + } +} + +func TestSyncSet_AddAll(t *testing.T) { + s := NewSyncSet[int]() + s.AddAll([]int{1, 2, 3, 2}) + if !s.Contains(1) || !s.Contains(2) || !s.Contains(3) { + t.Fatal("missing items") + } + if len(s.Get()) != 3 { + t.Fatalf("got len %d", len(s.Get())) + } +} + +func TestSyncSet_Remove(t *testing.T) { + s := NewSyncSet[string]() + s.Add("a") + if !s.Remove("a") { + t.Fatal("remove existing failed") + } + if s.Remove("a") { + t.Fatal("remove missing returned true") + } + if s.Contains("a") { + t.Fatal("still contains after remove") + } +} + +func TestSyncSet_RemoveAll(t *testing.T) { + s := NewSyncSet[int]() + s.AddAll([]int{1, 2, 3}) + s.RemoveAll([]int{1, 2, 99}) + if s.Contains(1) || s.Contains(2) { + t.Fatal("should be removed") + } + if !s.Contains(3) { + t.Fatal("3 should remain") + } +} + +func TestSyncSet_Get(t *testing.T) { + s := NewSyncSet[int]() + s.AddAll([]int{3, 1, 2}) + out := s.Get() + sort.Ints(out) + if len(out) != 3 || out[0] 
!= 1 || out[2] != 3 { + t.Fatalf("out=%v", out) + } +} + +func TestSyncSet_AddIfNotContainsRemoveIfContains(t *testing.T) { + s := NewSyncSet[string]() + if !s.AddIfNotContains("x") { + t.Fatal("first AddIfNotContains failed") + } + if s.AddIfNotContains("x") { + t.Fatal("second AddIfNotContains succeeded") + } + if !s.RemoveIfContains("x") { + t.Fatal("RemoveIfContains failed") + } + if s.RemoveIfContains("x") { + t.Fatal("RemoveIfContains on missing succeeded") + } +} diff --git a/dataext/tuple_test.go b/dataext/tuple_test.go new file mode 100644 index 0000000..05b3f9f --- /dev/null +++ b/dataext/tuple_test.go @@ -0,0 +1,136 @@ +package dataext + +import ( + "reflect" + "testing" +) + +func TestSingle(t *testing.T) { + s := NewSingle[int](7) + if s.V1 != 7 { + t.Fatalf("V1=%d", s.V1) + } + if s.TupleLength() != 1 { + t.Fatalf("len=%d", s.TupleLength()) + } + if !reflect.DeepEqual(s.TupleValues(), []any{7}) { + t.Fatalf("values=%v", s.TupleValues()) + } + if NewTuple1[int](7).V1 != 7 { + t.Fatal("NewTuple1 mismatch") + } +} + +func TestTuple(t *testing.T) { + tp := NewTuple[int, string](1, "two") + if tp.V1 != 1 || tp.V2 != "two" { + t.Fatal("values wrong") + } + if tp.TupleLength() != 2 { + t.Fatal("len wrong") + } + if !reflect.DeepEqual(tp.TupleValues(), []any{1, "two"}) { + t.Fatalf("values=%v", tp.TupleValues()) + } + if NewTuple2[int, string](1, "two") != tp { + t.Fatal("NewTuple2 mismatch") + } +} + +func TestTriple(t *testing.T) { + tr := NewTriple[int, string, bool](1, "x", true) + if tr.TupleLength() != 3 { + t.Fatal("len wrong") + } + if !reflect.DeepEqual(tr.TupleValues(), []any{1, "x", true}) { + t.Fatalf("values=%v", tr.TupleValues()) + } + if NewTuple3[int, string, bool](1, "x", true) != tr { + t.Fatal("NewTuple3 mismatch") + } +} + +func TestQuadruple(t *testing.T) { + q := NewQuadruple[int, int, int, int](1, 2, 3, 4) + if q.TupleLength() != 4 { + t.Fatal("len wrong") + } + if !reflect.DeepEqual(q.TupleValues(), []any{1, 2, 3, 4}) { + 
t.Fatalf("values=%v", q.TupleValues()) + } + if NewTuple4[int, int, int, int](1, 2, 3, 4) != q { + t.Fatal("NewTuple4 mismatch") + } +} + +func TestQuintuple(t *testing.T) { + q := NewQuintuple[int, int, int, int, int](1, 2, 3, 4, 5) + if q.TupleLength() != 5 { + t.Fatal("len wrong") + } + if !reflect.DeepEqual(q.TupleValues(), []any{1, 2, 3, 4, 5}) { + t.Fatalf("values=%v", q.TupleValues()) + } + if NewTuple5[int, int, int, int, int](1, 2, 3, 4, 5) != q { + t.Fatal("NewTuple5 mismatch") + } +} + +func TestSextuple(t *testing.T) { + s := NewSextuple[int, int, int, int, int, int](1, 2, 3, 4, 5, 6) + if s.TupleLength() != 6 { + t.Fatal("len wrong") + } + if !reflect.DeepEqual(s.TupleValues(), []any{1, 2, 3, 4, 5, 6}) { + t.Fatalf("values=%v", s.TupleValues()) + } + if NewTuple6[int, int, int, int, int, int](1, 2, 3, 4, 5, 6) != s { + t.Fatal("NewTuple6 mismatch") + } +} + +func TestSeptuple(t *testing.T) { + s := NewSeptuple[int, int, int, int, int, int, int](1, 2, 3, 4, 5, 6, 7) + if s.TupleLength() != 7 { + t.Fatal("len wrong") + } + if !reflect.DeepEqual(s.TupleValues(), []any{1, 2, 3, 4, 5, 6, 7}) { + t.Fatalf("values=%v", s.TupleValues()) + } + if NewTuple7[int, int, int, int, int, int, int](1, 2, 3, 4, 5, 6, 7) != s { + t.Fatal("NewTuple7 mismatch") + } +} + +func TestOctuple(t *testing.T) { + o := NewOctuple[int, int, int, int, int, int, int, int](1, 2, 3, 4, 5, 6, 7, 8) + if o.TupleLength() != 8 { + t.Fatal("len wrong") + } + if !reflect.DeepEqual(o.TupleValues(), []any{1, 2, 3, 4, 5, 6, 7, 8}) { + t.Fatalf("values=%v", o.TupleValues()) + } + if NewTuple8[int, int, int, int, int, int, int, int](1, 2, 3, 4, 5, 6, 7, 8) != o { + t.Fatal("NewTuple8 mismatch") + } +} + +func TestNonuple(t *testing.T) { + n := NewNonuple[int, int, int, int, int, int, int, int, int](1, 2, 3, 4, 5, 6, 7, 8, 9) + if n.TupleLength() != 9 { + t.Fatal("len wrong") + } + if !reflect.DeepEqual(n.TupleValues(), []any{1, 2, 3, 4, 5, 6, 7, 8, 9}) { + t.Fatalf("values=%v", n.TupleValues()) + 
} + if NewTuple9[int, int, int, int, int, int, int, int, int](1, 2, 3, 4, 5, 6, 7, 8, 9) != n { + t.Fatal("NewTuple9 mismatch") + } +} + +func TestValueGroupInterface(t *testing.T) { + var vg ValueGroup = NewTuple[int, string](1, "a") + if vg.TupleLength() != 2 { + t.Fatal("interface length wrong") + } +} diff --git a/enums/enum_test.go b/enums/enum_test.go new file mode 100644 index 0000000..01681b3 --- /dev/null +++ b/enums/enum_test.go @@ -0,0 +1,258 @@ +package enums + +import ( + "encoding/json" + "reflect" + "testing" +) + +type mockEnum struct { + name string +} + +func (m mockEnum) Valid() bool { return m.name != "" } +func (m mockEnum) ValuesAny() []any { return []any{mockEnum{name: "a"}, mockEnum{name: "b"}} } +func (m mockEnum) ValuesMeta() []EnumMetaValue { return nil } +func (m mockEnum) VarName() string { return m.name } +func (m mockEnum) TypeName() string { return "mockEnum" } +func (m mockEnum) PackageName() string { return "enums_test" } +func (m mockEnum) String() string { return "str:" + m.name } +func (m mockEnum) Description() string { return "desc:" + m.name } +func (m mockEnum) DescriptionMeta() EnumDescriptionMetaValue { + return EnumDescriptionMetaValue{VarName: m.name, Value: m, Description: "desc:" + m.name} +} + +func (m mockEnum) MarshalJSON() ([]byte, error) { + return json.Marshal(m.name) +} + +func TestMockEnumImplementsInterfaces(t *testing.T) { + var _ Enum = mockEnum{} + var _ StringEnum = mockEnum{} + var _ DescriptionEnum = mockEnum{} +} + +func TestEnumValid(t *testing.T) { + if !(mockEnum{name: "x"}).Valid() { + t.Errorf("expected Valid() == true") + } + if (mockEnum{}).Valid() { + t.Errorf("expected Valid() == false for zero value") + } +} + +func TestEnumMetaValueJSON(t *testing.T) { + desc := "the-description" + mv := EnumMetaValue{ + VarName: "Foo", + Value: mockEnum{name: "foo"}, + Description: &desc, + } + + data, err := json.Marshal(mv) + if err != nil { + t.Fatalf("json.Marshal failed: %v", err) + } + + var got 
map[string]any + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("json.Unmarshal failed: %v", err) + } + + if got["varName"] != "Foo" { + t.Errorf("varName == %v, want Foo", got["varName"]) + } + if got["value"] != "foo" { + t.Errorf("value == %v, want foo", got["value"]) + } + if got["description"] != "the-description" { + t.Errorf("description == %v, want the-description", got["description"]) + } +} + +func TestEnumMetaValueJSONNilDescription(t *testing.T) { + mv := EnumMetaValue{ + VarName: "Foo", + Value: mockEnum{name: "foo"}, + Description: nil, + } + + data, err := json.Marshal(mv) + if err != nil { + t.Fatalf("json.Marshal failed: %v", err) + } + + var got map[string]any + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("json.Unmarshal failed: %v", err) + } + + if got["description"] != nil { + t.Errorf("description == %v, want nil", got["description"]) + } +} + +func TestEnumDescriptionMetaValueJSON(t *testing.T) { + mv := EnumDescriptionMetaValue{ + VarName: "Bar", + Value: mockEnum{name: "bar"}, + Description: "bar-desc", + } + + data, err := json.Marshal(mv) + if err != nil { + t.Fatalf("json.Marshal failed: %v", err) + } + + var got map[string]any + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("json.Unmarshal failed: %v", err) + } + + expected := map[string]any{ + "varName": "Bar", + "value": "bar", + "description": "bar-desc", + } + if !reflect.DeepEqual(got, expected) { + t.Errorf("json output == %v, want %v", got, expected) + } +} + +func TestEnumDataMetaValueMarshalJSON(t *testing.T) { + desc := "data-desc" + mv := EnumDataMetaValue{ + VarName: "Baz", + Value: mockEnum{name: "baz"}, + Description: &desc, + Data: map[string]any{ + "extra1": "hello", + "extra2": float64(42), + }, + } + + data, err := json.Marshal(mv) + if err != nil { + t.Fatalf("json.Marshal failed: %v", err) + } + + var got map[string]any + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("json.Unmarshal failed: %v", err) + } 
+ + if got["varName"] != "Baz" { + t.Errorf("varName == %v, want Baz", got["varName"]) + } + if got["value"] != "baz" { + t.Errorf("value == %v, want baz", got["value"]) + } + if got["description"] != "data-desc" { + t.Errorf("description == %v, want data-desc", got["description"]) + } + if got["extra1"] != "hello" { + t.Errorf("extra1 == %v, want hello", got["extra1"]) + } + if got["extra2"] != float64(42) { + t.Errorf("extra2 == %v, want 42", got["extra2"]) + } +} + +func TestEnumDataMetaValueMarshalJSONNilData(t *testing.T) { + mv := EnumDataMetaValue{ + VarName: "Baz", + Value: mockEnum{name: "baz"}, + Description: nil, + Data: nil, + } + + data, err := json.Marshal(mv) + if err != nil { + t.Fatalf("json.Marshal failed: %v", err) + } + + var got map[string]any + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("json.Unmarshal failed: %v", err) + } + + if got["varName"] != "Baz" { + t.Errorf("varName == %v, want Baz", got["varName"]) + } + if got["value"] != "baz" { + t.Errorf("value == %v, want baz", got["value"]) + } + if _, ok := got["description"]; !ok { + t.Errorf("description key missing in JSON output") + } + if got["description"] != nil { + t.Errorf("description == %v, want nil", got["description"]) + } + if len(got) != 3 { + t.Errorf("expected 3 keys with nil Data, got %d: %v", len(got), got) + } +} + +func TestEnumDataMetaValueMarshalJSONDataDoesNotOverrideStandardFields(t *testing.T) { + desc := "real-desc" + mv := EnumDataMetaValue{ + VarName: "Real", + Value: mockEnum{name: "real"}, + Description: &desc, + Data: map[string]any{ + "varName": "ShouldBeOverwritten", + "value": "ShouldBeOverwritten", + "description": "ShouldBeOverwritten", + "keep": "kept", + }, + } + + data, err := json.Marshal(mv) + if err != nil { + t.Fatalf("json.Marshal failed: %v", err) + } + + var got map[string]any + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("json.Unmarshal failed: %v", err) + } + + if got["varName"] != "Real" { + 
t.Errorf("varName == %v, want Real (standard field must override Data)", got["varName"]) + } + if got["value"] != "real" { + t.Errorf("value == %v, want real (standard field must override Data)", got["value"]) + } + if got["description"] != "real-desc" { + t.Errorf("description == %v, want real-desc (standard field must override Data)", got["description"]) + } + if got["keep"] != "kept" { + t.Errorf("keep == %v, want kept", got["keep"]) + } +} + +func TestEnumDataMetaValueMarshalJSONEmptyData(t *testing.T) { + mv := EnumDataMetaValue{ + VarName: "E", + Value: mockEnum{name: "e"}, + Description: nil, + Data: map[string]any{}, + } + + data, err := json.Marshal(mv) + if err != nil { + t.Fatalf("json.Marshal failed: %v", err) + } + + var got map[string]any + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("json.Unmarshal failed: %v", err) + } + + if got["varName"] != "E" { + t.Errorf("varName == %v, want E", got["varName"]) + } + if got["value"] != "e" { + t.Errorf("value == %v, want e", got["value"]) + } +} diff --git a/excelext/mapper_test.go b/excelext/mapper_test.go new file mode 100644 index 0000000..bed4ae7 --- /dev/null +++ b/excelext/mapper_test.go @@ -0,0 +1,304 @@ +package excelext + +import ( + "bytes" + "errors" + "testing" + + "git.blackforestbytes.com/BlackForestBytes/goext/langext" + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + "github.com/xuri/excelize/v2" +) + +type testRow struct { + Name string + Age int + Score float64 +} + +func openBytes(t *testing.T, data []byte) *excelize.File { + t.Helper() + f, err := excelize.OpenReader(bytes.NewReader(data)) + if err != nil { + t.Fatalf("failed to open xlsx bytes: %v", err) + } + return f +} + +func cellValue(t *testing.T, f *excelize.File, sheet, axis string) string { + t.Helper() + v, err := f.GetCellValue(sheet, axis) + if err != nil { + t.Fatalf("GetCellValue(%s, %s) failed: %v", sheet, axis, err) + } + return v +} + +func TestNewExcelMapper(t *testing.T) { + em, err := 
NewExcelMapper[testRow]() + tst.AssertNoErr(t, err) + if em == nil { + t.Fatal("expected non-nil mapper") + } + tst.AssertEqual(t, em.SkipColumnHeader, false) + tst.AssertEqual(t, len(em.colDefinitions), 0) + tst.AssertEqual(t, len(em.wsHeader), 0) + tst.AssertEqual(t, len(em.colFilter), 0) + if em.StyleDate != nil || em.StyleHeader != nil { + t.Errorf("expected styles to be nil before init") + } +} + +func TestInitNewFileAndStyles(t *testing.T) { + em, _ := NewExcelMapper[testRow]() + f, err := em.InitNewFile("Sheet-Foo") + tst.AssertNoErr(t, err) + if f == nil { + t.Fatal("expected non-nil file") + } + + sheets := f.GetSheetList() + tst.AssertEqual(t, len(sheets), 1) + tst.AssertEqual(t, sheets[0], "Sheet-Foo") + + if em.StyleDate == nil || em.StyleDatetime == nil || em.StyleEUR == nil || + em.StylePercentage == nil || em.StyleHeader == nil || em.StyleWSHeader == nil { + t.Errorf("expected all styles to be initialized") + } +} + +func TestAddColumn(t *testing.T) { + em, _ := NewExcelMapper[testRow]() + em.AddColumn("Name", nil, nil, func(r testRow) any { return r.Name }) + em.AddColumn("Age", nil, langext.Ptr(12.0), func(r testRow) any { return r.Age }) + + tst.AssertEqual(t, len(em.colDefinitions), 2) + tst.AssertEqual(t, em.colDefinitions[0].header, "Name") + tst.AssertEqual(t, em.colDefinitions[1].header, "Age") + if em.colDefinitions[1].width == nil || *em.colDefinitions[1].width != 12.0 { + t.Errorf("expected width 12.0") + } + + val, err := em.colDefinitions[0].fn(testRow{Name: "Alice"}) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, val.(string), "Alice") +} + +func TestAddColumnErr(t *testing.T) { + em, _ := NewExcelMapper[testRow]() + sentinel := errors.New("boom") + em.AddColumnErr("X", nil, nil, func(r testRow) (any, error) { + if r.Age < 0 { + return nil, sentinel + } + return r.Age, nil + }) + + tst.AssertEqual(t, len(em.colDefinitions), 1) + + v, err := em.colDefinitions[0].fn(testRow{Age: 5}) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, v.(int), 
5) + + _, err = em.colDefinitions[0].fn(testRow{Age: -1}) + if !errors.Is(err, sentinel) { + t.Errorf("expected sentinel error, got %v", err) + } +} + +func TestAddWorksheetHeader(t *testing.T) { + em, _ := NewExcelMapper[testRow]() + em.AddWorksheetHeader("Title 1", nil) + em.AddWorksheetHeader("Title 2", langext.Ptr(7)) + + tst.AssertEqual(t, len(em.wsHeader), 2) + tst.AssertEqual(t, em.wsHeader[0].V1, "Title 1") + tst.AssertEqual(t, em.wsHeader[1].V1, "Title 2") + if em.wsHeader[1].V2 == nil || *em.wsHeader[1].V2 != 7 { + t.Errorf("expected style ptr 7") + } + if em.wsHeader[0].V2 != nil { + t.Errorf("expected nil style for first header") + } +} + +func TestAddFilter(t *testing.T) { + em, _ := NewExcelMapper[testRow]() + em.AddFilter(func(v testRow) bool { return v.Age >= 18 }) + em.AddFilter(func(v testRow) bool { return v.Score > 0 }) + tst.AssertEqual(t, len(em.colFilter), 2) +} + +func TestBuildBasic(t *testing.T) { + em, _ := NewExcelMapper[testRow]() + em.AddColumn("Name", nil, nil, func(r testRow) any { return r.Name }) + em.AddColumn("Age", nil, nil, func(r testRow) any { return r.Age }) + + rows := []testRow{ + {Name: "Alice", Age: 30}, + {Name: "Bob", Age: 25}, + } + + data, err := em.Build("Sheet1", rows) + tst.AssertNoErr(t, err) + if len(data) == 0 { + t.Fatal("expected non-empty xlsx output") + } + + f := openBytes(t, data) + defer f.Close() + + tst.AssertEqual(t, cellValue(t, f, "Sheet1", "A1"), "Name") + tst.AssertEqual(t, cellValue(t, f, "Sheet1", "B1"), "Age") + tst.AssertEqual(t, cellValue(t, f, "Sheet1", "A2"), "Alice") + tst.AssertEqual(t, cellValue(t, f, "Sheet1", "B2"), "30") + tst.AssertEqual(t, cellValue(t, f, "Sheet1", "A3"), "Bob") + tst.AssertEqual(t, cellValue(t, f, "Sheet1", "B3"), "25") +} + +func TestBuildSkipColumnHeader(t *testing.T) { + em, _ := NewExcelMapper[testRow]() + em.SkipColumnHeader = true + em.AddColumn("Name", nil, nil, func(r testRow) any { return r.Name }) + + rows := []testRow{{Name: "Alice"}, {Name: "Bob"}} + + 
data, err := em.Build("Data", rows) + tst.AssertNoErr(t, err) + + f := openBytes(t, data) + defer f.Close() + + tst.AssertEqual(t, cellValue(t, f, "Data", "A1"), "Alice") + tst.AssertEqual(t, cellValue(t, f, "Data", "A2"), "Bob") +} + +func TestBuildWithFilter(t *testing.T) { + em, _ := NewExcelMapper[testRow]() + em.AddColumn("Name", nil, nil, func(r testRow) any { return r.Name }) + em.AddFilter(func(v testRow) bool { return v.Age >= 18 }) + + rows := []testRow{ + {Name: "Alice", Age: 30}, + {Name: "Charlie", Age: 12}, + {Name: "Bob", Age: 25}, + } + + data, err := em.Build("S", rows) + tst.AssertNoErr(t, err) + + f := openBytes(t, data) + defer f.Close() + + tst.AssertEqual(t, cellValue(t, f, "S", "A1"), "Name") + tst.AssertEqual(t, cellValue(t, f, "S", "A2"), "Alice") + tst.AssertEqual(t, cellValue(t, f, "S", "A3"), "Bob") + tst.AssertEqual(t, cellValue(t, f, "S", "A4"), "") +} + +func TestBuildWithWorksheetHeader(t *testing.T) { + em, _ := NewExcelMapper[testRow]() + em.AddWorksheetHeader("My Big Title", nil) + em.AddColumn("Name", nil, nil, func(r testRow) any { return r.Name }) + em.AddColumn("Age", nil, nil, func(r testRow) any { return r.Age }) + + rows := []testRow{{Name: "Alice", Age: 30}} + + data, err := em.Build("S", rows) + tst.AssertNoErr(t, err) + + f := openBytes(t, data) + defer f.Close() + + tst.AssertEqual(t, cellValue(t, f, "S", "A1"), "My Big Title") + tst.AssertEqual(t, cellValue(t, f, "S", "A3"), "Name") + tst.AssertEqual(t, cellValue(t, f, "S", "B3"), "Age") + tst.AssertEqual(t, cellValue(t, f, "S", "A4"), "Alice") + tst.AssertEqual(t, cellValue(t, f, "S", "B4"), "30") +} + +func TestBuildHandlesNilPointer(t *testing.T) { + type ptrRow struct { + Name *string + } + + em, _ := NewExcelMapper[ptrRow]() + em.AddColumn("Name", nil, nil, func(r ptrRow) any { return r.Name }) + + name := "Alice" + rows := []ptrRow{ + {Name: &name}, + {Name: nil}, + } + + data, err := em.Build("S", rows) + tst.AssertNoErr(t, err) + + f := openBytes(t, data) + 
defer f.Close() + + tst.AssertEqual(t, cellValue(t, f, "S", "A2"), "Alice") + tst.AssertEqual(t, cellValue(t, f, "S", "A3"), "") +} + +func TestBuildPropagatesColumnError(t *testing.T) { + em, _ := NewExcelMapper[testRow]() + sentinel := errors.New("col fail") + em.AddColumnErr("Bad", nil, nil, func(r testRow) (any, error) { + return nil, sentinel + }) + + _, err := em.Build("S", []testRow{{Name: "X"}}) + if err == nil { + t.Fatal("expected error from column fn to propagate") + } +} + +func TestBuildEmptyData(t *testing.T) { + em, _ := NewExcelMapper[testRow]() + em.AddColumn("Name", nil, nil, func(r testRow) any { return r.Name }) + + data, err := em.Build("S", []testRow{}) + tst.AssertNoErr(t, err) + + f := openBytes(t, data) + defer f.Close() + + tst.AssertEqual(t, cellValue(t, f, "S", "A1"), "Name") + tst.AssertEqual(t, cellValue(t, f, "S", "A2"), "") +} + +func TestBuildSingleSheetWithExistingFile(t *testing.T) { + em, _ := NewExcelMapper[testRow]() + em.AddColumn("Name", nil, nil, func(r testRow) any { return r.Name }) + + f, err := em.InitNewFile("S1") + tst.AssertNoErr(t, err) + + _, err = f.NewSheet("S2") + tst.AssertNoErr(t, err) + + err = em.BuildSingleSheet(f, "S2", []testRow{{Name: "Bob"}}) + tst.AssertNoErr(t, err) + + tst.AssertEqual(t, cellValue(t, f, "S2", "A1"), "Name") + tst.AssertEqual(t, cellValue(t, f, "S2", "A2"), "Bob") +} + +func TestBuildWithColumnWidthAndStyle(t *testing.T) { + em, _ := NewExcelMapper[testRow]() + + f, err := em.InitNewFile("S") + tst.AssertNoErr(t, err) + + em.AddColumn("Name", em.StyleHeader, langext.Ptr(20.5), func(r testRow) any { return r.Name }) + + err = em.BuildSingleSheet(f, "S", []testRow{{Name: "Alice"}}) + tst.AssertNoErr(t, err) + + w, err := f.GetColWidth("S", "A") + tst.AssertNoErr(t, err) + if w < 20.0 || w > 21.0 { + t.Errorf("expected column width near 20.5, got %v", w) + } +} diff --git a/excelext/utils_test.go b/excelext/utils_test.go new file mode 100644 index 0000000..bc2e9f6 --- /dev/null +++ 
b/excelext/utils_test.go @@ -0,0 +1,61 @@ +package excelext + +import ( + "testing" + "time" + + "git.blackforestbytes.com/BlackForestBytes/goext/rfctime" + "git.blackforestbytes.com/BlackForestBytes/goext/tst" +) + +func TestCellAddress(t *testing.T) { + tst.AssertEqual(t, c(1, 0), "A1") + tst.AssertEqual(t, c(1, 1), "B1") + tst.AssertEqual(t, c(2, 0), "A2") + tst.AssertEqual(t, c(10, 25), "Z10") + tst.AssertEqual(t, c(1, 26), "AA1") + tst.AssertEqual(t, c(99, 27), "AB99") + tst.AssertEqual(t, c(100, 51), "AZ100") + tst.AssertEqual(t, c(1, 52), "BA1") +} + +func TestExcelizeOptTimeNil(t *testing.T) { + got := excelizeOptTime(nil) + if got != "" { + t.Errorf("expected empty string for nil time, got %v", got) + } +} + +func TestExcelizeOptTimeValue(t *testing.T) { + now := time.Date(2024, 5, 17, 13, 45, 30, 0, time.UTC) + rt := rfctime.RFC3339NanoTime(now) + got := excelizeOptTime(&rt) + + gt, ok := got.(time.Time) + if !ok { + t.Fatalf("expected time.Time, got %T", got) + } + if !gt.Equal(now) { + t.Errorf("expected %v, got %v", now, gt) + } +} + +func TestExcelizeOptDateNil(t *testing.T) { + got := excelizeOptDate(nil) + if got != "" { + t.Errorf("expected empty string for nil date, got %v", got) + } +} + +func TestExcelizeOptDateValue(t *testing.T) { + d := rfctime.NewDate(time.Date(2025, 11, 3, 0, 0, 0, 0, time.UTC)) + got := excelizeOptDate(&d) + + gt, ok := got.(time.Time) + if !ok { + t.Fatalf("expected time.Time, got %T", got) + } + if gt.Year() != 2025 || gt.Month() != 11 || gt.Day() != 3 { + t.Errorf("unexpected date returned: %v", gt) + } +} diff --git a/exerr/unit_test.go b/exerr/unit_test.go new file mode 100644 index 0000000..95e9314 --- /dev/null +++ b/exerr/unit_test.go @@ -0,0 +1,836 @@ +package exerr + +import ( + "encoding/json" + "errors" + "fmt" + "strings" + "testing" + "time" + + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + "go.mongodb.org/mongo-driver/v2/bson" +) + +// 
============================================================================ +// Builder / Constructor tests +// ============================================================================ + +func TestNewBuildsExErr(t *testing.T) { + err := New(TypeInternal, "boom").Build() + tst.AssertTrue(t, err != nil) + + ee, ok := err.(*ExErr) + tst.AssertTrue(t, ok) + tst.AssertEqual(t, ee.Message, "boom") + tst.AssertEqual(t, ee.Type, TypeInternal) + tst.AssertEqual(t, ee.Category, CatSystem) + tst.AssertEqual(t, ee.Severity, SevErr) + tst.AssertTrue(t, ee.UniqueID != "") +} + +func TestWrapNilProducesInternalError(t *testing.T) { + err := Wrap(nil, "msg").Build() + tst.AssertTrue(t, err != nil) + + ee, _ := err.(*ExErr) + tst.AssertEqual(t, ee.Message, "msg") + tst.AssertEqual(t, ee.Type, TypeInternal) +} + +func TestWrapForeignError(t *testing.T) { + plain := errors.New("plain go error") + err := Wrap(plain, "wrapped").Build() + ee, _ := err.(*ExErr) + + // outer is a wrap (TypeWrap), inner is the foreign error + tst.AssertEqual(t, ee.Type, TypeWrap) + tst.AssertEqual(t, ee.Category, CatWrap) + tst.AssertEqual(t, ee.Message, "wrapped") + tst.AssertTrue(t, ee.OriginalError != nil) + tst.AssertEqual(t, ee.OriginalError.Category, CatForeign) + tst.AssertEqual(t, ee.OriginalError.Message, "plain go error") +} + +func TestWrapExErrChainsDepth(t *testing.T) { + e1 := New(TypeInternal, "level-1").Build() + e2 := Wrap(e1, "level-2").Build() + e3 := Wrap(e2, "level-3").Build() + ee3 := e3.(*ExErr) + + tst.AssertEqual(t, ee3.Depth(), 3) +} + +func TestGetReturnsExErr(t *testing.T) { + plain := errors.New("foreign") + b := Get(plain) + tst.AssertTrue(t, b != nil) + tst.AssertEqual(t, b.errorData.Category, CatForeign) + tst.AssertEqual(t, b.errorData.Message, "foreign") +} + +func TestBuilderWithModifiers(t *testing.T) { + err := New(TypeInternal, "msg"). + WithType(TypeAssert). + WithStatuscode(418). + WithMessage("teapot"). + WithSeverity(SevWarn). + WithCategory(CatUser). 
+ Build() + ee := err.(*ExErr) + + tst.AssertEqual(t, ee.Type, TypeAssert) + tst.AssertDeRefEqual(t, ee.StatusCode, 418) + tst.AssertEqual(t, ee.Message, "teapot") + tst.AssertEqual(t, ee.Severity, SevWarn) + tst.AssertEqual(t, ee.Category, CatUser) +} + +func TestBuilderSeverityShortcuts(t *testing.T) { + tst.AssertEqual(t, New(TypeInternal, "x").Err().Build().(*ExErr).Severity, SevErr) + tst.AssertEqual(t, New(TypeInternal, "x").Warn().Build().(*ExErr).Severity, SevWarn) + tst.AssertEqual(t, New(TypeInternal, "x").Info().Build().(*ExErr).Severity, SevInfo) +} + +func TestBuilderCategoryShortcuts(t *testing.T) { + tst.AssertEqual(t, New(TypeInternal, "x").User().Build().(*ExErr).Category, CatUser) + tst.AssertEqual(t, New(TypeInternal, "x").System().Build().(*ExErr).Category, CatSystem) +} + +func TestBuilderNoLog(t *testing.T) { + b := New(TypeInternal, "x").NoLog() + tst.AssertTrue(t, b.noLog) +} + +func TestBuilderExtra(t *testing.T) { + err := New(TypeInternal, "x").Extra("k", 42).Build() + ee := err.(*ExErr) + v, ok := ee.GetExtra("k") + tst.AssertTrue(t, ok) + tst.AssertEqual(t, v.(int), 42) +} + +func TestBuilderMetaTypes(t *testing.T) { + now := time.Date(2024, 1, 2, 3, 4, 5, 0, time.UTC) + err := New(TypeInternal, "msg"). + Str("s", "value"). + Int("i", 7). + Int8("i8", 8). + Int16("i16", 16). + Int32("i32", 32). + Int64("i64", 64). + Float32("f32", 1.5). + Float64("f64", 2.5). + Bool("b", true). + Bytes("by", []byte{0xAA, 0xBB}). + Time("t", now). + Dur("d", 5*time.Second). + Strs("strs", []string{"a", "b"}). + Ints("ints", []int{1, 2, 3}). + Ints32("ints32", []int32{4, 5}). + Type("typ", "hello"). 
+ Build() + + ee := err.(*ExErr) + + gotS, _ := ee.GetMetaString("s") + tst.AssertEqual(t, gotS, "value") + + gotI, _ := ee.GetMetaInt("i") + tst.AssertEqual(t, gotI, 7) + + gotB, _ := ee.GetMetaBool("b") + tst.AssertEqual(t, gotB, true) + + gotF32, _ := ee.GetMetaFloat32("f32") + tst.AssertEqual(t, gotF32, float32(1.5)) + + gotF64, _ := ee.GetMetaFloat64("f64") + tst.AssertEqual(t, gotF64, 2.5) + + gotT, _ := ee.GetMetaTime("t") + tst.AssertEqual(t, gotT.Equal(now), true) +} + +func TestBuilderInterfaceAndAny(t *testing.T) { + type payload struct { + A int `json:"a"` + B string `json:"b"` + } + err := New(TypeInternal, "msg").Interface("p", payload{A: 1, B: "x"}).Any("p2", payload{A: 2, B: "y"}).Build() + ee := err.(*ExErr) + v1, ok := ee.GetMeta("p") + tst.AssertTrue(t, ok) + mv := v1.(AnyWrap) + tst.AssertTrue(t, strings.Contains(mv.Json, "\"a\":1")) + _, ok = ee.GetMeta("p2") + tst.AssertTrue(t, ok) +} + +func TestBuilderStack(t *testing.T) { + err := New(TypeInternal, "msg").Stack().Build() + ee := err.(*ExErr) + v, ok := ee.GetMetaString("@Stack") + tst.AssertTrue(t, ok) + tst.AssertTrue(t, len(v) > 0) +} + +func TestBuilderErrs(t *testing.T) { + in := []error{errors.New("first"), errors.New("second")} + err := New(TypeInternal, "msg").Errs("errs", in).Build() + ee := err.(*ExErr) + + v0, ok := ee.GetMetaString("errs[0]") + tst.AssertTrue(t, ok) + tst.AssertTrue(t, strings.Contains(v0, "first")) + v1, ok := ee.GetMetaString("errs[1]") + tst.AssertTrue(t, ok) + tst.AssertTrue(t, strings.Contains(v1, "second")) +} + +type stringerImpl struct{ s string } + +func (s stringerImpl) String() string { return s.s } + +func TestBuilderStringerAndId(t *testing.T) { + err := New(TypeInternal, "msg"). + Stringer("s", stringerImpl{s: "hello"}). + Id("id", stringerImpl{s: "abc-123"}). 
+ Build() + ee := err.(*ExErr) + + v, _ := ee.GetMetaString("s") + tst.AssertEqual(t, v, "hello") + + idv, ok := ee.GetMeta("id") + tst.AssertTrue(t, ok) + w := idv.(IDWrap) + tst.AssertEqual(t, w.Value, "abc-123") +} + +func TestBuilderObjectID(t *testing.T) { + oid := bson.NewObjectID() + err := New(TypeInternal, "msg").ObjectID("oid", oid).Build() + ee := err.(*ExErr) + mv, ok := ee.Meta["oid"] + tst.AssertTrue(t, ok) + tst.AssertEqual(t, mv.DataType, MDTObjectID) + tst.AssertEqual(t, mv.Value.(bson.ObjectID), oid) +} + +func TestBuilderStrPtr(t *testing.T) { + s := "hello" + err := New(TypeInternal, "msg").StrPtr("p", &s).StrPtr("n", nil).Build() + ee := err.(*ExErr) + _, ok := ee.Meta["p"] + tst.AssertTrue(t, ok) + _, ok = ee.Meta["n"] + tst.AssertTrue(t, ok) +} + +func TestBuilderMetaCollision(t *testing.T) { + err := New(TypeInternal, "msg").Str("k", "v1").Str("k", "v2").Str("k", "v3").Build() + ee := err.(*ExErr) + + v1, _ := ee.GetMetaString("k") + tst.AssertEqual(t, v1, "v1") + + v2, ok := ee.GetMetaString("k-2") + tst.AssertTrue(t, ok) + tst.AssertEqual(t, v2, "v2") + + v3, ok := ee.GetMetaString("k-3") + tst.AssertTrue(t, ok) + tst.AssertEqual(t, v3, "v3") +} + +// ============================================================================ +// FromError tests +// ============================================================================ + +func TestFromErrorNil(t *testing.T) { + ee := FromError(nil) + tst.AssertTrue(t, ee != nil) + tst.AssertEqual(t, ee.Category, CatForeign) + tst.AssertEqual(t, ee.WrappedErrType, "nil") +} + +func TestFromErrorPassThrough(t *testing.T) { + orig := New(TypeInternal, "msg").Build().(*ExErr) + got := FromError(orig) + tst.AssertTrue(t, orig == got) +} + +func TestFromErrorForeign(t *testing.T) { + in := errors.New("standard") + ee := FromError(in) + tst.AssertEqual(t, ee.Category, CatForeign) + tst.AssertEqual(t, ee.Message, "standard") +} + +// 
============================================================================ +// ExErr method tests +// ============================================================================ + +func TestErrorReturnsRecursiveMessage(t *testing.T) { + in := errors.New("orig") + err := Wrap(in, "outer").Build() + tst.AssertTrue(t, strings.Contains(err.Error(), "outer") || strings.Contains(err.Error(), "orig")) +} + +func TestUnwrapReturnsOriginalError(t *testing.T) { + e1 := New(TypeInternal, "inner").Build().(*ExErr) + e2 := Wrap(e1, "outer").Build().(*ExErr) + tst.AssertEqual(t, e2.Unwrap().(*ExErr) == e1, true) +} + +func TestUnwrapForeign(t *testing.T) { + in := errors.New("std") + ee := FromError(in) + u := ee.Unwrap() + tst.AssertEqual(t, u.Error(), "std") +} + +func TestUnwrapNilWhenNoneAvailable(t *testing.T) { + ee := New(TypeInternal, "msg").Build().(*ExErr) + tst.AssertTrue(t, ee.Unwrap() == nil) +} + +func TestRecursiveMessageSkipsForeign(t *testing.T) { + in := errors.New("foreign-msg") + err := Wrap(in, "wrapped-msg").Build().(*ExErr) + tst.AssertEqual(t, err.RecursiveMessage(), "wrapped-msg") +} + +func TestRecursiveMessageFallback(t *testing.T) { + ee := &ExErr{Message: "self"} + tst.AssertEqual(t, ee.RecursiveMessage(), "self") +} + +func TestRecursiveType(t *testing.T) { + e1 := New(TypeAssert, "inner").Build().(*ExErr) + e2 := Wrap(e1, "outer").Build().(*ExErr) + tst.AssertEqual(t, e2.RecursiveType(), TypeAssert) +} + +func TestRecursiveStatuscode(t *testing.T) { + e1 := New(TypeInternal, "inner").WithStatuscode(404).Build().(*ExErr) + e2 := Wrap(e1, "outer").Build().(*ExErr) + + got := e2.RecursiveStatuscode() + tst.AssertDeRefEqual(t, got, 404) +} + +func TestRecursiveStatuscodeNil(t *testing.T) { + e1 := New(TypeWrap, "x").Build().(*ExErr) + tst.AssertTrue(t, e1.RecursiveStatuscode() == nil) +} + +func TestRecursiveCategory(t *testing.T) { + e1 := New(TypeInternal, "inner").User().Build().(*ExErr) + e2 := Wrap(e1, "outer").Build().(*ExErr) + 
tst.AssertEqual(t, e2.RecursiveCategory(), CatUser) +} + +func TestRecursiveMeta(t *testing.T) { + e1 := New(TypeInternal, "inner").Str("xkey", "xval").Build().(*ExErr) + e2 := Wrap(e1, "outer").Build().(*ExErr) + + mv := e2.RecursiveMeta("xkey") + tst.AssertTrue(t, mv != nil) + tst.AssertEqual(t, mv.Value.(string), "xval") + tst.AssertTrue(t, e2.RecursiveMeta("nope") == nil) +} + +func TestDepth(t *testing.T) { + e1 := New(TypeInternal, "1").Build().(*ExErr) + tst.AssertEqual(t, e1.Depth(), 1) + e2 := Wrap(e1, "2").Build().(*ExErr) + tst.AssertEqual(t, e2.Depth(), 2) + e3 := Wrap(e2, "3").Build().(*ExErr) + tst.AssertEqual(t, e3.Depth(), 3) +} + +func TestGetMetaStringTypeMismatch(t *testing.T) { + err := New(TypeInternal, "msg").Int("k", 1).Build().(*ExErr) + _, ok := err.GetMetaString("k") + tst.AssertFalse(t, ok) +} + +func TestGetMetaMissing(t *testing.T) { + err := New(TypeInternal, "msg").Build().(*ExErr) + _, ok := err.GetMeta("missing") + tst.AssertFalse(t, ok) +} + +func TestGetExtraMissing(t *testing.T) { + err := New(TypeInternal, "msg").Build().(*ExErr) + _, ok := err.GetExtra("missing") + tst.AssertFalse(t, ok) +} + +func TestUniqueIDsCollects(t *testing.T) { + e1 := New(TypeInternal, "1").Build().(*ExErr) + e2 := Wrap(e1, "2").Build().(*ExErr) + e3 := Wrap(e2, "3").Build().(*ExErr) + ids := e3.UniqueIDs() + tst.AssertEqual(t, len(ids), 3) +} + +// ============================================================================ +// Format / Log +// ============================================================================ + +func TestFormatLogShort(t *testing.T) { + err := New(TypeAssert, "boom").Build().(*ExErr) + out := err.FormatLog(LogPrintShort) + tst.AssertTrue(t, strings.Contains(out, "boom")) + tst.AssertTrue(t, strings.Contains(out, TypeAssert.Key)) +} + +func TestFormatLogOverview(t *testing.T) { + e1 := New(TypeAssert, "inner").Build().(*ExErr) + e2 := Wrap(e1, "outer").Build().(*ExErr) + out := e2.FormatLog(LogPrintOverview) + 
tst.AssertTrue(t, strings.Contains(out, "outer")) + tst.AssertTrue(t, strings.Contains(out, "inner")) +} + +func TestFormatLogFull(t *testing.T) { + err := New(TypeInternal, "boom").Str("k", "v").Build().(*ExErr) + out := err.FormatLog(LogPrintFull) + tst.AssertTrue(t, strings.Contains(out, "boom")) + tst.AssertTrue(t, strings.Contains(out, "k")) +} + +func TestFormatLogUnknownLevel(t *testing.T) { + err := New(TypeInternal, "boom").Build().(*ExErr) + out := err.FormatLog(LogPrintLevel("__nope__")) + tst.AssertTrue(t, strings.HasPrefix(out, "[?[")) +} + +func TestBuilderFormat(t *testing.T) { + b := New(TypeInternal, "boom") + out := b.Format(LogPrintShort) + tst.AssertTrue(t, strings.Contains(out, "boom")) +} + +// ============================================================================ +// helper.go +// ============================================================================ + +func TestIsTypeMatching(t *testing.T) { + err := New(TypeAssert, "x").Build() + tst.AssertTrue(t, IsType(err, TypeAssert)) + tst.AssertFalse(t, IsType(err, TypeNotImplemented)) + tst.AssertFalse(t, IsType(nil, TypeAssert)) +} + +func TestIsTypeRecursive(t *testing.T) { + e1 := New(TypeAssert, "inner").Build() + e2 := Wrap(e1, "outer").Build() + tst.AssertTrue(t, IsType(e2, TypeAssert)) +} + +func TestIsFromIdentity(t *testing.T) { + e := errors.New("x") + tst.AssertTrue(t, IsFrom(e, e)) + tst.AssertFalse(t, IsFrom(nil, e)) +} + +func TestIsFromForeign(t *testing.T) { + src := errors.New("origmsg") + wrap := Wrap(src, "outer").Build() + tst.AssertTrue(t, IsFrom(wrap, src)) + + other := errors.New("other") + tst.AssertFalse(t, IsFrom(wrap, other)) +} + +func TestHasSourceMessage(t *testing.T) { + src := errors.New("origmsg") + wrap := Wrap(src, "outer").Build() + tst.AssertTrue(t, HasSourceMessage(wrap, "origmsg")) + tst.AssertFalse(t, HasSourceMessage(wrap, "other")) + tst.AssertFalse(t, HasSourceMessage(nil, "x")) +} + +func TestMessageMatch(t *testing.T) { + e := New(TypeInternal, 
"alpha-beta").Build() + tst.AssertTrue(t, MessageMatch(e, func(s string) bool { return strings.Contains(s, "alpha") })) + tst.AssertFalse(t, MessageMatch(e, func(s string) bool { return strings.Contains(s, "missing") })) + tst.AssertFalse(t, MessageMatch(nil, func(s string) bool { return true })) +} + +func TestOriginalError(t *testing.T) { + src := errors.New("the-source") + wrap := Wrap(src, "outer").Build() + got := OriginalError(wrap) + tst.AssertEqual(t, got.Error(), "the-source") + tst.AssertTrue(t, OriginalError(nil) == nil) +} + +func TestUniqueIDHelper(t *testing.T) { + err := New(TypeInternal, "x").Build() + id := UniqueID(err) + tst.AssertTrue(t, id != nil) + tst.AssertTrue(t, *id != "") + + tst.AssertTrue(t, UniqueID(nil) == nil) + + plain := errors.New("plain") + tst.AssertTrue(t, UniqueID(plain) == nil) +} + +// ============================================================================ +// errors.Is / errors.As +// ============================================================================ + +type customErr struct{ msg string } + +func (c customErr) Error() string { return c.msg } + +func TestErrorsAsForeign(t *testing.T) { + src := customErr{msg: "x"} + wrap := Wrap(src, "outer").Build() + + var got customErr + ok := errors.As(wrap, &got) + tst.AssertTrue(t, ok) + tst.AssertEqual(t, got.msg, "x") +} + +func TestErrorsIsForeign(t *testing.T) { + src := customErr{msg: "x"} + wrap := Wrap(src, "outer").Build() + + tst.AssertTrue(t, errors.Is(wrap, customErr{msg: "x"})) + tst.AssertFalse(t, errors.Is(wrap, customErr{msg: "y"})) +} + +// ============================================================================ +// MetaValue serialize / deserialize roundtrip +// ============================================================================ + +func TestMetaValueRoundtripPrimitives(t *testing.T) { + cases := []MetaValue{ + {DataType: MDTString, Value: "hello"}, + {DataType: MDTInt, Value: 42}, + {DataType: MDTInt8, Value: int8(8)}, + {DataType: MDTInt16, 
Value: int16(16)}, + {DataType: MDTInt32, Value: int32(32)}, + {DataType: MDTInt64, Value: int64(64)}, + {DataType: MDTFloat32, Value: float32(1.5)}, + {DataType: MDTFloat64, Value: float64(2.5)}, + {DataType: MDTBool, Value: true}, + {DataType: MDTBool, Value: false}, + {DataType: MDTBytes, Value: []byte{0x01, 0x02, 0xAB}}, + {DataType: MDTStringArray, Value: []string{"a", "b"}}, + {DataType: MDTIntArray, Value: []int{1, 2, 3}}, + {DataType: MDTInt32Array, Value: []int32{4, 5}}, + {DataType: MDTNil, Value: nil}, + } + + for _, c := range cases { + s, err := c.SerializeValue() + tst.AssertNoErr(t, err) + + var dec MetaValue + tst.AssertNoErr(t, dec.Deserialize(s, c.DataType)) + tst.AssertStrRepEqual(t, dec.Value, c.Value) + } +} + +func TestMetaValueRoundtripStringPtr(t *testing.T) { + v := "hello" + mv := MetaValue{DataType: MDTStringPtr, Value: &v} + s, err := mv.SerializeValue() + tst.AssertNoErr(t, err) + var dec MetaValue + tst.AssertNoErr(t, dec.Deserialize(s, MDTStringPtr)) + tst.AssertEqual(t, *(dec.Value.(*string)), v) + + mv2 := MetaValue{DataType: MDTStringPtr, Value: (*string)(nil)} + s, err = mv2.SerializeValue() + tst.AssertNoErr(t, err) + tst.AssertEqual(t, s, "#") +} + +func TestMetaValueRoundtripTime(t *testing.T) { + tm := time.Date(2025, 4, 27, 12, 34, 56, 12345, time.UTC) + mv := MetaValue{DataType: MDTTime, Value: tm} + s, err := mv.SerializeValue() + tst.AssertNoErr(t, err) + var dec MetaValue + tst.AssertNoErr(t, dec.Deserialize(s, MDTTime)) + got := dec.Value.(time.Time) + tst.AssertEqual(t, got.Unix(), tm.Unix()) + tst.AssertEqual(t, got.Nanosecond(), tm.Nanosecond()) +} + +func TestMetaValueRoundtripDuration(t *testing.T) { + d := 3*time.Second + 250*time.Millisecond + mv := MetaValue{DataType: MDTDuration, Value: d} + s, err := mv.SerializeValue() + tst.AssertNoErr(t, err) + var dec MetaValue + tst.AssertNoErr(t, dec.Deserialize(s, MDTDuration)) + tst.AssertEqual(t, dec.Value.(time.Duration), d) +} + +func TestMetaValueRoundtripObjectID(t 
*testing.T) { + oid := bson.NewObjectID() + mv := MetaValue{DataType: MDTObjectID, Value: oid} + s, err := mv.SerializeValue() + tst.AssertNoErr(t, err) + var dec MetaValue + tst.AssertNoErr(t, dec.Deserialize(s, MDTObjectID)) + tst.AssertEqual(t, dec.Value.(bson.ObjectID), oid) +} + +func TestMetaValueDeserializeUnknownType(t *testing.T) { + var mv MetaValue + err := mv.Deserialize("x", metaDataType("unknown")) + tst.AssertTrue(t, err != nil) +} + +func TestMetaValueDeserializeBadBool(t *testing.T) { + var mv MetaValue + err := mv.Deserialize("nope", MDTBool) + tst.AssertTrue(t, err != nil) +} + +func TestMetaValueShortString(t *testing.T) { + cases := []struct { + mv MetaValue + out string + }{ + {MetaValue{DataType: MDTString, Value: "hello"}, "hello"}, + {MetaValue{DataType: MDTInt, Value: 42}, "42"}, + {MetaValue{DataType: MDTBool, Value: true}, "true"}, + {MetaValue{DataType: MDTNil, Value: nil}, "<>"}, + } + for _, c := range cases { + tst.AssertEqual(t, c.mv.ShortString(100), c.out) + } +} + +func TestMetaValueValueString(t *testing.T) { + mv := MetaValue{DataType: MDTInt, Value: 42} + tst.AssertEqual(t, mv.ValueString(), "42") + + mv = MetaValue{DataType: MDTString, Value: "ok"} + tst.AssertEqual(t, mv.ValueString(), "ok") +} + +func TestMetaValueJSONMarshal(t *testing.T) { + mv := MetaValue{DataType: MDTString, Value: "abc"} + bin, err := json.Marshal(mv) + tst.AssertNoErr(t, err) + + var dec MetaValue + tst.AssertNoErr(t, json.Unmarshal(bin, &dec)) + tst.AssertEqual(t, dec.DataType, MDTString) + tst.AssertEqual(t, dec.Value.(string), "abc") +} + +func TestMetaValueJSONMarshalInvalidString(t *testing.T) { + var mv MetaValue + err := json.Unmarshal([]byte("\"badformat\""), &mv) + tst.AssertTrue(t, err != nil) +} + +// ============================================================================ +// MetaMap +// ============================================================================ + +func TestMetaMapAny(t *testing.T) { + mm := MetaMap{} + 
tst.AssertFalse(t, mm.Any()) + mm.add("k", MDTString, "v") + tst.AssertTrue(t, mm.Any()) +} + +func TestMetaMapFormatOneLine(t *testing.T) { + mm := MetaMap{} + mm.add("k", MDTString, "v") + out := mm.FormatOneLine(100) + tst.AssertTrue(t, strings.Contains(out, "k")) + tst.AssertTrue(t, strings.Contains(out, "v")) +} + +func TestMetaMapFormatMultiLine(t *testing.T) { + mm := MetaMap{} + mm.add("k1", MDTString, "v1") + mm.add("gin_body", MDTString, "should-be-skipped") + out := mm.FormatMultiLine("", " ", 100) + tst.AssertTrue(t, strings.Contains(out, "k1")) + tst.AssertFalse(t, strings.Contains(out, "should-be-skipped")) +} + +// ============================================================================ +// typeWrapper.go +// ============================================================================ + +func TestIDWrap(t *testing.T) { + w := newIDWrap(stringerImpl{s: "id-1"}) + tst.AssertEqual(t, w.Value, "id-1") + tst.AssertFalse(t, w.IsNil) + + s := w.Serialize() + got := deserializeIDWrap(s) + tst.AssertEqual(t, got.Value, "id-1") + tst.AssertEqual(t, got.Type, w.Type) +} + +func TestIDWrapNil(t *testing.T) { + var nilStringer fmt.Stringer = (*stringerImpl)(nil) + w := newIDWrap(nilStringer) + tst.AssertTrue(t, w.IsNil) + + s := w.Serialize() + got := deserializeIDWrap(s) + tst.AssertTrue(t, got.IsNil) +} + +func TestAnyWrap(t *testing.T) { + type p struct { + X int `json:"x"` + } + w := newAnyWrap(p{X: 7}) + tst.AssertFalse(t, w.IsError) + tst.AssertFalse(t, w.IsNil) + tst.AssertTrue(t, strings.Contains(w.Json, "\"x\":7")) + + s := w.Serialize() + got := deserializeAnyWrap(s) + tst.AssertEqual(t, got.IsError, false) + tst.AssertTrue(t, strings.Contains(got.Json, "\"x\":7")) +} + +func TestAnyWrapNil(t *testing.T) { + w := newAnyWrap(nil) + tst.AssertTrue(t, w.IsNil) + + s := w.Serialize() + got := deserializeAnyWrap(s) + tst.AssertTrue(t, got.IsNil) +} + +func TestAnyWrapDeserializeBad(t *testing.T) { + got := deserializeAnyWrap("xx") + tst.AssertTrue(t, 
got.IsError) +} + +// ============================================================================ +// dataType.go +// ============================================================================ + +func TestNewTypeRegisters(t *testing.T) { + custom := NewType("UNIT_TEST_CUSTOM_TYPE", new(503)) + tst.AssertEqual(t, custom.Key, "UNIT_TEST_CUSTOM_TYPE") + tst.AssertDeRefEqual(t, custom.DefaultStatusCode, 503) + + all := ListRegisteredTypes() + found := false + for _, et := range all { + if et.Key == "UNIT_TEST_CUSTOM_TYPE" { + found = true + break + } + } + tst.AssertTrue(t, found) +} + +func TestErrorTypeJSONUnmarshalKnown(t *testing.T) { + var et ErrorType + tst.AssertNoErr(t, json.Unmarshal([]byte("\"NOT_IMPLEMENTED\""), &et)) + tst.AssertEqual(t, et.Key, TypeNotImplemented.Key) +} + +func TestErrorTypeJSONUnmarshalUnknown(t *testing.T) { + var et ErrorType + tst.AssertNoErr(t, json.Unmarshal([]byte("\"COMPLETELY_UNKNOWN_TYPE_QQQ\""), &et)) + tst.AssertEqual(t, et.Key, "COMPLETELY_UNKNOWN_TYPE_QQQ") + tst.AssertTrue(t, et.DefaultStatusCode == nil) +} + +func TestCategoryAndSeverityJSONMarshal(t *testing.T) { + bin, err := json.Marshal(CatUser) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, string(bin), "\"User\"") + + var c ErrorCategory + tst.AssertNoErr(t, json.Unmarshal(bin, &c)) + tst.AssertEqual(t, c, CatUser) + + bin, err = json.Marshal(SevWarn) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, string(bin), "\"Warn\"") +} + +// ============================================================================ +// proxy.go +// ============================================================================ + +func TestProxy(t *testing.T) { + ee := New(TypeInternal, "x").Build().(*ExErr) + p := Proxy{v: *ee} + tst.AssertEqual(t, p.UniqueID(), ee.UniqueID) + tst.AssertEqual(t, p.Get().Message, ee.Message) +} + +// ============================================================================ +// listener.go +// 
============================================================================ + +func TestRegisterListenerInvokedOnBuild(t *testing.T) { + var gotMethod Method + var gotErr *ExErr + RegisterListener(func(method Method, v *ExErr, opt ListenerOpt) { + if v != nil && strings.Contains(v.Message, "listener-marker-1") { + gotMethod = method + gotErr = v + } + }) + + _ = New(TypeInternal, "listener-marker-1").Build() + + tst.AssertEqual(t, gotMethod, MethodBuild) + tst.AssertTrue(t, gotErr != nil) +} + +// ============================================================================ +// Initialized +// ============================================================================ + +func TestInitialized(t *testing.T) { + tst.AssertTrue(t, Initialized()) +} + +// ============================================================================ +// JSON output (toJson / ToAPIJson) +// ============================================================================ + +func TestToAPIJsonContainsCoreFields(t *testing.T) { + err := New(TypeInternal, "boom").Str("k", "v").Extra("ex", 1).Build().(*ExErr) + out := err.ToAPIJson(false, true, true) + + tst.AssertEqual(t, out["errorid"].(string), err.UniqueID) + tst.AssertEqual(t, out["errorcode"].(string), TypeInternal.Key) + tst.AssertEqual(t, out["category"].(string), CatSystem.Category) + tst.AssertEqual(t, out["message"].(string), "boom") + + _, hasData := out["__data"] + tst.AssertTrue(t, hasData) + + tst.AssertEqual(t, out["ex"].(int), 1) +} + +func TestToDefaultAPIJson(t *testing.T) { + err := New(TypeInternal, "boom").Build().(*ExErr) + out, jerr := err.ToDefaultAPIJson() + tst.AssertNoErr(t, jerr) + tst.AssertTrue(t, strings.Contains(out, "boom")) + tst.AssertTrue(t, strings.Contains(out, err.UniqueID)) +} diff --git a/fsext/exists_test.go b/fsext/exists_test.go new file mode 100644 index 0000000..3c2522f --- /dev/null +++ b/fsext/exists_test.go @@ -0,0 +1,264 @@ +package fsext + +import ( + "os" + "path/filepath" + "runtime" + 
"testing" +) + +func TestPathExistsFile(t *testing.T) { + dir := t.TempDir() + fp := filepath.Join(dir, "file.txt") + if err := os.WriteFile(fp, []byte("hello"), 0644); err != nil { + t.Fatalf("setup: %v", err) + } + + ok, err := PathExists(fp) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !ok { + t.Errorf("PathExists(%q) = false, want true", fp) + } +} + +func TestPathExistsDirectory(t *testing.T) { + dir := t.TempDir() + + ok, err := PathExists(dir) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !ok { + t.Errorf("PathExists(%q) = false, want true", dir) + } +} + +func TestPathExistsMissing(t *testing.T) { + dir := t.TempDir() + fp := filepath.Join(dir, "does_not_exist") + + ok, err := PathExists(fp) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ok { + t.Errorf("PathExists(%q) = true, want false", fp) + } +} + +func TestPathExistsMissingNested(t *testing.T) { + dir := t.TempDir() + fp := filepath.Join(dir, "nope", "still_nope", "file.txt") + + ok, err := PathExists(fp) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ok { + t.Errorf("PathExists(%q) = true, want false", fp) + } +} + +func TestFileExistsFile(t *testing.T) { + dir := t.TempDir() + fp := filepath.Join(dir, "file.txt") + if err := os.WriteFile(fp, []byte("data"), 0644); err != nil { + t.Fatalf("setup: %v", err) + } + + ok, err := FileExists(fp) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !ok { + t.Errorf("FileExists(%q) = false, want true", fp) + } +} + +func TestFileExistsEmptyFile(t *testing.T) { + dir := t.TempDir() + fp := filepath.Join(dir, "empty.txt") + if err := os.WriteFile(fp, []byte{}, 0644); err != nil { + t.Fatalf("setup: %v", err) + } + + ok, err := FileExists(fp) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !ok { + t.Errorf("FileExists(%q) = false, want true", fp) + } +} + +func TestFileExistsDirectoryReturnsFalse(t *testing.T) { + dir := t.TempDir() + + ok, err := 
FileExists(dir) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ok { + t.Errorf("FileExists(%q) = true, want false (it's a directory)", dir) + } +} + +func TestFileExistsMissing(t *testing.T) { + dir := t.TempDir() + fp := filepath.Join(dir, "missing.txt") + + ok, err := FileExists(fp) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ok { + t.Errorf("FileExists(%q) = true, want false", fp) + } +} + +func TestDirectoryExistsDirectory(t *testing.T) { + dir := t.TempDir() + + ok, err := DirectoryExists(dir) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !ok { + t.Errorf("DirectoryExists(%q) = false, want true", dir) + } +} + +func TestDirectoryExistsNestedDirectory(t *testing.T) { + dir := t.TempDir() + nested := filepath.Join(dir, "a", "b", "c") + if err := os.MkdirAll(nested, 0755); err != nil { + t.Fatalf("setup: %v", err) + } + + ok, err := DirectoryExists(nested) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !ok { + t.Errorf("DirectoryExists(%q) = false, want true", nested) + } +} + +func TestDirectoryExistsFileReturnsFalse(t *testing.T) { + dir := t.TempDir() + fp := filepath.Join(dir, "file.txt") + if err := os.WriteFile(fp, []byte("data"), 0644); err != nil { + t.Fatalf("setup: %v", err) + } + + ok, err := DirectoryExists(fp) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ok { + t.Errorf("DirectoryExists(%q) = true, want false (it's a file)", fp) + } +} + +func TestDirectoryExistsMissing(t *testing.T) { + dir := t.TempDir() + fp := filepath.Join(dir, "missing") + + ok, err := DirectoryExists(fp) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ok { + t.Errorf("DirectoryExists(%q) = true, want false", fp) + } +} + +func TestPathExistsEmptyString(t *testing.T) { + ok, err := PathExists("") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ok { + t.Errorf("PathExists(\"\") = true, want false") + } +} + +func 
TestFileExistsEmptyString(t *testing.T) { + ok, err := FileExists("") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ok { + t.Errorf("FileExists(\"\") = true, want false") + } +} + +func TestDirectoryExistsEmptyString(t *testing.T) { + ok, err := DirectoryExists("") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ok { + t.Errorf("DirectoryExists(\"\") = true, want false") + } +} + +func TestPathExistsSymlinkToFile(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("symlinks require admin privileges on windows") + } + dir := t.TempDir() + target := filepath.Join(dir, "target.txt") + if err := os.WriteFile(target, []byte("data"), 0644); err != nil { + t.Fatalf("setup: %v", err) + } + link := filepath.Join(dir, "link.txt") + if err := os.Symlink(target, link); err != nil { + t.Fatalf("setup: %v", err) + } + + ok, err := PathExists(link) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !ok { + t.Errorf("PathExists(symlink) = false, want true") + } + + ok, err = FileExists(link) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !ok { + t.Errorf("FileExists(symlink-to-file) = false, want true") + } + + ok, err = DirectoryExists(link) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ok { + t.Errorf("DirectoryExists(symlink-to-file) = true, want false") + } +} + +func TestPathExistsBrokenSymlink(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("symlinks require admin privileges on windows") + } + dir := t.TempDir() + link := filepath.Join(dir, "broken") + if err := os.Symlink(filepath.Join(dir, "nonexistent_target"), link); err != nil { + t.Fatalf("setup: %v", err) + } + + ok, err := PathExists(link) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ok { + t.Errorf("PathExists(broken-symlink) = true, want false") + } +} diff --git a/ginext/appContext_test.go b/ginext/appContext_test.go new file mode 100644 index 0000000..dc7db4d --- /dev/null +++ 
b/ginext/appContext_test.go @@ -0,0 +1,121 @@ +package ginext + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gin-gonic/gin" +) + +func TestCreateBackgroundAppContext(t *testing.T) { + ac := CreateBackgroundAppContext() + if ac == nil { + t.Fatalf("expected non-nil context") + } + if ac.GinContext != nil { + t.Fatalf("background context should have no gin context") + } + if ac.Err() != nil { + t.Fatalf("expected no error") + } + if _, ok := ac.Deadline(); ok { + t.Fatalf("background context should have no deadline") + } +} + +func TestCreateAppContext_CopiesGinKeys(t *testing.T) { + rec := httptest.NewRecorder() + g, _ := gin.CreateTestContext(rec) + g.Request = httptest.NewRequest(http.MethodGet, "/", nil) + g.Set("foo", "bar") + g.Set("num", 42) + + inner, cancel := context.WithCancel(context.Background()) + defer cancel() + ac := CreateAppContext(g, inner, cancel) + + if ac.Value("foo") != "bar" { + t.Fatalf("expected key foo to be copied") + } + if ac.Value("num") != 42 { + t.Fatalf("expected key num to be copied") + } + if ac.GinContext != g { + t.Fatalf("expected GinContext to be set") + } +} + +func TestAppContext_Set(t *testing.T) { + ac := CreateBackgroundAppContext() + ac.Set("k", "v") + if ac.Value("k") != "v" { + t.Fatalf("expected Set to store value") + } +} + +func TestAppContext_Cancel(t *testing.T) { + called := false + cancel := func() { called = true } + ac := &AppContext{ + inner: context.Background(), + cancelFunc: cancel, + } + ac.Cancel() + if !called { + t.Fatalf("expected cancel function to be invoked") + } + if !ac.cancelled { + t.Fatalf("expected cancelled flag set") + } +} + +func TestAppContext_DeadlineDoneErr(t *testing.T) { + deadline := time.Now().Add(1 * time.Hour) + inner, cancel := context.WithDeadline(context.Background(), deadline) + defer cancel() + + ac := &AppContext{inner: inner, cancelFunc: cancel} + d, ok := ac.Deadline() + if !ok { + t.Fatalf("expected deadline 
ok") + } + if !d.Equal(deadline) { + t.Fatalf("deadline mismatch") + } + if ac.Done() == nil { + t.Fatalf("expected non-nil Done channel") + } + if ac.Err() != nil { + t.Fatalf("expected no err yet") + } + + cancel() + // After cancel, Err should return Canceled + if !errors.Is(ac.Err(), context.Canceled) { + t.Fatalf("expected context.Canceled, got %v", ac.Err()) + } +} + +func TestAppContext_RequestURI(t *testing.T) { + bg := CreateBackgroundAppContext() + if bg.RequestURI() != "" { + t.Fatalf("expected empty for background context") + } + + rec := httptest.NewRecorder() + g, _ := gin.CreateTestContext(rec) + g.Request = httptest.NewRequest(http.MethodPost, "/foo/bar", nil) + + inner, cancel := context.WithCancel(context.Background()) + defer cancel() + ac := CreateAppContext(g, inner, cancel) + + uri := ac.RequestURI() + if uri != "POST :: /foo/bar" { + t.Fatalf("expected POST :: /foo/bar, got %q", uri) + } +} diff --git a/ginext/commonHandler_test.go b/ginext/commonHandler_test.go new file mode 100644 index 0000000..e32add2 --- /dev/null +++ b/ginext/commonHandler_test.go @@ -0,0 +1,33 @@ +package ginext + +import ( + "net/http" + "testing" +) + +func TestRedirectFound(t *testing.T) { + hf := RedirectFound("/x") + resp := hf(PreContext{}) + if resp == nil { + t.Fatalf("expected response") + } + if resp.(InspectableHTTPResponse).Statuscode() != http.StatusFound { + t.Fatalf("expected 302") + } +} + +func TestRedirectTemporary(t *testing.T) { + hf := RedirectTemporary("/x") + resp := hf(PreContext{}) + if resp.(InspectableHTTPResponse).Statuscode() != http.StatusTemporaryRedirect { + t.Fatalf("expected 307") + } +} + +func TestRedirectPermanent(t *testing.T) { + hf := RedirectPermanent("/x") + resp := hf(PreContext{}) + if resp.(InspectableHTTPResponse).Statuscode() != http.StatusPermanentRedirect { + t.Fatalf("expected 308") + } +} diff --git a/ginext/commonMiddlewares_test.go b/ginext/commonMiddlewares_test.go new file mode 100644 index 0000000..07b3dc6 --- 
/dev/null +++ b/ginext/commonMiddlewares_test.go @@ -0,0 +1,45 @@ +package ginext + +import ( + "bytes" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gin-gonic/gin" +) + +func TestBodyBuffer_WrapsBody(t *testing.T) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", strings.NewReader("payload")) + + original := c.Request.Body + BodyBuffer(c) + if c.Request.Body == original { + t.Fatalf("expected body to be replaced with buffered reader") + } + + data, err := io.ReadAll(c.Request.Body) + if err != nil { + t.Fatalf("read err: %v", err) + } + if !bytes.Equal(data, []byte("payload")) { + t.Fatalf("body mismatch: %q", data) + } +} + +func TestBodyBuffer_NilBody(t *testing.T) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.Request.Body = nil + + // Should not panic + BodyBuffer(c) + if c.Request.Body != nil { + t.Fatalf("expected nil body to remain nil") + } +} diff --git a/ginext/cors_test.go b/ginext/cors_test.go new file mode 100644 index 0000000..5b3cff9 --- /dev/null +++ b/ginext/cors_test.go @@ -0,0 +1,81 @@ +package ginext + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gin-gonic/gin" +) + +func TestCorsMiddleware_SetsHeaders(t *testing.T) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + + mw := CorsMiddleware([]string{"X-Foo", "X-Bar"}, []string{"X-Exposed"}) + mw(c) + + h := rec.Header() + if h.Get("Access-Control-Allow-Origin") != "*" { + t.Fatalf("expected Allow-Origin *") + } + if h.Get("Access-Control-Allow-Credentials") != "true" { + t.Fatalf("expected Allow-Credentials true") + } + if h.Get("Access-Control-Allow-Headers") != "X-Foo, X-Bar" { + t.Fatalf("expected Allow-Headers X-Foo, X-Bar got %q", h.Get("Access-Control-Allow-Headers")) + } + 
if h.Get("Access-Control-Expose-Headers") != "X-Exposed" { + t.Fatalf("expected Expose-Headers X-Exposed got %q", h.Get("Access-Control-Expose-Headers")) + } + allowMethods := h.Get("Access-Control-Allow-Methods") + for _, want := range []string{"OPTIONS", "GET", "POST", "PUT", "PATCH", "DELETE", "COUNT"} { + if !strings.Contains(allowMethods, want) { + t.Errorf("expected Allow-Methods to contain %q, got %q", want, allowMethods) + } + } +} + +func TestCorsMiddleware_NoExposeHeader(t *testing.T) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + + mw := CorsMiddleware([]string{"X-Foo"}, []string{}) + mw(c) + + if _, ok := rec.Header()["Access-Control-Expose-Headers"]; ok { + t.Fatalf("expected Expose-Headers to be unset when empty") + } +} + +func TestCorsMiddleware_OptionsAborts(t *testing.T) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodOptions, "/", nil) + + mw := CorsMiddleware([]string{"X-Foo"}, nil) + mw(c) + + if !c.IsAborted() { + t.Fatalf("expected context aborted on OPTIONS") + } + if rec.Code != http.StatusOK { + t.Fatalf("expected 200 on OPTIONS, got %d", rec.Code) + } +} + +func TestCorsMiddleware_NonOptionsContinues(t *testing.T) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + + mw := CorsMiddleware([]string{"X-Foo"}, nil) + mw(c) + + if c.IsAborted() { + t.Fatalf("non-OPTIONS request should not be aborted") + } +} diff --git a/ginext/engine_test.go b/ginext/engine_test.go new file mode 100644 index 0000000..d404bfd --- /dev/null +++ b/ginext/engine_test.go @@ -0,0 +1,174 @@ +package ginext + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "git.blackforestbytes.com/BlackForestBytes/goext/langext" + "github.com/gin-gonic/gin" +) + +func TestNewEngine_DefaultsApplied(t *testing.T) { + w := 
NewEngine(Options{}) + if w == nil { + t.Fatalf("expected non-nil wrapper") + } + if w.engine == nil { + t.Fatalf("expected gin engine") + } + if w.allowCors { + t.Fatalf("expected allowCors default false") + } + if w.bufferBody { + t.Fatalf("expected bufferBody default false") + } + if !w.ginDebug { + t.Fatalf("expected ginDebug default true") + } + if w.requestTimeout != 24*time.Hour { + t.Fatalf("expected default 24h timeout, got %s", w.requestTimeout) + } +} + +func TestNewEngine_OptionsHonored(t *testing.T) { + allowCors := true + bufferBody := true + suppress := true + debug := false + timeout := 5 * time.Second + + w := NewEngine(Options{ + AllowCors: &allowCors, + BufferBody: &bufferBody, + SuppressGinLogs: &suppress, + GinDebug: &debug, + Timeout: &timeout, + CorsAllowHeader: &[]string{"X-Custom"}, + }) + + if !w.allowCors { + t.Fatalf("allowCors") + } + if !w.bufferBody { + t.Fatalf("bufferBody") + } + if !w.suppressGinLogs { + t.Fatalf("suppressGinLogs") + } + if w.ginDebug { + t.Fatalf("ginDebug should be false") + } + if w.requestTimeout != timeout { + t.Fatalf("timeout mismatch") + } + if !langext.ArrEqualsExact(w.corsAllowHeader, []string{"X-Custom"}) { + t.Fatalf("expected custom allow header") + } +} + +func TestNewEngine_BuildRequestBindError_DefaultIsErrorWrapper(t *testing.T) { + w := NewEngine(Options{}) + if w.buildRequestBindError == nil { + t.Fatalf("expected default builder") + } + resp := w.buildRequestBindError(nil, "URI", http.ErrAbortHandler) + if resp == nil { + t.Fatalf("expected response") + } + if resp.IsSuccess() { + t.Fatalf("expected error response, not success") + } +} + +func TestNewEngine_BuildRequestBindError_Custom(t *testing.T) { + called := false + custom := func(c *gin.Context, fieldtype string, err error) HTTPResponse { + called = true + return Status(http.StatusTeapot) + } + _ = custom // referenced below to avoid unused warning if signature mismatch + w := NewEngine(Options{BuildRequestBindError: custom}) + resp := 
w.buildRequestBindError(nil, "URI", http.ErrAbortHandler) + if !called { + t.Fatalf("expected custom builder to be invoked") + } + if resp.(InspectableHTTPResponse).Statuscode() != http.StatusTeapot { + t.Fatalf("expected 418 from custom builder") + } +} + +func TestServeHTTP_RoundTrip(t *testing.T) { + w := NewEngine(Options{}) + w.Routes().GET("/hello").Handle(func(p PreContext) HTTPResponse { + return Text(http.StatusOK, "world") + }) + + req := httptest.NewRequest(http.MethodGet, "/hello", nil) + rec := w.ServeHTTP(req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", rec.Code) + } + if rec.Body.String() != "world" { + t.Fatalf("expected world, got %q", rec.Body.String()) + } +} + +func TestForwardRequest(t *testing.T) { + w := NewEngine(Options{}) + w.Routes().GET("/fwd").Handle(func(p PreContext) HTTPResponse { + return Text(http.StatusOK, "ok") + }) + + req := httptest.NewRequest(http.MethodGet, "/fwd", nil) + rec := httptest.NewRecorder() + w.ForwardRequest(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", rec.Code) + } +} + +func TestListRoutes(t *testing.T) { + w := NewEngine(Options{}) + w.Routes().GET("/a").Handle(func(p PreContext) HTTPResponse { return Status(200) }) + w.Routes().POST("/b").Handle(func(p PreContext) HTTPResponse { return Status(200) }) + + rs := w.ListRoutes() + if len(rs) < 2 { + t.Fatalf("expected at least 2 routes, got %d", len(rs)) + } +} + +func TestDebugPrintRoutes_NoPanic(t *testing.T) { + w := NewEngine(Options{}) + w.Routes().GET("/x").Handle(func(p PreContext) HTTPResponse { return Status(200) }) + // just verify it doesn't panic + w.DebugPrintRoutes() +} + +func TestCleanMiddlewareName(t *testing.T) { + w := NewEngine(Options{ + DebugTrimHandlerPrefixes: []string{"customprefix."}, + DebugReplaceHandlerNames: map[string]string{"BadName": "GoodName"}, + }) + + cases := []struct { + in, want string + }{ + {"ginext.BodyBuffer", "[BodyBuffer]"}, + 
{"foo.(*GinRoutesWrapper).WithJSONFilter", "[JSONFilter]"}, + {"ginext.someThing", "someThing"}, + {"api.someThing", "someThing"}, + {"customprefix.thing", "thing"}, + {"BadName", "GoodName"}, + {"badname", "GoodName"}, + {"some.pkg.Func.func1", "some.pkg.Func"}, + {"some.pkg.Func.func1.2", "some.pkg.Func"}, + } + for _, tc := range cases { + if got := w.cleanMiddlewareName(tc.in); got != tc.want { + t.Errorf("cleanMiddlewareName(%q) = %q, want %q", tc.in, got, tc.want) + } + } +} diff --git a/ginext/jsonFilter_test.go b/ginext/jsonFilter_test.go new file mode 100644 index 0000000..2ed1c1f --- /dev/null +++ b/ginext/jsonFilter_test.go @@ -0,0 +1,35 @@ +package ginext + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" +) + +func TestSetJSONFilter_StoresInContext(t *testing.T) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + + SetJSONFilter(c, "myfilter") + + v := c.GetString(jsonFilterKey) + if v != "myfilter" { + t.Fatalf("expected filter to be stored, got %q", v) + } +} + +func TestSetJSONFilter_Empty(t *testing.T) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + + SetJSONFilter(c, "") + + v := c.GetString(jsonFilterKey) + if v != "" { + t.Fatalf("expected empty filter") + } +} diff --git a/ginext/response_test.go b/ginext/response_test.go new file mode 100644 index 0000000..d16e6dc --- /dev/null +++ b/ginext/response_test.go @@ -0,0 +1,320 @@ +package ginext + +import ( + "bytes" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" +) + +func init() { + gin.SetMode(gin.TestMode) +} + +func newTestCtx() (*gin.Context, *httptest.ResponseRecorder) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + return c, rec +} + +func TestJSONResponse_Basics(t 
*testing.T) { + r := JSON(http.StatusOK, map[string]any{"a": 1}) + if !r.IsSuccess() { + t.Fatalf("expected IsSuccess true for 200") + } + ir, ok := r.(InspectableHTTPResponse) + if !ok { + t.Fatalf("expected InspectableHTTPResponse") + } + if ir.Statuscode() != http.StatusOK { + t.Fatalf("statuscode mismatch: %d", ir.Statuscode()) + } + if ir.ContentType() != "application/json" { + t.Fatalf("expected content-type application/json, got %q", ir.ContentType()) + } +} + +func TestJSONResponse_Write(t *testing.T) { + c, rec := newTestCtx() + r := JSON(http.StatusCreated, map[string]any{"hello": "world"}) + r.Write(c) + if rec.Code != http.StatusCreated { + t.Fatalf("expected 201, got %d", rec.Code) + } + if !bytes.Contains(rec.Body.Bytes(), []byte("\"hello\":\"world\"")) { + t.Fatalf("unexpected body: %s", rec.Body.String()) + } +} + +func TestJSONResponse_BodyString(t *testing.T) { + c, _ := newTestCtx() + r := JSON(http.StatusOK, map[string]any{"x": 42}) + ir := r.(InspectableHTTPResponse) + body := ir.BodyString(c) + if body == nil { + t.Fatalf("expected body, got nil") + } + if !bytes.Contains([]byte(*body), []byte("\"x\":42")) { + t.Fatalf("unexpected body: %q", *body) + } +} + +func TestJSONResponse_FilterFromContext(t *testing.T) { + c, _ := newTestCtx() + SetJSONFilter(c, "abc") + r := JSON(http.StatusOK, map[string]any{"x": 1}) + body := r.(InspectableHTTPResponse).BodyString(c) + if body == nil { + t.Fatalf("expected body") + } +} + +func TestJSONResponse_FilterOverride(t *testing.T) { + c, _ := newTestCtx() + SetJSONFilter(c, "abc") + r := JSONWithFilter(http.StatusOK, map[string]any{"x": 1}, "override") + if r == nil { + t.Fatalf("expected response") + } + if !r.IsSuccess() { + t.Fatalf("expected success") + } + body := r.(InspectableHTTPResponse).BodyString(c) + if body == nil { + t.Fatalf("expected body") + } +} + +func TestResponse_IsSuccessRanges(t *testing.T) { + cases := []struct { + code int + ok bool + }{ + {100, false}, + {199, false}, + {200, 
true}, + {201, true}, + {299, true}, + {300, true}, + {399, true}, + {400, false}, + {500, false}, + } + for _, tc := range cases { + r := JSON(tc.code, nil) + if r.IsSuccess() != tc.ok { + t.Errorf("status %d: expected IsSuccess=%v", tc.code, tc.ok) + } + r2 := Status(tc.code) + if r2.IsSuccess() != tc.ok { + t.Errorf("Status(%d): expected IsSuccess=%v", tc.code, tc.ok) + } + r3 := Text(tc.code, "x") + if r3.IsSuccess() != tc.ok { + t.Errorf("Text(%d): expected IsSuccess=%v", tc.code, tc.ok) + } + r4 := Data(tc.code, "text/plain", []byte("x")) + if r4.IsSuccess() != tc.ok { + t.Errorf("Data(%d): expected IsSuccess=%v", tc.code, tc.ok) + } + } +} + +func TestResponse_WithHeader(t *testing.T) { + r := JSON(http.StatusOK, nil). + WithHeader("X-Foo", "bar"). + WithHeader("X-Baz", "qux") + headers := r.(InspectableHTTPResponse).Headers() + if len(headers) != 2 { + t.Fatalf("expected 2 headers, got %d", len(headers)) + } + if headers[0] != "X-Foo=bar" || headers[1] != "X-Baz=qux" { + t.Fatalf("headers wrong: %v", headers) + } +} + +func TestResponse_WithCookie_DoesNotPanic(t *testing.T) { + r := JSON(http.StatusOK, nil). 
+ WithCookie("session", "abc", 3600, "/", "example.com", true, true) + if r == nil { + t.Fatalf("expected response") + } + c, rec := newTestCtx() + r.Write(c) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", rec.Code) + } + cookies := rec.Result().Cookies() + if len(cookies) != 1 { + t.Fatalf("expected 1 cookie, got %d", len(cookies)) + } + if cookies[0].Name != "session" || cookies[0].Value != "abc" { + t.Fatalf("cookie wrong: %+v", cookies[0]) + } +} + +func TestTextResponse(t *testing.T) { + c, rec := newTestCtx() + r := Text(http.StatusOK, "hello") + if r.(InspectableHTTPResponse).ContentType() != "text/plain" { + t.Fatalf("expected text/plain") + } + body := r.(InspectableHTTPResponse).BodyString(c) + if body == nil || *body != "hello" { + t.Fatalf("body mismatch") + } + r.Write(c) + if rec.Body.String() != "hello" { + t.Fatalf("write body mismatch: %q", rec.Body.String()) + } +} + +func TestDataResponse(t *testing.T) { + c, rec := newTestCtx() + payload := []byte{0x01, 0x02, 0x03} + r := Data(http.StatusOK, "application/octet-stream", payload) + if r.(InspectableHTTPResponse).ContentType() != "application/octet-stream" { + t.Fatalf("contenttype mismatch") + } + r.Write(c) + if !bytes.Equal(rec.Body.Bytes(), payload) { + t.Fatalf("body mismatch") + } + body := r.(InspectableHTTPResponse).BodyString(c) + if body == nil || *body != string(payload) { + t.Fatalf("BodyString mismatch") + } +} + +func TestStatusResponse(t *testing.T) { + c, rec := newTestCtx() + r := Status(http.StatusNoContent) + if r.(InspectableHTTPResponse).ContentType() != "" { + t.Fatalf("expected empty content type") + } + if r.(InspectableHTTPResponse).BodyString(c) != nil { + t.Fatalf("expected nil body") + } + r.Write(c) + if c.Writer.Status() != http.StatusNoContent { + t.Fatalf("expected 204, got %d", c.Writer.Status()) + } + _ = rec +} + +func TestRedirectResponse(t *testing.T) { + c, rec := newTestCtx() + r := Redirect(http.StatusFound, "/elsewhere") + if 
r.(InspectableHTTPResponse).Statuscode() != http.StatusFound { + t.Fatalf("status mismatch") + } + if r.(InspectableHTTPResponse).ContentType() != "" { + t.Fatalf("expected empty content type") + } + if r.(InspectableHTTPResponse).BodyString(c) != nil { + t.Fatalf("expected nil body") + } + r.Write(c) + if rec.Code != http.StatusFound { + t.Fatalf("expected 302, got %d", rec.Code) + } + if rec.Header().Get("Location") != "/elsewhere" { + t.Fatalf("expected Location header") + } +} + +func TestNotImplemented(t *testing.T) { + r := NotImplemented() + if r == nil { + t.Fatalf("expected response") + } + if r.IsSuccess() { + t.Fatalf("NotImplemented must not be success") + } +} + +func TestError_NotSuccess(t *testing.T) { + r := Error(http.ErrAbortHandler) + if r.IsSuccess() { + t.Fatalf("error response must not be success") + } + herr, ok := r.(HTTPErrorResponse) + if !ok { + t.Fatalf("expected HTTPErrorResponse") + } + if herr.Error() == nil { + t.Fatalf("expected non-nil err") + } +} + +func TestError_ContentTypeJSON(t *testing.T) { + r := Error(http.ErrAbortHandler) + if r.(InspectableHTTPResponse).ContentType() != "application/json" { + t.Fatalf("expected application/json") + } +} + +func TestDownloadData(t *testing.T) { + c, rec := newTestCtx() + payload := []byte("file content") + r := DownloadData(http.StatusOK, "text/plain", "f.txt", payload) + if r.(InspectableHTTPResponse).ContentType() != "text/plain" { + t.Fatalf("contenttype mismatch") + } + if r.(InspectableHTTPResponse).Statuscode() != http.StatusOK { + t.Fatalf("status mismatch") + } + body := r.(InspectableHTTPResponse).BodyString(c) + if body == nil || *body != string(payload) { + t.Fatalf("body mismatch") + } + r.Write(c) + if rec.Header().Get("Content-Disposition") == "" { + t.Fatalf("expected Content-Disposition header") + } + if !bytes.Contains([]byte(rec.Header().Get("Content-Disposition")), []byte("f.txt")) { + t.Fatalf("Content-Disposition does not contain filename: %q", 
rec.Header().Get("Content-Disposition")) + } +} + +func TestSeekable(t *testing.T) { + r := Seekable("foo.bin", "application/octet-stream", bytes.NewReader([]byte("xyz"))) + if !r.IsSuccess() { + t.Fatalf("seekable must be success") + } + ir := r.(InspectableHTTPResponse) + if ir.Statuscode() != 200 { + t.Fatalf("expected 200") + } + if ir.ContentType() != "application/octet-stream" { + t.Fatalf("contenttype mismatch") + } + body := ir.BodyString(nil) + if body == nil || *body != "(seekable)" { + t.Fatalf("BodyString mismatch") + } +} + +func TestFile_Builders(t *testing.T) { + r := File("text/plain", "/tmp/this-file-should-not-exist-xyz") + if !r.IsSuccess() { + t.Fatalf("File must IsSuccess true") + } + if r.(InspectableHTTPResponse).Statuscode() != 200 { + t.Fatalf("expected 200") + } + if r.(InspectableHTTPResponse).ContentType() != "text/plain" { + t.Fatalf("contenttype mismatch") + } + r2 := Download("application/pdf", "/tmp/this-file-should-not-exist-xyz", "doc.pdf") + if r2.(InspectableHTTPResponse).ContentType() != "application/pdf" { + t.Fatalf("contenttype mismatch") + } + body := r.(InspectableHTTPResponse).BodyString(nil) + if body != nil { + t.Fatalf("expected nil body for nonexistent file") + } +} diff --git a/ginext/routes_test.go b/ginext/routes_test.go new file mode 100644 index 0000000..11f84ef --- /dev/null +++ b/ginext/routes_test.go @@ -0,0 +1,278 @@ +package ginext + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" +) + +func TestJoinPaths(t *testing.T) { + cases := []struct { + abs, rel, want string + }{ + {"", "", ""}, + {"/api", "", "/api"}, + {"/api", "users", "/api/users"}, + {"/api/", "users", "/api/users"}, + {"/api", "/users", "/api/users"}, + {"/api/", "/users/", "/api/users/"}, + {"/api", "users/", "/api/users/"}, + {"", "/users", "/users"}, + {"/", "/", "/"}, + } + for _, tc := range cases { + got := joinPaths(tc.abs, tc.rel) + if got != tc.want { + t.Errorf("joinPaths(%q,%q)=%q want %q", 
tc.abs, tc.rel, got, tc.want) + } + } +} + +func TestLastChar(t *testing.T) { + if lastChar("hello") != 'o' { + t.Fatalf("expected 'o'") + } + if lastChar("/") != '/' { + t.Fatalf("expected '/'") + } +} + +func TestLastChar_PanicsOnEmpty(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Fatalf("expected panic on empty string") + } + }() + lastChar("") +} + +func sampleHandler(_ *gin.Context) {} + +func TestNameOfFunction(t *testing.T) { + name := nameOfFunction(sampleHandler) + if name == "" { + t.Fatalf("expected non-empty name") + } + // nameOfFunction strips path prefix, expecting form "ginext.sampleHandler" or similar + if name != "ginext.sampleHandler" { + t.Errorf("expected ginext.sampleHandler, got %q", name) + } +} + +type sampleStruct struct{} + +func (s sampleStruct) Method(_ *gin.Context) {} + +func TestNameOfFunction_StripsFmSuffix(t *testing.T) { + s := sampleStruct{} + // Method values get a "-fm" suffix that nameOfFunction should strip + name := nameOfFunction(s.Method) + if name == "" { + t.Fatalf("expected non-empty name") + } + if got := name; got[len(got)-3:] == "-fm" { + t.Errorf("expected -fm suffix to be stripped, got %q", name) + } +} + +func TestRoutes_GroupAndAbsPath(t *testing.T) { + w := NewEngine(Options{}) + rw := w.Routes() + if rw.absPath != "" { + t.Fatalf("expected empty absPath") + } + g1 := rw.Group("/api") + if g1.absPath != "/api" { + t.Fatalf("expected /api, got %q", g1.absPath) + } + g2 := g1.Group("/v1") + if g2.absPath != "/api/v1" { + t.Fatalf("expected /api/v1, got %q", g2.absPath) + } +} + +func TestRoutes_UseAccumulatesMiddleware(t *testing.T) { + w := NewEngine(Options{}) + rw := w.Routes() + + mw1 := func(c *gin.Context) {} + mw2 := func(c *gin.Context) {} + + r1 := rw.Use(mw1) + if len(r1.defaultHandler) != 1 { + t.Fatalf("expected 1 handler after Use, got %d", len(r1.defaultHandler)) + } + r2 := r1.Use(mw2) + if len(r2.defaultHandler) != 2 { + t.Fatalf("expected 2 handlers, got %d", 
len(r2.defaultHandler)) + } + // Original parent should be unchanged + if len(rw.defaultHandler) != 0 { + t.Fatalf("expected parent to remain unchanged, got %d", len(rw.defaultHandler)) + } +} + +func TestRoutes_GroupCopiesMiddleware(t *testing.T) { + w := NewEngine(Options{}) + rw := w.Routes().Use(func(c *gin.Context) {}) + g := rw.Group("/x") + if len(g.defaultHandler) != 1 { + t.Fatalf("expected group to inherit middleware") + } +} + +func TestRoutes_MethodBuilders(t *testing.T) { + w := NewEngine(Options{}) + rw := w.Routes() + + cases := []struct { + name string + build func(string) *GinRouteBuilder + want string + }{ + {"GET", rw.GET, http.MethodGet}, + {"POST", rw.POST, http.MethodPost}, + {"PUT", rw.PUT, http.MethodPut}, + {"PATCH", rw.PATCH, http.MethodPatch}, + {"DELETE", rw.DELETE, http.MethodDelete}, + {"HEAD", rw.HEAD, http.MethodHead}, + {"OPTIONS", rw.OPTIONS, http.MethodOptions}, + {"COUNT", rw.COUNT, "COUNT"}, + {"Any", rw.Any, "*"}, + } + for _, tc := range cases { + b := tc.build("/foo") + if b.method != tc.want { + t.Errorf("%s: expected method %q, got %q", tc.name, tc.want, b.method) + } + if b.relPath != "/foo" { + t.Errorf("%s: expected relPath /foo, got %q", tc.name, b.relPath) + } + } +} + +func TestRoutes_RouteBuilderUseAppends(t *testing.T) { + w := NewEngine(Options{}) + rw := w.Routes() + b := rw.GET("/x") + startCount := len(b.handlers) + b.Use(func(c *gin.Context) {}) + if len(b.handlers) != startCount+1 { + t.Fatalf("expected handlers to grow by 1") + } +} + +func TestRoutes_RouteBuilderInheritsDefaultHandlers(t *testing.T) { + w := NewEngine(Options{}) + rw := w.Routes().Use(func(c *gin.Context) {}) + b := rw.GET("/x") + if len(b.handlers) != 1 { + t.Fatalf("expected route to inherit default handler") + } +} + +func TestRoutes_WithJSONFilter_AddsHandler(t *testing.T) { + w := NewEngine(Options{}) + rw := w.Routes().WithJSONFilter("xyz") + if len(rw.defaultHandler) != 1 { + t.Fatalf("expected json filter middleware to be added") + } 
+ // invoke it to verify it sets the key + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + rw.defaultHandler[0](c) + if c.GetString(jsonFilterKey) != "xyz" { + t.Fatalf("expected jsonFilterKey to be set to xyz") + } +} + +func TestRoutes_RouteBuilderWithJSONFilter(t *testing.T) { + w := NewEngine(Options{}) + b := w.Routes().GET("/x").WithJSONFilter("abc") + if len(b.handlers) != 1 { + t.Fatalf("expected handler to be added") + } +} + +func TestRoutes_HandleRegistersAndStoresSpec(t *testing.T) { + w := NewEngine(Options{}) + w.Routes().GET("/foo").Handle(func(p PreContext) HTTPResponse { + return Status(http.StatusOK) + }) + + if len(w.routeSpecs) != 1 { + t.Fatalf("expected 1 route spec, got %d", len(w.routeSpecs)) + } + spec := w.routeSpecs[0] + if spec.Method != http.MethodGet { + t.Fatalf("expected GET, got %s", spec.Method) + } + if spec.URL != "/foo" { + t.Fatalf("expected /foo, got %s", spec.URL) + } + + // Hitting the route should serve our handler + req := httptest.NewRequest(http.MethodGet, "/foo", nil) + rec := w.ServeHTTP(req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", rec.Code) + } +} + +func TestRoutes_AnyRegistersAllMethods(t *testing.T) { + w := NewEngine(Options{}) + w.Routes().Any("/wild").Handle(func(p PreContext) HTTPResponse { + return Status(http.StatusOK) + }) + + for _, m := range anyMethods { + req := httptest.NewRequest(m, "/wild", nil) + rec := w.ServeHTTP(req) + if rec.Code != http.StatusOK { + t.Errorf("method %s: expected 200, got %d", m, rec.Code) + } + } + if w.routeSpecs[0].Method != "ANY" { + t.Fatalf("expected method label ANY, got %s", w.routeSpecs[0].Method) + } +} + +func TestRoutes_GroupedRoutes(t *testing.T) { + w := NewEngine(Options{}) + api := w.Routes().Group("/api") + api.GET("/ping").Handle(func(p PreContext) HTTPResponse { + return Text(http.StatusOK, "pong") + }) + + req := httptest.NewRequest(http.MethodGet, 
"/api/ping", nil) + rec := w.ServeHTTP(req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", rec.Code) + } + if rec.Body.String() != "pong" { + t.Fatalf("expected pong, got %q", rec.Body.String()) + } + if w.routeSpecs[0].URL != "/api/ping" { + t.Fatalf("expected absPath /api/ping in spec, got %s", w.routeSpecs[0].URL) + } +} + +func TestRoutes_NoRoute(t *testing.T) { + w := NewEngine(Options{}) + w.NoRoute(func(p PreContext) HTTPResponse { + return Status(http.StatusTeapot) + }) + + req := httptest.NewRequest(http.MethodGet, "/missing", nil) + rec := w.ServeHTTP(req) + if rec.Code != http.StatusTeapot { + t.Fatalf("expected 418, got %d", rec.Code) + } + + if len(w.routeSpecs) != 1 || w.routeSpecs[0].URL != "[NO_ROUTE]" { + t.Fatalf("expected NO_ROUTE spec to be recorded") + } +} diff --git a/googleapi/attachment_test.go b/googleapi/attachment_test.go new file mode 100644 index 0000000..88d11e2 --- /dev/null +++ b/googleapi/attachment_test.go @@ -0,0 +1,149 @@ +package googleapi + +import ( + "encoding/base64" + "git.blackforestbytes.com/BlackForestBytes/goext/langext" + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + "strings" + "testing" +) + +func TestAttachmentDumpNormalWithFilename(t *testing.T) { + a := MailAttachment{ + IsInline: false, + ContentType: "text/plain", + Filename: "hello.txt", + Data: []byte("HelloWorld"), + } + + lines := a.dump() + joined := strings.Join(lines, "\n") + + tst.AssertTrue(t, strings.Contains(joined, "Content-Type: text/plain; charset=UTF-8")) + tst.AssertTrue(t, strings.Contains(joined, "Content-Transfer-Encoding: base64")) + tst.AssertTrue(t, strings.Contains(joined, `Content-Disposition: attachment;filename="hello.txt"`)) + tst.AssertFalse(t, strings.Contains(joined, "Content-Disposition: inline")) + + expectedB64 := base64.StdEncoding.EncodeToString([]byte("HelloWorld")) + tst.AssertTrue(t, strings.Contains(joined, expectedB64)) +} + +func TestAttachmentDumpInlineWithFilename(t *testing.T) { + a := 
MailAttachment{ + IsInline: true, + ContentType: "image/png", + Filename: "img.png", + Data: []byte{0x01, 0x02, 0x03, 0x04}, + } + + lines := a.dump() + joined := strings.Join(lines, "\n") + + tst.AssertTrue(t, strings.Contains(joined, "Content-Type: image/png; charset=UTF-8")) + tst.AssertTrue(t, strings.Contains(joined, "Content-Transfer-Encoding: base64")) + tst.AssertTrue(t, strings.Contains(joined, `Content-Disposition: inline;filename="img.png"`)) +} + +func TestAttachmentDumpNormalNoFilename(t *testing.T) { + a := MailAttachment{ + IsInline: false, + ContentType: "text/plain", + Filename: "", + Data: []byte("foo"), + } + + lines := a.dump() + joined := strings.Join(lines, "\n") + + tst.AssertTrue(t, langext.InArray("Content-Disposition: attachment", lines)) + tst.AssertFalse(t, strings.Contains(joined, "filename=")) +} + +func TestAttachmentDumpInlineNoFilename(t *testing.T) { + a := MailAttachment{ + IsInline: true, + ContentType: "text/plain", + Filename: "", + Data: []byte("foo"), + } + + lines := a.dump() + + tst.AssertTrue(t, langext.InArray("Content-Disposition: inline", lines)) +} + +func TestAttachmentDumpNoContentType(t *testing.T) { + a := MailAttachment{ + IsInline: false, + ContentType: "", + Filename: "x.bin", + Data: []byte("x"), + } + + lines := a.dump() + + for _, l := range lines { + tst.AssertFalse(t, strings.HasPrefix(l, "Content-Type:")) + } + tst.AssertTrue(t, langext.InArray("Content-Transfer-Encoding: base64", lines)) +} + +func TestAttachmentDumpEmptyData(t *testing.T) { + a := MailAttachment{ + IsInline: false, + ContentType: "application/octet-stream", + Filename: "empty.bin", + Data: []byte{}, + } + + lines := a.dump() + joined := strings.Join(lines, "\n") + + tst.AssertTrue(t, strings.Contains(joined, "Content-Type: application/octet-stream; charset=UTF-8")) + tst.AssertTrue(t, strings.Contains(joined, "Content-Transfer-Encoding: base64")) +} + +func TestAttachmentDumpLongDataLineWrapped(t *testing.T) { + // Data needs to result 
in > 80 base64 chars to test the wrapping. + // 100 bytes => 136 base64 chars => should wrap into 2 lines (80 + 56). + data := make([]byte, 100) + for i := range data { + data[i] = byte(i) + } + + a := MailAttachment{ + IsInline: false, + ContentType: "application/octet-stream", + Filename: "big.bin", + Data: data, + } + + lines := a.dump() + + // Find the base64 lines (everything after the headers). + b64Lines := make([]string, 0) + foundFirstHeader := false + for _, l := range lines { + if !foundFirstHeader && (strings.HasPrefix(l, "Content-") || l == "") { + foundFirstHeader = true + continue + } + if strings.HasPrefix(l, "Content-") { + continue + } + b64Lines = append(b64Lines, l) + } + + full := strings.Join(b64Lines, "") + expected := base64.StdEncoding.EncodeToString(data) + tst.AssertEqual(t, full, expected) + + // Each line (except possibly last) should be 80 chars. + for i, l := range b64Lines { + if i < len(b64Lines)-1 { + tst.AssertEqual(t, len(l), 80) + } else { + tst.AssertTrue(t, len(l) <= 80) + } + } +} diff --git a/googleapi/mimeMessage_assert_test.go b/googleapi/mimeMessage_assert_test.go new file mode 100644 index 0000000..82e324b --- /dev/null +++ b/googleapi/mimeMessage_assert_test.go @@ -0,0 +1,259 @@ +package googleapi + +import ( + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + "strings" + "testing" +) + +func TestMimeMailPlainOnly(t *testing.T) { + mail := encodeMimeMail( + "from@example.com", + []string{"to@example.com"}, + nil, nil, nil, + "Test Subject", + MailBody{Plain: "Hello plain body"}, + nil) + + tst.AssertTrue(t, strings.Contains(mail, "From: from@example.com")) + tst.AssertTrue(t, strings.Contains(mail, "To: to@example.com")) + tst.AssertTrue(t, strings.Contains(mail, "Subject: Test Subject")) + tst.AssertTrue(t, strings.Contains(mail, "Content-Type: text/plain; charset=UTF-8")) + tst.AssertTrue(t, strings.Contains(mail, "Content-Transfer-Encoding: 7bit")) + tst.AssertTrue(t, strings.Contains(mail, "Hello plain body")) 
+ tst.AssertTrue(t, strings.Contains(mail, "MIME-Version: 1.0")) + + // Each line must be terminated by CRLF. + tst.AssertTrue(t, strings.Contains(mail, "\r\n")) +} + +func TestMimeMailHTMLOnly(t *testing.T) { + mail := encodeMimeMail( + "from@example.com", + []string{"to@example.com"}, + nil, nil, nil, + "S", + MailBody{HTML: "

Hi

"}, + nil) + + tst.AssertTrue(t, strings.Contains(mail, "Content-Type: text/html; charset=UTF-8")) + tst.AssertTrue(t, strings.Contains(mail, "

Hi

")) + tst.AssertFalse(t, strings.Contains(mail, "multipart/")) +} + +func TestMimeMailAlternative(t *testing.T) { + mail := encodeMimeMail( + "from@example.com", + []string{"to@example.com"}, + nil, nil, nil, + "S", + MailBody{ + Plain: "Plain Body", + HTML: "

HTML Body

", + }, + nil) + + tst.AssertTrue(t, strings.Contains(mail, "Content-Type: multipart/alternative;")) + tst.AssertTrue(t, strings.Contains(mail, "Plain Body")) + tst.AssertTrue(t, strings.Contains(mail, "

HTML Body

")) + tst.AssertTrue(t, strings.Contains(mail, "Content-Type: text/plain; charset=UTF-8")) + tst.AssertTrue(t, strings.Contains(mail, "Content-Type: text/html; charset=UTF-8")) +} + +func TestMimeMailWithCC(t *testing.T) { + mail := encodeMimeMail( + "from@example.com", + []string{"to@example.com"}, + []string{"cc@example.com"}, + nil, nil, + "S", + MailBody{Plain: "x"}, + nil) + + tst.AssertTrue(t, strings.Contains(mail, "cc@example.com")) +} + +func TestMimeMailWithBCC(t *testing.T) { + mail := encodeMimeMail( + "from@example.com", + []string{"to@example.com"}, + nil, + []string{"bcc@example.com"}, + nil, + "S", + MailBody{Plain: "x"}, + nil) + + tst.AssertTrue(t, strings.Contains(mail, "Bcc: bcc@example.com")) +} + +func TestMimeMailWithReplyTo(t *testing.T) { + mail := encodeMimeMail( + "from@example.com", + []string{"to@example.com"}, + nil, nil, + []string{"reply@example.com"}, + "S", + MailBody{Plain: "x"}, + nil) + + tst.AssertTrue(t, strings.Contains(mail, "Reply-To: reply@example.com")) +} + +func TestMimeMailMultipleRecipients(t *testing.T) { + mail := encodeMimeMail( + "from@example.com", + []string{"a@example.com", "b@example.com", "c@example.com"}, + nil, nil, nil, + "S", + MailBody{Plain: "x"}, + nil) + + tst.AssertTrue(t, strings.Contains(mail, "a@example.com")) + tst.AssertTrue(t, strings.Contains(mail, "b@example.com")) + tst.AssertTrue(t, strings.Contains(mail, "c@example.com")) + tst.AssertTrue(t, strings.Contains(mail, "a@example.com, b@example.com, c@example.com")) +} + +func TestMimeMailSubjectEncoding(t *testing.T) { + mail := encodeMimeMail( + "from@example.com", + []string{"to@example.com"}, + nil, nil, nil, + "Hällö Wörld", + MailBody{Plain: "x"}, + nil) + + // Non-ASCII subject must be quoted-printable encoded. 
+ tst.AssertTrue(t, strings.Contains(mail, "Subject: =?UTF-8?q?")) +} + +func TestMimeMailWithNormalAttachment(t *testing.T) { + mail := encodeMimeMail( + "from@example.com", + []string{"to@example.com"}, + nil, nil, nil, + "S", + MailBody{Plain: "Body"}, + []MailAttachment{ + {Data: []byte("attached"), Filename: "f.txt", IsInline: false, ContentType: "text/plain"}, + }) + + tst.AssertTrue(t, strings.Contains(mail, "Content-Type: multipart/mixed;")) + tst.AssertTrue(t, strings.Contains(mail, `Content-Disposition: attachment;filename="f.txt"`)) +} + +func TestMimeMailWithInlineAttachment(t *testing.T) { + mail := encodeMimeMail( + "from@example.com", + []string{"to@example.com"}, + nil, nil, nil, + "S", + MailBody{HTML: "

x

"}, + []MailAttachment{ + {Data: []byte{1, 2, 3}, Filename: "img.png", IsInline: true, ContentType: "image/png"}, + }) + + tst.AssertTrue(t, strings.Contains(mail, "Content-Type: multipart/related;")) + tst.AssertTrue(t, strings.Contains(mail, `Content-Disposition: inline;filename="img.png"`)) +} + +func TestMimeMailWithBothAttachmentTypes(t *testing.T) { + mail := encodeMimeMail( + "from@example.com", + []string{"to@example.com"}, + nil, nil, nil, + "S", + MailBody{HTML: "

x

"}, + []MailAttachment{ + {Data: []byte{1}, Filename: "img.png", IsInline: true, ContentType: "image/png"}, + {Data: []byte{2}, Filename: "f.txt", IsInline: false, ContentType: "text/plain"}, + }) + + tst.AssertTrue(t, strings.Contains(mail, "Content-Type: multipart/mixed;")) + tst.AssertTrue(t, strings.Contains(mail, "Content-Type: multipart/related;")) + tst.AssertTrue(t, strings.Contains(mail, `Content-Disposition: inline;filename="img.png"`)) + tst.AssertTrue(t, strings.Contains(mail, `Content-Disposition: attachment;filename="f.txt"`)) +} + +func TestMimeMailEmptyBody(t *testing.T) { + mail := encodeMimeMail( + "from@example.com", + []string{"to@example.com"}, + nil, nil, nil, + "S", + MailBody{}, + nil) + + tst.AssertTrue(t, strings.Contains(mail, "From: from@example.com")) + tst.AssertTrue(t, strings.Contains(mail, "To: to@example.com")) + tst.AssertTrue(t, strings.Contains(mail, "Subject: S")) + // No body type was set, so no Content-Type for body should be present. + tst.AssertFalse(t, strings.Contains(mail, "Content-Type: text/")) + tst.AssertFalse(t, strings.Contains(mail, "Content-Type: multipart/")) +} + +func TestMimeMailHasDateHeader(t *testing.T) { + mail := encodeMimeMail( + "from@example.com", + []string{"to@example.com"}, + nil, nil, nil, + "S", + MailBody{Plain: "x"}, + nil) + + tst.AssertTrue(t, strings.HasPrefix(mail, "Date: ")) +} + +func TestDumpMailBodyPlainOnly(t *testing.T) { + lines := dumpMailBody(MailBody{Plain: "Plain"}, false, false, "BOUND", "BOUNDALT") + joined := strings.Join(lines, "\n") + tst.AssertTrue(t, strings.Contains(joined, "--BOUND")) + tst.AssertTrue(t, strings.Contains(joined, "Content-Type: text/plain; charset=UTF-8")) + tst.AssertTrue(t, strings.Contains(joined, "Plain")) +} + +func TestDumpMailBodyHTMLOnly(t *testing.T) { + lines := dumpMailBody(MailBody{HTML: "

x

"}, false, false, "BOUND", "BOUNDALT") + joined := strings.Join(lines, "\n") + tst.AssertTrue(t, strings.Contains(joined, "--BOUND")) + tst.AssertTrue(t, strings.Contains(joined, "Content-Type: text/html; charset=UTF-8")) + tst.AssertTrue(t, strings.Contains(joined, "

x

")) +} + +func TestDumpMailBodyEmpty(t *testing.T) { + lines := dumpMailBody(MailBody{}, false, false, "BOUND", "BOUNDALT") + joined := strings.Join(lines, "\n") + // Default empty case still emits the boundary and a default Content-Type header. + tst.AssertTrue(t, strings.Contains(joined, "--BOUND")) + tst.AssertTrue(t, strings.Contains(joined, "Content-Type: text/plain; charset=UTF-8")) +} + +func TestDumpMailBodyMixedAlternative(t *testing.T) { + // HTML+Plain with normal attachments and no inline → uses alternative sub-block. + lines := dumpMailBody(MailBody{Plain: "P", HTML: "

H

"}, false, true, "BOUND", "BOUNDALT") + joined := strings.Join(lines, "\n") + tst.AssertTrue(t, strings.Contains(joined, "--BOUND")) + tst.AssertTrue(t, strings.Contains(joined, "Content-Type: multipart/alternative; boundary=BOUNDALT")) + tst.AssertTrue(t, strings.Contains(joined, "--BOUNDALT")) + tst.AssertTrue(t, strings.Contains(joined, "P")) + tst.AssertTrue(t, strings.Contains(joined, "

H

")) +} + +func TestDumpMailBodyMixedInline(t *testing.T) { + // HTML+Plain with inline attachments → simplified to single HTML block. + lines := dumpMailBody(MailBody{Plain: "P", HTML: "

H

"}, true, false, "BOUND", "BOUNDALT") + tst.AssertEqual(t, len(lines), 2) + tst.AssertEqual(t, lines[0], "--BOUND") + tst.AssertEqual(t, lines[1], "

H

") +} + +func TestDumpMailBodyBothNoAttachments(t *testing.T) { + lines := dumpMailBody(MailBody{Plain: "P", HTML: "

H

"}, false, false, "BOUND", "BOUNDALT") + joined := strings.Join(lines, "\n") + tst.AssertTrue(t, strings.Contains(joined, "--BOUND")) + tst.AssertTrue(t, strings.Contains(joined, "Content-Type: text/plain; charset=UTF-8")) + tst.AssertTrue(t, strings.Contains(joined, "Content-Type: text/html; charset=UTF-8")) + tst.AssertTrue(t, strings.Contains(joined, "P")) + tst.AssertTrue(t, strings.Contains(joined, "

H

")) +} diff --git a/googleapi/oAuth_test.go b/googleapi/oAuth_test.go new file mode 100644 index 0000000..cf2f0ae --- /dev/null +++ b/googleapi/oAuth_test.go @@ -0,0 +1,59 @@ +package googleapi + +import ( + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + "testing" + "time" +) + +func TestNewGoogleOAuthReturnsNonNil(t *testing.T) { + auth := NewGoogleOAuth("cid", "csecret", "rtok") + tst.AssertTrue(t, auth != nil) +} + +func TestNewGoogleOAuthFieldsSet(t *testing.T) { + auth := NewGoogleOAuth("cid", "csecret", "rtok") + c, ok := auth.(*oauth) + tst.AssertTrue(t, ok) + tst.AssertEqual(t, c.clientID, "cid") + tst.AssertEqual(t, c.clientSecret, "csecret") + tst.AssertEqual(t, c.refreshToken, "rtok") + tst.AssertTrue(t, c.accessToken == nil) + tst.AssertTrue(t, c.expiryDate == nil) +} + +func TestOAuthAccessTokenCachedReturnsStored(t *testing.T) { + c := &oauth{ + clientID: "cid", + clientSecret: "csecret", + refreshToken: "rtok", + } + + tok := "cached-token-value" + expiry := time.Now().Add(1 * time.Hour) + c.accessToken = &tok + c.expiryDate = &expiry + + got, err := c.AccessToken() + tst.AssertNoErr(t, err) + tst.AssertEqual(t, got, "cached-token-value") +} + +func TestOAuthAccessTokenCachedMultipleCalls(t *testing.T) { + c := &oauth{ + clientID: "cid", + clientSecret: "csecret", + refreshToken: "rtok", + } + + tok := "another-token" + expiry := time.Now().Add(30 * time.Minute) + c.accessToken = &tok + c.expiryDate = &expiry + + for range 5 { + got, err := c.AccessToken() + tst.AssertNoErr(t, err) + tst.AssertEqual(t, got, "another-token") + } +} diff --git a/googleapi/service_test.go b/googleapi/service_test.go new file mode 100644 index 0000000..3aad5d8 --- /dev/null +++ b/googleapi/service_test.go @@ -0,0 +1,26 @@ +package googleapi + +import ( + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + "testing" +) + +func TestNewGoogleClientReturnsNonNil(t *testing.T) { + auth := NewGoogleOAuth("cid", "csecret", "rtok") + gc := NewGoogleClient(auth) + 
tst.AssertTrue(t, gc != nil) +} + +func TestNewGoogleClientWiresOAuth(t *testing.T) { + auth := NewGoogleOAuth("cid", "csecret", "rtok") + gc := NewGoogleClient(auth) + c, ok := gc.(*client) + tst.AssertTrue(t, ok) + tst.AssertTrue(t, c.oauth == auth) +} + +func TestMailBodyZeroValue(t *testing.T) { + b := MailBody{} + tst.AssertEqual(t, b.Plain, "") + tst.AssertEqual(t, b.HTML, "") +} diff --git a/imageext/enums_test.go b/imageext/enums_test.go new file mode 100644 index 0000000..36f9f46 --- /dev/null +++ b/imageext/enums_test.go @@ -0,0 +1,122 @@ +package imageext + +import ( + "testing" + + "git.blackforestbytes.com/BlackForestBytes/goext/tst" +) + +func TestImageFit_Valid(t *testing.T) { + tst.AssertTrue(t, ImageFitStretch.Valid()) + tst.AssertTrue(t, ImageFitCover.Valid()) + tst.AssertTrue(t, ImageFitContainCenter.Valid()) + tst.AssertTrue(t, ImageFitContainTopLeft.Valid()) + tst.AssertTrue(t, ImageFitContainTopRight.Valid()) + tst.AssertTrue(t, ImageFitContainBottomLeft.Valid()) + tst.AssertTrue(t, ImageFitContainBottomRight.Valid()) + + tst.AssertFalse(t, ImageFit("UNKNOWN").Valid()) + tst.AssertFalse(t, ImageFit("").Valid()) +} + +func TestImageFit_String(t *testing.T) { + tst.AssertEqual(t, ImageFitStretch.String(), "STRETCH") + tst.AssertEqual(t, ImageFitCover.String(), "COVER") + tst.AssertEqual(t, ImageFitContainCenter.String(), "CONTAIN_CENTER") +} + +func TestImageFit_VarName(t *testing.T) { + tst.AssertEqual(t, ImageFitStretch.VarName(), "ImageFitStretch") + tst.AssertEqual(t, ImageFitContainBottomRight.VarName(), "ImageFitContainBottomRight") + tst.AssertEqual(t, ImageFit("UNKNOWN").VarName(), "") +} + +func TestImageFit_TypeName(t *testing.T) { + tst.AssertEqual(t, ImageFitStretch.TypeName(), "ImageFit") +} + +func TestImageFit_Values(t *testing.T) { + values := ImageFitValues() + tst.AssertEqual(t, len(values), 7) + tst.AssertEqual(t, values[0], ImageFitStretch) +} + +func TestImageFit_ValuesAny(t *testing.T) { + tst.AssertEqual(t, 
len(ImageFitStretch.ValuesAny()), 7) +} + +func TestImageFit_ValuesMeta(t *testing.T) { + meta := ImageFitValuesMeta() + tst.AssertEqual(t, len(meta), 7) + tst.AssertEqual(t, meta[0].VarName, "ImageFitStretch") +} + +func TestParseImageFit(t *testing.T) { + v, ok := ParseImageFit("COVER") + tst.AssertTrue(t, ok) + tst.AssertEqual(t, v, ImageFitCover) + + v, ok = ParseImageFit("CONTAIN_TOPLEFT") + tst.AssertTrue(t, ok) + tst.AssertEqual(t, v, ImageFitContainTopLeft) + + _, ok = ParseImageFit("BOGUS") + tst.AssertFalse(t, ok) +} + +func TestImageCompresson_Valid(t *testing.T) { + tst.AssertTrue(t, CompressionPNGNone.Valid()) + tst.AssertTrue(t, CompressionPNGSpeed.Valid()) + tst.AssertTrue(t, CompressionPNGBest.Valid()) + tst.AssertTrue(t, CompressionJPEG100.Valid()) + tst.AssertTrue(t, CompressionJPEG1.Valid()) + + tst.AssertFalse(t, ImageCompresson("UNKNOWN").Valid()) + tst.AssertFalse(t, ImageCompresson("").Valid()) +} + +func TestImageCompresson_String(t *testing.T) { + tst.AssertEqual(t, CompressionPNGNone.String(), "PNG_NONE") + tst.AssertEqual(t, CompressionJPEG90.String(), "JPEG_090") +} + +func TestImageCompresson_VarName(t *testing.T) { + tst.AssertEqual(t, CompressionJPEG50.VarName(), "CompressionJPEG50") + tst.AssertEqual(t, ImageCompresson("UNKNOWN").VarName(), "") +} + +func TestImageCompresson_TypeName(t *testing.T) { + tst.AssertEqual(t, CompressionJPEG50.TypeName(), "ImageCompresson") +} + +func TestImageCompresson_Values(t *testing.T) { + values := ImageCompressonValues() + tst.AssertEqual(t, len(values), 12) +} + +func TestImageCompresson_ValuesAny(t *testing.T) { + tst.AssertEqual(t, len(CompressionPNGBest.ValuesAny()), 12) +} + +func TestImageCompresson_ValuesMeta(t *testing.T) { + meta := ImageCompressonValuesMeta() + tst.AssertEqual(t, len(meta), 12) +} + +func TestParseImageCompresson(t *testing.T) { + v, ok := ParseImageCompresson("PNG_BEST") + tst.AssertTrue(t, ok) + tst.AssertEqual(t, v, CompressionPNGBest) + + v, ok = 
ParseImageCompresson("JPEG_080") + tst.AssertTrue(t, ok) + tst.AssertEqual(t, v, CompressionJPEG80) + + _, ok = ParseImageCompresson("BOGUS") + tst.AssertFalse(t, ok) +} + +func TestAllPackageEnums(t *testing.T) { + enums := AllPackageEnums() + tst.AssertEqual(t, len(enums), 2) +} diff --git a/imageext/image_test.go b/imageext/image_test.go new file mode 100644 index 0000000..d7c5e51 --- /dev/null +++ b/imageext/image_test.go @@ -0,0 +1,302 @@ +package imageext + +import ( + "bytes" + "image" + "image/color" + "image/jpeg" + "image/png" + "strings" + "testing" + + "git.blackforestbytes.com/BlackForestBytes/goext/tst" +) + +// makeRGBA creates a solid-color RGBA image of the given size. +func makeRGBA(w, h int, c color.Color) *image.RGBA { + img := image.NewRGBA(image.Rect(0, 0, w, h)) + for y := 0; y < h; y++ { + for x := 0; x < w; x++ { + img.Set(x, y, c) + } + } + return img +} + +// makeGradient creates an RGBA image with a gradient pattern. +// Useful for codecs that may behave oddly with uniform colors. +func makeGradient(w, h int) *image.RGBA { + img := image.NewRGBA(image.Rect(0, 0, w, h)) + for y := 0; y < h; y++ { + for x := 0; x < w; x++ { + img.Set(x, y, color.RGBA{ + R: uint8((x * 255) / max1(w-1)), + G: uint8((y * 255) / max1(h-1)), + B: uint8(((x + y) * 255) / max1(w+h-2)), + A: 255, + }) + } + } + return img +} + +func max1(v int) int { + if v <= 0 { + return 1 + } + return v +} + +func TestCropImage_HalfRegion(t *testing.T) { + src := makeGradient(100, 80) + + out, err := CropImage(src, 0.0, 0.0, 0.5, 0.5) + tst.AssertNoErr(t, err) + + tst.AssertEqual(t, out.Bounds().Dx(), 50) + tst.AssertEqual(t, out.Bounds().Dy(), 40) +} + +func TestCropImage_Offset(t *testing.T) { + src := makeGradient(200, 100) + + out, err := CropImage(src, 0.25, 0.5, 0.5, 0.5) + tst.AssertNoErr(t, err) + + // SubImage preserves coordinates of the parent image. 
+ tst.AssertEqual(t, out.Bounds().Min.X, 50) + tst.AssertEqual(t, out.Bounds().Min.Y, 50) + tst.AssertEqual(t, out.Bounds().Dx(), 100) + tst.AssertEqual(t, out.Bounds().Dy(), 50) +} + +func TestCropImage_Full(t *testing.T) { + src := makeGradient(40, 40) + + out, err := CropImage(src, 0.0, 0.0, 1.0, 1.0) + tst.AssertNoErr(t, err) + + tst.AssertEqual(t, out.Bounds().Dx(), 40) + tst.AssertEqual(t, out.Bounds().Dy(), 40) +} + +func TestEncodeImage_AllCompressions(t *testing.T) { + src := makeGradient(20, 16) + + cases := []struct { + comp ImageCompresson + mime string + signature []byte + }{ + {CompressionPNGNone, "image/png", []byte{0x89, 'P', 'N', 'G'}}, + {CompressionPNGSpeed, "image/png", []byte{0x89, 'P', 'N', 'G'}}, + {CompressionPNGBest, "image/png", []byte{0x89, 'P', 'N', 'G'}}, + {CompressionJPEG100, "image/jpeg", []byte{0xFF, 0xD8, 0xFF}}, + {CompressionJPEG90, "image/jpeg", []byte{0xFF, 0xD8, 0xFF}}, + {CompressionJPEG80, "image/jpeg", []byte{0xFF, 0xD8, 0xFF}}, + {CompressionJPEG70, "image/jpeg", []byte{0xFF, 0xD8, 0xFF}}, + {CompressionJPEG60, "image/jpeg", []byte{0xFF, 0xD8, 0xFF}}, + {CompressionJPEG50, "image/jpeg", []byte{0xFF, 0xD8, 0xFF}}, + {CompressionJPEG25, "image/jpeg", []byte{0xFF, 0xD8, 0xFF}}, + {CompressionJPEG10, "image/jpeg", []byte{0xFF, 0xD8, 0xFF}}, + {CompressionJPEG1, "image/jpeg", []byte{0xFF, 0xD8, 0xFF}}, + } + + for _, c := range cases { + t.Run(string(c.comp), func(t *testing.T) { + buf, mime, err := EncodeImage(src, c.comp) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, mime, c.mime) + tst.AssertTrue(t, buf.Len() > 0) + + data := buf.Bytes() + tst.AssertTrue(t, len(data) >= len(c.signature)) + tst.AssertTrue(t, bytes.Equal(data[:len(c.signature)], c.signature)) + + // Round-trip decode to confirm the bytes form a valid image. 
+ var dec image.Image + var derr error + if c.mime == "image/png" { + dec, derr = png.Decode(bytes.NewReader(data)) + } else { + dec, derr = jpeg.Decode(bytes.NewReader(data)) + } + tst.AssertNoErr(t, derr) + tst.AssertEqual(t, dec.Bounds().Dx(), 20) + tst.AssertEqual(t, dec.Bounds().Dy(), 16) + }) + } +} + +func TestEncodeImage_UnknownCompression(t *testing.T) { + src := makeRGBA(4, 4, color.White) + + buf, mime, err := EncodeImage(src, ImageCompresson("UNKNOWN")) + tst.AssertTrue(t, err != nil) + tst.AssertEqual(t, mime, "") + tst.AssertEqual(t, buf.Len(), 0) +} + +func TestObjectFitImage_Cover_SmallerThanBB(t *testing.T) { + // Image (100x100) is smaller than the BB (200x100), so the output is + // scaled down to the smaller-axis factor of 0.5: 100x50. + src := makeGradient(100, 100) + + out, rect, err := ObjectFitImage(src, 200, 100, ImageFitCover, color.Transparent) + tst.AssertNoErr(t, err) + + tst.AssertEqual(t, out.Bounds().Dx(), 100) + tst.AssertEqual(t, out.Bounds().Dy(), 50) + tst.AssertDeepEqual(t, rect, PercentageRectangle{0, 0, 1, 1}) +} + +func TestObjectFitImage_Cover_LargerThanBB(t *testing.T) { + // Image (400x200) is larger than the BB (200x100); fac is capped at 1, so + // the output is exactly the BB size. + src := makeGradient(400, 200) + + out, _, err := ObjectFitImage(src, 200, 100, ImageFitCover, color.Transparent) + tst.AssertNoErr(t, err) + + tst.AssertEqual(t, out.Bounds().Dx(), 200) + tst.AssertEqual(t, out.Bounds().Dy(), 100) +} + +func TestObjectFitImage_ContainCenter(t *testing.T) { + // Image 100x100 in a BB of 200x100 -> output 200x100, drawn rect 100x100 centered. 
+ src := makeGradient(100, 100) + + out, rect, err := ObjectFitImage(src, 200, 100, ImageFitContainCenter, color.Black) + tst.AssertNoErr(t, err) + + tst.AssertEqual(t, out.Bounds().Dx(), 200) + tst.AssertEqual(t, out.Bounds().Dy(), 100) + + // (200-100)/2 = 50 -> X=50/200 = 0.25, W=100/200=0.5, Y=0, H=1 + tst.AssertDeepEqual(t, rect, PercentageRectangle{X: 0.25, Y: 0, W: 0.5, H: 1}) +} + +func TestObjectFitImage_ContainTopLeft(t *testing.T) { + src := makeGradient(100, 100) + + out, rect, err := ObjectFitImage(src, 200, 100, ImageFitContainTopLeft, color.Black) + tst.AssertNoErr(t, err) + + tst.AssertEqual(t, out.Bounds().Dx(), 200) + tst.AssertEqual(t, out.Bounds().Dy(), 100) + tst.AssertDeepEqual(t, rect, PercentageRectangle{X: 0, Y: 0, W: 0.5, H: 1}) +} + +func TestObjectFitImage_ContainTopRight(t *testing.T) { + src := makeGradient(100, 100) + + out, rect, err := ObjectFitImage(src, 200, 100, ImageFitContainTopRight, color.Black) + tst.AssertNoErr(t, err) + + tst.AssertEqual(t, out.Bounds().Dx(), 200) + tst.AssertEqual(t, out.Bounds().Dy(), 100) + tst.AssertDeepEqual(t, rect, PercentageRectangle{X: 0.5, Y: 0, W: 0.5, H: 1}) +} + +func TestObjectFitImage_ContainBottomLeft(t *testing.T) { + // Image 200x100 in a BB of 100x100 (image is bigger so facOut is capped at 1) + src := makeGradient(200, 100) + + out, rect, err := ObjectFitImage(src, 100, 100, ImageFitContainBottomLeft, color.Black) + tst.AssertNoErr(t, err) + + tst.AssertEqual(t, out.Bounds().Dx(), 100) + tst.AssertEqual(t, out.Bounds().Dy(), 100) + // dw=100, dh=50 -> bottom-left rect: (0, 50, 100, 100) -> Y=0.5, H=0.5 + tst.AssertDeepEqual(t, rect, PercentageRectangle{X: 0, Y: 0.5, W: 1, H: 0.5}) +} + +func TestObjectFitImage_ContainBottomRight(t *testing.T) { + src := makeGradient(200, 100) + + out, rect, err := ObjectFitImage(src, 100, 100, ImageFitContainBottomRight, color.Black) + tst.AssertNoErr(t, err) + + tst.AssertEqual(t, out.Bounds().Dx(), 100) + tst.AssertEqual(t, out.Bounds().Dy(), 100) + 
tst.AssertDeepEqual(t, rect, PercentageRectangle{X: 0, Y: 0.5, W: 1, H: 0.5}) +} + +func TestObjectFitImage_Stretch(t *testing.T) { + // Image 100x100 in BB 200x100 -> uses max(facW=0.5, facH=1.0) capped at 1, so result is 200x100. + src := makeGradient(100, 100) + + out, rect, err := ObjectFitImage(src, 200, 100, ImageFitStretch, color.Black) + tst.AssertNoErr(t, err) + + tst.AssertEqual(t, out.Bounds().Dx(), 200) + tst.AssertEqual(t, out.Bounds().Dy(), 100) + tst.AssertDeepEqual(t, rect, PercentageRectangle{0, 0, 1, 1}) +} + +func TestObjectFitImage_Stretch_SmallImage(t *testing.T) { + // Image 50x25 in BB 200x100 -> max(0.25, 0.25) = 0.25, output 50x25. + src := makeGradient(50, 25) + + out, _, err := ObjectFitImage(src, 200, 100, ImageFitStretch, color.Black) + tst.AssertNoErr(t, err) + + tst.AssertEqual(t, out.Bounds().Dx(), 50) + tst.AssertEqual(t, out.Bounds().Dy(), 25) +} + +func TestObjectFitImage_UnknownFit(t *testing.T) { + src := makeGradient(20, 20) + + out, _, err := ObjectFitImage(src, 100, 100, ImageFit("BOGUS"), color.Black) + tst.AssertTrue(t, err != nil) + if out != nil { + t.Errorf("expected nil image on error, got %v", out) + } +} + +func TestVerifyAndDecodeImage_PNG(t *testing.T) { + src := makeGradient(12, 8) + buf := bytes.Buffer{} + tst.AssertNoErr(t, png.Encode(&buf, src)) + + out, err := VerifyAndDecodeImage(&buf, "image/png") + tst.AssertNoErr(t, err) + tst.AssertEqual(t, out.Bounds().Dx(), 12) + tst.AssertEqual(t, out.Bounds().Dy(), 8) +} + +func TestVerifyAndDecodeImage_JPEG(t *testing.T) { + src := makeGradient(16, 16) + buf := bytes.Buffer{} + tst.AssertNoErr(t, jpeg.Encode(&buf, src, &jpeg.Options{Quality: 90})) + + out, err := VerifyAndDecodeImage(&buf, "image/jpeg") + tst.AssertNoErr(t, err) + tst.AssertEqual(t, out.Bounds().Dx(), 16) + tst.AssertEqual(t, out.Bounds().Dy(), 16) +} + +func TestVerifyAndDecodeImage_UnknownMime(t *testing.T) { + out, err := VerifyAndDecodeImage(strings.NewReader("whatever"), "image/gif") + 
tst.AssertTrue(t, err != nil) + if out != nil { + t.Errorf("expected nil image on error, got %v", out) + } +} + +func TestVerifyAndDecodeImage_BadPNG(t *testing.T) { + out, err := VerifyAndDecodeImage(strings.NewReader("not a png"), "image/png") + tst.AssertTrue(t, err != nil) + if out != nil { + t.Errorf("expected nil image on error, got %v", out) + } +} + +func TestVerifyAndDecodeImage_BadJPEG(t *testing.T) { + out, err := VerifyAndDecodeImage(strings.NewReader("not a jpeg"), "image/jpeg") + tst.AssertTrue(t, err != nil) + if out != nil { + t.Errorf("expected nil image on error, got %v", out) + } +} diff --git a/imageext/types_test.go b/imageext/types_test.go new file mode 100644 index 0000000..4d3464a --- /dev/null +++ b/imageext/types_test.go @@ -0,0 +1,56 @@ +package imageext + +import ( + "image" + "testing" + + "git.blackforestbytes.com/BlackForestBytes/goext/tst" +) + +func TestPercentageRectangle_Of_FullRef(t *testing.T) { + r := PercentageRectangle{X: 0.25, Y: 0.5, W: 0.5, H: 0.25} + ref := Rectangle{X: 0, Y: 0, W: 100, H: 200} + + got := r.Of(ref) + tst.AssertDeepEqual(t, got, Rectangle{X: 25, Y: 100, W: 50, H: 50}) +} + +func TestPercentageRectangle_Of_OffsetRef(t *testing.T) { + r := PercentageRectangle{X: 0.5, Y: 0.5, W: 0.5, H: 0.5} + ref := Rectangle{X: 10, Y: 20, W: 100, H: 100} + + got := r.Of(ref) + tst.AssertDeepEqual(t, got, Rectangle{X: 60, Y: 70, W: 50, H: 50}) +} + +func TestPercentageRectangle_Of_Identity(t *testing.T) { + r := PercentageRectangle{X: 0, Y: 0, W: 1, H: 1} + ref := Rectangle{X: 5, Y: 6, W: 7, H: 8} + + got := r.Of(ref) + tst.AssertDeepEqual(t, got, ref) +} + +func TestCalcRelativeRect_FullInner(t *testing.T) { + inner := image.Rect(0, 0, 100, 100) + outer := image.Rect(0, 0, 100, 100) + + got := calcRelativeRect(inner, outer) + tst.AssertDeepEqual(t, got, PercentageRectangle{X: 0, Y: 0, W: 1, H: 1}) +} + +func TestCalcRelativeRect_Centered(t *testing.T) { + inner := image.Rect(50, 25, 150, 75) + outer := image.Rect(0, 0, 200, 
100) + + got := calcRelativeRect(inner, outer) + tst.AssertDeepEqual(t, got, PercentageRectangle{X: 0.25, Y: 0.25, W: 0.5, H: 0.5}) +} + +func TestCalcRelativeRect_OffsetOuter(t *testing.T) { + inner := image.Rect(120, 60, 220, 110) + outer := image.Rect(100, 50, 300, 150) + + got := calcRelativeRect(inner, outer) + tst.AssertDeepEqual(t, got, PercentageRectangle{X: 0.1, Y: 0.1, W: 0.5, H: 0.5}) +} diff --git a/langext/base62_test.go b/langext/base62_test.go new file mode 100644 index 0000000..b345335 --- /dev/null +++ b/langext/base62_test.go @@ -0,0 +1,86 @@ +package langext + +import ( + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + "strings" + "testing" +) + +func TestEncodeBase62Zero(t *testing.T) { + tst.AssertEqual(t, EncodeBase62(0), "0") +} + +func TestEncodeBase62Small(t *testing.T) { + tst.AssertEqual(t, EncodeBase62(1), "1") + tst.AssertEqual(t, EncodeBase62(9), "9") + tst.AssertEqual(t, EncodeBase62(10), "A") + tst.AssertEqual(t, EncodeBase62(35), "Z") + tst.AssertEqual(t, EncodeBase62(36), "a") + tst.AssertEqual(t, EncodeBase62(61), "z") + tst.AssertEqual(t, EncodeBase62(62), "10") +} + +func TestDecodeBase62Empty(t *testing.T) { + _, err := DecodeBase62("") + if err == nil { + t.Errorf("expected error on empty input") + } +} + +func TestDecodeBase62Invalid(t *testing.T) { + _, err := DecodeBase62("foo!bar") + if err == nil { + t.Errorf("expected error on invalid character") + } +} + +func TestDecodeBase62Basic(t *testing.T) { + v, err := DecodeBase62("0") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + tst.AssertEqual(t, v, uint64(0)) + + v, err = DecodeBase62("10") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + tst.AssertEqual(t, v, uint64(62)) +} + +func TestEncodeDecodeBase62RoundTrip(t *testing.T) { + for _, n := range []uint64{0, 1, 61, 62, 100, 12345, 1<<32 - 1, 1 << 40} { + s := EncodeBase62(n) + v, err := DecodeBase62(s) + if err != nil { + t.Errorf("decode error for %d (encoded %q): %v", n, s, err) 
+ continue + } + tst.AssertEqual(t, v, n) + } +} + +func TestRandBase62Length(t *testing.T) { + for _, l := range []int{0, 1, 8, 32, 64} { + s := RandBase62(l) + tst.AssertEqual(t, len(s), l) + } +} + +func TestRandBase62Alphabet(t *testing.T) { + const alphabet = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + s := RandBase62(256) + for _, r := range s { + if !strings.ContainsRune(alphabet, r) { + t.Errorf("character %q not in base62 alphabet", string(r)) + } + } +} + +func TestRandBase62Distinct(t *testing.T) { + a := RandBase62(32) + b := RandBase62(32) + if a == b { + t.Errorf("two base62 random strings of length 32 should not be equal") + } +} diff --git a/langext/bool_test.go b/langext/bool_test.go new file mode 100644 index 0000000..39b03da --- /dev/null +++ b/langext/bool_test.go @@ -0,0 +1,84 @@ +package langext + +import ( + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + "testing" +) + +func TestFormatBoolTrue(t *testing.T) { + tst.AssertEqual(t, FormatBool(true, "yes", "no"), "yes") +} + +func TestFormatBoolFalse(t *testing.T) { + tst.AssertEqual(t, FormatBool(false, "yes", "no"), "no") +} + +func TestConditional(t *testing.T) { + tst.AssertEqual(t, Conditional(true, 1, 2), 1) + tst.AssertEqual(t, Conditional(false, 1, 2), 2) + tst.AssertEqual(t, Conditional(true, "a", "b"), "a") + tst.AssertEqual(t, Conditional(false, "a", "b"), "b") +} + +func TestConditionalFn00(t *testing.T) { + tst.AssertEqual(t, ConditionalFn00(true, 10, 20), 10) + tst.AssertEqual(t, ConditionalFn00(false, 10, 20), 20) +} + +func TestConditionalFn10Lazy(t *testing.T) { + called := false + v := ConditionalFn10(false, func() int { + called = true + return 1 + }, 99) + tst.AssertEqual(t, v, 99) + tst.AssertEqual(t, called, false) + + v = ConditionalFn10(true, func() int { + called = true + return 1 + }, 99) + tst.AssertEqual(t, v, 1) + tst.AssertEqual(t, called, true) +} + +func TestConditionalFn01Lazy(t *testing.T) { + called := false + v := 
ConditionalFn01(true, 1, func() int { + called = true + return 99 + }) + tst.AssertEqual(t, v, 1) + tst.AssertEqual(t, called, false) + + v = ConditionalFn01(false, 1, func() int { + called = true + return 99 + }) + tst.AssertEqual(t, v, 99) + tst.AssertEqual(t, called, true) +} + +func TestConditionalFn11Lazy(t *testing.T) { + calledT := false + calledF := false + + v := ConditionalFn11(true, + func() int { calledT = true; return 1 }, + func() int { calledF = true; return 2 }, + ) + tst.AssertEqual(t, v, 1) + tst.AssertEqual(t, calledT, true) + tst.AssertEqual(t, calledF, false) + + calledT = false + calledF = false + + v = ConditionalFn11(false, + func() int { calledT = true; return 1 }, + func() int { calledF = true; return 2 }, + ) + tst.AssertEqual(t, v, 2) + tst.AssertEqual(t, calledT, false) + tst.AssertEqual(t, calledF, true) +} diff --git a/langext/bytes_test.go b/langext/bytes_test.go new file mode 100644 index 0000000..791e864 --- /dev/null +++ b/langext/bytes_test.go @@ -0,0 +1,68 @@ +package langext + +import ( + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + "testing" +) + +func TestFormatBytesToSI(t *testing.T) { + tst.AssertEqual(t, FormatBytesToSI(0), "0 B") + tst.AssertEqual(t, FormatBytesToSI(999), "999 B") + tst.AssertEqual(t, FormatBytesToSI(1000), "1.0 kB") + tst.AssertEqual(t, FormatBytesToSI(1500), "1.5 kB") + tst.AssertEqual(t, FormatBytesToSI(1000*1000), "1.0 MB") + tst.AssertEqual(t, FormatBytesToSI(1000*1000*1000), "1.0 GB") +} + +func TestFormatBytes(t *testing.T) { + tst.AssertEqual(t, FormatBytes(0), "0 B") + tst.AssertEqual(t, FormatBytes(1023), "1023 B") + tst.AssertEqual(t, FormatBytes(1024), "1.0 KiB") + tst.AssertEqual(t, FormatBytes(1024*1024), "1.0 MiB") + tst.AssertEqual(t, FormatBytes(1024*1024*1024), "1.0 GiB") + tst.AssertEqual(t, FormatBytes(1536), "1.5 KiB") +} + +func TestBytesXOR(t *testing.T) { + a := []byte{0x01, 0x02, 0x03, 0xFF} + b := []byte{0xFF, 0xFE, 0xFD, 0x00} + + r, err := BytesXOR(a, b) + if err != 
nil { + t.Fatalf("unexpected error: %v", err) + } + + expected := []byte{0xFE, 0xFC, 0xFE, 0xFF} + tst.AssertArrayEqual(t, r, expected) +} + +func TestBytesXORLengthMismatch(t *testing.T) { + a := []byte{0x01, 0x02} + b := []byte{0x01, 0x02, 0x03} + + _, err := BytesXOR(a, b) + if err == nil { + t.Fatalf("expected error on length mismatch, got nil") + } +} + +func TestBytesXOREmpty(t *testing.T) { + r, err := BytesXOR([]byte{}, []byte{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + tst.AssertEqual(t, len(r), 0) +} + +func TestBytesXORSelfIsZero(t *testing.T) { + a := []byte{0xAB, 0xCD, 0xEF} + r, err := BytesXOR(a, a) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + for i, v := range r { + if v != 0 { + t.Errorf("expected zero at index %d, got %#x", i, v) + } + } +} diff --git a/langext/coalesce_test.go b/langext/coalesce_test.go new file mode 100644 index 0000000..4f9bae0 --- /dev/null +++ b/langext/coalesce_test.go @@ -0,0 +1,134 @@ +package langext + +import ( + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + "testing" + "time" +) + +type stringerImpl struct{ v string } + +func (s stringerImpl) String() string { return s.v } + +func TestCoalesceWithValue(t *testing.T) { + v := 42 + tst.AssertEqual(t, Coalesce(&v, 0), 42) +} + +func TestCoalesceWithNil(t *testing.T) { + var p *int + tst.AssertEqual(t, Coalesce(p, 99), 99) +} + +func TestCoalesceOpt(t *testing.T) { + v := 1 + w := 2 + tst.AssertDeRefEqual(t, CoalesceOpt(&v, &w), 1) + + var p *int + tst.AssertDeRefEqual(t, CoalesceOpt(p, &w), 2) + + tst.AssertPtrEqual(t, CoalesceOpt[int](nil, nil), nil) +} + +func TestCoalesce3(t *testing.T) { + v := 1 + w := 2 + tst.AssertEqual(t, Coalesce3(&v, &w, 99), 1) + tst.AssertEqual(t, Coalesce3[int](nil, &w, 99), 2) + tst.AssertEqual(t, Coalesce3[int](nil, nil, 99), 99) +} + +func TestCoalesce3Opt(t *testing.T) { + v := 1 + tst.AssertDeRefEqual(t, Coalesce3Opt[int](nil, nil, &v), 1) + tst.AssertPtrEqual(t, Coalesce3Opt[int](nil, 
nil, nil), nil) +} + +func TestCoalesce4(t *testing.T) { + v := 1 + w := 2 + x := 3 + tst.AssertEqual(t, Coalesce4(&v, &w, &x, 99), 1) + tst.AssertEqual(t, Coalesce4[int](nil, &w, &x, 99), 2) + tst.AssertEqual(t, Coalesce4[int](nil, nil, &x, 99), 3) + tst.AssertEqual(t, Coalesce4[int](nil, nil, nil, 99), 99) +} + +func TestCoalesce4Opt(t *testing.T) { + v := 4 + tst.AssertDeRefEqual(t, Coalesce4Opt[int](nil, nil, nil, &v), 4) + tst.AssertPtrEqual(t, Coalesce4Opt[int](nil, nil, nil, nil), nil) +} + +func TestCoalesceString(t *testing.T) { + s := "hello" + tst.AssertEqual(t, CoalesceString(&s, "def"), "hello") + tst.AssertEqual(t, CoalesceString(nil, "def"), "def") +} + +func TestCoalesceInt(t *testing.T) { + v := 5 + tst.AssertEqual(t, CoalesceInt(&v, 99), 5) + tst.AssertEqual(t, CoalesceInt(nil, 99), 99) +} + +func TestCoalesceInt32(t *testing.T) { + v := int32(7) + tst.AssertEqual(t, CoalesceInt32(&v, 99), int32(7)) + tst.AssertEqual(t, CoalesceInt32(nil, 99), int32(99)) +} + +func TestCoalesceBool(t *testing.T) { + v := true + tst.AssertEqual(t, CoalesceBool(&v, false), true) + tst.AssertEqual(t, CoalesceBool(nil, true), true) + tst.AssertEqual(t, CoalesceBool(nil, false), false) +} + +func TestCoalesceTime(t *testing.T) { + now := time.Now() + def := time.Unix(0, 0) + tst.AssertEqual(t, CoalesceTime(&now, def), now) + tst.AssertEqual(t, CoalesceTime(nil, def), def) +} + +func TestCoalesceStringer(t *testing.T) { + s := stringerImpl{v: "hi"} + tst.AssertEqual(t, CoalesceStringer(s, "def"), "hi") + + var nilStringer *stringerImpl + tst.AssertEqual(t, CoalesceStringer(nilStringer, "def"), "def") +} + +func TestCoalesceDefault(t *testing.T) { + tst.AssertEqual(t, CoalesceDefault(0, 99), 99) + tst.AssertEqual(t, CoalesceDefault(5, 99), 5) + tst.AssertEqual(t, CoalesceDefault("", "def"), "def") + tst.AssertEqual(t, CoalesceDefault("v", "def"), "v") +} + +func TestSafeCastMatching(t *testing.T) { + var v any = "hello" + tst.AssertEqual(t, SafeCast(v, "default"), 
"hello") +} + +func TestSafeCastMismatch(t *testing.T) { + var v any = 42 + tst.AssertEqual(t, SafeCast(v, "default"), "default") +} + +func TestSafeCastNil(t *testing.T) { + tst.AssertEqual(t, SafeCast(nil, 99), 99) +} + +func TestCoalesceDblPtrWithValue(t *testing.T) { + v := 1 + pv := &v + tst.AssertDeRefEqual(t, CoalesceDblPtr(&pv, nil), 1) +} + +func TestCoalesceDblPtrFallback(t *testing.T) { + w := 2 + tst.AssertDeRefEqual(t, CoalesceDblPtr[int](nil, &w), 2) +} diff --git a/langext/compare_test.go b/langext/compare_test.go new file mode 100644 index 0000000..131419a --- /dev/null +++ b/langext/compare_test.go @@ -0,0 +1,63 @@ +package langext + +import ( + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + "testing" +) + +func TestCompareIntArrLess(t *testing.T) { + tst.AssertEqual(t, CompareIntArr([]int{1, 2, 3}, []int{1, 2, 4}), true) + tst.AssertEqual(t, CompareIntArr([]int{0}, []int{1}), true) +} + +func TestCompareIntArrGreater(t *testing.T) { + tst.AssertEqual(t, CompareIntArr([]int{1, 2, 5}, []int{1, 2, 4}), false) + tst.AssertEqual(t, CompareIntArr([]int{2}, []int{1}), false) +} + +func TestCompareIntArrEqual(t *testing.T) { + tst.AssertEqual(t, CompareIntArr([]int{1, 2, 3}, []int{1, 2, 3}), false) + tst.AssertEqual(t, CompareIntArr([]int{}, []int{}), false) +} + +func TestCompareArrLess(t *testing.T) { + tst.AssertEqual(t, CompareArr([]int{1, 2, 3}, []int{1, 2, 4}), -1) +} + +func TestCompareArrGreater(t *testing.T) { + r := CompareArr([]int{1, 2, 5}, []int{1, 2, 4}) + if r <= 0 { + t.Errorf("expected positive, got %d", r) + } +} + +func TestCompareArrEqual(t *testing.T) { + tst.AssertEqual(t, CompareArr([]int{1, 2, 3}, []int{1, 2, 3}), 0) + tst.AssertEqual(t, CompareArr([]int{}, []int{}), 0) +} + +func TestCompareString(t *testing.T) { + tst.AssertEqual(t, CompareString("a", "b"), -1) + tst.AssertEqual(t, CompareString("b", "a"), 1) + tst.AssertEqual(t, CompareString("a", "a"), 0) +} + +func TestCompareInt(t *testing.T) { + tst.AssertEqual(t, 
CompareInt(1, 2), -1) + tst.AssertEqual(t, CompareInt(2, 1), 1) + tst.AssertEqual(t, CompareInt(2, 2), 0) +} + +func TestCompareInt64(t *testing.T) { + tst.AssertEqual(t, CompareInt64(int64(1), int64(2)), -1) + tst.AssertEqual(t, CompareInt64(int64(2), int64(1)), 1) + tst.AssertEqual(t, CompareInt64(int64(0), int64(0)), 0) +} + +func TestCompareGeneric(t *testing.T) { + tst.AssertEqual(t, Compare(1, 2), -1) + tst.AssertEqual(t, Compare(2, 1), 1) + tst.AssertEqual(t, Compare(2, 2), 0) + tst.AssertEqual(t, Compare("x", "y"), -1) + tst.AssertEqual(t, Compare(3.5, 1.2), 1) +} diff --git a/langext/coords_test.go b/langext/coords_test.go new file mode 100644 index 0000000..938d08b --- /dev/null +++ b/langext/coords_test.go @@ -0,0 +1,59 @@ +package langext + +import ( + "math" + "testing" +) + +func floatEquals(a, b, eps float64) bool { + return math.Abs(a-b) < eps +} + +func TestDegToRadZero(t *testing.T) { + if !floatEquals(DegToRad(0), 0, 1e-9) { + t.Errorf("expected 0, got %v", DegToRad(0)) + } +} + +func TestDegToRad180(t *testing.T) { + if !floatEquals(DegToRad(180), math.Pi, 1e-9) { + t.Errorf("expected Pi, got %v", DegToRad(180)) + } +} + +func TestDegToRad90(t *testing.T) { + if !floatEquals(DegToRad(90), math.Pi/2, 1e-9) { + t.Errorf("expected Pi/2, got %v", DegToRad(90)) + } +} + +func TestRadToDegZero(t *testing.T) { + // note: function is implemented as rad / (Pi*180), tests document actual behavior + if !floatEquals(RadToDeg(0), 0, 1e-9) { + t.Errorf("expected 0, got %v", RadToDeg(0)) + } +} + +func TestGeoDistanceSamePoint(t *testing.T) { + d := GeoDistance(10.0, 50.0, 10.0, 50.0) + if !floatEquals(d, 0, 1e-3) { + t.Errorf("expected 0, got %v", d) + } +} + +func TestGeoDistancePositive(t *testing.T) { + // Berlin (~52.5200, 13.4050) to Munich (~48.1351, 11.5820) + d := GeoDistance(13.4050, 52.5200, 11.5820, 48.1351) + // Distance should be around ~500km - just check the order of magnitude. 
+ if d < 400000 || d > 700000 { + t.Errorf("Berlin-Munich distance unexpected: got %v", d) + } +} + +func TestGeoDistanceSymmetric(t *testing.T) { + d1 := GeoDistance(10.0, 50.0, 11.0, 51.0) + d2 := GeoDistance(11.0, 51.0, 10.0, 50.0) + if !floatEquals(d1, d2, 1e-3) { + t.Errorf("expected symmetry, got %v != %v", d1, d2) + } +} diff --git a/langext/func_test.go b/langext/func_test.go new file mode 100644 index 0000000..18cef10 --- /dev/null +++ b/langext/func_test.go @@ -0,0 +1,22 @@ +package langext + +import ( + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + "testing" +) + +func TestFuncChain(t *testing.T) { + addOne := func(v int) int { return v + 1 } + timesTwo := func(v int) int { return v * 2 } + + chained := FuncChain(addOne, timesTwo) + tst.AssertEqual(t, chained(3), 8) +} + +func TestFuncChainOrder(t *testing.T) { + first := func(v string) string { return v + "A" } + second := func(v string) string { return v + "B" } + + chained := FuncChain(first, second) + tst.AssertEqual(t, chained("X"), "XAB") +} diff --git a/langext/io_test.go b/langext/io_test.go new file mode 100644 index 0000000..7a1a732 --- /dev/null +++ b/langext/io_test.go @@ -0,0 +1,36 @@ +package langext + +import ( + "bytes" + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + "testing" +) + +func TestWriteNopCloserWrite(t *testing.T) { + var buf bytes.Buffer + wc := WriteNopCloser(&buf) + + n, err := wc.Write([]byte("hello")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + tst.AssertEqual(t, n, 5) + tst.AssertEqual(t, buf.String(), "hello") +} + +func TestWriteNopCloserClose(t *testing.T) { + var buf bytes.Buffer + wc := WriteNopCloser(&buf) + + err := wc.Close() + if err != nil { + t.Errorf("expected nil error from no-op Close, got %v", err) + } + + // Can still write after close (it's a no-op) + _, err = wc.Write([]byte("after")) + if err != nil { + t.Errorf("expected to write after Close, got %v", err) + } + tst.AssertEqual(t, buf.String(), "after") +} diff 
--git a/langext/iter_test.go b/langext/iter_test.go new file mode 100644 index 0000000..cc0289d --- /dev/null +++ b/langext/iter_test.go @@ -0,0 +1,61 @@ +package langext + +import ( + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + "testing" +) + +func TestIterSingleValueSeq(t *testing.T) { + seq := IterSingleValueSeq(42) + + count := 0 + var got int + for v := range seq { + got = v + count++ + } + + tst.AssertEqual(t, count, 1) + tst.AssertEqual(t, got, 42) +} + +func TestIterSingleValueSeqString(t *testing.T) { + seq := IterSingleValueSeq("hello") + + values := make([]string, 0) + for v := range seq { + values = append(values, v) + } + + tst.AssertEqual(t, len(values), 1) + tst.AssertEqual(t, values[0], "hello") +} + +func TestIterSingleValueSeq2(t *testing.T) { + seq := IterSingleValueSeq2("key", 42) + + count := 0 + var k string + var v int + for kk, vv := range seq { + k = kk + v = vv + count++ + } + + tst.AssertEqual(t, count, 1) + tst.AssertEqual(t, k, "key") + tst.AssertEqual(t, v, 42) +} + +func TestIterSingleValueSeqEarlyBreak(t *testing.T) { + seq := IterSingleValueSeq(1) + + count := 0 + for range seq { + count++ + break + } + + tst.AssertEqual(t, count, 1) +} diff --git a/langext/json_test.go b/langext/json_test.go new file mode 100644 index 0000000..0511719 --- /dev/null +++ b/langext/json_test.go @@ -0,0 +1,192 @@ +package langext + +import ( + "encoding/json" + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + "strings" + "testing" +) + +func TestTryPrettyPrintJsonValid(t *testing.T) { + in := `{"a":1,"b":2}` + out := TryPrettyPrintJson(in) + if !strings.Contains(out, "\n") { + t.Errorf("expected pretty-printed result with newlines, got %q", out) + } + if !strings.Contains(out, `"a"`) { + t.Errorf("expected key in result, got %q", out) + } +} + +func TestTryPrettyPrintJsonInvalidPassThrough(t *testing.T) { + in := `not valid json` + tst.AssertEqual(t, TryPrettyPrintJson(in), in) +} + +func TestPrettyPrintJsonValid(t *testing.T) { + 
out, ok := PrettyPrintJson(`{"a":1}`) + tst.AssertEqual(t, ok, true) + if !strings.Contains(out, "\n") { + t.Errorf("expected formatted output, got %q", out) + } +} + +func TestPrettyPrintJsonInvalid(t *testing.T) { + in := `not json` + out, ok := PrettyPrintJson(in) + tst.AssertEqual(t, ok, false) + tst.AssertEqual(t, out, in) +} + +func TestPatchJsonString(t *testing.T) { + in := `{"a":1,"b":2}` + out, err := PatchJson(in, "c", 3) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var m map[string]any + if err := json.Unmarshal([]byte(out), &m); err != nil { + t.Fatalf("invalid json result: %v", err) + } + if v, ok := m["c"].(float64); !ok || v != 3 { + t.Errorf("expected c=3, got %v", m["c"]) + } + if v, ok := m["a"].(float64); !ok || v != 1 { + t.Errorf("expected a=1, got %v", m["a"]) + } +} + +func TestPatchJsonBytes(t *testing.T) { + in := []byte(`{"a":1}`) + out, err := PatchJson(in, "b", "hello") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + var m map[string]any + if err := json.Unmarshal(out, &m); err != nil { + t.Fatalf("invalid json result: %v", err) + } + if v, ok := m["b"].(string); !ok || v != "hello" { + t.Errorf("expected b=hello, got %v", m["b"]) + } +} + +func TestPatchJsonInvalid(t *testing.T) { + _, err := PatchJson("not json", "k", "v") + if err == nil { + t.Errorf("expected error on invalid json") + } +} + +func TestPatchRemJson(t *testing.T) { + in := `{"a":1,"b":2}` + out, err := PatchRemJson(in, "a") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + var m map[string]any + if err := json.Unmarshal([]byte(out), &m); err != nil { + t.Fatalf("invalid json result: %v", err) + } + if _, exists := m["a"]; exists { + t.Errorf("expected key 'a' to be removed") + } + if v, ok := m["b"].(float64); !ok || v != 2 { + t.Errorf("expected b=2, got %v", m["b"]) + } +} + +func TestPatchRemJsonMissingKey(t *testing.T) { + in := `{"a":1}` + out, err := PatchRemJson(in, "missing") + if err != nil { + 
t.Fatalf("unexpected error: %v", err) + } + var m map[string]any + if err := json.Unmarshal([]byte(out), &m); err != nil { + t.Fatalf("invalid json: %v", err) + } + if v, ok := m["a"].(float64); !ok || v != 1 { + t.Errorf("expected a=1, got %v", m["a"]) + } +} + +func TestMarshalJsonOrPanic(t *testing.T) { + tst.AssertEqual(t, MarshalJsonOrPanic(42), "42") + tst.AssertEqual(t, MarshalJsonOrPanic("hi"), `"hi"`) +} + +func TestMarshalJsonOrPanicPanics(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("expected panic on un-marshalable input") + } + }() + // channels can't be marshaled + MarshalJsonOrPanic(make(chan int)) +} + +func TestMarshalJsonOrDefault(t *testing.T) { + tst.AssertEqual(t, MarshalJsonOrDefault(42, "def"), "42") + tst.AssertEqual(t, MarshalJsonOrDefault(make(chan int), "def"), "def") +} + +func TestMarshalJsonOrNilSuccess(t *testing.T) { + p := MarshalJsonOrNil(42) + if p == nil { + t.Fatalf("expected non-nil pointer") + } + tst.AssertEqual(t, *p, "42") +} + +func TestMarshalJsonOrNilError(t *testing.T) { + p := MarshalJsonOrNil(make(chan int)) + if p != nil { + t.Errorf("expected nil pointer on error, got %v", *p) + } +} + +func TestMarshalJsonIndentOrPanic(t *testing.T) { + out := MarshalJsonIndentOrPanic(map[string]int{"a": 1}, "", " ") + if !strings.Contains(out, "\n") { + t.Errorf("expected indented output, got %q", out) + } +} + +func TestMarshalJsonIndentOrDefault(t *testing.T) { + out := MarshalJsonIndentOrDefault(make(chan int), "", " ", "DEF") + tst.AssertEqual(t, out, "DEF") +} + +func TestMarshalJsonIndentOrNilSuccess(t *testing.T) { + p := MarshalJsonIndentOrNil(map[string]int{"a": 1}, "", " ") + if p == nil || !strings.Contains(*p, "\n") { + t.Errorf("expected indented JSON pointer") + } +} + +func TestMarshalJsonIndentOrNilFailure(t *testing.T) { + p := MarshalJsonIndentOrNil(make(chan int), "", " ") + if p != nil { + t.Errorf("expected nil pointer on error") + } +} + +func TestHTypeIsMap(t *testing.T) { + h 
:= H{"a": 1} + out, err := json.Marshal(h) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + tst.AssertEqual(t, string(out), `{"a":1}`) +} + +func TestATypeIsArray(t *testing.T) { + a := A{1, "x", true} + out, err := json.Marshal(a) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + tst.AssertEqual(t, string(out), `[1,"x",true]`) +} diff --git a/langext/maps_test.go b/langext/maps_test.go new file mode 100644 index 0000000..a8943ba --- /dev/null +++ b/langext/maps_test.go @@ -0,0 +1,146 @@ +package langext + +import ( + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + "sort" + "testing" +) + +func TestMapKeyArr(t *testing.T) { + m := map[string]int{"a": 1, "b": 2, "c": 3} + keys := MapKeyArr(m) + sort.Strings(keys) + tst.AssertArrayEqual(t, keys, []string{"a", "b", "c"}) +} + +func TestMapKeyArrEmpty(t *testing.T) { + m := map[string]int{} + keys := MapKeyArr(m) + tst.AssertEqual(t, len(keys), 0) +} + +func TestMapValueArr(t *testing.T) { + m := map[string]int{"a": 1, "b": 2, "c": 3} + values := MapValueArr(m) + sort.Ints(values) + tst.AssertArrayEqual(t, values, []int{1, 2, 3}) +} + +func TestArrToMap(t *testing.T) { + type item struct { + key string + v int + } + arr := []item{{"a", 1}, {"b", 2}} + m := ArrToMap(arr, func(i item) string { return i.key }) + tst.AssertEqual(t, len(m), 2) + tst.AssertEqual(t, m["a"].v, 1) + tst.AssertEqual(t, m["b"].v, 2) +} + +func TestArrToKVMap(t *testing.T) { + arr := []int{1, 2, 3} + m := ArrToKVMap(arr, + func(v int) int { return v }, + func(v int) string { + return [...]string{"", "one", "two", "three"}[v] + }, + ) + tst.AssertEqual(t, m[1], "one") + tst.AssertEqual(t, m[2], "two") + tst.AssertEqual(t, m[3], "three") +} + +func TestArrToSet(t *testing.T) { + arr := []string{"a", "b", "a", "c"} + set := ArrToSet(arr) + tst.AssertEqual(t, len(set), 3) + tst.AssertEqual(t, set["a"], true) + tst.AssertEqual(t, set["b"], true) + tst.AssertEqual(t, set["c"], true) + tst.AssertEqual(t, set["d"], 
false) +} + +func TestMapToArr(t *testing.T) { + m := map[string]int{"a": 1, "b": 2} + arr := MapToArr(m) + tst.AssertEqual(t, len(arr), 2) + roundTrip := make(map[string]int) + for _, e := range arr { + roundTrip[e.Key] = e.Value + } + tst.AssertEqual(t, roundTrip["a"], 1) + tst.AssertEqual(t, roundTrip["b"], 2) +} + +func TestCopyMap(t *testing.T) { + src := map[string]int{"a": 1, "b": 2} + dst := CopyMap(src) + tst.AssertEqual(t, len(dst), 2) + tst.AssertEqual(t, dst["a"], 1) + + // Mutating dst should not affect src + dst["a"] = 99 + tst.AssertEqual(t, src["a"], 1) +} + +func TestForceMapNil(t *testing.T) { + var m map[string]int + res := ForceMap(m) + if res == nil { + t.Errorf("expected non-nil result") + } + tst.AssertEqual(t, len(res), 0) +} + +func TestForceMapNonNil(t *testing.T) { + m := map[string]int{"x": 1} + res := ForceMap(m) + tst.AssertEqual(t, res["x"], 1) +} + +func TestForceJsonMapOrPanic(t *testing.T) { + type s struct { + A int `json:"a"` + B string `json:"b"` + } + res := ForceJsonMapOrPanic(s{A: 1, B: "x"}) + if v, ok := res["a"].(float64); !ok || v != 1 { + t.Errorf("expected a=1, got %v", res["a"]) + } + if v, ok := res["b"].(string); !ok || v != "x" { + t.Errorf("expected b=x, got %v", res["b"]) + } +} + +func TestForceJsonMapOrPanicPanics(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("expected panic on un-marshalable input") + } + }() + ForceJsonMapOrPanic(make(chan int)) +} + +func TestMapMerge(t *testing.T) { + base := map[string]int{"a": 1, "b": 2} + a := map[string]int{"b": 22, "c": 3} + b := map[string]int{"d": 4} + + res := MapMerge(base, a, b) + tst.AssertEqual(t, res["a"], 1) + tst.AssertEqual(t, res["b"], 22) // overwritten + tst.AssertEqual(t, res["c"], 3) + tst.AssertEqual(t, res["d"], 4) + tst.AssertEqual(t, len(res), 4) + + // base must remain untouched + tst.AssertEqual(t, base["b"], 2) +} + +func TestMapMergeNoExtras(t *testing.T) { + base := map[string]int{"a": 1} + res := MapMerge(base) + 
tst.AssertEqual(t, res["a"], 1) + tst.AssertEqual(t, len(res), 1) +} diff --git a/langext/must_test.go b/langext/must_test.go new file mode 100644 index 0000000..c550b10 --- /dev/null +++ b/langext/must_test.go @@ -0,0 +1,38 @@ +package langext + +import ( + "errors" + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + "testing" +) + +func TestMustSuccess(t *testing.T) { + v := Must(42, nil) + tst.AssertEqual(t, v, 42) + + s := Must("hello", nil) + tst.AssertEqual(t, s, "hello") +} + +func TestMustPanics(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("expected panic on error") + } + }() + Must(0, errors.New("boom")) +} + +func TestMustBoolSuccess(t *testing.T) { + v := MustBool(42, true) + tst.AssertEqual(t, v, 42) +} + +func TestMustBoolPanics(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("expected panic on not ok") + } + }() + MustBool(0, false) +} diff --git a/langext/object_test.go b/langext/object_test.go new file mode 100644 index 0000000..da8fbae --- /dev/null +++ b/langext/object_test.go @@ -0,0 +1,43 @@ +package langext + +import ( + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + "testing" +) + +func TestDeepCopyByJsonStruct(t *testing.T) { + type item struct { + Name string `json:"name"` + Age int `json:"age"` + } + src := item{Name: "alice", Age: 30} + dst, err := DeepCopyByJson(src) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + tst.AssertEqual(t, dst.Name, "alice") + tst.AssertEqual(t, dst.Age, 30) +} + +func TestDeepCopyByJsonSlice(t *testing.T) { + src := []int{1, 2, 3} + dst, err := DeepCopyByJson(src) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + tst.AssertArrayEqual(t, dst, []int{1, 2, 3}) + + // Mutating the copy must not affect the source + dst[0] = 99 + tst.AssertEqual(t, src[0], 1) +} + +func TestDeepCopyByJsonError(t *testing.T) { + type bad struct { + C chan int + } + _, err := DeepCopyByJson(bad{C: make(chan int)}) + if err == nil { + 
t.Errorf("expected error for un-marshalable type") + } +} diff --git a/langext/os_test.go b/langext/os_test.go new file mode 100644 index 0000000..d3ce33c --- /dev/null +++ b/langext/os_test.go @@ -0,0 +1,28 @@ +package langext + +import ( + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + "os" + "path/filepath" + "testing" +) + +func TestFileExistsTrue(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "f.txt") + if err := os.WriteFile(path, []byte("hi"), 0o644); err != nil { + t.Fatalf("setup failed: %v", err) + } + tst.AssertEqual(t, FileExists(path), true) +} + +func TestFileExistsFalse(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "missing.txt") + tst.AssertEqual(t, FileExists(path), false) +} + +func TestFileExistsDirectoryReturnsFalse(t *testing.T) { + dir := t.TempDir() + tst.AssertEqual(t, FileExists(dir), false) +} diff --git a/langext/panic_test.go b/langext/panic_test.go new file mode 100644 index 0000000..1bcb353 --- /dev/null +++ b/langext/panic_test.go @@ -0,0 +1,121 @@ +package langext + +import ( + "errors" + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + "testing" +) + +func TestRunPanicSafeNoPanic(t *testing.T) { + called := false + err := RunPanicSafe(func() { + called = true + }) + tst.AssertEqual(t, called, true) + if err != nil { + t.Errorf("expected nil err, got %v", err) + } +} + +func TestRunPanicSafeRecovers(t *testing.T) { + err := RunPanicSafe(func() { + panic("boom") + }) + if err == nil { + t.Fatalf("expected error from panic") + } + pwe, ok := err.(PanicWrappedErr) + if !ok { + t.Fatalf("expected PanicWrappedErr, got %T", err) + } + tst.AssertEqual(t, pwe.RecoveredObj(), "boom") + tst.AssertEqual(t, pwe.Error(), "A panic occured") +} + +func TestRunPanicSafeR1NoPanic(t *testing.T) { + expected := errors.New("expected") + err := RunPanicSafeR1(func() error { + return expected + }) + if err != expected { + t.Errorf("expected original error, got %v", err) + } +} + +func 
TestRunPanicSafeR1Panics(t *testing.T) { + err := RunPanicSafeR1(func() error { + panic("boom") + }) + if err == nil { + t.Fatalf("expected wrapped panic") + } + if _, ok := err.(PanicWrappedErr); !ok { + t.Errorf("expected PanicWrappedErr, got %T", err) + } +} + +func TestRunPanicSafeR2NoPanic(t *testing.T) { + v, err := RunPanicSafeR2(func() (int, error) { + return 42, nil + }) + tst.AssertEqual(t, v, 42) + if err != nil { + t.Errorf("expected nil err, got %v", err) + } +} + +func TestRunPanicSafeR2Panics(t *testing.T) { + v, err := RunPanicSafeR2(func() (int, error) { + panic("boom") + }) + tst.AssertEqual(t, v, 0) // zero value + if err == nil { + t.Errorf("expected wrapped panic") + } +} + +func TestRunPanicSafeR3NoPanic(t *testing.T) { + a, b, err := RunPanicSafeR3(func() (int, string, error) { + return 1, "two", nil + }) + tst.AssertEqual(t, a, 1) + tst.AssertEqual(t, b, "two") + if err != nil { + t.Errorf("expected nil err, got %v", err) + } +} + +func TestRunPanicSafeR3Panics(t *testing.T) { + a, b, err := RunPanicSafeR3(func() (int, string, error) { + panic("boom") + }) + tst.AssertEqual(t, a, 0) + tst.AssertEqual(t, b, "") + if err == nil { + t.Errorf("expected wrapped panic") + } +} + +func TestRunPanicSafeR4NoPanic(t *testing.T) { + a, b, c, err := RunPanicSafeR4(func() (int, string, bool, error) { + return 1, "two", true, nil + }) + tst.AssertEqual(t, a, 1) + tst.AssertEqual(t, b, "two") + tst.AssertEqual(t, c, true) + if err != nil { + t.Errorf("expected nil err, got %v", err) + } +} + +func TestRunPanicSafeR4Panics(t *testing.T) { + a, b, c, err := RunPanicSafeR4(func() (int, string, bool, error) { + panic("boom") + }) + tst.AssertEqual(t, a, 0) + tst.AssertEqual(t, b, "") + tst.AssertEqual(t, c, false) + if err == nil { + t.Errorf("expected wrapped panic") + } +} diff --git a/langext/pointer_test.go b/langext/pointer_test.go new file mode 100644 index 0000000..6482290 --- /dev/null +++ b/langext/pointer_test.go @@ -0,0 +1,143 @@ +package langext + 
+import ( + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + "testing" +) + +func TestPtr(t *testing.T) { + p := Ptr(42) + if p == nil { + t.Fatalf("expected non-nil") + } + tst.AssertEqual(t, *p, 42) +} + +func TestPtrString(t *testing.T) { + p := Ptr("hi") + tst.AssertEqual(t, *p, "hi") +} + +func TestPTrue(t *testing.T) { + if PTrue == nil || *PTrue != true { + t.Errorf("PTrue should point to true") + } +} + +func TestPFalse(t *testing.T) { + if PFalse == nil || *PFalse != false { + t.Errorf("PFalse should point to false") + } +} + +func TestDblPtr(t *testing.T) { + pp := DblPtr(7) + if pp == nil || *pp == nil { + t.Fatalf("expected non-nil double pointer") + } + tst.AssertEqual(t, **pp, 7) +} + +func TestDblPtrIfNotNilWithValue(t *testing.T) { + v := 5 + pp := DblPtrIfNotNil(&v) + if pp == nil { + t.Fatalf("expected non-nil double pointer") + } + tst.AssertEqual(t, **pp, 5) +} + +func TestDblPtrIfNotNilNil(t *testing.T) { + pp := DblPtrIfNotNil[int](nil) + if pp != nil { + t.Errorf("expected nil for nil input") + } +} + +func TestDblPtrNil(t *testing.T) { + pp := DblPtrNil[int]() + if pp == nil { + t.Fatalf("expected non-nil outer pointer") + } + if *pp != nil { + t.Errorf("expected inner pointer to be nil") + } +} + +func TestArrPtr(t *testing.T) { + p := ArrPtr(1, 2, 3) + if p == nil { + t.Fatalf("expected non-nil pointer") + } + tst.AssertArrayEqual(t, *p, []int{1, 2, 3}) +} + +func TestPtrInt32(t *testing.T) { + p := PtrInt32(7) + tst.AssertEqual(t, *p, int32(7)) +} + +func TestPtrInt64(t *testing.T) { + p := PtrInt64(7) + tst.AssertEqual(t, *p, int64(7)) +} + +func TestPtrFloat32(t *testing.T) { + p := PtrFloat32(1.5) + tst.AssertEqual(t, *p, float32(1.5)) +} + +func TestPtrFloat64(t *testing.T) { + p := PtrFloat64(2.5) + tst.AssertEqual(t, *p, 2.5) +} + +func TestIsNilTrue(t *testing.T) { + tst.AssertEqual(t, IsNil(nil), true) + + var p *int + tst.AssertEqual(t, IsNil(p), true) + + var m map[string]int + tst.AssertEqual(t, IsNil(m), true) + + var s 
[]int + tst.AssertEqual(t, IsNil(s), true) + + var c chan int + tst.AssertEqual(t, IsNil(c), true) + + var f func() + tst.AssertEqual(t, IsNil(f), true) +} + +func TestIsNilFalse(t *testing.T) { + v := 5 + tst.AssertEqual(t, IsNil(&v), false) + tst.AssertEqual(t, IsNil(5), false) + tst.AssertEqual(t, IsNil("hi"), false) + tst.AssertEqual(t, IsNil(map[string]int{}), false) + tst.AssertEqual(t, IsNil([]int{}), false) +} + +func TestPtrEqualsBothNil(t *testing.T) { + tst.AssertEqual(t, PtrEquals[int](nil, nil), true) +} + +func TestPtrEqualsBothEqual(t *testing.T) { + a := 5 + b := 5 + tst.AssertEqual(t, PtrEquals(&a, &b), true) +} + +func TestPtrEqualsBothDifferent(t *testing.T) { + a := 5 + b := 6 + tst.AssertEqual(t, PtrEquals(&a, &b), false) +} + +func TestPtrEqualsOneNil(t *testing.T) { + a := 5 + tst.AssertEqual(t, PtrEquals(&a, nil), false) + tst.AssertEqual(t, PtrEquals[int](nil, &a), false) +} diff --git a/langext/rand_test.go b/langext/rand_test.go new file mode 100644 index 0000000..044cfbb --- /dev/null +++ b/langext/rand_test.go @@ -0,0 +1,34 @@ +package langext + +import ( + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + "testing" +) + +func TestRandBytesLength(t *testing.T) { + for _, sz := range []int{0, 1, 16, 32, 1024} { + b := RandBytes(sz) + tst.AssertEqual(t, len(b), sz) + } +} + +func TestRandBytesDistinct(t *testing.T) { + a := RandBytes(32) + b := RandBytes(32) + + // Two cryptographic random sequences should not be equal in 32 bytes. 
+ if len(a) != 32 || len(b) != 32 { + t.Fatalf("unexpected length") + } + + equal := true + for i := range a { + if a[i] != b[i] { + equal = false + break + } + } + if equal { + t.Errorf("two consecutive 32-byte RandBytes calls returned identical results") + } +} diff --git a/langext/sort_test.go b/langext/sort_test.go new file mode 100644 index 0000000..c768178 --- /dev/null +++ b/langext/sort_test.go @@ -0,0 +1,118 @@ +package langext + +import ( + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + "testing" +) + +func TestSortInPlace(t *testing.T) { + arr := []int{3, 1, 2} + Sort(arr) + tst.AssertArrayEqual(t, arr, []int{1, 2, 3}) +} + +func TestSortStrings(t *testing.T) { + arr := []string{"c", "a", "b"} + Sort(arr) + tst.AssertArrayEqual(t, arr, []string{"a", "b", "c"}) +} + +func TestAsSorted(t *testing.T) { + src := []int{3, 1, 2} + out := AsSorted(src) + tst.AssertArrayEqual(t, out, []int{1, 2, 3}) + // original unchanged + tst.AssertArrayEqual(t, src, []int{3, 1, 2}) +} + +func TestSortStable(t *testing.T) { + arr := []int{3, 1, 2, 1} + SortStable(arr) + tst.AssertArrayEqual(t, arr, []int{1, 1, 2, 3}) +} + +func TestAsSortedStable(t *testing.T) { + src := []int{3, 1, 2, 1} + out := AsSortedStable(src) + tst.AssertArrayEqual(t, out, []int{1, 1, 2, 3}) + tst.AssertArrayEqual(t, src, []int{3, 1, 2, 1}) +} + +func TestIsSorted(t *testing.T) { + tst.AssertEqual(t, IsSorted([]int{1, 2, 3}), true) + tst.AssertEqual(t, IsSorted([]int{3, 2, 1}), false) + tst.AssertEqual(t, IsSorted([]int{1, 1, 1}), true) + tst.AssertEqual(t, IsSorted([]int{}), true) +} + +func TestSortSlice(t *testing.T) { + arr := []int{3, 1, 2} + SortSlice(arr, func(a, b int) bool { return a < b }) + tst.AssertArrayEqual(t, arr, []int{1, 2, 3}) +} + +func TestAsSortedSlice(t *testing.T) { + src := []int{3, 1, 2} + out := AsSortedSlice(src, func(a, b int) bool { return a > b }) + tst.AssertArrayEqual(t, out, []int{3, 2, 1}) + tst.AssertArrayEqual(t, src, []int{3, 1, 2}) +} + +func 
TestSortSliceStable(t *testing.T) { + arr := []int{3, 1, 2, 1} + SortSliceStable(arr, func(a, b int) bool { return a < b }) + tst.AssertArrayEqual(t, arr, []int{1, 1, 2, 3}) +} + +func TestAsSortedSliceStable(t *testing.T) { + src := []int{3, 1, 2, 1} + out := AsSortedSliceStable(src, func(a, b int) bool { return a < b }) + tst.AssertArrayEqual(t, out, []int{1, 1, 2, 3}) + tst.AssertArrayEqual(t, src, []int{3, 1, 2, 1}) +} + +func TestIsSliceSorted(t *testing.T) { + tst.AssertEqual(t, IsSliceSorted([]int{1, 2, 3}, func(a, b int) bool { return a < b }), true) + tst.AssertEqual(t, IsSliceSorted([]int{3, 2, 1}, func(a, b int) bool { return a < b }), false) +} + +type byKey struct { + key int + v string +} + +func TestSortBy(t *testing.T) { + arr := []byKey{{3, "c"}, {1, "a"}, {2, "b"}} + SortBy(arr, func(v byKey) int { return v.key }) + tst.AssertEqual(t, arr[0].v, "a") + tst.AssertEqual(t, arr[1].v, "b") + tst.AssertEqual(t, arr[2].v, "c") +} + +func TestAsSortedBy(t *testing.T) { + src := []byKey{{3, "c"}, {1, "a"}, {2, "b"}} + out := AsSortedBy(src, func(v byKey) int { return v.key }) + tst.AssertEqual(t, out[0].v, "a") + tst.AssertEqual(t, out[2].v, "c") + // source unchanged + tst.AssertEqual(t, src[0].v, "c") +} + +func TestSortByStable(t *testing.T) { + arr := []byKey{{1, "a1"}, {1, "a2"}, {0, "b"}} + SortByStable(arr, func(v byKey) int { return v.key }) + tst.AssertEqual(t, arr[0].v, "b") + // stable order for ties + tst.AssertEqual(t, arr[1].v, "a1") + tst.AssertEqual(t, arr[2].v, "a2") +} + +func TestAsSortedByStable(t *testing.T) { + src := []byKey{{1, "a1"}, {1, "a2"}, {0, "b"}} + out := AsSortedByStable(src, func(v byKey) int { return v.key }) + tst.AssertEqual(t, out[0].v, "b") + tst.AssertEqual(t, out[1].v, "a1") + tst.AssertEqual(t, out[2].v, "a2") + // source unchanged + tst.AssertEqual(t, src[0].v, "a1") +} diff --git a/langext/uuid_test.go b/langext/uuid_test.go new file mode 100644 index 0000000..a531c60 --- /dev/null +++ b/langext/uuid_test.go @@ 
-0,0 +1,132 @@ +package langext + +import ( + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + "regexp" + "strings" + "testing" +) + +func TestNewUUIDLength(t *testing.T) { + u, err := NewUUID() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + tst.AssertEqual(t, len(u), 16) +} + +func TestNewUUIDVersionAndVariant(t *testing.T) { + u, err := NewUUID() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + // Version 4 is in upper nibble of byte 6 + tst.AssertEqual(t, u[6]&0xf0, byte(0x40)) + // Variant 10 in top two bits of byte 8 + tst.AssertEqual(t, u[8]&0xc0, byte(0x80)) +} + +func TestNewUUIDRandomness(t *testing.T) { + a, _ := NewUUID() + b, _ := NewUUID() + if a == b { + t.Errorf("two UUIDs should not be equal") + } +} + +var hexUUIDRegex = regexp.MustCompile(`^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$`) + +func TestNewHexUUIDFormat(t *testing.T) { + s, err := NewHexUUID() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + tst.AssertEqual(t, len(s), 36) + if !hexUUIDRegex.MatchString(s) { + t.Errorf("not a valid hex UUID format: %q", s) + } +} + +func TestMustHexUUID(t *testing.T) { + s := MustHexUUID() + tst.AssertEqual(t, len(s), 36) + if !hexUUIDRegex.MatchString(s) { + t.Errorf("not a valid hex UUID format: %q", s) + } +} + +var upperHexRegex = regexp.MustCompile(`^[0-9A-F]{8}-[0-9A-F]{4}-4[0-9A-F]{3}-[89AB][0-9A-F]{3}-[0-9A-F]{12}$`) + +func TestNewUpperHexUUIDFormat(t *testing.T) { + s, err := NewUpperHexUUID() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + tst.AssertEqual(t, len(s), 36) + tst.AssertEqual(t, s, strings.ToUpper(s)) + if !upperHexRegex.MatchString(s) { + t.Errorf("not a valid upper-hex UUID format: %q", s) + } +} + +func TestMustUpperHexUUID(t *testing.T) { + s := MustUpperHexUUID() + tst.AssertEqual(t, len(s), 36) +} + +func TestNewRawHexUUIDFormat(t *testing.T) { + s, err := NewRawHexUUID() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } 
+ tst.AssertEqual(t, len(s), 32) + if strings.Contains(s, "-") { + t.Errorf("raw hex should have no dashes: %q", s) + } + tst.AssertEqual(t, s, strings.ToUpper(s)) +} + +func TestMustRawHexUUID(t *testing.T) { + s := MustRawHexUUID() + tst.AssertEqual(t, len(s), 32) +} + +func TestNewBracesUUID(t *testing.T) { + s, err := NewBracesUUID() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + tst.AssertEqual(t, len(s), 38) + tst.AssertEqual(t, string(s[37]), "}") +} + +func TestMustBracesUUID(t *testing.T) { + s := MustBracesUUID() + tst.AssertEqual(t, len(s), 38) +} + +func TestNewParensUUID(t *testing.T) { + s, err := NewParensUUID() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + tst.AssertEqual(t, len(s), 38) + tst.AssertEqual(t, string(s[37]), ")") +} + +func TestMustParensUUID(t *testing.T) { + s := MustParensUUID() + tst.AssertEqual(t, len(s), 38) +} + +func TestUUIDsAreUnique(t *testing.T) { + const count = 100 + seen := make(map[string]bool, count) + for range count { + s := MustHexUUID() + if seen[s] { + t.Errorf("collision in UUID set: %q", s) + } + seen[s] = true + } +} diff --git a/mathext/clamp_test.go b/mathext/clamp_test.go new file mode 100644 index 0000000..d12c92b --- /dev/null +++ b/mathext/clamp_test.go @@ -0,0 +1,158 @@ +package mathext + +import "testing" + +func TestClampIntWithinRange(t *testing.T) { + if got := ClampInt(5, 1, 10); got != 5 { + t.Errorf("ClampInt(5, 1, 10) = %v, want 5", got) + } +} + +func TestClampIntBelowRange(t *testing.T) { + if got := ClampInt(-3, 1, 10); got != 1 { + t.Errorf("ClampInt(-3, 1, 10) = %v, want 1", got) + } +} + +func TestClampIntAboveRange(t *testing.T) { + if got := ClampInt(15, 1, 10); got != 10 { + t.Errorf("ClampInt(15, 1, 10) = %v, want 10", got) + } +} + +func TestClampIntAtLowerBound(t *testing.T) { + if got := ClampInt(1, 1, 10); got != 1 { + t.Errorf("ClampInt(1, 1, 10) = %v, want 1", got) + } +} + +func TestClampIntAtUpperBound(t *testing.T) { + if got := ClampInt(10, 
1, 10); got != 10 { + t.Errorf("ClampInt(10, 1, 10) = %v, want 10", got) + } +} + +func TestClampInt32WithinRange(t *testing.T) { + if got := ClampInt32(int32(5), int32(1), int32(10)); got != 5 { + t.Errorf("ClampInt32(5, 1, 10) = %v, want 5", got) + } +} + +func TestClampInt32BelowRange(t *testing.T) { + if got := ClampInt32(int32(-3), int32(1), int32(10)); got != 1 { + t.Errorf("ClampInt32(-3, 1, 10) = %v, want 1", got) + } +} + +func TestClampInt32AboveRange(t *testing.T) { + if got := ClampInt32(int32(15), int32(1), int32(10)); got != 10 { + t.Errorf("ClampInt32(15, 1, 10) = %v, want 10", got) + } +} + +func TestClampFloat32WithinRange(t *testing.T) { + if got := ClampFloat32(float32(5.5), float32(1.0), float32(10.0)); got != 5.5 { + t.Errorf("ClampFloat32(5.5, 1.0, 10.0) = %v, want 5.5", got) + } +} + +func TestClampFloat32BelowRange(t *testing.T) { + if got := ClampFloat32(float32(-1.5), float32(0.0), float32(10.0)); got != 0.0 { + t.Errorf("ClampFloat32(-1.5, 0.0, 10.0) = %v, want 0.0", got) + } +} + +func TestClampFloat32AboveRange(t *testing.T) { + if got := ClampFloat32(float32(11.5), float32(0.0), float32(10.0)); got != 10.0 { + t.Errorf("ClampFloat32(11.5, 0.0, 10.0) = %v, want 10.0", got) + } +} + +func TestClampFloat64WithinRange(t *testing.T) { + if got := ClampFloat64(5.5, 1.0, 10.0); got != 5.5 { + t.Errorf("ClampFloat64(5.5, 1.0, 10.0) = %v, want 5.5", got) + } +} + +func TestClampFloat64BelowRange(t *testing.T) { + if got := ClampFloat64(-1.5, 0.0, 10.0); got != 0.0 { + t.Errorf("ClampFloat64(-1.5, 0.0, 10.0) = %v, want 0.0", got) + } +} + +func TestClampFloat64AboveRange(t *testing.T) { + if got := ClampFloat64(11.5, 0.0, 10.0); got != 10.0 { + t.Errorf("ClampFloat64(11.5, 0.0, 10.0) = %v, want 10.0", got) + } +} + +func TestClampGenericIntWithinRange(t *testing.T) { + if got := Clamp(5, 1, 10); got != 5 { + t.Errorf("Clamp(5, 1, 10) = %v, want 5", got) + } +} + +func TestClampGenericIntBelowRange(t *testing.T) { + if got := Clamp(-3, 1, 10); 
got != 1 { + t.Errorf("Clamp(-3, 1, 10) = %v, want 1", got) + } +} + +func TestClampGenericIntAboveRange(t *testing.T) { + if got := Clamp(15, 1, 10); got != 10 { + t.Errorf("Clamp(15, 1, 10) = %v, want 10", got) + } +} + +func TestClampGenericFloat64WithinRange(t *testing.T) { + if got := Clamp(5.5, 1.0, 10.0); got != 5.5 { + t.Errorf("Clamp(5.5, 1.0, 10.0) = %v, want 5.5", got) + } +} + +func TestClampGenericFloat64BelowRange(t *testing.T) { + if got := Clamp(-2.0, 0.0, 10.0); got != 0.0 { + t.Errorf("Clamp(-2.0, 0.0, 10.0) = %v, want 0.0", got) + } +} + +func TestClampGenericFloat64AboveRange(t *testing.T) { + if got := Clamp(20.5, 0.0, 10.0); got != 10.0 { + t.Errorf("Clamp(20.5, 0.0, 10.0) = %v, want 10.0", got) + } +} + +func TestClampOptNilFallback(t *testing.T) { + var v *int = nil + if got := ClampOpt(v, 7, 1, 10); got != 7 { + t.Errorf("ClampOpt(nil, 7, 1, 10) = %v, want 7", got) + } +} + +func TestClampOptValueWithinRange(t *testing.T) { + val := 5 + if got := ClampOpt(&val, 7, 1, 10); got != 5 { + t.Errorf("ClampOpt(&5, 7, 1, 10) = %v, want 5", got) + } +} + +func TestClampOptValueBelowRange(t *testing.T) { + val := -3 + if got := ClampOpt(&val, 7, 1, 10); got != 1 { + t.Errorf("ClampOpt(&-3, 7, 1, 10) = %v, want 1", got) + } +} + +func TestClampOptValueAboveRange(t *testing.T) { + val := 15 + if got := ClampOpt(&val, 7, 1, 10); got != 10 { + t.Errorf("ClampOpt(&15, 7, 1, 10) = %v, want 10", got) + } +} + +func TestClampOptFloat64Nil(t *testing.T) { + var v *float64 = nil + if got := ClampOpt(v, 2.5, 0.0, 10.0); got != 2.5 { + t.Errorf("ClampOpt(nil, 2.5, 0.0, 10.0) = %v, want 2.5", got) + } +} diff --git a/mathext/float_test.go b/mathext/float_test.go new file mode 100644 index 0000000..e933cf6 --- /dev/null +++ b/mathext/float_test.go @@ -0,0 +1,57 @@ +package mathext + +import "testing" + +func TestFloat64EpsilonEqExactlyEqual(t *testing.T) { + if !Float64EpsilonEq(1.0, 1.0, 1e-9) { + t.Errorf("Float64EpsilonEq(1.0, 1.0, 1e-9) = false, want true") + 
} +} + +func TestFloat64EpsilonEqWithinEpsilon(t *testing.T) { + if !Float64EpsilonEq(1.0, 1.0+1e-10, 1e-9) { + t.Errorf("Float64EpsilonEq(1.0, 1.0+1e-10, 1e-9) = false, want true") + } +} + +func TestFloat64EpsilonEqOutsideEpsilon(t *testing.T) { + if Float64EpsilonEq(1.0, 1.1, 1e-9) { + t.Errorf("Float64EpsilonEq(1.0, 1.1, 1e-9) = true, want false") + } +} + +func TestFloat64EpsilonEqAtEpsilonBoundary(t *testing.T) { + if !Float64EpsilonEq(0.0, 0.5, 0.5) { + t.Errorf("Float64EpsilonEq(0.0, 0.5, 0.5) = false, want true") + } +} + +func TestFloat64EpsilonEqNegativeDifference(t *testing.T) { + if !Float64EpsilonEq(2.0, 2.0-1e-10, 1e-9) { + t.Errorf("Float64EpsilonEq(2.0, 2.0-1e-10, 1e-9) = false, want true") + } +} + +func TestFloat64EpsilonEqLargeDifference(t *testing.T) { + if Float64EpsilonEq(0.0, 100.0, 0.5) { + t.Errorf("Float64EpsilonEq(0.0, 100.0, 0.5) = true, want false") + } +} + +func TestFloat64EpsilonEqNegativeNumbers(t *testing.T) { + if !Float64EpsilonEq(-1.0, -1.0+1e-10, 1e-9) { + t.Errorf("Float64EpsilonEq(-1.0, -1.0+1e-10, 1e-9) = false, want true") + } +} + +func TestFloat64EpsilonEqZeroEpsilonEqualValues(t *testing.T) { + if !Float64EpsilonEq(3.14, 3.14, 0.0) { + t.Errorf("Float64EpsilonEq(3.14, 3.14, 0.0) = false, want true") + } +} + +func TestFloat64EpsilonEqZeroEpsilonDifferentValues(t *testing.T) { + if Float64EpsilonEq(3.14, 3.15, 0.0) { + t.Errorf("Float64EpsilonEq(3.14, 3.15, 0.0) = true, want false") + } +} diff --git a/mathext/math_test.go b/mathext/math_test.go new file mode 100644 index 0000000..f47fb7a --- /dev/null +++ b/mathext/math_test.go @@ -0,0 +1,215 @@ +package mathext + +import ( + "math" + "testing" +) + +func TestSumFloat64HappyPath(t *testing.T) { + values := []float64{1.0, 2.0, 3.0, 4.0} + expected := 10.0 + if got := SumFloat64(values); got != expected { + t.Errorf("SumFloat64(%v) = %v, want %v", values, got, expected) + } +} + +func TestSumFloat64Empty(t *testing.T) { + values := []float64{} + if got := 
SumFloat64(values); got != 0.0 { + t.Errorf("SumFloat64(empty) = %v, want 0.0", got) + } +} + +func TestSumFloat64Negatives(t *testing.T) { + values := []float64{-1.0, -2.0, 3.0} + expected := 0.0 + if got := SumFloat64(values); got != expected { + t.Errorf("SumFloat64(%v) = %v, want %v", values, got, expected) + } +} + +func TestAvgFloat64HappyPath(t *testing.T) { + values := []float64{2.0, 4.0, 6.0, 8.0} + expected := 5.0 + if got := AvgFloat64(values); got != expected { + t.Errorf("AvgFloat64(%v) = %v, want %v", values, got, expected) + } +} + +func TestAvgFloat64SingleValue(t *testing.T) { + values := []float64{42.0} + expected := 42.0 + if got := AvgFloat64(values); got != expected { + t.Errorf("AvgFloat64(%v) = %v, want %v", values, got, expected) + } +} + +func TestAvgFloat64EmptyReturnsNaN(t *testing.T) { + values := []float64{} + got := AvgFloat64(values) + if !math.IsNaN(got) { + t.Errorf("AvgFloat64(empty) = %v, want NaN", got) + } +} + +func TestMaxIntFirstLarger(t *testing.T) { + if got := Max(5, 3); got != 5 { + t.Errorf("Max(5, 3) = %v, want 5", got) + } +} + +func TestMaxIntSecondLarger(t *testing.T) { + if got := Max(3, 5); got != 5 { + t.Errorf("Max(3, 5) = %v, want 5", got) + } +} + +func TestMaxIntEqual(t *testing.T) { + if got := Max(5, 5); got != 5 { + t.Errorf("Max(5, 5) = %v, want 5", got) + } +} + +func TestMaxFloat64(t *testing.T) { + if got := Max(2.7, 3.1); got != 3.1 { + t.Errorf("Max(2.7, 3.1) = %v, want 3.1", got) + } +} + +func TestMaxString(t *testing.T) { + if got := Max("apple", "banana"); got != "banana" { + t.Errorf(`Max("apple", "banana") = %v, want "banana"`, got) + } +} + +func TestMax3FirstLargest(t *testing.T) { + if got := Max3(10, 5, 3); got != 10 { + t.Errorf("Max3(10, 5, 3) = %v, want 10", got) + } +} + +func TestMax3MiddleLargest(t *testing.T) { + if got := Max3(5, 10, 3); got != 10 { + t.Errorf("Max3(5, 10, 3) = %v, want 10", got) + } +} + +func TestMax3LastLargest(t *testing.T) { + if got := Max3(5, 3, 10); got != 10 
{ + t.Errorf("Max3(5, 3, 10) = %v, want 10", got) + } +} + +func TestMax4(t *testing.T) { + if got := Max4(1, 5, 3, 7); got != 7 { + t.Errorf("Max4(1, 5, 3, 7) = %v, want 7", got) + } +} + +func TestMax4FirstLargest(t *testing.T) { + if got := Max4(10, 5, 3, 7); got != 10 { + t.Errorf("Max4(10, 5, 3, 7) = %v, want 10", got) + } +} + +func TestMax4ThirdLargest(t *testing.T) { + if got := Max4(1, 5, 100, 7); got != 100 { + t.Errorf("Max4(1, 5, 100, 7) = %v, want 100", got) + } +} + +func TestMinIntFirstSmaller(t *testing.T) { + if got := Min(3, 5); got != 3 { + t.Errorf("Min(3, 5) = %v, want 3", got) + } +} + +func TestMinIntSecondSmaller(t *testing.T) { + if got := Min(5, 3); got != 3 { + t.Errorf("Min(5, 3) = %v, want 3", got) + } +} + +func TestMinIntEqual(t *testing.T) { + if got := Min(5, 5); got != 5 { + t.Errorf("Min(5, 5) = %v, want 5", got) + } +} + +func TestMinFloat64(t *testing.T) { + if got := Min(2.7, 3.1); got != 2.7 { + t.Errorf("Min(2.7, 3.1) = %v, want 2.7", got) + } +} + +func TestMin3FirstSmallest(t *testing.T) { + if got := Min3(1, 5, 10); got != 1 { + t.Errorf("Min3(1, 5, 10) = %v, want 1", got) + } +} + +func TestMin3MiddleSmallest(t *testing.T) { + if got := Min3(5, 1, 10); got != 1 { + t.Errorf("Min3(5, 1, 10) = %v, want 1", got) + } +} + +func TestMin3LastSmallest(t *testing.T) { + if got := Min3(5, 10, 1); got != 1 { + t.Errorf("Min3(5, 10, 1) = %v, want 1", got) + } +} + +func TestMin4(t *testing.T) { + if got := Min4(7, 3, 5, 1); got != 1 { + t.Errorf("Min4(7, 3, 5, 1) = %v, want 1", got) + } +} + +func TestMin4FirstSmallest(t *testing.T) { + if got := Min4(1, 5, 3, 7); got != 1 { + t.Errorf("Min4(1, 5, 3, 7) = %v, want 1", got) + } +} + +func TestMin4ThirdSmallest(t *testing.T) { + if got := Min4(10, 5, 1, 7); got != 1 { + t.Errorf("Min4(10, 5, 1, 7) = %v, want 1", got) + } +} + +func TestAbsPositiveInt(t *testing.T) { + if got := Abs(5); got != 5 { + t.Errorf("Abs(5) = %v, want 5", got) + } +} + +func TestAbsNegativeInt(t *testing.T) { 
+ if got := Abs(-5); got != 5 { + t.Errorf("Abs(-5) = %v, want 5", got) + } +} + +func TestAbsZeroInt(t *testing.T) { + if got := Abs(0); got != 0 { + t.Errorf("Abs(0) = %v, want 0", got) + } +} + +func TestAbsPositiveFloat64(t *testing.T) { + if got := Abs(3.14); got != 3.14 { + t.Errorf("Abs(3.14) = %v, want 3.14", got) + } +} + +func TestAbsNegativeFloat64(t *testing.T) { + if got := Abs(-3.14); got != 3.14 { + t.Errorf("Abs(-3.14) = %v, want 3.14", got) + } +} + +func TestAbsNegativeFloat32(t *testing.T) { + if got := Abs(float32(-2.5)); got != float32(2.5) { + t.Errorf("Abs(-2.5) = %v, want 2.5", got) + } +} diff --git a/mongoext/pipeline_test.go b/mongoext/pipeline_test.go new file mode 100644 index 0000000..372d2ae --- /dev/null +++ b/mongoext/pipeline_test.go @@ -0,0 +1,130 @@ +package mongoext + +import ( + "testing" + + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/mongo" +) + +func TestFixTextSearchPipelineEmpty(t *testing.T) { + pipeline := mongo.Pipeline{} + result := FixTextSearchPipeline(pipeline) + tst.AssertEqual(t, len(result), 0) +} + +func TestFixTextSearchPipelineNoTextSearch(t *testing.T) { + pipeline := mongo.Pipeline{ + bson.D{{Key: "$match", Value: bson.M{"foo": "bar"}}}, + bson.D{{Key: "$sort", Value: bson.M{"baz": 1}}}, + } + + result := FixTextSearchPipeline(pipeline) + + tst.AssertEqual(t, len(result), 2) + tst.AssertEqual(t, result[0][0].Key, "$match") + tst.AssertEqual(t, result[1][0].Key, "$sort") +} + +func TestFixTextSearchPipelineMovesTextSearchToFront(t *testing.T) { + pipeline := mongo.Pipeline{ + bson.D{{Key: "$match", Value: bson.M{"foo": "bar"}}}, + bson.D{{Key: "$sort", Value: bson.M{"baz": 1}}}, + bson.D{{Key: "$match", Value: bson.M{"$text": bson.M{"$search": "hello world"}}}}, + } + + result := FixTextSearchPipeline(pipeline) + + tst.AssertEqual(t, len(result), 3) + + // $text/$search should be at front + first := result[0] + matchVal, ok := 
first[0].Value.(bson.M) + tst.AssertTrue(t, ok) + textVal, ok := matchVal["$text"].(bson.M) + tst.AssertTrue(t, ok) + tst.AssertEqual(t, textVal["$search"].(string), "hello world") + + // other entries should preserve order + tst.AssertEqual(t, result[1][0].Key, "$match") + tst.AssertEqual(t, result[2][0].Key, "$sort") +} + +func TestFixTextSearchPipelineMultipleTextSearches(t *testing.T) { + pipeline := mongo.Pipeline{ + bson.D{{Key: "$sort", Value: bson.M{"baz": 1}}}, + bson.D{{Key: "$match", Value: bson.M{"$text": bson.M{"$search": "first"}}}}, + bson.D{{Key: "$match", Value: bson.M{"foo": "bar"}}}, + bson.D{{Key: "$match", Value: bson.M{"$text": bson.M{"$search": "second"}}}}, + } + + result := FixTextSearchPipeline(pipeline) + + tst.AssertEqual(t, len(result), 4) + + // Last seen text-search should be prepended last, ending up at front. + first := result[0][0].Value.(bson.M) + tst.AssertEqual(t, first["$text"].(bson.M)["$search"].(string), "second") + + second := result[1][0].Value.(bson.M) + tst.AssertEqual(t, second["$text"].(bson.M)["$search"].(string), "first") + + tst.AssertEqual(t, result[2][0].Key, "$sort") + tst.AssertEqual(t, result[3][0].Key, "$match") +} + +func TestFixTextSearchPipelineMatchWithoutText(t *testing.T) { + pipeline := mongo.Pipeline{ + bson.D{{Key: "$sort", Value: bson.M{"baz": 1}}}, + bson.D{{Key: "$match", Value: bson.M{"name": "alice"}}}, + } + + result := FixTextSearchPipeline(pipeline) + + tst.AssertEqual(t, len(result), 2) + tst.AssertEqual(t, result[0][0].Key, "$sort") + tst.AssertEqual(t, result[1][0].Key, "$match") +} + +func TestFixTextSearchPipelineTextWithoutSearch(t *testing.T) { + // $text present but without $search key — should NOT be moved + pipeline := mongo.Pipeline{ + bson.D{{Key: "$sort", Value: bson.M{"baz": 1}}}, + bson.D{{Key: "$match", Value: bson.M{"$text": bson.M{"$language": "en"}}}}, + } + + result := FixTextSearchPipeline(pipeline) + + tst.AssertEqual(t, len(result), 2) + tst.AssertEqual(t, 
result[0][0].Key, "$sort") + tst.AssertEqual(t, result[1][0].Key, "$match") +} + +func TestFixTextSearchPipelineMatchValueWrongType(t *testing.T) { + // $match with non-bson.M value — function should keep entry in place + pipeline := mongo.Pipeline{ + bson.D{{Key: "$match", Value: "not a map"}}, + bson.D{{Key: "$sort", Value: bson.M{"baz": 1}}}, + } + + result := FixTextSearchPipeline(pipeline) + + tst.AssertEqual(t, len(result), 2) + tst.AssertEqual(t, result[0][0].Key, "$match") + tst.AssertEqual(t, result[1][0].Key, "$sort") +} + +func TestFixTextSearchPipelinePreservesOriginal(t *testing.T) { + original := mongo.Pipeline{ + bson.D{{Key: "$sort", Value: bson.M{"baz": 1}}}, + bson.D{{Key: "$match", Value: bson.M{"$text": bson.M{"$search": "x"}}}}, + } + originalLen := len(original) + originalFirstKey := original[0][0].Key + + _ = FixTextSearchPipeline(original) + + tst.AssertEqual(t, len(original), originalLen) + tst.AssertEqual(t, original[0][0].Key, originalFirstKey) +} diff --git a/mongoext/projections_test.go b/mongoext/projections_test.go new file mode 100644 index 0000000..229103a --- /dev/null +++ b/mongoext/projections_test.go @@ -0,0 +1,90 @@ +package mongoext + +import ( + "testing" + + "git.blackforestbytes.com/BlackForestBytes/goext/tst" +) + +func TestProjectionFromStructSimple(t *testing.T) { + type model struct { + ID string `bson:"_id"` + Name string `bson:"name"` + Age int `bson:"age"` + } + + res := ProjectionFromStruct(model{}) + + tst.AssertEqual(t, len(res), 3) + tst.AssertEqual(t, res["_id"], 1) + tst.AssertEqual(t, res["name"], 1) + tst.AssertEqual(t, res["age"], 1) +} + +func TestProjectionFromStructIgnoresUntagged(t *testing.T) { + type model struct { + Tagged string `bson:"tagged"` + Untagged string + Other int `json:"other"` + } + + res := ProjectionFromStruct(model{}) + + tst.AssertEqual(t, len(res), 1) + tst.AssertEqual(t, res["tagged"], 1) + + if _, ok := res["Untagged"]; ok { + t.Errorf("untagged field should not be in projection") 
+ } + if _, ok := res["Other"]; ok { + t.Errorf("non-bson-tagged field should not be in projection") + } +} + +func TestProjectionFromStructWithOptions(t *testing.T) { + type model struct { + ID string `bson:"_id,omitempty"` + Name string `bson:"name,omitempty"` + Slug string `bson:"slug,inline"` + } + + res := ProjectionFromStruct(model{}) + + tst.AssertEqual(t, len(res), 3) + tst.AssertEqual(t, res["_id"], 1) + tst.AssertEqual(t, res["name"], 1) + tst.AssertEqual(t, res["slug"], 1) +} + +func TestProjectionFromStructEmpty(t *testing.T) { + type empty struct{} + + res := ProjectionFromStruct(empty{}) + + tst.AssertEqual(t, len(res), 0) +} + +func TestProjectionFromStructPointerValues(t *testing.T) { + type model struct { + Name *string `bson:"name"` + Tags []int `bson:"tags"` + } + + res := ProjectionFromStruct(model{}) + + tst.AssertEqual(t, len(res), 2) + tst.AssertEqual(t, res["name"], 1) + tst.AssertEqual(t, res["tags"], 1) +} + +func TestProjectionFromStructAllSkipped(t *testing.T) { + type model struct { + A string + B int + C bool + } + + res := ProjectionFromStruct(model{}) + + tst.AssertEqual(t, len(res), 0) +} diff --git a/mongoext/registry_test.go b/mongoext/registry_test.go new file mode 100644 index 0000000..a377238 --- /dev/null +++ b/mongoext/registry_test.go @@ -0,0 +1,79 @@ +package mongoext + +import ( + "bytes" + "testing" + + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + "go.mongodb.org/mongo-driver/v2/bson" +) + +func TestCreateGoExtBsonRegistryNotNil(t *testing.T) { + reg := CreateGoExtBsonRegistry() + if reg == nil { + t.Fatal("registry should not be nil") + } +} + +func TestCreateGoExtBsonRegistryEmbeddedDocumentDecodesAsBsonM(t *testing.T) { + reg := CreateGoExtBsonRegistry() + + doc := bson.M{ + "name": "alice", + "nested": bson.M{ + "key": "value", + "num": int32(42), + }, + } + + raw, err := bson.Marshal(doc) + tst.AssertNoErr(t, err) + + dec := bson.NewDecoder(bson.NewDocumentReader(bytes.NewReader(raw))) + 
dec.SetRegistry(reg) + + var decoded map[string]any + err = dec.Decode(&decoded) + tst.AssertNoErr(t, err) + + nested, ok := decoded["nested"].(bson.M) + if !ok { + t.Fatalf("expected nested to be bson.M, got %T", decoded["nested"]) + } + + tst.AssertEqual(t, nested["key"].(string), "value") +} + +func TestCreateGoExtBsonRegistryStructFieldOfTypeAny(t *testing.T) { + reg := CreateGoExtBsonRegistry() + + type wrapper struct { + Payload any `bson:"payload"` + } + + source := wrapper{Payload: bson.M{"x": "y"}} + raw, err := bson.Marshal(source) + tst.AssertNoErr(t, err) + + dec := bson.NewDecoder(bson.NewDocumentReader(bytes.NewReader(raw))) + dec.SetRegistry(reg) + + var decoded wrapper + err = dec.Decode(&decoded) + tst.AssertNoErr(t, err) + + payload, ok := decoded.Payload.(bson.M) + if !ok { + t.Fatalf("expected Payload to be bson.M, got %T", decoded.Payload) + } + tst.AssertEqual(t, payload["x"].(string), "y") +} + +func TestCreateGoExtBsonRegistryReturnsIndependentInstances(t *testing.T) { + r1 := CreateGoExtBsonRegistry() + r2 := CreateGoExtBsonRegistry() + + if r1 == r2 { + t.Error("expected each call to return a new registry instance") + } +} diff --git a/pagination/filter_test.go b/pagination/filter_test.go new file mode 100644 index 0000000..3d961b7 --- /dev/null +++ b/pagination/filter_test.go @@ -0,0 +1,102 @@ +package pagination + +import ( + "context" + "testing" + + "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/mongo" +) + +func TestCreateFilterReturnsNonNil(t *testing.T) { + f := CreateFilter(mongo.Pipeline{}, bson.D{}) + if f == nil { + t.Fatal("expected non-nil MongoFilter") + } +} + +func TestCreateFilterPreservesPipeline(t *testing.T) { + pipeline := mongo.Pipeline{ + bson.D{{Key: "$match", Value: bson.D{{Key: "active", Value: true}}}}, + bson.D{{Key: "$limit", Value: 10}}, + } + f := CreateFilter(pipeline, bson.D{}) + + got := f.FilterQuery(context.Background()) + if len(got) != len(pipeline) { + t.Fatalf("expected 
pipeline len %d, got %d", len(pipeline), len(got)) + } + for i := range pipeline { + if len(got[i]) != len(pipeline[i]) { + t.Errorf("stage %d: expected len %d, got %d", i, len(pipeline[i]), len(got[i])) + } + for j, e := range pipeline[i] { + if got[i][j].Key != e.Key { + t.Errorf("stage %d field %d: expected key %q, got %q", i, j, e.Key, got[i][j].Key) + } + } + } +} + +func TestCreateFilterPreservesSort(t *testing.T) { + sort := bson.D{ + {Key: "createdAt", Value: -1}, + {Key: "_id", Value: 1}, + } + f := CreateFilter(mongo.Pipeline{}, sort) + + got := f.Sort(context.Background()) + if len(got) != len(sort) { + t.Fatalf("expected sort len %d, got %d", len(sort), len(got)) + } + for i, e := range sort { + if got[i].Key != e.Key { + t.Errorf("field %d: expected key %q, got %q", i, e.Key, got[i].Key) + } + if got[i].Value != e.Value { + t.Errorf("field %d: expected value %v, got %v", i, e.Value, got[i].Value) + } + } +} + +func TestCreateFilterEmptyInputs(t *testing.T) { + f := CreateFilter(mongo.Pipeline{}, bson.D{}) + + if got := f.FilterQuery(context.Background()); len(got) != 0 { + t.Errorf("expected empty pipeline, got len %d", len(got)) + } + if got := f.Sort(context.Background()); len(got) != 0 { + t.Errorf("expected empty sort, got len %d", len(got)) + } +} + +func TestCreateFilterNilInputs(t *testing.T) { + f := CreateFilter(nil, nil) + + if got := f.FilterQuery(context.Background()); got != nil { + t.Errorf("expected nil pipeline, got %v", got) + } + if got := f.Sort(context.Background()); got != nil { + t.Errorf("expected nil sort, got %v", got) + } +} + +func TestCreateFilterImplementsMongoFilter(t *testing.T) { + var _ MongoFilter = CreateFilter(mongo.Pipeline{}, bson.D{}) +} + +func TestCreateFilterIgnoresContext(t *testing.T) { + pipeline := mongo.Pipeline{bson.D{{Key: "$count", Value: "n"}}} + sort := bson.D{{Key: "x", Value: 1}} + f := CreateFilter(pipeline, sort) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + if got := 
f.FilterQuery(ctx); len(got) != 1 { + t.Errorf("expected pipeline len 1 even with cancelled ctx, got %d", len(got)) + } + if got := f.Sort(ctx); len(got) != 1 { + t.Errorf("expected sort len 1 even with cancelled ctx, got %d", len(got)) + } +} diff --git a/pagination/pagination_test.go b/pagination/pagination_test.go new file mode 100644 index 0000000..504f535 --- /dev/null +++ b/pagination/pagination_test.go @@ -0,0 +1,82 @@ +package pagination + +import "testing" + +func TestCalcPaginationTotalPagesZeroItems(t *testing.T) { + if got := CalcPaginationTotalPages(0, 10); got != 0 { + t.Errorf("expected 0, got %d", got) + } +} + +func TestCalcPaginationTotalPagesZeroLimit(t *testing.T) { + if got := CalcPaginationTotalPages(100, 0); got != 0 { + t.Errorf("expected 0, got %d", got) + } +} + +func TestCalcPaginationTotalPagesZeroBoth(t *testing.T) { + if got := CalcPaginationTotalPages(0, 0); got != 0 { + t.Errorf("expected 0, got %d", got) + } +} + +func TestCalcPaginationTotalPagesExactMultiple(t *testing.T) { + if got := CalcPaginationTotalPages(100, 10); got != 10 { + t.Errorf("expected 10, got %d", got) + } +} + +func TestCalcPaginationTotalPagesPartialLastPage(t *testing.T) { + if got := CalcPaginationTotalPages(101, 10); got != 11 { + t.Errorf("expected 11, got %d", got) + } +} + +func TestCalcPaginationTotalPagesSingleItem(t *testing.T) { + if got := CalcPaginationTotalPages(1, 10); got != 1 { + t.Errorf("expected 1, got %d", got) + } +} + +func TestCalcPaginationTotalPagesItemsLessThanLimit(t *testing.T) { + if got := CalcPaginationTotalPages(5, 10); got != 1 { + t.Errorf("expected 1, got %d", got) + } +} + +func TestCalcPaginationTotalPagesLimitOfOne(t *testing.T) { + if got := CalcPaginationTotalPages(7, 1); got != 7 { + t.Errorf("expected 7, got %d", got) + } +} + +func TestCalcPaginationTotalPagesItemsEqualLimit(t *testing.T) { + if got := CalcPaginationTotalPages(10, 10); got != 1 { + t.Errorf("expected 1, got %d", got) + } +} + +func 
TestCalcPaginationTotalPagesLargeNumbers(t *testing.T) { + if got := CalcPaginationTotalPages(1_000_000, 250); got != 4000 { + t.Errorf("expected 4000, got %d", got) + } +} + +func TestCalcPaginationTotalPagesOneMoreThanMultiple(t *testing.T) { + if got := CalcPaginationTotalPages(11, 5); got != 3 { + t.Errorf("expected 3, got %d", got) + } +} + +func TestPaginationStructFields(t *testing.T) { + p := Pagination{ + Page: 2, + Limit: 25, + TotalPages: 4, + TotalItems: 100, + CurrentPageCount: 25, + } + if p.Page != 2 || p.Limit != 25 || p.TotalPages != 4 || p.TotalItems != 100 || p.CurrentPageCount != 25 { + t.Errorf("unexpected pagination struct values: %+v", p) + } +} diff --git a/reflectext/casting_test.go b/reflectext/casting_test.go new file mode 100644 index 0000000..8063cb8 --- /dev/null +++ b/reflectext/casting_test.go @@ -0,0 +1,188 @@ +package reflectext + +import ( + "reflect" + "testing" +) + +type aliasInt int +type aliasString string +type aliasFloat float64 +type aliasBool bool + +type aliasIntSlice []int +type aliasStringMap map[string]int +type aliasArr [3]int +type aliasIntPtr *int + +type myStruct struct { + A int + B string +} + +func TestUnderlying_Primitives(t *testing.T) { + cases := []struct { + in reflect.Type + want reflect.Kind + }{ + {reflect.TypeFor[aliasInt](), reflect.Int}, + {reflect.TypeFor[aliasString](), reflect.String}, + {reflect.TypeFor[aliasFloat](), reflect.Float64}, + {reflect.TypeFor[aliasBool](), reflect.Bool}, + } + for _, c := range cases { + got := Underlying(c.in) + if got.Kind() != c.want { + t.Errorf("Underlying(%v).Kind() = %v, want %v", c.in, got.Kind(), c.want) + } + if got.Name() != "" && got != reflectBasicTypes[c.want] { + t.Errorf("Underlying(%v) was not the basic type", c.in) + } + } +} + +func TestUnderlying_UnnamedReturnsSelf(t *testing.T) { + t1 := reflect.TypeFor[[]int]() + got := Underlying(t1) + if got != t1 { + t.Errorf("Underlying of unnamed slice should be itself") + } +} + +func TestUnderlying_Slice(t 
*testing.T) { + t1 := reflect.TypeFor[aliasIntSlice]() + got := Underlying(t1) + if got.Kind() != reflect.Slice { + t.Errorf("expected slice kind, got %v", got.Kind()) + } + if got.Elem().Kind() != reflect.Int { + t.Errorf("expected element of int, got %v", got.Elem().Kind()) + } + if got.Name() != "" { + t.Errorf("underlying type should be unnamed, got name %q", got.Name()) + } +} + +func TestUnderlying_Map(t *testing.T) { + t1 := reflect.TypeFor[aliasStringMap]() + got := Underlying(t1) + if got.Kind() != reflect.Map { + t.Errorf("expected map kind, got %v", got.Kind()) + } +} + +func TestUnderlying_Array(t *testing.T) { + t1 := reflect.TypeFor[aliasArr]() + got := Underlying(t1) + if got.Kind() != reflect.Array { + t.Errorf("expected array kind, got %v", got.Kind()) + } + if got.Len() != 3 { + t.Errorf("expected array len 3, got %d", got.Len()) + } +} + +func TestUnderlying_Pointer(t *testing.T) { + t1 := reflect.TypeFor[aliasIntPtr]() + got := Underlying(t1) + if got.Kind() != reflect.Pointer { + t.Errorf("expected pointer kind, got %v", got.Kind()) + } + if got.Elem().Kind() != reflect.Int { + t.Errorf("expected element kind int, got %v", got.Elem().Kind()) + } +} + +func TestUnderlying_Func(t *testing.T) { + type fnType func(a int, b string) (bool, error) + t1 := reflect.TypeFor[fnType]() + got := Underlying(t1) + if got.Kind() != reflect.Func { + t.Errorf("expected func kind, got %v", got.Kind()) + } + if got.NumIn() != 2 || got.NumOut() != 2 { + t.Errorf("unexpected in/out count: %d/%d", got.NumIn(), got.NumOut()) + } +} + +func TestTryCast_AliasToBase(t *testing.T) { + v := aliasInt(42) + got, ok := TryCast[int](v) + if !ok { + t.Errorf("expected ok cast") + } + if got != 42 { + t.Errorf("expected 42, got %v", got) + } +} + +func TestTryCast_SameType(t *testing.T) { + got, ok := TryCast[int](42) + if !ok { + t.Errorf("expected ok cast") + } + if got != 42 { + t.Errorf("expected 42, got %v", got) + } +} + +func TestTryCast_StringSameType(t *testing.T) { + 
got, ok := TryCast[string]("hello") + if !ok { + t.Errorf("expected ok cast") + } + if got != "hello" { + t.Errorf("expected hello, got %v", got) + } +} + +func TestTryCast_IncompatibleTypes(t *testing.T) { + _, ok := TryCast[string](aliasInt(42)) + if ok { + t.Errorf("expected fail cast int->string") + } + + _, ok = TryCast[int](aliasString("foo")) + if ok { + t.Errorf("expected fail cast string->int") + } +} + +func TestTryCastType_AliasToBase(t *testing.T) { + v := aliasInt(42) + res, ok := TryCastType(v, reflect.TypeFor[int]()) + if !ok { + t.Errorf("expected ok cast") + } + if i, isInt := res.(int); !isInt || i != 42 { + t.Errorf("expected int(42), got %T:%v", res, res) + } +} + +func TestTryCastType_BaseToAlias(t *testing.T) { + res, ok := TryCastType(42, reflect.TypeFor[aliasInt]()) + if !ok { + t.Errorf("expected ok cast") + } + if i, isAlias := res.(aliasInt); !isAlias || i != aliasInt(42) { + t.Errorf("expected aliasInt(42), got %T:%v", res, res) + } +} + +func TestTryCastType_Incompatible(t *testing.T) { + _, ok := TryCastType("hello", reflect.TypeFor[int]()) + if ok { + t.Errorf("expected fail cast string->int") + } +} + +func TestUnderlying_Struct(t *testing.T) { + t1 := reflect.TypeFor[myStruct]() + got := Underlying(t1) + if got.Kind() != reflect.Struct { + t.Errorf("expected struct kind, got %v", got.Kind()) + } + if got.NumField() != 2 { + t.Errorf("expected 2 fields, got %d", got.NumField()) + } +} diff --git a/reflectext/primitiveStringSerializer_test.go b/reflectext/primitiveStringSerializer_test.go new file mode 100644 index 0000000..7e9b1a7 --- /dev/null +++ b/reflectext/primitiveStringSerializer_test.go @@ -0,0 +1,367 @@ +package reflectext + +import ( + "reflect" + "testing" + "time" + + "go.mongodb.org/mongo-driver/v2/bson" +) + +func TestPSS_String(t *testing.T) { + pss := PrimitiveStringSerializer{} + + s, err := pss.ValueToString("hello") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if s != "hello" { + 
t.Errorf("expected hello, got %q", s) + } + + v, err := pss.ValueFromString("world", reflect.TypeFor[string]()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if v.(string) != "world" { + t.Errorf("expected world, got %v", v) + } +} + +func TestPSS_Int(t *testing.T) { + pss := PrimitiveStringSerializer{} + + s, err := pss.ValueToString(42) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if s != "42" { + t.Errorf("expected 42, got %q", s) + } + + v, err := pss.ValueFromString("42", reflect.TypeFor[int]()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if v.(int) != 42 { + t.Errorf("expected 42, got %v", v) + } +} + +func TestPSS_Int64(t *testing.T) { + pss := PrimitiveStringSerializer{} + + s, err := pss.ValueToString(int64(-100)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if s != "-100" { + t.Errorf("expected -100, got %q", s) + } + + v, err := pss.ValueFromString("-100", reflect.TypeFor[int64]()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if v.(int64) != -100 { + t.Errorf("expected -100, got %v", v) + } +} + +func TestPSS_Uint(t *testing.T) { + pss := PrimitiveStringSerializer{} + + s, err := pss.ValueToString(uint(123)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if s != "123" { + t.Errorf("expected 123, got %q", s) + } + + v, err := pss.ValueFromString("123", reflect.TypeFor[uint]()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if v.(uint) != 123 { + t.Errorf("expected 123, got %v", v) + } +} + +func TestPSS_Float(t *testing.T) { + pss := PrimitiveStringSerializer{} + + s, err := pss.ValueToString(3.14) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if s != "3.14" { + t.Errorf("expected 3.14, got %q", s) + } + + v, err := pss.ValueFromString("3.14", reflect.TypeFor[float64]()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if v.(float64) != 3.14 { + t.Errorf("expected 3.14, got %v", v) + } +} + +func 
TestPSS_Float32(t *testing.T) { + pss := PrimitiveStringSerializer{} + + v, err := pss.ValueFromString("1.5", reflect.TypeFor[float32]()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if v.(float32) != 1.5 { + t.Errorf("expected 1.5, got %v", v) + } +} + +func TestPSS_Bool(t *testing.T) { + pss := PrimitiveStringSerializer{} + + s, err := pss.ValueToString(true) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if s != "true" { + t.Errorf("expected true, got %q", s) + } + + s, err = pss.ValueToString(false) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if s != "false" { + t.Errorf("expected false, got %q", s) + } + + v, err := pss.ValueFromString("true", reflect.TypeFor[bool]()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if v.(bool) != true { + t.Errorf("expected true, got %v", v) + } +} + +func TestPSS_BoolInvalid(t *testing.T) { + pss := PrimitiveStringSerializer{} + + _, err := pss.ValueFromString("notabool", reflect.TypeFor[bool]()) + if err == nil { + t.Errorf("expected error for invalid bool") + } +} + +func TestPSS_IntInvalid(t *testing.T) { + pss := PrimitiveStringSerializer{} + + _, err := pss.ValueFromString("notanint", reflect.TypeFor[int]()) + if err == nil { + t.Errorf("expected error for invalid int") + } +} + +func TestPSS_FloatInvalid(t *testing.T) { + pss := PrimitiveStringSerializer{} + + _, err := pss.ValueFromString("notafloat", reflect.TypeFor[float64]()) + if err == nil { + t.Errorf("expected error for invalid float") + } +} + +func TestPSS_Time(t *testing.T) { + pss := PrimitiveStringSerializer{} + + tm := time.Date(2023, 4, 5, 12, 30, 45, 0, time.UTC) + s, err := pss.ValueToString(tm) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + v, err := pss.ValueFromString(s, reflect.TypeFor[time.Time]()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !v.(time.Time).Equal(tm) { + t.Errorf("expected %v, got %v", tm, v) + } +} + +func 
TestPSS_TimeInvalid(t *testing.T) { + pss := PrimitiveStringSerializer{} + + _, err := pss.ValueFromString("not-a-time", reflect.TypeFor[time.Time]()) + if err == nil { + t.Errorf("expected error for invalid time") + } +} + +func TestPSS_ObjectID(t *testing.T) { + pss := PrimitiveStringSerializer{} + + oid := bson.NewObjectID() + s, err := pss.ValueToString(oid) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if s != oid.Hex() { + t.Errorf("expected %v, got %v", oid.Hex(), s) + } + + v, err := pss.ValueFromString(oid.Hex(), reflect.TypeFor[bson.ObjectID]()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if v.(bson.ObjectID) != oid { + t.Errorf("expected %v, got %v", oid, v) + } +} + +func TestPSS_ObjectIDInvalid(t *testing.T) { + pss := PrimitiveStringSerializer{} + + _, err := pss.ValueFromString("not-a-hex-id", reflect.TypeFor[bson.ObjectID]()) + if err == nil { + t.Errorf("expected error for invalid object id") + } +} + +func TestPSS_PointerNil(t *testing.T) { + pss := PrimitiveStringSerializer{} + + var p *int + s, err := pss.ValueToString(p) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if s != "" { + t.Errorf("expected empty string, got %q", s) + } +} + +func TestPSS_PointerSet(t *testing.T) { + pss := PrimitiveStringSerializer{} + + x := 99 + p := &x + s, err := pss.ValueToString(p) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if s != "99" { + t.Errorf("expected 99, got %q", s) + } +} + +func TestPSS_FromStringEmptyToPointer(t *testing.T) { + pss := PrimitiveStringSerializer{} + + v, err := pss.ValueFromString("", reflect.TypeFor[*int]()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + pInt, ok := v.(*int) + if !ok { + t.Fatalf("expected *int, got %T", v) + } + if pInt != nil { + t.Errorf("expected nil pointer, got %v", *pInt) + } +} + +func TestPSS_FromStringPointer(t *testing.T) { + pss := PrimitiveStringSerializer{} + + v, err := pss.ValueFromString("55", 
reflect.TypeFor[*int]()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + pInt, ok := v.(*int) + if !ok { + t.Fatalf("expected *int, got %T", v) + } + if pInt == nil || *pInt != 55 { + t.Errorf("expected pointer to 55, got %v", pInt) + } +} + +func TestPSS_FromStringEmptyToInt(t *testing.T) { + pss := PrimitiveStringSerializer{} + + v, err := pss.ValueFromString("", reflect.TypeFor[int]()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if v.(int) != 0 { + t.Errorf("expected 0, got %v", v) + } +} + +type psAliasInt int +type psAliasString string + +func TestPSS_AliasToString(t *testing.T) { + pss := PrimitiveStringSerializer{} + + v := psAliasInt(77) + s, err := pss.ValueToString(v) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if s != "77" { + t.Errorf("expected 77, got %q", s) + } +} + +func TestPSS_AliasFromString(t *testing.T) { + pss := PrimitiveStringSerializer{} + + v, err := pss.ValueFromString("77", reflect.TypeFor[psAliasInt]()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if v.(psAliasInt) != psAliasInt(77) { + t.Errorf("expected 77, got %v", v) + } +} + +func TestPSS_AliasStringToString(t *testing.T) { + pss := PrimitiveStringSerializer{} + + v := psAliasString("hello") + s, err := pss.ValueToString(v) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if s != "hello" { + t.Errorf("expected hello, got %q", s) + } +} + +func TestPSS_AliasStringFromString(t *testing.T) { + pss := PrimitiveStringSerializer{} + + v, err := pss.ValueFromString("hello", reflect.TypeFor[psAliasString]()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if v.(psAliasString) != psAliasString("hello") { + t.Errorf("expected hello, got %v", v) + } +} + +func TestPSS_UnknownTypeToString(t *testing.T) { + pss := PrimitiveStringSerializer{} + + type unknownStruct struct{ X int } + _, err := pss.ValueToString(unknownStruct{X: 1}) + if err == nil { + t.Errorf("expected error for unknown 
type") + } +} diff --git a/rext/wrapper_extra_test.go b/rext/wrapper_extra_test.go new file mode 100644 index 0000000..4926d21 --- /dev/null +++ b/rext/wrapper_extra_test.go @@ -0,0 +1,260 @@ +package rext + +import ( + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + "regexp" + "strings" + "testing" +) + +func TestW(t *testing.T) { + r := W(regexp.MustCompile(`\d+`)) + tst.AssertTrue(t, r != nil) + tst.AssertEqual(t, r.String(), `\d+`) +} + +func TestIsMatchTrue(t *testing.T) { + r := W(regexp.MustCompile(`\d+`)) + tst.AssertTrue(t, r.IsMatch("abc 123 def")) +} + +func TestIsMatchFalse(t *testing.T) { + r := W(regexp.MustCompile(`\d+`)) + tst.AssertFalse(t, r.IsMatch("abc def")) +} + +func TestString(t *testing.T) { + r := W(regexp.MustCompile(`^foo(bar)?$`)) + tst.AssertEqual(t, r.String(), `^foo(bar)?$`) +} + +func TestGroupCountWrapper(t *testing.T) { + r0 := W(regexp.MustCompile(`abc`)) + tst.AssertEqual(t, r0.GroupCount(), 0) + + r1 := W(regexp.MustCompile(`(a)(b)(c)`)) + tst.AssertEqual(t, r1.GroupCount(), 3) + + r2 := W(regexp.MustCompile(`(?P\d+)-(?P\d+)`)) + tst.AssertEqual(t, r2.GroupCount(), 2) +} + +func TestMatchFirstFound(t *testing.T) { + r := W(regexp.MustCompile(`(\d+)-(\d+)`)) + m, ok := r.MatchFirst("a 12-34 b 56-78 c") + tst.AssertTrue(t, ok) + tst.AssertEqual(t, m.FullMatch().Value(), "12-34") + tst.AssertEqual(t, m.GroupByIndex(1).Value(), "12") + tst.AssertEqual(t, m.GroupByIndex(2).Value(), "34") +} + +func TestMatchFirstNotFound(t *testing.T) { + r := W(regexp.MustCompile(`\d+`)) + m, ok := r.MatchFirst("nothing here") + tst.AssertFalse(t, ok) + // zero-value match should be returned + tst.AssertEqual(t, len(m.submatchesIndex), 0) +} + +func TestMatchAllMultiple(t *testing.T) { + r := W(regexp.MustCompile(`\d+`)) + matches := r.MatchAll("a 1 b 22 c 333") + tst.AssertEqual(t, len(matches), 3) + tst.AssertEqual(t, matches[0].FullMatch().Value(), "1") + tst.AssertEqual(t, matches[1].FullMatch().Value(), "22") + tst.AssertEqual(t, 
matches[2].FullMatch().Value(), "333") +} + +func TestMatchAllNone(t *testing.T) { + r := W(regexp.MustCompile(`\d+`)) + matches := r.MatchAll("abc") + tst.AssertEqual(t, len(matches), 0) +} + +func TestMatchAllSingle(t *testing.T) { + r := W(regexp.MustCompile(`(?P<num>\d+)`)) + matches := r.MatchAll("only 42 here") + tst.AssertEqual(t, len(matches), 1) + tst.AssertEqual(t, matches[0].GroupByName("num").Value(), "42") +} + +func TestReplaceAllNonLiteralExpansion(t *testing.T) { + r := W(regexp.MustCompile(`(\w+)@(\w+)`)) + out := r.ReplaceAll("hi alice@example, hi bob@example", "$2/$1", false) + tst.AssertEqual(t, out, "hi example/alice, hi example/bob") +} + +func TestReplaceAllLiteralNoExpansion(t *testing.T) { + r := W(regexp.MustCompile(`(\w+)@(\w+)`)) + out := r.ReplaceAll("hi alice@example", "$2/$1", true) + tst.AssertEqual(t, out, "hi $2/$1") +} + +func TestReplaceAllFunc(t *testing.T) { + r := W(regexp.MustCompile(`\d+`)) + out := r.ReplaceAllFunc("a1 b22 c333", func(s string) string { + return strings.Repeat("x", len(s)) + }) + tst.AssertEqual(t, out, "ax bxx cxxx") +} + +func TestRemoveAll(t *testing.T) { + r := W(regexp.MustCompile(`\s+`)) + out := r.RemoveAll(" hello world ") + tst.AssertEqual(t, out, "helloworld") +} + +func TestRemoveAllNoMatch(t *testing.T) { + r := W(regexp.MustCompile(`\d+`)) + out := r.RemoveAll("no digits here") + tst.AssertEqual(t, out, "no digits here") +} + +func TestRemoveAllDoesNotExpandPlaceholders(t *testing.T) { + // removal uses literal replacement, so no expansion happens + r := W(regexp.MustCompile(`(\w+)`)) + out := r.RemoveAll("abc") + tst.AssertEqual(t, out, "") +} + +// --- RegexMatch --- + +func TestRegexMatchFullMatch(t *testing.T) { + r := W(regexp.MustCompile(`b\w+d`)) + m, ok := r.MatchFirst("aa beard cc") + tst.AssertTrue(t, ok) + fm := m.FullMatch() + tst.AssertEqual(t, fm.Value(), "beard") + tst.AssertEqual(t, fm.Start(), 3) + tst.AssertEqual(t, fm.End(), 8) +} + +func TestRegexMatchGroupCount(t *testing.T) { + 
r := W(regexp.MustCompile(`(a)(b)(c)`)) + m, ok := r.MatchFirst("abc") + tst.AssertTrue(t, ok) + tst.AssertEqual(t, m.GroupCount(), 3) +} + +func TestRegexMatchGroupByIndex(t *testing.T) { + r := W(regexp.MustCompile(`(\w+)-(\w+)-(\w+)`)) + m, ok := r.MatchFirst("foo-bar-baz") + tst.AssertTrue(t, ok) + tst.AssertEqual(t, m.GroupByIndex(0).Value(), "foo-bar-baz") + tst.AssertEqual(t, m.GroupByIndex(1).Value(), "foo") + tst.AssertEqual(t, m.GroupByIndex(2).Value(), "bar") + tst.AssertEqual(t, m.GroupByIndex(3).Value(), "baz") +} + +func TestRegexMatchGroupByName(t *testing.T) { + r := W(regexp.MustCompile(`(?P<first>\w+)\s+(?P<last>\w+)`)) + m, ok := r.MatchFirst("John Doe") + tst.AssertTrue(t, ok) + tst.AssertEqual(t, m.GroupByName("first").Value(), "John") + tst.AssertEqual(t, m.GroupByName("last").Value(), "Doe") +} + +func TestRegexMatchGroupByNamePanics(t *testing.T) { + r := W(regexp.MustCompile(`(?P<num>\d+)`)) + m, ok := r.MatchFirst("99") + tst.AssertTrue(t, ok) + + defer func() { + rec := recover() + tst.AssertTrue(t, rec != nil) + }() + + _ = m.GroupByName("nonexistent") + t.Fatal("expected panic") +} + +func TestGroupByNameOrEmptyMissingName(t *testing.T) { + r := W(regexp.MustCompile(`(?P<num>\d+)`)) + m, ok := r.MatchFirst("99") + tst.AssertTrue(t, ok) + + g := m.GroupByNameOrEmpty("not-present") + tst.AssertTrue(t, g.IsEmpty()) + tst.AssertFalse(t, g.Exists()) + tst.AssertEqual(t, g.ValueOrEmpty(), "") + tst.AssertPtrEqual(t, g.ValueOrNil(), nil) +} + +// --- RegexMatchGroup --- + +func TestRegexMatchGroupAccessors(t *testing.T) { + r := W(regexp.MustCompile(`(\w+)`)) + m, ok := r.MatchFirst(" hello ") + tst.AssertTrue(t, ok) + + g := m.GroupByIndex(1) + tst.AssertEqual(t, g.Value(), "hello") + tst.AssertEqual(t, g.Start(), 2) + tst.AssertEqual(t, g.End(), 7) + s, e := g.Range() + tst.AssertEqual(t, s, 2) + tst.AssertEqual(t, e, 7) + tst.AssertEqual(t, g.Length(), 5) +} + +// --- OptRegexMatchGroup --- + +func TestOptRegexMatchGroupExisting(t *testing.T) { + r := 
W(regexp.MustCompile(`(?P<word>\w+)`)) + m, ok := r.MatchFirst(" hello ") + tst.AssertTrue(t, ok) + + g := m.GroupByNameOrEmpty("word") + tst.AssertTrue(t, g.Exists()) + tst.AssertFalse(t, g.IsEmpty()) + tst.AssertEqual(t, g.Value(), "hello") + tst.AssertEqual(t, g.ValueOrEmpty(), "hello") + tst.AssertEqual(t, *g.ValueOrNil(), "hello") + tst.AssertEqual(t, g.Start(), 2) + tst.AssertEqual(t, g.End(), 7) + s, e := g.Range() + tst.AssertEqual(t, s, 2) + tst.AssertEqual(t, e, 7) + tst.AssertEqual(t, g.Length(), 5) +} + +func TestOptRegexMatchGroupOptionalNotMatched(t *testing.T) { + // group2 is optional and won't match; group1 will + r := W(regexp.MustCompile(`(?P<group1>A+)(?P<group2>B+)?`)) + m, ok := r.MatchFirst("AAA") + tst.AssertTrue(t, ok) + + g1 := m.GroupByNameOrEmpty("group1") + tst.AssertTrue(t, g1.Exists()) + tst.AssertEqual(t, g1.ValueOrEmpty(), "AAA") + + g2 := m.GroupByNameOrEmpty("group2") + tst.AssertTrue(t, g2.IsEmpty()) + tst.AssertFalse(t, g2.Exists()) + tst.AssertEqual(t, g2.ValueOrEmpty(), "") + tst.AssertPtrEqual(t, g2.ValueOrNil(), nil) +} + +// --- Misc combined behavior --- + +func TestMultipleNamedGroupsAcrossMatches(t *testing.T) { + r := W(regexp.MustCompile(`(?P<k>\w+)=(?P<v>\d+)`)) + matches := r.MatchAll("a=1 b=22 c=333") + tst.AssertEqual(t, len(matches), 3) + + tst.AssertEqual(t, matches[0].GroupByName("k").Value(), "a") + tst.AssertEqual(t, matches[0].GroupByName("v").Value(), "1") + tst.AssertEqual(t, matches[1].GroupByName("k").Value(), "b") + tst.AssertEqual(t, matches[1].GroupByName("v").Value(), "22") + tst.AssertEqual(t, matches[2].GroupByName("k").Value(), "c") + tst.AssertEqual(t, matches[2].GroupByName("v").Value(), "333") +} + +func TestEmptyMatchHaystack(t *testing.T) { + r := W(regexp.MustCompile(`.*`)) + tst.AssertTrue(t, r.IsMatch("")) + m, ok := r.MatchFirst("") + tst.AssertTrue(t, ok) + tst.AssertEqual(t, m.FullMatch().Value(), "") + tst.AssertEqual(t, m.FullMatch().Length(), 0) +} diff --git a/rfctime/date_test.go b/rfctime/date_test.go new 
file mode 100644 index 0000000..d84309c --- /dev/null +++ b/rfctime/date_test.go @@ -0,0 +1,196 @@ +package rfctime + +import ( + "encoding/json" + "testing" + "time" + + "git.blackforestbytes.com/BlackForestBytes/goext/tst" +) + +func TestDateString(t *testing.T) { + d := Date{Year: 2023, Month: 5, Day: 7} + tst.AssertEqual(t, d.String(), "2023-05-07") + tst.AssertEqual(t, d.Serialize(), "2023-05-07") + tst.AssertEqual(t, d.GoString(), "rfctime.Date{Year: 2023, Month: 5, Day: 7}") + tst.AssertEqual(t, d.FormatStr(), "2006-01-02") +} + +func TestDateIsZero(t *testing.T) { + tst.AssertEqual(t, Date{}.IsZero(), true) + tst.AssertEqual(t, Date{Year: 1, Month: 1, Day: 1}.IsZero(), false) +} + +func TestDateNew(t *testing.T) { + tm := time.Date(2023, 5, 7, 12, 30, 0, 0, time.UTC) + d := NewDate(tm) + tst.AssertEqual(t, d.Year, 2023) + tst.AssertEqual(t, d.Month, 5) + tst.AssertEqual(t, d.Day, 7) +} + +func TestDateTimeConversions(t *testing.T) { + d := Date{Year: 2023, Month: 5, Day: 7} + utc := d.TimeUTC() + tst.AssertEqual(t, utc.Year(), 2023) + tst.AssertEqual(t, utc.Month(), time.May) + tst.AssertEqual(t, utc.Day(), 7) + tst.AssertEqual(t, utc.Location(), time.UTC) + + loc := d.TimeLocal() + tst.AssertEqual(t, loc.Location(), time.Local) + + custom := d.Time(time.UTC) + tst.AssertEqual(t, custom.Hour(), 0) + tst.AssertEqual(t, custom.Location(), time.UTC) +} + +func TestDateJSON(t *testing.T) { + type Wrap struct { + D Date `json:"d"` + } + w1 := Wrap{D: Date{Year: 2023, Month: 5, Day: 7}} + b, err := json.Marshal(w1) + if err != nil { + t.Fatal(err) + } + tst.AssertEqual(t, string(b), `{"d":"2023-05-07"}`) + + var w2 Wrap + if err := json.Unmarshal(b, &w2); err != nil { + t.Fatal(err) + } + tst.AssertEqual(t, w2.D, w1.D) +} + +func TestDateJSONInvalid(t *testing.T) { + var d Date + if err := d.UnmarshalJSON([]byte(`"not-a-date"`)); err == nil { + t.Errorf("expected parse error") + } + if err := d.UnmarshalJSON([]byte(`123`)); err == nil { + t.Errorf("expected json 
error for number") + } +} + +func TestDateText(t *testing.T) { + d := Date{Year: 2023, Month: 5, Day: 7} + b, err := d.MarshalText() + if err != nil { + t.Fatal(err) + } + tst.AssertEqual(t, string(b), "2023-05-07") + + var d2 Date + if err := d2.UnmarshalText(b); err != nil { + t.Fatal(err) + } + tst.AssertEqual(t, d2, d) + + if err := d2.UnmarshalText([]byte("garbage")); err == nil { + t.Errorf("expected error") + } +} + +func TestDateBinaryGob(t *testing.T) { + d := Date{Year: 2023, Month: 5, Day: 7} + + bin, err := d.MarshalBinary() + if err != nil { + t.Fatal(err) + } + var d2 Date + if err := d2.UnmarshalBinary(bin); err != nil { + t.Fatal(err) + } + tst.AssertEqual(t, d2, d) + + gob, err := d.GobEncode() + if err != nil { + t.Fatal(err) + } + var d3 Date + if err := d3.GobDecode(gob); err != nil { + t.Fatal(err) + } + tst.AssertEqual(t, d3, d) +} + +func TestDateAccessors(t *testing.T) { + d := Date{Year: 2023, Month: 5, Day: 17} + y, m, day := d.Date() + tst.AssertEqual(t, y, 2023) + tst.AssertEqual(t, m, time.May) + tst.AssertEqual(t, day, 17) + tst.AssertEqual(t, d.Weekday(), time.Wednesday) + + wy, ww := d.ISOWeek() + ey, ew := d.TimeUTC().ISOWeek() + tst.AssertEqual(t, wy, ey) + tst.AssertEqual(t, ww, ew) + + tst.AssertEqual(t, d.YearDay(), d.TimeUTC().YearDay()) + tst.AssertEqual(t, d.Unix(), d.TimeUTC().Unix()) + tst.AssertEqual(t, d.UnixMilli(), d.TimeUTC().UnixMilli()) + tst.AssertEqual(t, d.UnixMicro(), d.TimeUTC().UnixMicro()) + tst.AssertEqual(t, d.UnixNano(), d.TimeUTC().UnixNano()) + tst.AssertEqual(t, d.Format("2006/01/02"), "2023/05/17") +} + +func TestDateAddDate(t *testing.T) { + d := Date{Year: 2023, Month: 5, Day: 17} + d2 := d.AddDate(1, 2, 3) + tst.AssertEqual(t, d2.Year, 2024) + tst.AssertEqual(t, d2.Month, 7) + tst.AssertEqual(t, d2.Day, 20) +} + +func TestDateParseString(t *testing.T) { + tests := []struct { + input string + ok bool + expected Date + }{ + {"2023-05-07", true, Date{2023, 5, 7}}, + {"0001-01-01", true, Date{1, 1, 1}}, 
+ {"2023-13-01", false, Date{}}, // bad month + {"2023-12-32", false, Date{}}, // bad day + {"2023-00-15", false, Date{}}, // month 0 + {"2023-05", false, Date{}}, // bad format + {"2023-05-07-extra", false, Date{}}, + {"abcd-ef-gh", false, Date{}}, + {"-1-05-07", false, Date{}}, // negative year + } + + for _, tc := range tests { + var d Date + err := d.ParseString(tc.input) + if tc.ok { + if err != nil { + t.Errorf("ParseString(%q) failed: %v", tc.input, err) + continue + } + tst.AssertEqual(t, d, tc.expected) + } else if err == nil { + t.Errorf("ParseString(%q) should have failed", tc.input) + } + } +} + +func TestNowDate(t *testing.T) { + now := time.Now().UTC() + d := NowDate(time.UTC) + tst.AssertEqual(t, d.Year, now.Year()) + tst.AssertEqual(t, d.Month, int(now.Month())) + tst.AssertEqual(t, d.Day, now.Day()) + + dl := NowDateLoc() + if dl.Year < 1970 { + t.Errorf("NowDateLoc returned implausible year: %d", dl.Year) + } + + du := NowDateUTC() + if du.Year < 1970 { + t.Errorf("NowDateUTC returned implausible year: %d", du.Year) + } +} diff --git a/rfctime/rfc3339_test.go b/rfctime/rfc3339_test.go new file mode 100644 index 0000000..157c6f9 --- /dev/null +++ b/rfctime/rfc3339_test.go @@ -0,0 +1,197 @@ +package rfctime + +import ( + "encoding/json" + "testing" + "time" + + "git.blackforestbytes.com/BlackForestBytes/goext/timeext" + "git.blackforestbytes.com/BlackForestBytes/goext/tst" +) + +func TestRFC3339TimeRoundtripJSON(t *testing.T) { + type Wrap struct { + Value RFC3339Time `json:"v"` + } + + val := NewRFC3339(time.Unix(1675951556, 0).In(timeext.TimezoneBerlin)) + w1 := Wrap{val} + + jstr1, err := json.Marshal(w1) + if err != nil { + t.Fatal(err) + } + + if string(jstr1) != "{\"v\":\"2023-02-09T15:05:56+01:00\"}" { + t.Errorf("unexpected json: %s", string(jstr1)) + } + + w2 := Wrap{} + if err := json.Unmarshal(jstr1, &w2); err != nil { + t.Fatal(err) + } + + jstr2, err := json.Marshal(w2) + if err != nil { + t.Fatal(err) + } + + tst.AssertEqual(t, 
string(jstr1), string(jstr2)) + + if !w1.Value.EqualAny(w2.Value) { + t.Errorf("time differs after roundtrip") + } +} + +func TestRFC3339TimeUnmarshalJSONInvalid(t *testing.T) { + var v RFC3339Time + if err := v.UnmarshalJSON([]byte(`"not-a-date"`)); err == nil { + t.Errorf("expected error parsing invalid date") + } + if err := v.UnmarshalJSON([]byte(`12345`)); err == nil { + t.Errorf("expected error for non-string json") + } +} + +func TestRFC3339TimeText(t *testing.T) { + val := NewRFC3339(time.Date(2023, 2, 9, 15, 5, 56, 0, time.UTC)) + b, err := val.MarshalText() + if err != nil { + t.Fatal(err) + } + tst.AssertEqual(t, string(b), "2023-02-09T15:05:56Z") + + var v2 RFC3339Time + if err := v2.UnmarshalText(b); err != nil { + t.Fatal(err) + } + if !v2.Equal(val) { + t.Errorf("text roundtrip mismatch") + } + + if err := v2.UnmarshalText([]byte("garbage")); err == nil { + t.Errorf("expected error on bad text") + } +} + +func TestRFC3339TimeBinaryAndGob(t *testing.T) { + val := NewRFC3339(time.Date(2023, 2, 9, 15, 5, 56, 123, time.UTC)) + + bin, err := val.MarshalBinary() + if err != nil { + t.Fatal(err) + } + var v2 RFC3339Time + if err := v2.UnmarshalBinary(bin); err != nil { + t.Fatal(err) + } + if !v2.Equal(val) { + t.Errorf("binary roundtrip mismatch") + } + + gob, err := val.GobEncode() + if err != nil { + t.Fatal(err) + } + var v3 RFC3339Time + if err := v3.GobDecode(gob); err != nil { + t.Fatal(err) + } + if !v3.Equal(val) { + t.Errorf("gob roundtrip mismatch") + } +} + +func TestRFC3339TimeAccessors(t *testing.T) { + loc, _ := time.LoadLocation("UTC") + tm := time.Date(2023, 5, 17, 14, 30, 45, 123456789, loc) + val := NewRFC3339(tm) + + tst.AssertEqual(t, val.Year(), 2023) + tst.AssertEqual(t, val.Month(), time.May) + tst.AssertEqual(t, val.Day(), 17) + tst.AssertEqual(t, val.Hour(), 14) + tst.AssertEqual(t, val.Minute(), 30) + tst.AssertEqual(t, val.Second(), 45) + tst.AssertEqual(t, val.Nanosecond(), 123456789) + tst.AssertEqual(t, val.Weekday(), 
time.Wednesday) + tst.AssertEqual(t, val.YearDay(), tm.YearDay()) + tst.AssertEqual(t, val.Unix(), tm.Unix()) + tst.AssertEqual(t, val.UnixMilli(), tm.UnixMilli()) + tst.AssertEqual(t, val.UnixMicro(), tm.UnixMicro()) + tst.AssertEqual(t, val.UnixNano(), tm.UnixNano()) + tst.AssertEqual(t, val.Location(), loc) + tst.AssertEqual(t, val.Format(time.RFC3339), tm.Format(time.RFC3339)) + tst.AssertEqual(t, val.GoString(), tm.GoString()) + tst.AssertEqual(t, val.String(), tm.String()) + tst.AssertEqual(t, val.Serialize(), tm.Format(time.RFC3339)) + tst.AssertEqual(t, val.FormatStr(), time.RFC3339) + + y, mo, d := val.Date() + tst.AssertEqual(t, y, 2023) + tst.AssertEqual(t, mo, time.May) + tst.AssertEqual(t, d, 17) + + wy, ww := val.ISOWeek() + ey, ew := tm.ISOWeek() + tst.AssertEqual(t, wy, ey) + tst.AssertEqual(t, ww, ew) + + h, m, s := val.Clock() + tst.AssertEqual(t, h, 14) + tst.AssertEqual(t, m, 30) + tst.AssertEqual(t, s, 45) + + tst.AssertEqual(t, val.IsZero(), false) + tst.AssertEqual(t, RFC3339Time{}.IsZero(), true) +} + +func TestRFC3339TimeAddSub(t *testing.T) { + tm := time.Date(2023, 5, 17, 14, 30, 45, 0, time.UTC) + a := NewRFC3339(tm) + b := a.Add(2 * time.Hour) + + tst.AssertEqual(t, b.Sub(a), 2*time.Hour) + tst.AssertEqual(t, b.After(a), true) + tst.AssertEqual(t, a.Before(b), true) + tst.AssertEqual(t, a.After(b), false) + tst.AssertEqual(t, b.Before(a), false) + + c := a.AddDate(1, 2, 3) + tst.AssertEqual(t, c.Year(), 2024) + tst.AssertEqual(t, c.Month(), time.July) + tst.AssertEqual(t, c.Day(), 20) +} + +func TestRFC3339TimeEqual(t *testing.T) { + tm := time.Date(2023, 5, 17, 14, 30, 45, 0, time.UTC) + a := NewRFC3339(tm) + b := NewRFC3339(tm) + c := NewRFC3339(tm.Add(time.Second)) + + tst.AssertEqual(t, a.Equal(b), true) + tst.AssertEqual(t, a.Equal(c), false) + tst.AssertEqual(t, a.EqualAny(b), true) + tst.AssertEqual(t, a.EqualAny(c), false) + tst.AssertEqual(t, a.EqualAny(nil), false) + // Cross-type comparison via tt() + tst.AssertEqual(t, 
a.EqualAny(NewRFC3339Nano(tm)), true) + tst.AssertEqual(t, a.EqualAny(tm), true) +} + +func TestRFC3339TimeToNano(t *testing.T) { + tm := time.Date(2023, 5, 17, 14, 30, 45, 12345, time.UTC) + a := NewRFC3339(tm) + n := a.ToNano() + tst.AssertEqual(t, n.UnixNano(), tm.UnixNano()) +} + +func TestNowRFC3339(t *testing.T) { + before := time.Now() + v := NowRFC3339() + after := time.Now() + + if v.Time().Before(before.Add(-time.Second)) || v.Time().After(after.Add(time.Second)) { + t.Errorf("NowRFC3339 not within expected range") + } +} diff --git a/rfctime/seconds_test.go b/rfctime/seconds_test.go new file mode 100644 index 0000000..4533d21 --- /dev/null +++ b/rfctime/seconds_test.go @@ -0,0 +1,50 @@ +package rfctime + +import ( + "encoding/json" + "math" + "testing" + "time" + + "git.blackforestbytes.com/BlackForestBytes/goext/tst" +) + +func TestSecondsF64Basics(t *testing.T) { + d := NewSecondsF64(2*time.Second + 500*time.Millisecond) + tst.AssertEqual(t, d.Duration(), 2500*time.Millisecond) + tst.AssertEqual(t, d.Seconds(), 2.5) + tst.AssertEqual(t, d.Milliseconds(), int64(2500)) + tst.AssertEqual(t, d.Microseconds(), int64(2500000)) + tst.AssertEqual(t, d.Nanoseconds(), int64(2500000000)) + tst.AssertEqual(t, d.Minutes(), 2.5/60.0) + tst.AssertEqual(t, d.Hours(), 2.5/3600.0) + tst.AssertEqual(t, d.String(), (2500 * time.Millisecond).String()) +} + +func TestSecondsF64JSON(t *testing.T) { + type Wrap struct { + D SecondsF64 `json:"d"` + } + + w1 := Wrap{D: NewSecondsF64(2500 * time.Millisecond)} + b, err := json.Marshal(w1) + if err != nil { + t.Fatal(err) + } + tst.AssertEqual(t, string(b), `{"d":2.5}`) + + var w2 Wrap + if err := json.Unmarshal(b, &w2); err != nil { + t.Fatal(err) + } + if math.Abs(float64(w2.D.Duration()-w1.D.Duration())) > float64(time.Microsecond) { + t.Errorf("roundtrip mismatch: %v vs %v", w1.D, w2.D) + } +} + +func TestSecondsF64UnmarshalJSONInvalid(t *testing.T) { + var d SecondsF64 + if err := d.UnmarshalJSON([]byte(`"not-a-number"`)); 
err == nil { + t.Errorf("expected error") + } +} diff --git a/rfctime/time_test.go b/rfctime/time_test.go new file mode 100644 index 0000000..ef83abb --- /dev/null +++ b/rfctime/time_test.go @@ -0,0 +1,100 @@ +package rfctime + +import ( + "testing" + "time" + + "git.blackforestbytes.com/BlackForestBytes/goext/tst" +) + +func TestTimeNew(t *testing.T) { + v := NewTime(14, 30, 45, 123) + tst.AssertEqual(t, v.Hour, 14) + tst.AssertEqual(t, v.Minute, 30) + tst.AssertEqual(t, v.Second, 45) + tst.AssertEqual(t, v.NanoSecond, 123) +} + +func TestTimeFromTS(t *testing.T) { + tm := time.Date(2023, 5, 17, 14, 30, 45, 123456789, time.UTC) + v := NewTimeFromTS(tm) + tst.AssertEqual(t, v.Hour, 14) + tst.AssertEqual(t, v.Minute, 30) + tst.AssertEqual(t, v.Second, 45) + tst.AssertEqual(t, v.NanoSecond, 123456789) +} + +func TestTimeSerialize(t *testing.T) { + v := NewTime(14, 30, 45, 123456789) + tst.AssertEqual(t, v.Serialize(), "0014:30:45.123456789") + tst.AssertEqual(t, v.String(), "0014:30:45.123456789") + tst.AssertEqual(t, v.GoString(), "rfctime.NewTime(14, 30, 45, 123456789)") + tst.AssertEqual(t, v.FormatStr(), "15:04:05.999999999") +} + +func TestTimeSerializeShort(t *testing.T) { + tst.AssertEqual(t, NewTime(14, 30, 0, 0).SerializeShort(), "14:30") + tst.AssertEqual(t, NewTime(14, 30, 45, 0).SerializeShort(), "14:30:45") + tst.AssertEqual(t, NewTime(14, 30, 45, 123).SerializeShort(), "14:30:45.000000123") + tst.AssertEqual(t, NewTime(0, 0, 0, 0).SerializeShort(), "00:00") +} + +func TestTimeDeserialize(t *testing.T) { + tests := []struct { + input string + ok bool + expected Time + }{ + {"14:30", true, Time{Hour: 14, Minute: 30, Second: 0, NanoSecond: 0}}, + {"14:30:45", true, Time{Hour: 14, Minute: 30, Second: 45, NanoSecond: 0}}, + {"14:30:45.123", true, Time{Hour: 14, Minute: 30, Second: 45, NanoSecond: 123000000}}, + {"14:30:45.123456789", true, Time{Hour: 14, Minute: 30, Second: 45, NanoSecond: 123456789}}, + {"00:00:00.000000000", true, Time{Hour: 0, Minute: 0, 
Second: 0, NanoSecond: 0}}, + {"14", false, Time{}}, + {"14:30:45.123:extra", false, Time{}}, + {"ab:cd", false, Time{}}, + {"14:bb", false, Time{}}, + {"14:30:cc", false, Time{}}, + {"14:30:45.zz", false, Time{}}, + } + + for _, tc := range tests { + var v Time + err := v.Deserialize(tc.input) + if tc.ok { + if err != nil { + t.Errorf("Deserialize(%q) failed: %v", tc.input, err) + continue + } + tst.AssertEqual(t, v, tc.expected) + } else if err == nil { + t.Errorf("Deserialize(%q) should have failed", tc.input) + } + } +} + +func TestNowTime(t *testing.T) { + now := time.Now().UTC() + v := NowTime(time.UTC) + // Within a couple of seconds + if abs(v.Hour-now.Hour()) > 1 && !(now.Hour() == 23 && v.Hour == 0) { + t.Errorf("NowTime hour mismatch: %d vs %d", v.Hour, now.Hour()) + } + + vl := NowTimeLoc() + if vl.Hour < 0 || vl.Hour > 23 { + t.Errorf("NowTimeLoc invalid hour: %d", vl.Hour) + } + + vu := NowTimeUTC() + if vu.Hour < 0 || vu.Hour > 23 { + t.Errorf("NowTimeUTC invalid hour: %d", vu.Hour) + } +} + +func abs(x int) int { + if x < 0 { + return -x + } + return x +} diff --git a/rfctime/unix_test.go b/rfctime/unix_test.go new file mode 100644 index 0000000..98411c0 --- /dev/null +++ b/rfctime/unix_test.go @@ -0,0 +1,330 @@ +package rfctime + +import ( + "encoding/json" + "strconv" + "testing" + "time" + + "git.blackforestbytes.com/BlackForestBytes/goext/tst" +) + +func TestUnixTimeRoundtripJSON(t *testing.T) { + type Wrap struct { + Value UnixTime `json:"v"` + } + + val := NewUnix(time.Unix(1675951556, 0).UTC()) + w1 := Wrap{val} + + jstr1, err := json.Marshal(w1) + if err != nil { + t.Fatal(err) + } + tst.AssertEqual(t, string(jstr1), `{"v":"1675951556"}`) + + w2 := Wrap{} + if err := json.Unmarshal(jstr1, &w2); err != nil { + t.Fatal(err) + } + tst.AssertEqual(t, w2.Value.Unix(), val.Unix()) +} + +func TestUnixTimeUnmarshalJSONInvalid(t *testing.T) { + var v UnixTime + if err := v.UnmarshalJSON([]byte(`"not-a-number"`)); err == nil { + t.Errorf("expected 
parse error") + } + if err := v.UnmarshalJSON([]byte(`{}`)); err == nil { + t.Errorf("expected json error on object") + } +} + +func TestUnixTimeText(t *testing.T) { + val := NewUnix(time.Unix(1675951556, 0)) + b, err := val.MarshalText() + if err != nil { + t.Fatal(err) + } + tst.AssertEqual(t, string(b), "1675951556") + + var v2 UnixTime + if err := v2.UnmarshalText(b); err != nil { + t.Fatal(err) + } + tst.AssertEqual(t, v2.Unix(), val.Unix()) + + if err := v2.UnmarshalText([]byte("garbage")); err == nil { + t.Errorf("expected error") + } +} + +func TestUnixTimeBinaryGob(t *testing.T) { + val := NewUnix(time.Date(2023, 5, 17, 14, 30, 45, 0, time.UTC)) + + bin, err := val.MarshalBinary() + if err != nil { + t.Fatal(err) + } + var v2 UnixTime + if err := v2.UnmarshalBinary(bin); err != nil { + t.Fatal(err) + } + tst.AssertEqual(t, v2.Unix(), val.Unix()) + + gob, err := val.GobEncode() + if err != nil { + t.Fatal(err) + } + var v3 UnixTime + if err := v3.GobDecode(gob); err != nil { + t.Fatal(err) + } + tst.AssertEqual(t, v3.Unix(), val.Unix()) +} + +func TestUnixTimeAccessors(t *testing.T) { + tm := time.Date(2023, 5, 17, 14, 30, 45, 12345, time.UTC) + val := NewUnix(tm) + + tst.AssertEqual(t, val.Year(), 2023) + tst.AssertEqual(t, val.Month(), time.May) + tst.AssertEqual(t, val.Day(), 17) + tst.AssertEqual(t, val.Hour(), 14) + tst.AssertEqual(t, val.Minute(), 30) + tst.AssertEqual(t, val.Second(), 45) + tst.AssertEqual(t, val.Nanosecond(), 12345) + tst.AssertEqual(t, val.Weekday(), time.Wednesday) + tst.AssertEqual(t, val.Unix(), tm.Unix()) + tst.AssertEqual(t, val.UnixMilli(), tm.UnixMilli()) + tst.AssertEqual(t, val.UnixMicro(), tm.UnixMicro()) + tst.AssertEqual(t, val.UnixNano(), tm.UnixNano()) + tst.AssertEqual(t, val.Format(time.RFC3339), tm.Format(time.RFC3339)) + tst.AssertEqual(t, val.GoString(), tm.GoString()) + tst.AssertEqual(t, val.String(), tm.String()) + tst.AssertEqual(t, val.Serialize(), strconv.FormatInt(tm.Unix(), 10)) + tst.AssertEqual(t, 
val.IsZero(), false) + tst.AssertEqual(t, UnixTime{}.IsZero(), true) + + y, mo, d := val.Date() + tst.AssertEqual(t, y, 2023) + tst.AssertEqual(t, mo, time.May) + tst.AssertEqual(t, d, 17) + + wy, ww := val.ISOWeek() + ey, ew := tm.ISOWeek() + tst.AssertEqual(t, wy, ey) + tst.AssertEqual(t, ww, ew) + + h, m, s := val.Clock() + tst.AssertEqual(t, h, 14) + tst.AssertEqual(t, m, 30) + tst.AssertEqual(t, s, 45) + + tst.AssertEqual(t, val.YearDay(), tm.YearDay()) +} + +func TestUnixTimeAddSubCompare(t *testing.T) { + tm := time.Date(2023, 5, 17, 14, 30, 45, 0, time.UTC) + a := NewUnix(tm) + b := a.Add(time.Hour) + + tst.AssertEqual(t, b.Sub(a), time.Hour) + tst.AssertEqual(t, b.After(a), true) + tst.AssertEqual(t, a.Before(b), true) + + c := a.AddDate(0, 1, 0) + tst.AssertEqual(t, c.Month(), time.June) + + d := NewUnix(tm) + tst.AssertEqual(t, a.Equal(d), true) + tst.AssertEqual(t, a.EqualAny(d), true) + tst.AssertEqual(t, a.EqualAny(b), false) + tst.AssertEqual(t, a.EqualAny(nil), false) +} + +func TestNowUnix(t *testing.T) { + before := time.Now() + v := NowUnix() + after := time.Now() + + if v.Time().Before(before.Add(-time.Second)) || v.Time().After(after.Add(time.Second)) { + t.Errorf("NowUnix not within expected range") + } +} + +// ---------- UnixMilliTime ---------- + +func TestUnixMilliTimeRoundtripJSON(t *testing.T) { + type Wrap struct { + Value UnixMilliTime `json:"v"` + } + + val := NewUnixMilli(time.UnixMilli(1675951556789).UTC()) + w1 := Wrap{val} + + jstr1, err := json.Marshal(w1) + if err != nil { + t.Fatal(err) + } + tst.AssertEqual(t, string(jstr1), `{"v":"1675951556789"}`) + + w2 := Wrap{} + if err := json.Unmarshal(jstr1, &w2); err != nil { + t.Fatal(err) + } + tst.AssertEqual(t, w2.Value.UnixMilli(), val.UnixMilli()) +} + +func TestUnixMilliTimeText(t *testing.T) { + val := NewUnixMilli(time.UnixMilli(1675951556789)) + b, err := val.MarshalText() + if err != nil { + t.Fatal(err) + } + tst.AssertEqual(t, string(b), "1675951556789") + + var v2 
UnixMilliTime + if err := v2.UnmarshalText(b); err != nil { + t.Fatal(err) + } + tst.AssertEqual(t, v2.UnixMilli(), val.UnixMilli()) +} + +func TestUnixMilliTimeUnmarshalJSONInvalid(t *testing.T) { + var v UnixMilliTime + if err := v.UnmarshalJSON([]byte(`"abc"`)); err == nil { + t.Errorf("expected error") + } + if err := v.UnmarshalJSON([]byte(`[]`)); err == nil { + t.Errorf("expected error on array") + } +} + +func TestUnixMilliTimeAccessors(t *testing.T) { + tm := time.Date(2023, 5, 17, 14, 30, 45, 1000000, time.UTC) + val := NewUnixMilli(tm) + + tst.AssertEqual(t, val.Year(), 2023) + tst.AssertEqual(t, val.Serialize(), strconv.FormatInt(tm.UnixMilli(), 10)) + tst.AssertEqual(t, val.IsZero(), false) + tst.AssertEqual(t, UnixMilliTime{}.IsZero(), true) + + a := val.Add(time.Hour) + tst.AssertEqual(t, a.Sub(val), time.Hour) + tst.AssertEqual(t, a.After(val), true) + tst.AssertEqual(t, val.Before(a), true) + + d := NewUnixMilli(tm) + tst.AssertEqual(t, val.Equal(d), true) + tst.AssertEqual(t, val.EqualAny(d), true) + tst.AssertEqual(t, val.EqualAny(nil), false) +} + +func TestNowUnixMilli(t *testing.T) { + before := time.Now() + v := NowUnixMilli() + after := time.Now() + + if v.Time().Before(before.Add(-time.Second)) || v.Time().After(after.Add(time.Second)) { + t.Errorf("NowUnixMilli not within expected range") + } +} + +// ---------- UnixNanoTime ---------- + +func TestUnixNanoTimeRoundtripJSON(t *testing.T) { + type Wrap struct { + Value UnixNanoTime `json:"v"` + } + + val := NewUnixNano(time.Unix(0, 1675951556820915171).UTC()) + w1 := Wrap{val} + + jstr1, err := json.Marshal(w1) + if err != nil { + t.Fatal(err) + } + tst.AssertEqual(t, string(jstr1), `{"v":"1675951556820915171"}`) + + w2 := Wrap{} + if err := json.Unmarshal(jstr1, &w2); err != nil { + t.Fatal(err) + } + tst.AssertEqual(t, w2.Value.UnixNano(), val.UnixNano()) +} + +func TestUnixNanoTimeText(t *testing.T) { + val := NewUnixNano(time.Unix(0, 1675951556820915171)) + b, err := val.MarshalText() + 
if err != nil { + t.Fatal(err) + } + tst.AssertEqual(t, string(b), "1675951556820915171") + + var v2 UnixNanoTime + if err := v2.UnmarshalText(b); err != nil { + t.Fatal(err) + } + tst.AssertEqual(t, v2.UnixNano(), val.UnixNano()) + + if err := v2.UnmarshalText([]byte("xyz")); err == nil { + t.Errorf("expected error") + } +} + +func TestUnixNanoTimeUnmarshalJSONInvalid(t *testing.T) { + var v UnixNanoTime + if err := v.UnmarshalJSON([]byte(`"abc"`)); err == nil { + t.Errorf("expected error") + } +} + +func TestUnixNanoTimeAccessors(t *testing.T) { + tm := time.Date(2023, 5, 17, 14, 30, 45, 123456789, time.UTC) + val := NewUnixNano(tm) + + tst.AssertEqual(t, val.Year(), 2023) + tst.AssertEqual(t, val.Nanosecond(), 123456789) + tst.AssertEqual(t, val.Serialize(), strconv.FormatInt(tm.UnixNano(), 10)) + tst.AssertEqual(t, val.IsZero(), false) + tst.AssertEqual(t, UnixNanoTime{}.IsZero(), true) + + a := val.Add(2 * time.Second) + tst.AssertEqual(t, a.Sub(val), 2*time.Second) + tst.AssertEqual(t, a.After(val), true) + tst.AssertEqual(t, val.Before(a), true) + + c := val.AddDate(0, 0, 1) + tst.AssertEqual(t, c.Day(), 18) + + d := NewUnixNano(tm) + tst.AssertEqual(t, val.Equal(d), true) + tst.AssertEqual(t, val.EqualAny(d), true) + tst.AssertEqual(t, val.EqualAny(nil), false) +} + +func TestNowUnixNano(t *testing.T) { + before := time.Now() + v := NowUnixNano() + after := time.Now() + + if v.Time().Before(before.Add(-time.Second)) || v.Time().After(after.Add(time.Second)) { + t.Errorf("NowUnixNano not within expected range") + } +} + +func TestUnixCrossTypeEqualAny(t *testing.T) { + tm := time.Date(2023, 5, 17, 14, 30, 45, 0, time.UTC) + u := NewUnix(tm) + um := NewUnixMilli(tm) + un := NewUnixNano(tm) + r := NewRFC3339(tm) + rn := NewRFC3339Nano(tm) + + tst.AssertEqual(t, u.EqualAny(um), true) + tst.AssertEqual(t, u.EqualAny(un), true) + tst.AssertEqual(t, u.EqualAny(r), true) + tst.AssertEqual(t, u.EqualAny(rn), true) + tst.AssertEqual(t, u.EqualAny(tm), true) +} diff 
--git a/scn/scn_test.go b/scn/scn_test.go new file mode 100644 index 0000000..bdce988 --- /dev/null +++ b/scn/scn_test.go @@ -0,0 +1,286 @@ +package scn + +import ( + "testing" + "time" + + "git.blackforestbytes.com/BlackForestBytes/goext/tst" +) + +func TestNew(t *testing.T) { + c := New("my-token") + if c == nil { + t.Fatal("New returned nil") + } + tst.AssertEqual(t, c.token, "my-token") +} + +func TestNewEmptyToken(t *testing.T) { + c := New("") + if c == nil { + t.Fatal("New returned nil") + } + tst.AssertEqual(t, c.token, "") +} + +func TestNewReturnsDistinctInstances(t *testing.T) { + c1 := New("token-a") + c2 := New("token-b") + if c1 == c2 { + t.Fatal("New should return distinct instances") + } + tst.AssertEqual(t, c1.token, "token-a") + tst.AssertEqual(t, c2.token, "token-b") +} + +func TestConnectionMessage(t *testing.T) { + c := New("tok") + mb := c.Message("Hello") + + if mb == nil { + t.Fatal("Message returned nil") + } + if mb.conn != c { + t.Error("MessageBuilder.conn does not point to source Connection") + } + tst.AssertEqual(t, mb.title, "Hello") + if mb.content != nil { + t.Error("expected content to be nil") + } + if mb.channel != nil { + t.Error("expected channel to be nil") + } + if mb.time != nil { + t.Error("expected time to be nil") + } + if mb.sendername != nil { + t.Error("expected sendername to be nil") + } + if mb.priority != nil { + t.Error("expected priority to be nil") + } +} + +func TestConnectionTitle(t *testing.T) { + c := New("tok") + mb := c.Title("Hello") + + if mb == nil { + t.Fatal("Title returned nil") + } + if mb.conn != c { + t.Error("MessageBuilder.conn does not point to source Connection") + } + tst.AssertEqual(t, mb.title, "Hello") + if mb.content != nil { + t.Error("expected content to be nil") + } +} + +func TestMessageAndTitleAreEquivalent(t *testing.T) { + c := New("tok") + mbMsg := c.Message("X") + mbTitle := c.Title("X") + + tst.AssertEqual(t, mbMsg.title, mbTitle.title) + tst.AssertEqual(t, mbMsg.conn, 
mbTitle.conn) +} + +func TestBuilderChannel(t *testing.T) { + c := New("tok") + mb := c.Message("t") + res := mb.Channel("foo-channel") + + if res != mb { + t.Error("Channel did not return same builder") + } + if mb.channel == nil { + t.Fatal("expected channel to be set") + } + tst.AssertEqual(t, *mb.channel, "foo-channel") +} + +func TestBuilderContent(t *testing.T) { + c := New("tok") + mb := c.Message("t") + res := mb.Content("body") + + if res != mb { + t.Error("Content did not return same builder") + } + if mb.content == nil { + t.Fatal("expected content to be set") + } + tst.AssertEqual(t, *mb.content, "body") +} + +func TestBuilderTime(t *testing.T) { + c := New("tok") + mb := c.Message("t") + now := time.Date(2024, 1, 2, 3, 4, 5, 0, time.UTC) + res := mb.Time(now) + + if res != mb { + t.Error("Time did not return same builder") + } + if mb.time == nil { + t.Fatal("expected time to be set") + } + if !mb.time.Equal(now) { + t.Errorf("expected %v, got %v", now, *mb.time) + } +} + +func TestBuilderSenderName(t *testing.T) { + c := New("tok") + mb := c.Message("t") + res := mb.SenderName("alice") + + if res != mb { + t.Error("SenderName did not return same builder") + } + if mb.sendername == nil { + t.Fatal("expected sendername to be set") + } + tst.AssertEqual(t, *mb.sendername, "alice") +} + +func TestBuilderPriority(t *testing.T) { + c := New("tok") + mb := c.Message("t") + res := mb.Priority(2) + + if res != mb { + t.Error("Priority did not return same builder") + } + if mb.priority == nil { + t.Fatal("expected priority to be set") + } + tst.AssertEqual(t, *mb.priority, 2) +} + +func TestBuilderChaining(t *testing.T) { + c := New("tok") + tt := time.Date(2030, 5, 6, 7, 8, 9, 0, time.UTC) + + mb := c.Message("hello"). + Channel("ch"). + Content("content"). + Time(tt). + SenderName("bob"). 
+ Priority(7) + + if mb == nil { + t.Fatal("chained builder returned nil") + } + tst.AssertEqual(t, mb.title, "hello") + if mb.channel == nil || *mb.channel != "ch" { + t.Error("channel not set correctly") + } + if mb.content == nil || *mb.content != "content" { + t.Error("content not set correctly") + } + if mb.time == nil || !mb.time.Equal(tt) { + t.Error("time not set correctly") + } + if mb.sendername == nil || *mb.sendername != "bob" { + t.Error("sendername not set correctly") + } + if mb.priority == nil || *mb.priority != 7 { + t.Error("priority not set correctly") + } +} + +func TestBuilderOverwriteValues(t *testing.T) { + c := New("tok") + mb := c.Message("t"). + Channel("first"). + Content("first"). + SenderName("first"). + Priority(1) + + mb.Channel("second"). + Content("second"). + SenderName("second"). + Priority(2) + + tst.AssertEqual(t, *mb.channel, "second") + tst.AssertEqual(t, *mb.content, "second") + tst.AssertEqual(t, *mb.sendername, "second") + tst.AssertEqual(t, *mb.priority, 2) +} + +func TestBuilderIndependentInstances(t *testing.T) { + c := New("tok") + + a := c.Message("A").Content("aa").Priority(1) + b := c.Message("B").Content("bb").Priority(9) + + if a == b { + t.Fatal("expected distinct builders") + } + tst.AssertEqual(t, a.title, "A") + tst.AssertEqual(t, b.title, "B") + tst.AssertEqual(t, *a.content, "aa") + tst.AssertEqual(t, *b.content, "bb") + tst.AssertEqual(t, *a.priority, 1) + tst.AssertEqual(t, *b.priority, 9) +} + +func TestBuilderTimeZonePreserved(t *testing.T) { + c := New("tok") + loc, err := time.LoadLocation("Europe/Berlin") + if err != nil { + t.Skipf("timezone db not available: %v", err) + } + tt := time.Date(2025, 6, 15, 12, 0, 0, 0, loc) + + mb := c.Message("t").Time(tt) + + if mb.time == nil { + t.Fatal("expected time to be set") + } + if !mb.time.Equal(tt) { + t.Errorf("expected %v, got %v", tt, *mb.time) + } + if mb.time.Unix() != tt.Unix() { + t.Errorf("expected unix %d, got %d", tt.Unix(), mb.time.Unix()) + } +} 
+ +func TestBuilderNegativePriority(t *testing.T) { + c := New("tok") + mb := c.Message("t").Priority(-1) + if mb.priority == nil { + t.Fatal("expected priority to be set") + } + tst.AssertEqual(t, *mb.priority, -1) +} + +func TestBuilderEmptyStrings(t *testing.T) { + c := New("tok") + mb := c.Message(""). + Channel(""). + Content(""). + SenderName("") + + tst.AssertEqual(t, mb.title, "") + if mb.channel == nil || *mb.channel != "" { + t.Error("channel should be set to empty string (not nil)") + } + if mb.content == nil || *mb.content != "" { + t.Error("content should be set to empty string (not nil)") + } + if mb.sendername == nil || *mb.sendername != "" { + t.Error("sendername should be set to empty string (not nil)") + } +} + +func TestErrorTypesAreDistinct(t *testing.T) { + errs := []any{ErrAuthFailed, ErrQuota, ErrBadRequest, ErrInternalServerErr, ErrOther} + for i := range errs { + if errs[i] == nil { + t.Errorf("error type at index %d is nil", i) + } + } +} diff --git a/sq/builderUnit_test.go b/sq/builderUnit_test.go new file mode 100644 index 0000000..31a7ab4 --- /dev/null +++ b/sq/builderUnit_test.go @@ -0,0 +1,213 @@ +package sq + +import ( + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + "strings" + "testing" +) + +func TestBuildInsertStatementBasic(t *testing.T) { + type r struct { + ID string `db:"id"` + Name string `db:"name"` + } + + q := fakeQueryable{} + sqlstr, pp, err := BuildInsertStatement(q, "users", r{ID: "1", Name: "alice"}) + tst.AssertNoErr(t, err) + + tst.AssertTrue(t, strings.HasPrefix(sqlstr, "INSERT INTO users (")) + tst.AssertTrue(t, strings.Contains(sqlstr, "id")) + tst.AssertTrue(t, strings.Contains(sqlstr, "name")) + tst.AssertEqual(t, 2, len(pp)) + + values := []any{} + for _, v := range pp { + values = append(values, v) + } + + hasID, hasName := false, false + for _, v := range values { + if vs, ok := v.(string); ok { + if vs == "1" { + hasID = true + } + if vs == "alice" { + hasName = true + } + } + } + 
tst.AssertTrue(t, hasID) + tst.AssertTrue(t, hasName) +} + +func TestBuildInsertStatementSkipsUnexported(t *testing.T) { + type r struct { + ID string `db:"id"` + hidden string `db:"hidden"` //nolint:unused + } + + q := fakeQueryable{} + sqlstr, pp, err := BuildInsertStatement(q, "users", r{ID: "1"}) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, 1, len(pp)) + tst.AssertTrue(t, !strings.Contains(sqlstr, "hidden")) +} + +func TestBuildInsertStatementSkipsNoTagAndDash(t *testing.T) { + type r struct { + ID string `db:"id"` + Skip1 string `db:"-"` + Skip2 string + } + + q := fakeQueryable{} + sqlstr, pp, err := BuildInsertStatement(q, "users", r{ID: "1", Skip1: "x", Skip2: "y"}) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, 1, len(pp)) + tst.AssertTrue(t, !strings.Contains(sqlstr, "Skip")) +} + +func TestBuildInsertStatementNoFields(t *testing.T) { + type r struct { + Skip string + } + q := fakeQueryable{} + _, _, err := BuildInsertStatement(q, "x", r{}) + if err == nil { + t.Fatal("expected error for no usable fields") + } +} + +func TestBuildInsertStatementNilPointer(t *testing.T) { + type r struct { + ID string `db:"id"` + Note *string `db:"note"` + } + + q := fakeQueryable{} + sqlstr, pp, err := BuildInsertStatement(q, "users", r{ID: "1", Note: nil}) + tst.AssertNoErr(t, err) + + // Only id is parameterized; nil pointer becomes literal NULL + tst.AssertEqual(t, 1, len(pp)) + tst.AssertTrue(t, strings.Contains(sqlstr, "NULL")) +} + +func TestBuildInsertStatementWithConverter(t *testing.T) { + type r struct { + ID string `db:"id"` + Flag bool `db:"flag"` + } + + q := fakeQueryable{converters: []DBTypeConverter{ConverterBoolToBit}} + _, pp, err := BuildInsertStatement(q, "users", r{ID: "1", Flag: true}) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, 2, len(pp)) + + foundOne := false + for _, v := range pp { + if vi, ok := v.(int64); ok && vi == 1 { + foundOne = true + } + } + tst.AssertTrue(t, foundOne) +} + +func TestBuildUpdateStatementBasic(t *testing.T) { + type 
r struct { + ID string `db:"id"` + Name string `db:"name"` + } + + q := fakeQueryable{} + sqlstr, pp, err := BuildUpdateStatement(q, "users", r{ID: "1", Name: "alice"}, "id") + tst.AssertNoErr(t, err) + + tst.AssertTrue(t, strings.HasPrefix(sqlstr, "UPDATE users SET ")) + tst.AssertTrue(t, strings.Contains(sqlstr, "name = :")) + tst.AssertTrue(t, strings.Contains(sqlstr, "(id = :")) + tst.AssertEqual(t, 2, len(pp)) +} + +func TestBuildUpdateStatementMissingID(t *testing.T) { + type r struct { + Name string `db:"name"` + } + + q := fakeQueryable{} + _, _, err := BuildUpdateStatement(q, "users", r{Name: "alice"}, "id") + if err == nil { + t.Fatal("expected error for missing id column") + } +} + +func TestBuildUpdateStatementOnlyID(t *testing.T) { + type r struct { + ID string `db:"id"` + } + + q := fakeQueryable{} + _, _, err := BuildUpdateStatement(q, "users", r{ID: "1"}, "id") + if err == nil { + t.Fatal("expected error when no SET clauses") + } +} + +func TestBuildUpdateStatementNilPointer(t *testing.T) { + type r struct { + ID string `db:"id"` + Note *string `db:"note"` + } + + q := fakeQueryable{} + sqlstr, _, err := BuildUpdateStatement(q, "users", r{ID: "1", Note: nil}, "id") + tst.AssertNoErr(t, err) + tst.AssertTrue(t, strings.Contains(sqlstr, "note = NULL")) +} + +func TestBuildInsertMultipleStatementBasic(t *testing.T) { + type r struct { + ID string `db:"id"` + Name string `db:"name"` + } + + q := fakeQueryable{} + sqlstr, pp, err := BuildInsertMultipleStatement(q, "users", []r{ + {ID: "1", Name: "alice"}, + {ID: "2", Name: "bob"}, + }) + tst.AssertNoErr(t, err) + + tst.AssertTrue(t, strings.Contains(sqlstr, `INSERT INTO "users"`)) + tst.AssertTrue(t, strings.Contains(sqlstr, `"id"`)) + tst.AssertTrue(t, strings.Contains(sqlstr, `"name"`)) + // 2 rows × 2 fields = 4 placeholders + tst.AssertEqual(t, 4, len(pp)) + + // Two value tuples should appear -> exactly one "), (" separator + tst.AssertEqual(t, 1, strings.Count(sqlstr, "), (")) +} + +func 
TestBuildInsertMultipleStatementEmpty(t *testing.T) { + type r struct { + ID string `db:"id"` + } + q := fakeQueryable{} + _, _, err := BuildInsertMultipleStatement(q, "x", []r{}) + if err == nil { + t.Fatal("expected error for empty input") + } +} + +func TestBuildInsertMultipleStatementNilPointer(t *testing.T) { + type r struct { + ID string `db:"id"` + Note *string `db:"note"` + } + + q := fakeQueryable{} + sqlstr, _, err := BuildInsertMultipleStatement(q, "users", []r{{ID: "1", Note: nil}}) + tst.AssertNoErr(t, err) + tst.AssertTrue(t, strings.Contains(sqlstr, "NULL")) +} diff --git a/sq/commentTrimmer_test.go b/sq/commentTrimmer_test.go new file mode 100644 index 0000000..cb13914 --- /dev/null +++ b/sq/commentTrimmer_test.go @@ -0,0 +1,74 @@ +package sq + +import ( + "context" + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + "strings" + "testing" +) + +func TestFnTrimCommentsLineOnly(t *testing.T) { + // The line-only comment is replaced with an empty line. + sql := "SELECT *\n-- this is a comment\nFROM users" + pp := PP{} + err := fnTrimComments(context.Background(), "QUERY", nil, &sql, &pp) + tst.AssertNoErr(t, err) + tst.AssertTrue(t, !strings.Contains(sql, "comment")) + tst.AssertTrue(t, strings.Contains(sql, "SELECT *")) + tst.AssertTrue(t, strings.Contains(sql, "FROM users")) +} + +func TestFnTrimCommentsTrailing(t *testing.T) { + sql := "SELECT * -- inline\nFROM users -- end" + pp := PP{} + err := fnTrimComments(context.Background(), "QUERY", nil, &sql, &pp) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, "SELECT *\nFROM users", sql) +} + +func TestFnTrimCommentsIndented(t *testing.T) { + sql := "SELECT *\n -- indented comment\nFROM users" + pp := PP{} + err := fnTrimComments(context.Background(), "QUERY", nil, &sql, &pp) + tst.AssertNoErr(t, err) + tst.AssertTrue(t, !strings.Contains(sql, "indented comment")) + tst.AssertTrue(t, strings.Contains(sql, "SELECT *")) + tst.AssertTrue(t, strings.Contains(sql, "FROM users")) +} + +func 
TestFnTrimCommentsTrimsTrailingWhitespace(t *testing.T) { + sql := "SELECT * \t\nFROM users " + pp := PP{} + err := fnTrimComments(context.Background(), "QUERY", nil, &sql, &pp) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, "SELECT *\nFROM users", sql) +} + +func TestFnTrimCommentsNoComment(t *testing.T) { + sql := "SELECT id\nFROM users\nWHERE id=1" + pp := PP{} + err := fnTrimComments(context.Background(), "QUERY", nil, &sql, &pp) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, "SELECT id\nFROM users\nWHERE id=1", sql) +} + +func TestFnTrimCommentsEmpty(t *testing.T) { + sql := "" + pp := PP{} + err := fnTrimComments(context.Background(), "QUERY", nil, &sql, &pp) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, "", sql) +} + +func TestCommentTrimmerListener(t *testing.T) { + sql := "SELECT *\n-- a comment\nFROM x" + pp := PP{} + err := CommentTrimmer.PreQuery(context.Background(), nil, &sql, &pp, PreQueryMeta{}) + tst.AssertNoErr(t, err) + tst.AssertTrue(t, !strings.Contains(sql, "a comment")) + + sql2 := "INSERT INTO x VALUES (1) -- xx" + err = CommentTrimmer.PreExec(context.Background(), nil, &sql2, &pp, PreExecMeta{}) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, "INSERT INTO x VALUES (1)", sql2) +} diff --git a/sq/converterDefault_test.go b/sq/converterDefault_test.go new file mode 100644 index 0000000..37313a0 --- /dev/null +++ b/sq/converterDefault_test.go @@ -0,0 +1,199 @@ +package sq + +import ( + "git.blackforestbytes.com/BlackForestBytes/goext/rfctime" + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + "testing" + "time" +) + +func TestConverterBoolToBit(t *testing.T) { + v, err := ConverterBoolToBit.ModelToDB(true) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, int64(1), v.(int64)) + + v, err = ConverterBoolToBit.ModelToDB(false) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, int64(0), v.(int64)) + + v, err = ConverterBoolToBit.DBToModel(int64(1)) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, true, v.(bool)) + + v, err = 
ConverterBoolToBit.DBToModel(int64(0)) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, false, v.(bool)) + + _, err = ConverterBoolToBit.DBToModel(int64(2)) + if err == nil { + t.Fatal("expected error for value not in {0,1}") + } +} + +func TestConverterTimeToUnixMillis(t *testing.T) { + t0 := time.Date(2024, 6, 15, 12, 34, 56, int(789*time.Millisecond), time.UTC) + expected := t0.UnixMilli() + + v, err := ConverterTimeToUnixMillis.ModelToDB(t0) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, expected, v.(int64)) + + v, err = ConverterTimeToUnixMillis.DBToModel(expected) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, expected, v.(time.Time).UnixMilli()) +} + +func TestConverterRFCUnixMilliTime(t *testing.T) { + t0 := rfctime.NewUnixMilli(time.Date(2020, 1, 2, 3, 4, 5, 0, time.UTC)) + expected := t0.UnixMilli() + + v, err := ConverterRFCUnixMilliTimeToUnixMillis.ModelToDB(t0) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, expected, v.(int64)) + + v, err = ConverterRFCUnixMilliTimeToUnixMillis.DBToModel(expected) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, expected, v.(rfctime.UnixMilliTime).UnixMilli()) +} + +func TestConverterRFCUnixNanoTime(t *testing.T) { + t0 := rfctime.NewUnixNano(time.Date(2020, 1, 2, 3, 4, 5, 123456789, time.UTC)) + expected := t0.UnixNano() + + v, err := ConverterRFCUnixNanoTimeToUnixNanos.ModelToDB(t0) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, expected, v.(int64)) + + v, err = ConverterRFCUnixNanoTimeToUnixNanos.DBToModel(expected) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, expected, v.(rfctime.UnixNanoTime).UnixNano()) +} + +func TestConverterRFCUnixTime(t *testing.T) { + t0 := rfctime.NewUnix(time.Date(2020, 1, 2, 3, 4, 5, 0, time.UTC)) + expected := t0.Unix() + + v, err := ConverterRFCUnixTimeToUnixSeconds.ModelToDB(t0) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, expected, v.(int64)) + + v, err = ConverterRFCUnixTimeToUnixSeconds.DBToModel(expected) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, expected, 
v.(rfctime.UnixTime).Unix()) +} + +func TestConverterRFC339Time(t *testing.T) { + t0 := rfctime.NewRFC3339(time.Date(2020, 6, 15, 9, 30, 45, 0, time.UTC)) + + v, err := ConverterRFC339TimeToString.ModelToDB(t0) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, "2020-06-15 09:30:45", v.(string)) + + v, err = ConverterRFC339TimeToString.DBToModel("2020-06-15 09:30:45") + tst.AssertNoErr(t, err) + tt := v.(rfctime.RFC3339Time).Time().UTC() + tst.AssertEqual(t, 2020, tt.Year()) + tst.AssertEqual(t, time.June, tt.Month()) + tst.AssertEqual(t, 15, tt.Day()) + tst.AssertEqual(t, 9, tt.Hour()) + tst.AssertEqual(t, 30, tt.Minute()) + tst.AssertEqual(t, 45, tt.Second()) + + _, err = ConverterRFC339TimeToString.DBToModel("garbage") + if err == nil { + t.Fatal("expected parse error") + } +} + +func TestConverterRFC339NanoTime(t *testing.T) { + t0 := rfctime.NewRFC3339Nano(time.Date(2020, 6, 15, 9, 30, 45, 123456789, time.UTC)) + + v, err := ConverterRFC339NanoTimeToString.ModelToDB(t0) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, "2020-06-15 09:30:45.123456789", v.(string)) + + v, err = ConverterRFC339NanoTimeToString.DBToModel("2020-06-15 09:30:45.123456789") + tst.AssertNoErr(t, err) + tt := v.(rfctime.RFC3339NanoTime).Time().UTC() + tst.AssertEqual(t, 123456789, tt.Nanosecond()) + + _, err = ConverterRFC339NanoTimeToString.DBToModel("not a date") + if err == nil { + t.Fatal("expected parse error") + } +} + +func TestConverterRFCDate(t *testing.T) { + d := rfctime.Date{Year: 2024, Month: 3, Day: 9} + v, err := ConverterRFCDateToString.ModelToDB(d) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, "2024-03-09", v.(string)) + + v, err = ConverterRFCDateToString.DBToModel("2024-03-09") + tst.AssertNoErr(t, err) + d2 := v.(rfctime.Date) + tst.AssertEqual(t, 2024, d2.Year) + tst.AssertEqual(t, 3, d2.Month) + tst.AssertEqual(t, 9, d2.Day) + + _, err = ConverterRFCDateToString.DBToModel("invalid") + if err == nil { + t.Fatal("expected parse error") + } +} + +func TestConverterRFCTime(t 
*testing.T) { + tm := rfctime.NewTime(13, 30, 45, 0) + v, err := ConverterRFCTimeToString.ModelToDB(tm) + tst.AssertNoErr(t, err) + roundtrip, err := ConverterRFCTimeToString.DBToModel(v.(string)) + tst.AssertNoErr(t, err) + tm2 := roundtrip.(rfctime.Time) + tst.AssertEqual(t, 13, tm2.Hour) + tst.AssertEqual(t, 30, tm2.Minute) + tst.AssertEqual(t, 45, tm2.Second) + + _, err = ConverterRFCTimeToString.DBToModel("xx:xx:xx") + if err == nil { + t.Fatal("expected parse error") + } +} + +func TestConverterRFCSecondsF64(t *testing.T) { + d := 12*time.Second + 500*time.Millisecond + s := rfctime.NewSecondsF64(d) + v, err := ConverterRFCSecondsF64ToString.ModelToDB(s) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, 12.5, v.(float64)) + + v, err = ConverterRFCSecondsF64ToString.DBToModel(12.5) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, 12.5, v.(rfctime.SecondsF64).Seconds()) +} + +func TestConverterJsonObjToString(t *testing.T) { + tst.AssertEqual(t, "sq.JsonObj", ConverterJsonObjToString.ModelTypeString()) + tst.AssertEqual(t, "string", ConverterJsonObjToString.DBTypeString()) + + v, err := ConverterJsonObjToString.ModelToDB(JsonObj{"x": float64(1)}) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, `{"x":1}`, v.(string)) + + r, err := ConverterJsonObjToString.DBToModel(`{"x":1}`) + tst.AssertNoErr(t, err) + tst.AssertStrRepEqual(t, r.(JsonObj)["x"], float64(1)) +} + +func TestConverterJsonArrToString(t *testing.T) { + tst.AssertEqual(t, "sq.JsonArr", ConverterJsonArrToString.ModelTypeString()) + tst.AssertEqual(t, "string", ConverterJsonArrToString.DBTypeString()) + + v, err := ConverterJsonArrToString.ModelToDB(JsonArr{float64(1), float64(2)}) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, `[1,2]`, v.(string)) + + r, err := ConverterJsonArrToString.DBToModel(`[1,2]`) + tst.AssertNoErr(t, err) + arr := r.(JsonArr) + tst.AssertEqual(t, 2, len(arr)) +} diff --git a/sq/converter_test.go b/sq/converter_test.go new file mode 100644 index 0000000..4a591a8 --- /dev/null +++ 
b/sq/converter_test.go @@ -0,0 +1,154 @@ +package sq + +import ( + "context" + "database/sql" + "errors" + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + "github.com/jmoiron/sqlx" + "testing" +) + +type fakeQueryable struct { + converters []DBTypeConverter +} + +func (f fakeQueryable) Exec(ctx context.Context, sqlstr string, prep PP) (sql.Result, error) { + return nil, errors.New("not implemented") +} + +func (f fakeQueryable) Query(ctx context.Context, sqlstr string, prep PP) (*sqlx.Rows, error) { + return nil, errors.New("not implemented") +} + +func (f fakeQueryable) ListConverter() []DBTypeConverter { + return f.converters +} + +func TestNewDBTypeConverterTypeStrings(t *testing.T) { + conv := NewDBTypeConverter(func(v bool) (int64, error) { return 0, nil }, func(v int64) (bool, error) { return false, nil }) + tst.AssertEqual(t, "bool", conv.ModelTypeString()) + tst.AssertEqual(t, "int64", conv.DBTypeString()) +} + +func TestNewDBTypeConverterModelToDB(t *testing.T) { + conv := NewDBTypeConverter(func(v bool) (int64, error) { + if v { + return 1, nil + } + return 0, nil + }, func(v int64) (bool, error) { + return v != 0, nil + }) + + r, err := conv.ModelToDB(true) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, int64(1), r.(int64)) + + r, err = conv.ModelToDB(false) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, int64(0), r.(int64)) +} + +func TestNewDBTypeConverterDBToModel(t *testing.T) { + conv := NewDBTypeConverter(func(v bool) (int64, error) { return 0, nil }, func(v int64) (bool, error) { + return v != 0, nil + }) + + r, err := conv.DBToModel(int64(1)) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, true, r.(bool)) + + r, err = conv.DBToModel(int64(0)) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, false, r.(bool)) +} + +func TestNewDBTypeConverterTypeMismatch(t *testing.T) { + conv := NewDBTypeConverter(func(v bool) (int64, error) { return 0, nil }, func(v int64) (bool, error) { return false, nil }) + + _, err := conv.ModelToDB("not a bool") + if 
err == nil { + t.Fatal("expected error on type mismatch in ModelToDB") + } + + _, err = conv.DBToModel("not int64") + if err == nil { + t.Fatal("expected error on type mismatch in DBToModel") + } +} + +func TestNewAutoDBTypeConverter(t *testing.T) { + conv := NewAutoDBTypeConverter(JsonObj{}) + + tst.AssertEqual(t, "sq.JsonObj", conv.ModelTypeString()) + tst.AssertEqual(t, "string", conv.DBTypeString()) + + r, err := conv.ModelToDB(JsonObj{"k": "v"}) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, `{"k":"v"}`, r.(string)) + + r, err = conv.DBToModel(`{"k":"v"}`) + tst.AssertNoErr(t, err) + parsed, ok := r.(JsonObj) + tst.AssertTrue(t, ok) + tst.AssertStrRepEqual(t, parsed["k"], "v") +} + +func TestConvertValueToDBNoConverter(t *testing.T) { + q := fakeQueryable{} + + r, err := convertValueToDB(q, "hello") + tst.AssertNoErr(t, err) + tst.AssertEqual(t, "hello", r.(string)) + + r, err = convertValueToDB(q, 42) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, 42, r.(int)) +} + +func TestConvertValueToDBWithConverter(t *testing.T) { + q := fakeQueryable{converters: []DBTypeConverter{ConverterBoolToBit}} + + r, err := convertValueToDB(q, true) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, int64(1), r.(int64)) + + r, err = convertValueToDB(q, false) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, int64(0), r.(int64)) +} + +func TestConvertValueToDBNilPointer(t *testing.T) { + q := fakeQueryable{} + + var s *string + r, err := convertValueToDB(q, s) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, true, r == nil) +} + +func TestConvertValueToDBNonNilPointer(t *testing.T) { + q := fakeQueryable{converters: []DBTypeConverter{ConverterBoolToBit}} + + v := true + r, err := convertValueToDB(q, &v) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, int64(1), r.(int64)) +} + +func TestConvertValueToModelNoConverter(t *testing.T) { + q := fakeQueryable{} + + r, err := convertValueToModel(q, "hello", "string") + tst.AssertNoErr(t, err) + tst.AssertEqual(t, "hello", r.(string)) +} + +func 
TestConvertValueToModelWithConverter(t *testing.T) { + q := fakeQueryable{converters: []DBTypeConverter{ConverterBoolToBit}} + + r, err := convertValueToModel(q, int64(1), "bool") + tst.AssertNoErr(t, err) + tst.AssertEqual(t, true, r.(bool)) +} diff --git a/sq/filter_test.go b/sq/filter_test.go new file mode 100644 index 0000000..29542a9 --- /dev/null +++ b/sq/filter_test.go @@ -0,0 +1,70 @@ +package sq + +import ( + ct "git.blackforestbytes.com/BlackForestBytes/goext/cursortoken" + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + "testing" +) + +func TestNewEmptyPaginateFilter(t *testing.T) { + f := NewEmptyPaginateFilter() + + pp := PP{} + flt, join, joinTbl := f.SQL(pp) + tst.AssertEqual(t, "1=1", flt) + tst.AssertEqual(t, "", join) + tst.AssertEqual(t, 0, len(joinTbl)) + tst.AssertEqual(t, 0, len(pp)) + + tst.AssertEqual(t, 0, len(f.Sort())) +} + +func TestNewSimplePaginateFilter(t *testing.T) { + sortOrder := []FilterSort{ + {Field: "name", Direction: ct.SortASC}, + } + filterParams := PP{"p1": "value1"} + + f := NewSimplePaginateFilter("name = :p1", filterParams, sortOrder) + + pp := PP{} + flt, join, joinTbl := f.SQL(pp) + tst.AssertEqual(t, "name = :p1", flt) + tst.AssertEqual(t, "", join) + tst.AssertEqual(t, 0, len(joinTbl)) + // filterParams should be merged into pp + tst.AssertEqual(t, 1, len(pp)) + tst.AssertEqual(t, "value1", pp["p1"]) + + srt := f.Sort() + tst.AssertEqual(t, 1, len(srt)) + tst.AssertEqual(t, "name", srt[0].Field) + tst.AssertEqual(t, ct.SortASC, srt[0].Direction) +} + +func TestNewPaginateFilter(t *testing.T) { + sortOrder := []FilterSort{ + {Field: "id", Direction: ct.SortDESC}, + } + + called := 0 + f := NewPaginateFilter(func(params PP) (string, string, []string) { + called++ + params.Add("hello") + return "id > 0", "JOIN other ON other.id = main.id", []string{"other"} + }, sortOrder) + + pp := PP{} + flt, join, joinTbl := f.SQL(pp) + tst.AssertEqual(t, "id > 0", flt) + tst.AssertEqual(t, "JOIN other ON other.id = 
main.id", join) + tst.AssertEqual(t, 1, len(joinTbl)) + tst.AssertEqual(t, "other", joinTbl[0]) + tst.AssertEqual(t, 1, called) + tst.AssertEqual(t, 1, len(pp)) + + srt := f.Sort() + tst.AssertEqual(t, 1, len(srt)) + tst.AssertEqual(t, "id", srt[0].Field) + tst.AssertEqual(t, ct.SortDESC, srt[0].Direction) +} diff --git a/sq/json_test.go b/sq/json_test.go new file mode 100644 index 0000000..a80795d --- /dev/null +++ b/sq/json_test.go @@ -0,0 +1,85 @@ +package sq + +import ( + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + "testing" +) + +func TestJsonObjMarshalToDB(t *testing.T) { + j := JsonObj{} + out, err := j.MarshalToDB(JsonObj{"key": "value", "num": float64(7)}) + tst.AssertNoErr(t, err) + + // JSON map ordering is not guaranteed - parse and verify + roundtrip, err := j.UnmarshalToModel(out) + tst.AssertNoErr(t, err) + tst.AssertStrRepEqual(t, roundtrip["key"], "value") + tst.AssertStrRepEqual(t, roundtrip["num"], float64(7)) +} + +func TestJsonObjUnmarshalToModelInvalid(t *testing.T) { + j := JsonObj{} + _, err := j.UnmarshalToModel("{not valid json}") + if err == nil { + t.Fatal("expected error for invalid JSON") + } +} + +func TestJsonArrMarshalToDB(t *testing.T) { + j := JsonArr{} + out, err := j.MarshalToDB(JsonArr{float64(1), "two", true}) + tst.AssertNoErr(t, err) + + tst.AssertEqual(t, `[1,"two",true]`, out) +} + +func TestJsonArrRoundtrip(t *testing.T) { + j := JsonArr{} + out, err := j.MarshalToDB(JsonArr{"a", "b", "c"}) + tst.AssertNoErr(t, err) + + r, err := j.UnmarshalToModel(out) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, 3, len(r)) + tst.AssertEqual(t, "a", r[0]) + tst.AssertEqual(t, "b", r[1]) + tst.AssertEqual(t, "c", r[2]) +} + +func TestJsonArrUnmarshalInvalid(t *testing.T) { + j := JsonArr{} + _, err := j.UnmarshalToModel("not json") + if err == nil { + t.Fatal("expected error for invalid JSON") + } +} + +func TestAutoJsonRoundtrip(t *testing.T) { + type inner struct { + A int `json:"a"` + B string `json:"b"` + } + + aj := 
AutoJson[inner]{Value: inner{A: 42, B: "foo"}} + zero := AutoJson[inner]{} + + out, err := zero.MarshalToDB(aj) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, `{"a":42,"b":"foo"}`, out) + + r, err := zero.UnmarshalToModel(out) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, 42, r.Value.A) + tst.AssertEqual(t, "foo", r.Value.B) +} + +func TestAutoJsonUnmarshalInvalid(t *testing.T) { + type inner struct { + A int `json:"a"` + } + zero := AutoJson[inner]{} + _, err := zero.UnmarshalToModel("garbage") + if err == nil { + t.Fatal("expected error for invalid JSON") + } +} diff --git a/sq/listener_test.go b/sq/listener_test.go new file mode 100644 index 0000000..fcb3030 --- /dev/null +++ b/sq/listener_test.go @@ -0,0 +1,208 @@ +package sq + +import ( + "context" + "errors" + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + "testing" +) + +func TestNewPrePingListener(t *testing.T) { + called := 0 + expected := errors.New("ping err") + l := NewPrePingListener(func(ctx context.Context, meta PrePingMeta) error { + called++ + return expected + }) + + err := l.PrePing(context.Background(), PrePingMeta{}) + tst.AssertEqual(t, 1, called) + tst.AssertTrue(t, errors.Is(err, expected)) + + // Other handlers must not error / panic + tst.AssertNoErr(t, l.PreTxBegin(context.Background(), 0, PreTxBeginMeta{})) + tst.AssertNoErr(t, l.PreTxCommit(0, PreTxCommitMeta{})) + tst.AssertNoErr(t, l.PreTxRollback(0, PreTxRollbackMeta{})) + tst.AssertNoErr(t, l.PreQuery(context.Background(), nil, nil, nil, PreQueryMeta{})) + tst.AssertNoErr(t, l.PreExec(context.Background(), nil, nil, nil, PreExecMeta{})) + + // Post variants must not panic + l.PostPing(nil, PostPingMeta{}) + l.PostTxBegin(0, nil, PostTxBeginMeta{}) + l.PostTxCommit(0, nil, PostTxCommitMeta{}) + l.PostTxRollback(0, nil, PostTxRollbackMeta{}) + l.PostQuery(nil, "", "", PP{}, nil, PostQueryMeta{}) + l.PostExec(nil, "", "", PP{}, nil, PostExecMeta{}) +} + +func TestNewPreTxBeginListener(t *testing.T) { + called := 0 + l := 
NewPreTxBeginListener(func(ctx context.Context, txid uint16, meta PreTxBeginMeta) error { + called++ + tst.AssertEqual(t, uint16(7), txid) + return nil + }) + tst.AssertNoErr(t, l.PreTxBegin(context.Background(), 7, PreTxBeginMeta{})) + tst.AssertEqual(t, 1, called) +} + +func TestNewPreTxCommitListener(t *testing.T) { + called := 0 + l := NewPreTxCommitListener(func(txid uint16, meta PreTxCommitMeta) error { + called++ + return nil + }) + tst.AssertNoErr(t, l.PreTxCommit(0, PreTxCommitMeta{})) + tst.AssertEqual(t, 1, called) +} + +func TestNewPreTxRollbackListener(t *testing.T) { + called := 0 + l := NewPreTxRollbackListener(func(txid uint16, meta PreTxRollbackMeta) error { + called++ + return nil + }) + tst.AssertNoErr(t, l.PreTxRollback(0, PreTxRollbackMeta{})) + tst.AssertEqual(t, 1, called) +} + +func TestNewPreQueryListener(t *testing.T) { + called := 0 + l := NewPreQueryListener(func(ctx context.Context, txID *uint16, sql *string, params *PP, meta PreQueryMeta) error { + called++ + *sql = "modified" + return nil + }) + + sql := "original" + pp := PP{} + tst.AssertNoErr(t, l.PreQuery(context.Background(), nil, &sql, &pp, PreQueryMeta{})) + tst.AssertEqual(t, 1, called) + tst.AssertEqual(t, "modified", sql) +} + +func TestNewPreExecListener(t *testing.T) { + called := 0 + l := NewPreExecListener(func(ctx context.Context, txID *uint16, sql *string, params *PP, meta PreExecMeta) error { + called++ + return nil + }) + sql := "x" + pp := PP{} + tst.AssertNoErr(t, l.PreExec(context.Background(), nil, &sql, &pp, PreExecMeta{})) + tst.AssertEqual(t, 1, called) +} + +func TestNewPreListenerBoth(t *testing.T) { + queryCalls := 0 + execCalls := 0 + l := NewPreListener(func(ctx context.Context, cmdtype string, txID *uint16, sql *string, params *PP) error { + switch cmdtype { + case "QUERY": + queryCalls++ + case "EXEC": + execCalls++ + } + return nil + }) + + sql := "s" + pp := PP{} + tst.AssertNoErr(t, l.PreQuery(context.Background(), nil, &sql, &pp, PreQueryMeta{})) + 
tst.AssertNoErr(t, l.PreExec(context.Background(), nil, &sql, &pp, PreExecMeta{})) + + tst.AssertEqual(t, 1, queryCalls) + tst.AssertEqual(t, 1, execCalls) +} + +func TestNewPostPingListener(t *testing.T) { + called := 0 + l := NewPostPingListener(func(result error, meta PostPingMeta) { + called++ + }) + l.PostPing(nil, PostPingMeta{}) + tst.AssertEqual(t, 1, called) +} + +func TestNewPostTxBeginListener(t *testing.T) { + called := 0 + l := NewPostTxBeginListener(func(txid uint16, result error, meta PostTxBeginMeta) { + called++ + }) + l.PostTxBegin(0, nil, PostTxBeginMeta{}) + tst.AssertEqual(t, 1, called) +} + +func TestNewPostTxCommitListener(t *testing.T) { + called := 0 + l := NewPostTxCommitListener(func(txid uint16, result error, meta PostTxCommitMeta) { + called++ + }) + l.PostTxCommit(0, nil, PostTxCommitMeta{}) + tst.AssertEqual(t, 1, called) +} + +func TestNewPostTxRollbackListener(t *testing.T) { + called := 0 + l := NewPostTxRollbackListener(func(txid uint16, result error, meta PostTxRollbackMeta) { + called++ + }) + l.PostTxRollback(0, nil, PostTxRollbackMeta{}) + tst.AssertEqual(t, 1, called) +} + +func TestNewPostQueryListener(t *testing.T) { + called := 0 + l := NewPostQueryListener(func(txID *uint16, sqlOriginal string, sqlReal string, params PP, result error, meta PostQueryMeta) { + called++ + }) + l.PostQuery(nil, "", "", PP{}, nil, PostQueryMeta{}) + tst.AssertEqual(t, 1, called) +} + +func TestNewPostExecListener(t *testing.T) { + called := 0 + l := NewPostExecListener(func(txID *uint16, sqlOriginal string, sqlReal string, params PP, result error, meta PostExecMeta) { + called++ + }) + l.PostExec(nil, "", "", PP{}, nil, PostExecMeta{}) + tst.AssertEqual(t, 1, called) +} + +func TestNewPostListenerBoth(t *testing.T) { + queryCalls := 0 + execCalls := 0 + l := NewPostListener(func(cmdtype string, txID *uint16, sqlOriginal string, sqlReal string, result error, params PP) { + switch cmdtype { + case "QUERY": + queryCalls++ + case "EXEC": + 
execCalls++ + } + }) + l.PostQuery(nil, "", "", PP{}, nil, PostQueryMeta{}) + l.PostExec(nil, "", "", PP{}, nil, PostExecMeta{}) + tst.AssertEqual(t, 1, queryCalls) + tst.AssertEqual(t, 1, execCalls) +} + +func TestGenListenerNilHandlersDontPanic(t *testing.T) { + // A listener constructed with one constructor only sets one handler. + // Calls to other handlers should be safe no-ops. + l := NewPostPingListener(func(result error, meta PostPingMeta) {}) + + // All Pre* return nil + tst.AssertNoErr(t, l.PrePing(context.Background(), PrePingMeta{})) + tst.AssertNoErr(t, l.PreTxBegin(context.Background(), 0, PreTxBeginMeta{})) + tst.AssertNoErr(t, l.PreTxCommit(0, PreTxCommitMeta{})) + tst.AssertNoErr(t, l.PreTxRollback(0, PreTxRollbackMeta{})) + tst.AssertNoErr(t, l.PreQuery(context.Background(), nil, nil, nil, PreQueryMeta{})) + tst.AssertNoErr(t, l.PreExec(context.Background(), nil, nil, nil, PreExecMeta{})) + + // All Post* are no-ops + l.PostTxBegin(0, nil, PostTxBeginMeta{}) + l.PostTxCommit(0, nil, PostTxCommitMeta{}) + l.PostTxRollback(0, nil, PostTxRollbackMeta{}) + l.PostQuery(nil, "", "", PP{}, nil, PostQueryMeta{}) + l.PostExec(nil, "", "", PP{}, nil, PostExecMeta{}) +} diff --git a/sq/params_test.go b/sq/params_test.go new file mode 100644 index 0000000..7ff766a --- /dev/null +++ b/sq/params_test.go @@ -0,0 +1,88 @@ +package sq + +import ( + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + "strings" + "testing" +) + +func TestPPID(t *testing.T) { + id := PPID() + tst.AssertTrue(t, strings.HasPrefix(id, "p_")) + if len(id) != 2+8 { + t.Errorf("expected length 10, got %d (id=%q)", len(id), id) + } + + // uniqueness - very high probability with 8 base62 chars + seen := map[string]bool{} + for range 1000 { + x := PPID() + if seen[x] { + t.Errorf("duplicate PPID: %s", x) + } + seen[x] = true + } +} + +func TestPPAdd(t *testing.T) { + pp := PP{} + id1 := pp.Add(123) + id2 := pp.Add("hello") + + tst.AssertNotEqual(t, id1, id2) + tst.AssertEqual(t, 123, 
pp[id1]) + tst.AssertEqual(t, "hello", pp[id2]) + tst.AssertEqual(t, 2, len(pp)) +} + +func TestPPAddAll(t *testing.T) { + a := PP{"a": 1, "b": 2} + b := PP{"c": 3, "d": 4} + a.AddAll(b) + + tst.AssertEqual(t, 4, len(a)) + tst.AssertEqual(t, 1, a["a"]) + tst.AssertEqual(t, 2, a["b"]) + tst.AssertEqual(t, 3, a["c"]) + tst.AssertEqual(t, 4, a["d"]) +} + +func TestPPAddAllOverwrite(t *testing.T) { + a := PP{"a": 1, "b": 2} + b := PP{"a": 99} + a.AddAll(b) + + tst.AssertEqual(t, 2, len(a)) + tst.AssertEqual(t, 99, a["a"]) + tst.AssertEqual(t, 2, a["b"]) +} + +func TestJoin(t *testing.T) { + a := PP{"a": 1, "b": 2} + b := PP{"c": 3} + c := PP{"d": 4, "a": 99} + + r := Join(a, b, c) + + tst.AssertEqual(t, 4, len(r)) + tst.AssertEqual(t, 99, r["a"]) + tst.AssertEqual(t, 2, r["b"]) + tst.AssertEqual(t, 3, r["c"]) + tst.AssertEqual(t, 4, r["d"]) + + // Source maps must remain unchanged + tst.AssertEqual(t, 2, len(a)) + tst.AssertEqual(t, 1, a["a"]) +} + +func TestJoinEmpty(t *testing.T) { + r := Join() + tst.AssertEqual(t, 0, len(r)) +} + +func TestJoinSingle(t *testing.T) { + a := PP{"a": 1} + r := Join(a) + tst.AssertEqual(t, 1, len(r)) + tst.AssertEqual(t, 1, r["a"]) +} diff --git a/syncext/atomic_test.go b/syncext/atomic_test.go new file mode 100644 index 0000000..a286875 --- /dev/null +++ b/syncext/atomic_test.go @@ -0,0 +1,213 @@ +package syncext + +import ( + "context" + "sync" + "testing" + "time" +) + +func TestAtomicGetSet(t *testing.T) { + a := NewAtomic(42) + + if v := a.Get(); v != 42 { + t.Errorf("expected 42, got %d", v) + } + + old := a.Set(100) + if old != 42 { + t.Errorf("expected old value 42, got %d", old) + } + + if v := a.Get(); v != 100 { + t.Errorf("expected 100, got %d", v) + } +} + +func TestAtomicGetSetString(t *testing.T) { + a := NewAtomic("hello") + + if v := a.Get(); v != "hello" { + t.Errorf("expected 'hello', got %q", v) + } + + old := a.Set("world") + if old != "hello" { + t.Errorf("expected old value 'hello', got %q", old) + } + + if v := 
a.Get(); v != "world" { + t.Errorf("expected 'world', got %q", v) + } +} + +func TestAtomicUpdate(t *testing.T) { + a := NewAtomic(10) + + a.Update(func(old int) int { + return old * 2 + }) + + if v := a.Get(); v != 20 { + t.Errorf("expected 20, got %d", v) + } + + a.Update(func(old int) int { + return old + 5 + }) + + if v := a.Get(); v != 25 { + t.Errorf("expected 25, got %d", v) + } +} + +func TestAtomicCompareAndSwap(t *testing.T) { + a := NewAtomic(5) + + if !a.CompareAndSwap(5, 10) { + t.Error("CAS should have succeeded") + } + if v := a.Get(); v != 10 { + t.Errorf("expected 10, got %d", v) + } + + if a.CompareAndSwap(5, 20) { + t.Error("CAS should have failed") + } + if v := a.Get(); v != 10 { + t.Errorf("expected 10, got %d", v) + } +} + +func TestAtomicWaitAlreadyMatching(t *testing.T) { + a := NewAtomic(7) + + done := make(chan struct{}) + go func() { + a.Wait(7) + close(done) + }() + + select { + case <-done: + // ok + case <-time.After(500 * time.Millisecond): + t.Error("Wait should return immediately if value already matches") + } +} + +func TestAtomicWaitWithTimeoutNoMatch(t *testing.T) { + a := NewAtomic(1) + + err := a.WaitWithTimeout(50*time.Millisecond, 999) + if err == nil { + t.Error("expected timeout error") + } +} + +func TestAtomicWaitWithTimeoutMatchAfterSet(t *testing.T) { + a := NewAtomic(1) + + go func() { + time.Sleep(20 * time.Millisecond) + a.Set(99) + }() + + err := a.WaitWithTimeout(500*time.Millisecond, 99) + if err != nil { + t.Errorf("expected nil, got %v", err) + } +} + +func TestAtomicWaitWithContextCancel(t *testing.T) { + a := NewAtomic(1) + + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(20 * time.Millisecond) + cancel() + }() + + err := a.WaitWithContext(ctx, 999) + if err == nil { + t.Error("expected ctx error") + } +} + +func TestAtomicWaitWithContextAlreadyCancelled(t *testing.T) { + a := NewAtomic(1) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := 
a.WaitWithContext(ctx, 1) + if err == nil { + t.Error("expected ctx error") + } +} + +func TestAtomicWaitForChange(t *testing.T) { + a := NewAtomic(1) + + ch := a.WaitForChange() + + go func() { + time.Sleep(20 * time.Millisecond) + a.Set(2) + }() + + select { + case v := <-ch: + if v != 2 { + t.Errorf("expected 2, got %d", v) + } + case <-time.After(500 * time.Millisecond): + t.Error("WaitForChange did not deliver") + } +} + +func TestAtomicConcurrentSet(t *testing.T) { + a := NewAtomic(0) + + var wg sync.WaitGroup + for i := 0; i < 50; i++ { + wg.Add(1) + go func(v int) { + defer wg.Done() + a.Set(v) + }(i) + } + wg.Wait() + + v := a.Get() + if v < 0 || v >= 50 { + t.Errorf("unexpected final value %d", v) + } +} + +func TestAtomicConcurrentUpdate(t *testing.T) { + a := NewAtomic(0) + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + a.Update(func(old int) int { return old + 1 }) + }() + } + wg.Wait() + + if v := a.Get(); v != 100 { + t.Errorf("expected 100, got %d", v) + } +} + +func TestAtomicWaitWithTimeoutZero(t *testing.T) { + a := NewAtomic(1) + + err := a.WaitWithTimeout(0, 999) + if err == nil { + t.Error("expected error for zero timeout with non-matching value") + } +} diff --git a/syncext/bool_test.go b/syncext/bool_test.go new file mode 100644 index 0000000..d6569ac --- /dev/null +++ b/syncext/bool_test.go @@ -0,0 +1,124 @@ +package syncext + +import ( + "context" + "sync" + "testing" + "time" +) + +func TestAtomicBoolGetSet(t *testing.T) { + b := NewAtomicBool(false) + + if b.Get() { + t.Error("expected false") + } + + old := b.Set(true) + if old { + t.Error("expected old value false") + } + + if !b.Get() { + t.Error("expected true") + } + + old = b.Set(false) + if !old { + t.Error("expected old value true") + } +} + +func TestAtomicBoolWaitAlreadyMatching(t *testing.T) { + b := NewAtomicBool(true) + + done := make(chan struct{}) + go func() { + b.Wait(true) + close(done) + }() + + select { + case 
<-done: + // ok + case <-time.After(500 * time.Millisecond): + t.Error("Wait should return immediately if value already matches") + } +} + +func TestAtomicBoolWaitWithTimeoutNoMatch(t *testing.T) { + b := NewAtomicBool(false) + + err := b.WaitWithTimeout(50*time.Millisecond, true) + if err == nil { + t.Error("expected timeout error") + } +} + +func TestAtomicBoolWaitWithTimeoutMatchAfterSet(t *testing.T) { + b := NewAtomicBool(false) + + go func() { + time.Sleep(20 * time.Millisecond) + b.Set(true) + }() + + err := b.WaitWithTimeout(500*time.Millisecond, true) + if err != nil { + t.Errorf("expected nil, got %v", err) + } +} + +func TestAtomicBoolWaitWithContextCancel(t *testing.T) { + b := NewAtomicBool(false) + + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(20 * time.Millisecond) + cancel() + }() + + err := b.WaitWithContext(ctx, true) + if err == nil { + t.Error("expected ctx error") + } +} + +func TestAtomicBoolWaitWithContextAlreadyCancelled(t *testing.T) { + b := NewAtomicBool(false) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := b.WaitWithContext(ctx, false) + if err == nil { + t.Error("expected ctx error") + } +} + +func TestAtomicBoolWaitWithContextMatching(t *testing.T) { + b := NewAtomicBool(true) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + err := b.WaitWithContext(ctx, true) + if err != nil { + t.Errorf("expected nil, got %v", err) + } +} + +func TestAtomicBoolConcurrentSet(t *testing.T) { + b := NewAtomicBool(false) + + var wg sync.WaitGroup + for i := 0; i < 50; i++ { + wg.Add(1) + go func(v bool) { + defer wg.Done() + b.Set(v) + }(i%2 == 0) + } + wg.Wait() +} + diff --git a/syncext/channel_extra_test.go b/syncext/channel_extra_test.go new file mode 100644 index 0000000..adb1906 --- /dev/null +++ b/syncext/channel_extra_test.go @@ -0,0 +1,155 @@ +package syncext + +import ( + "context" + "testing" + "time" +) + +func 
TestWriteChannelWithTimeoutSuccess(t *testing.T) { + c := make(chan int, 1) + + ok := WriteChannelWithTimeout(c, 42, 100*time.Millisecond) + if !ok { + t.Error("expected write to succeed") + } + + select { + case v := <-c: + if v != 42 { + t.Errorf("expected 42, got %d", v) + } + default: + t.Error("no value received") + } +} + +func TestWriteChannelWithTimeoutFull(t *testing.T) { + c := make(chan int, 1) + c <- 1 + + ok := WriteChannelWithTimeout(c, 2, 50*time.Millisecond) + if ok { + t.Error("expected write to timeout") + } +} + +func TestWriteChannelWithTimeoutUnbuffered(t *testing.T) { + c := make(chan int) + + go func() { + time.Sleep(10 * time.Millisecond) + <-c + }() + + ok := WriteChannelWithTimeout(c, 99, 200*time.Millisecond) + if !ok { + t.Error("expected write to succeed") + } +} + +func TestWriteChannelWithTimeoutUnbufferedTimeout(t *testing.T) { + c := make(chan int) + + ok := WriteChannelWithTimeout(c, 99, 50*time.Millisecond) + if ok { + t.Error("expected timeout") + } +} + +func TestWriteChannelWithContextSuccess(t *testing.T) { + c := make(chan int, 1) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + err := WriteChannelWithContext(ctx, c, 7) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + if v := <-c; v != 7 { + t.Errorf("expected 7, got %d", v) + } +} + +func TestWriteChannelWithContextCancel(t *testing.T) { + c := make(chan int) + + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(20 * time.Millisecond) + cancel() + }() + + err := WriteChannelWithContext(ctx, c, 7) + if err == nil { + t.Error("expected ctx error") + } +} + +func TestWriteChannelWithContextAlreadyCancelled(t *testing.T) { + c := make(chan int) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := WriteChannelWithContext(ctx, c, 7) + if err == nil { + t.Error("expected ctx error") + } +} + +func TestReadNonBlockingEmpty(t *testing.T) { + c := 
make(chan int, 1) + + _, ok := ReadNonBlocking(c) + if ok { + t.Error("expected non-blocking read to return false on empty channel") + } +} + +func TestReadNonBlockingHasValue(t *testing.T) { + c := make(chan int, 1) + c <- 55 + + v, ok := ReadNonBlocking(c) + if !ok { + t.Error("expected non-blocking read to return true") + } + if v != 55 { + t.Errorf("expected 55, got %d", v) + } +} + +func TestWriteNonBlockingSuccess(t *testing.T) { + c := make(chan int, 1) + + ok := WriteNonBlocking(c, 33) + if !ok { + t.Error("expected non-blocking write to succeed") + } + + if v := <-c; v != 33 { + t.Errorf("expected 33, got %d", v) + } +} + +func TestWriteNonBlockingFull(t *testing.T) { + c := make(chan int, 1) + c <- 1 + + ok := WriteNonBlocking(c, 2) + if ok { + t.Error("expected non-blocking write to fail when full") + } +} + +func TestWriteNonBlockingUnbufferedNoReceiver(t *testing.T) { + c := make(chan int) + + ok := WriteNonBlocking(c, 1) + if ok { + t.Error("expected non-blocking write to fail without receiver") + } +} diff --git a/termext/colors_test.go b/termext/colors_test.go new file mode 100644 index 0000000..60ee203 --- /dev/null +++ b/termext/colors_test.go @@ -0,0 +1,166 @@ +package termext + +import ( + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + "strings" + "testing" +) + +const ( + resetSeq = "" + redSeq = "" + greenSeq = "" + yellowSeq = "" + blueSeq = "" + purpleSeq = "" + cyanSeq = "" + graySeq = "" + whiteSeq = "" +) + +func TestRedEmpty(t *testing.T) { + tst.AssertEqual(t, Red(""), redSeq+resetSeq) +} + +func TestGreenEmpty(t *testing.T) { + tst.AssertEqual(t, Green(""), greenSeq+resetSeq) +} + +func TestYellowEmpty(t *testing.T) { + tst.AssertEqual(t, Yellow(""), yellowSeq+resetSeq) +} + +func TestBlueEmpty(t *testing.T) { + tst.AssertEqual(t, Blue(""), blueSeq+resetSeq) +} + +func TestPurpleEmpty(t *testing.T) { + tst.AssertEqual(t, Purple(""), purpleSeq+resetSeq) +} + +func TestCyanEmpty(t *testing.T) { + tst.AssertEqual(t, Cyan(""), 
cyanSeq+resetSeq) +} + +func TestGrayEmpty(t *testing.T) { + tst.AssertEqual(t, Gray(""), graySeq+resetSeq) +} + +func TestWhiteEmpty(t *testing.T) { + tst.AssertEqual(t, White(""), whiteSeq+resetSeq) +} + +func TestColorsContainOriginalString(t *testing.T) { + input := "hello world" + tst.AssertTrue(t, strings.Contains(Red(input), input)) + tst.AssertTrue(t, strings.Contains(Green(input), input)) + tst.AssertTrue(t, strings.Contains(Yellow(input), input)) + tst.AssertTrue(t, strings.Contains(Blue(input), input)) + tst.AssertTrue(t, strings.Contains(Purple(input), input)) + tst.AssertTrue(t, strings.Contains(Cyan(input), input)) + tst.AssertTrue(t, strings.Contains(Gray(input), input)) + tst.AssertTrue(t, strings.Contains(White(input), input)) +} + +func TestColorsEndWithReset(t *testing.T) { + tst.AssertTrue(t, strings.HasSuffix(Red("x"), resetSeq)) + tst.AssertTrue(t, strings.HasSuffix(Green("x"), resetSeq)) + tst.AssertTrue(t, strings.HasSuffix(Yellow("x"), resetSeq)) + tst.AssertTrue(t, strings.HasSuffix(Blue("x"), resetSeq)) + tst.AssertTrue(t, strings.HasSuffix(Purple("x"), resetSeq)) + tst.AssertTrue(t, strings.HasSuffix(Cyan("x"), resetSeq)) + tst.AssertTrue(t, strings.HasSuffix(Gray("x"), resetSeq)) + tst.AssertTrue(t, strings.HasSuffix(White("x"), resetSeq)) +} + +func TestColorsStartWithCorrectSequence(t *testing.T) { + tst.AssertTrue(t, strings.HasPrefix(Red("x"), redSeq)) + tst.AssertTrue(t, strings.HasPrefix(Green("x"), greenSeq)) + tst.AssertTrue(t, strings.HasPrefix(Yellow("x"), yellowSeq)) + tst.AssertTrue(t, strings.HasPrefix(Blue("x"), blueSeq)) + tst.AssertTrue(t, strings.HasPrefix(Purple("x"), purpleSeq)) + tst.AssertTrue(t, strings.HasPrefix(Cyan("x"), cyanSeq)) + tst.AssertTrue(t, strings.HasPrefix(Gray("x"), graySeq)) + tst.AssertTrue(t, strings.HasPrefix(White("x"), whiteSeq)) +} + +func TestColorsAreDistinct(t *testing.T) { + input := "value" + results := []string{ + Red(input), + Green(input), + Yellow(input), + Blue(input), + 
Purple(input), + Cyan(input), + Gray(input), + White(input), + } + for i := 0; i < len(results); i++ { + for j := i + 1; j < len(results); j++ { + tst.AssertNotEqual(t, results[i], results[j]) + } + } +} + +func TestCleanStringEmpty(t *testing.T) { + tst.AssertEqual(t, CleanString(""), "") +} + +func TestCleanStringWithoutColors(t *testing.T) { + input := "plain text without any colors" + tst.AssertEqual(t, CleanString(input), input) +} + +func TestCleanStringMultipleColors(t *testing.T) { + input := Red("foo") + " " + Green("bar") + " " + Blue("baz") + tst.AssertEqual(t, CleanString(input), "foo bar baz") +} + +func TestCleanStringNested(t *testing.T) { + input := Red(Green("inner")) + tst.AssertEqual(t, CleanString(input), "inner") +} + +func TestCleanStringIdempotent(t *testing.T) { + input := Yellow("hello") + Purple("world") + cleaned := CleanString(input) + tst.AssertEqual(t, CleanString(cleaned), cleaned) +} + +func TestCleanStringEmptyColorWraps(t *testing.T) { + tst.AssertEqual(t, CleanString(Red("")), "") + tst.AssertEqual(t, CleanString(Green("")), "") + tst.AssertEqual(t, CleanString(Yellow("")), "") + tst.AssertEqual(t, CleanString(Blue("")), "") + tst.AssertEqual(t, CleanString(Purple("")), "") + tst.AssertEqual(t, CleanString(Cyan("")), "") + tst.AssertEqual(t, CleanString(Gray("")), "") + tst.AssertEqual(t, CleanString(White("")), "") +} + +func TestCleanStringPreservesNonAnsiContent(t *testing.T) { + input := "before " + Red("middle") + " after\nnewline\ttab" + expected := "before middle after\nnewline\ttab" + tst.AssertEqual(t, CleanString(input), expected) +} + +func TestCleanStringRemovesBareResetSequence(t *testing.T) { + input := "abc" + resetSeq + "def" + tst.AssertEqual(t, CleanString(input), "abcdef") +} + +func TestCleanStringUnicode(t *testing.T) { + input := Red("héllo wörld 你好 🌍") + tst.AssertEqual(t, CleanString(input), "héllo wörld 你好 🌍") +} + +func TestColorRoundTrip(t *testing.T) { + cases := []string{"", "x", "hello", 
"multi\nline", "with spaces", "héllo", "🌈"} + wrappers := []func(string) string{Red, Green, Yellow, Blue, Purple, Cyan, Gray, White} + for _, c := range cases { + for _, w := range wrappers { + tst.AssertEqual(t, CleanString(w(c)), c) + } + } +} diff --git a/termext/supportscolors_test.go b/termext/supportscolors_test.go new file mode 100644 index 0000000..711d497 --- /dev/null +++ b/termext/supportscolors_test.go @@ -0,0 +1,26 @@ +package termext + +import ( + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + "testing" +) + +func TestSupportsColorsNoPanic(t *testing.T) { + defer func() { + if r := recover(); r != nil { + t.Fatalf("SupportsColors panicked: %v", r) + } + }() + _ = SupportsColors() +} + +func TestSupportsColorsReturnsBool(t *testing.T) { + v := SupportsColors() + tst.AssertTrue(t, v == true || v == false) +} + +func TestSupportsColorsIsDeterministic(t *testing.T) { + a := SupportsColors() + b := SupportsColors() + tst.AssertEqual(t, a, b) +} diff --git a/timeext/duration_test.go b/timeext/duration_test.go new file mode 100644 index 0000000..917c442 --- /dev/null +++ b/timeext/duration_test.go @@ -0,0 +1,145 @@ +package timeext + +import ( + "testing" + "time" +) + +func TestFromNanoseconds(t *testing.T) { + if got := FromNanoseconds(0); got != 0 { + t.Errorf("expected 0, got %v", got) + } + if got := FromNanoseconds(1); got != time.Nanosecond { + t.Errorf("expected 1ns, got %v", got) + } + if got := FromNanoseconds(1000); got != 1000*time.Nanosecond { + t.Errorf("expected 1000ns, got %v", got) + } + if got := FromNanoseconds(int64(123456789)); got != time.Duration(123456789) { + t.Errorf("expected 123456789ns, got %v", got) + } +} + +func TestFromMicroseconds(t *testing.T) { + if got := FromMicroseconds(1); got != time.Microsecond { + t.Errorf("expected 1us, got %v", got) + } + if got := FromMicroseconds(1000); got != time.Millisecond { + t.Errorf("expected 1ms, got %v", got) + } + if got := FromMicroseconds(2.5); got != 
time.Microsecond*2+time.Nanosecond*500 { + t.Errorf("expected 2.5us, got %v", got) + } +} + +func TestFromMilliseconds(t *testing.T) { + if got := FromMilliseconds(1); got != time.Millisecond { + t.Errorf("expected 1ms, got %v", got) + } + if got := FromMilliseconds(1000); got != time.Second { + t.Errorf("expected 1s, got %v", got) + } +} + +func TestFromSeconds(t *testing.T) { + if got := FromSeconds(1); got != time.Second { + t.Errorf("expected 1s, got %v", got) + } + if got := FromSeconds(60); got != time.Minute { + t.Errorf("expected 1min, got %v", got) + } + if got := FromSeconds(0.5); got != 500*time.Millisecond { + t.Errorf("expected 0.5s, got %v", got) + } +} + +func TestFromMinutes(t *testing.T) { + if got := FromMinutes(1); got != time.Minute { + t.Errorf("expected 1min, got %v", got) + } + if got := FromMinutes(60); got != time.Hour { + t.Errorf("expected 1h, got %v", got) + } +} + +func TestFromHours(t *testing.T) { + if got := FromHours(1); got != time.Hour { + t.Errorf("expected 1h, got %v", got) + } + if got := FromHours(24); got != 24*time.Hour { + t.Errorf("expected 24h, got %v", got) + } +} + +func TestFromDays(t *testing.T) { + if got := FromDays(1); got != 24*time.Hour { + t.Errorf("expected 1d, got %v", got) + } + if got := FromDays(7); got != 7*24*time.Hour { + t.Errorf("expected 7d, got %v", got) + } + if got := FromDays(0); got != 0 { + t.Errorf("expected 0, got %v", got) + } +} + +func TestFormatNaturalDurationEnglish(t *testing.T) { + tests := []struct { + dur time.Duration + want string + }{ + {time.Second, "1 second ago"}, + {2 * time.Second, "2 seconds ago"}, + {30 * time.Second, "30 seconds ago"}, + {179 * time.Second, "179 seconds ago"}, + {180 * time.Second, "3 minutes ago"}, + {30 * time.Minute, "30 minutes ago"}, + {179 * time.Minute, "179 minutes ago"}, + {180 * time.Minute, "3 hours ago"}, + {24 * time.Hour, "24 hours ago"}, + {71 * time.Hour, "71 hours ago"}, + {72 * time.Hour, "3 days ago"}, + {20 * 24 * time.Hour, "20 days 
ago"}, + {21 * 24 * time.Hour, "3 weeks ago"}, + {11 * 7 * 24 * time.Hour, "11 weeks ago"}, + // The months tier divides hours by (24*7*30); the actual boundaries are unusual + // but we capture the current observable behavior: + {12 * 7 * 24 * time.Hour, "0 months ago"}, + {90 * 7 * 24 * time.Hour, "3 months ago"}, + } + for _, tt := range tests { + got := FormatNaturalDurationEnglish(tt.dur) + if got != tt.want { + t.Errorf("FormatNaturalDurationEnglish(%v) = %q; want %q", tt.dur, got, tt.want) + } + } +} + +func TestFormatDurationGerman(t *testing.T) { + tests := []struct { + dur time.Duration + want string + }{ + {time.Second, "1s"}, + {30 * time.Second, "30s"}, + {179 * time.Second, "179s"}, + {180 * time.Second, "3min"}, + {30 * time.Minute, "30min"}, + {179 * time.Minute, "179min"}, + {180 * time.Minute, "3h"}, + {24 * time.Hour, "24h"}, + {71 * time.Hour, "71h"}, + {72 * time.Hour, "3 Tage"}, + {20 * 24 * time.Hour, "20 Tage"}, + {21 * 24 * time.Hour, "3 Wochen"}, + {11 * 7 * 24 * time.Hour, "11 Wochen"}, + {12 * 7 * 24 * time.Hour, "0 Monate"}, + {90 * 7 * 24 * time.Hour, "3 Monate"}, + } + for _, tt := range tests { + got := FormatDurationGerman(tt.dur) + if got != tt.want { + t.Errorf("FormatDurationGerman(%v) = %q; want %q", tt.dur, got, tt.want) + } + } +} diff --git a/timeext/month_test.go b/timeext/month_test.go new file mode 100644 index 0000000..4148e02 --- /dev/null +++ b/timeext/month_test.go @@ -0,0 +1,74 @@ +package timeext + +import ( + "testing" + "time" +) + +func TestMonthNameGermanShort3(t *testing.T) { + tests := []struct { + m time.Month + want string + }{ + {time.January, "Jan"}, + {time.February, "Feb"}, + {time.March, "Mär"}, + {time.April, "Apr"}, + {time.May, "Mai"}, + {time.June, "Jun"}, + {time.July, "Jul"}, + {time.August, "Aug"}, + {time.September, "Sep"}, + {time.October, "Okt"}, + {time.November, "Nov"}, + {time.December, "Dez"}, + } + for _, tt := range tests { + got := MonthNameGermanShort3(tt.m) + if got != tt.want { + 
t.Errorf("MonthNameGermanShort3(%v) = %q; want %q", tt.m, got, tt.want) + } + } +} + +func TestMonthNameGermanShort3_Invalid(t *testing.T) { + got := MonthNameGermanShort3(time.Month(13)) + want := "%!Month(13)" + if got != want { + t.Errorf("MonthNameGermanShort3(13) = %q; want %q", got, want) + } +} + +func TestMonthNameGermanLong(t *testing.T) { + tests := []struct { + m time.Month + want string + }{ + {time.January, "Januar"}, + {time.February, "Februar"}, + {time.March, "März"}, + {time.April, "April"}, + {time.May, "Mai"}, + {time.June, "Juni"}, + {time.July, "Juli"}, + {time.August, "August"}, + {time.September, "September"}, + {time.October, "Oktober"}, + {time.November, "November"}, + {time.December, "Dezember"}, + } + for _, tt := range tests { + got := MonthNameGermanLong(tt.m) + if got != tt.want { + t.Errorf("MonthNameGermanLong(%v) = %q; want %q", tt.m, got, tt.want) + } + } +} + +func TestMonthNameGermanLong_Invalid(t *testing.T) { + got := MonthNameGermanLong(time.Month(0)) + want := "%!Month(0)" + if got != want { + t.Errorf("MonthNameGermanLong(0) = %q; want %q", got, want) + } +} diff --git a/timeext/range_test.go b/timeext/range_test.go new file mode 100644 index 0000000..38dcf81 --- /dev/null +++ b/timeext/range_test.go @@ -0,0 +1,180 @@ +package timeext + +import ( + "testing" + "time" +) + +func TestOpenTimeRange_String_Empty(t *testing.T) { + r := OpenTimeRange{} + if got := r.String(); got != "[]" { + t.Errorf("expected [], got %q", got) + } +} + +func TestOpenTimeRange_String_FromOnly(t *testing.T) { + tm := time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC) + r := OpenTimeRange{From: &tm} + got := r.String() + if got == "" || got[0] != '[' || got[len(got)-1] != ']' { + t.Errorf("unexpected format: %q", got) + } +} + +func TestOpenTimeRange_String_ToOnly(t *testing.T) { + tm := time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC) + r := OpenTimeRange{To: &tm} + got := r.String() + if got == "" || got[0] != '[' || got[len(got)-1] != ']' { + 
t.Errorf("unexpected format: %q", got) + } +} + +func TestOpenTimeRange_String_Both(t *testing.T) { + t1 := time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC) + t2 := time.Date(2022, 12, 31, 0, 0, 0, 0, time.UTC) + r := OpenTimeRange{From: &t1, To: &t2} + got := r.String() + if got == "" || got[0] != '[' || got[len(got)-1] != ']' { + t.Errorf("unexpected format: %q", got) + } +} + +func TestOpenTimeRange_Contains_Empty(t *testing.T) { + r := OpenTimeRange{} + if !r.Contains(time.Now()) { + t.Errorf("empty range should contain anything") + } +} + +func TestOpenTimeRange_Contains_FromOnly(t *testing.T) { + from := time.Date(2022, 6, 1, 0, 0, 0, 0, time.UTC) + r := OpenTimeRange{From: &from} + + if r.Contains(time.Date(2022, 5, 1, 0, 0, 0, 0, time.UTC)) { + t.Errorf("should not contain time before From") + } + if !r.Contains(from) { + t.Errorf("should contain From itself") + } + if !r.Contains(time.Date(2022, 7, 1, 0, 0, 0, 0, time.UTC)) { + t.Errorf("should contain time after From") + } +} + +func TestOpenTimeRange_Contains_ToOnly(t *testing.T) { + to := time.Date(2022, 6, 1, 0, 0, 0, 0, time.UTC) + r := OpenTimeRange{To: &to} + + if !r.Contains(time.Date(2022, 5, 1, 0, 0, 0, 0, time.UTC)) { + t.Errorf("should contain time before To") + } + if r.Contains(to) { + t.Errorf("should not contain To itself (exclusive)") + } + if r.Contains(time.Date(2022, 7, 1, 0, 0, 0, 0, time.UTC)) { + t.Errorf("should not contain time after To") + } +} + +func TestOpenTimeRange_Contains_Both(t *testing.T) { + from := time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC) + to := time.Date(2022, 12, 31, 0, 0, 0, 0, time.UTC) + r := OpenTimeRange{From: &from, To: &to} + + if r.Contains(time.Date(2021, 12, 31, 0, 0, 0, 0, time.UTC)) { + t.Errorf("should not contain time before From") + } + if !r.Contains(time.Date(2022, 6, 1, 0, 0, 0, 0, time.UTC)) { + t.Errorf("should contain time within range") + } + if r.Contains(time.Date(2023, 1, 1, 0, 0, 0, 0, time.UTC)) { + t.Errorf("should not contain time after 
To") + } +} + +func TestNewOpenTimeRange_Nil(t *testing.T) { + if NewOpenTimeRange(nil, nil) != nil { + t.Errorf("expected nil for both nil inputs") + } +} + +func TestNewOpenTimeRange_FromOnly(t *testing.T) { + tm := time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC) + r := NewOpenTimeRange(&tm, nil) + if r == nil || r.From == nil || !r.From.Equal(tm) || r.To != nil { + t.Errorf("unexpected result: %v", r) + } +} + +func TestNewOpenTimeRange_ToOnly(t *testing.T) { + tm := time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC) + r := NewOpenTimeRange(nil, &tm) + if r == nil || r.To == nil || !r.To.Equal(tm) || r.From != nil { + t.Errorf("unexpected result: %v", r) + } +} + +func TestNewOpenTimeRange_Both(t *testing.T) { + from := time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC) + to := time.Date(2022, 12, 31, 0, 0, 0, 0, time.UTC) + r := NewOpenTimeRange(&from, &to) + if r == nil || !r.From.Equal(from) || !r.To.Equal(to) { + t.Errorf("unexpected result: %v", r) + } +} + +func TestOpenTimeRange_ToMongoPipeline_Empty(t *testing.T) { + r := OpenTimeRange{} + pipeline := r.ToMongoPipeline("ts") + if len(pipeline) != 0 { + t.Errorf("expected empty pipeline, got %v", pipeline) + } +} + +func TestOpenTimeRange_ToMongoPipeline_FromOnly(t *testing.T) { + from := time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC) + r := OpenTimeRange{From: &from} + pipeline := r.ToMongoPipeline("ts") + if len(pipeline) != 1 { + t.Errorf("expected 1 stage, got %d", len(pipeline)) + } +} + +func TestOpenTimeRange_ToMongoPipeline_ToOnly(t *testing.T) { + to := time.Date(2022, 12, 31, 0, 0, 0, 0, time.UTC) + r := OpenTimeRange{To: &to} + pipeline := r.ToMongoPipeline("ts") + if len(pipeline) != 1 { + t.Errorf("expected 1 stage, got %d", len(pipeline)) + } +} + +func TestOpenTimeRange_ToMongoPipeline_Both(t *testing.T) { + from := time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC) + to := time.Date(2022, 12, 31, 0, 0, 0, 0, time.UTC) + r := OpenTimeRange{From: &from, To: &to} + pipeline := r.ToMongoPipeline("ts") + if len(pipeline) 
!= 2 { + t.Errorf("expected 2 stages, got %d", len(pipeline)) + } +} + +func TestOpenTimeRange_AppendToMongoPipeline_Nil(t *testing.T) { + var r *OpenTimeRange + existing := []any{"existing"} + got := r.AppendToMongoPipeline(existing, "ts") + if len(got) != 1 { + t.Errorf("expected unchanged pipeline, got %v", got) + } +} + +func TestOpenTimeRange_AppendToMongoPipeline_NonNil(t *testing.T) { + from := time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC) + r := &OpenTimeRange{From: &from} + existing := []any{"existing"} + got := r.AppendToMongoPipeline(existing, "ts") + if len(got) != 2 { + t.Errorf("expected pipeline with 2 entries, got %v", got) + } +} diff --git a/timeext/translation_test.go b/timeext/translation_test.go new file mode 100644 index 0000000..9fd04eb --- /dev/null +++ b/timeext/translation_test.go @@ -0,0 +1,43 @@ +package timeext + +import ( + "testing" + "time" +) + +func TestWeekdayNameGerman(t *testing.T) { + tests := []struct { + d time.Weekday + want string + }{ + {time.Sunday, "Sonntag"}, + {time.Monday, "Montag"}, + {time.Tuesday, "Dienstag"}, + {time.Wednesday, "Mittwoch"}, + {time.Thursday, "Donnerstag"}, + {time.Friday, "Freitag"}, + {time.Saturday, "Samstag"}, + } + for _, tt := range tests { + got := WeekdayNameGerman(tt.d) + if got != tt.want { + t.Errorf("WeekdayNameGerman(%v) = %q; want %q", tt.d, got, tt.want) + } + } +} + +func TestWeekdayNameGerman_Invalid(t *testing.T) { + got := WeekdayNameGerman(time.Weekday(8)) + want := "%!Weekday(8)" + if got != want { + t.Errorf("WeekdayNameGerman(8) = %q; want %q", got, want) + } +} + +func TestWeekdayNameGerman_Negative(t *testing.T) { + got := WeekdayNameGerman(time.Weekday(-1)) + want := "%!Weekday(-1)" + if got != want { + t.Errorf("WeekdayNameGerman(-1) = %q; want %q", got, want) + } +} diff --git a/timeext/weekcount_test.go b/timeext/weekcount_test.go new file mode 100644 index 0000000..9a75ee0 --- /dev/null +++ b/timeext/weekcount_test.go @@ -0,0 +1,81 @@ +package timeext + +import ( + 
"testing" + "time" +) + +func TestGetIsoWeekCount(t *testing.T) { + // The implementation subtracts 1 from the ISO week numbers internally, + // so for a 53-week year it returns 52, and for a 52-week year it returns 51. + tests := []struct { + year int + want int + }{ + {2020, 52}, // 2020 is a 53-week ISO year + {2021, 51}, + {2022, 51}, + {2023, 51}, + {2024, 51}, + } + for _, tt := range tests { + got := GetIsoWeekCount(tt.year) + if got != tt.want { + t.Errorf("GetIsoWeekCount(%d) = %d; want %d", tt.year, got, tt.want) + } + } +} + +func TestGetAggregateIsoWeekCount_Year1900(t *testing.T) { + got := GetAggregateIsoWeekCount(1900) + if got != 0 { + t.Errorf("GetAggregateIsoWeekCount(1900) = %d; want 0", got) + } +} + +func TestGetAggregateIsoWeekCount_Monotonic(t *testing.T) { + // The aggregate count must be strictly monotonically increasing year over year (for years > 1900) + prev := GetAggregateIsoWeekCount(1900) + for y := 1901; y <= 2030; y++ { + cur := GetAggregateIsoWeekCount(y) + if cur <= prev { + t.Errorf("GetAggregateIsoWeekCount(%d)=%d not greater than GetAggregateIsoWeekCount(%d)=%d", y, cur, y-1, prev) + } + prev = cur + } +} + +func TestGetAggregateIsoWeekCount_BeforeBaseline(t *testing.T) { + // for years < 1900 the aggregate is negative + got := GetAggregateIsoWeekCount(1899) + if got >= 0 { + t.Errorf("GetAggregateIsoWeekCount(1899) = %d; want < 0", got) + } +} + +func TestGetGlobalWeeknumber_Monotonic(t *testing.T) { + // Walking forward by one week should never decrease the global week number + t0 := time.Date(2020, 1, 6, 0, 0, 0, 0, TimezoneBerlin) // Monday, ISO week 2 of 2020 + prev := GetGlobalWeeknumber(t0) + for i := 1; i < 200; i++ { + ti := t0.AddDate(0, 0, i*7) + cur := GetGlobalWeeknumber(ti) + if cur < prev { + t.Errorf("week number decreased at offset %d weeks: %d -> %d", i, prev, cur) + } + prev = cur + } +} + +func TestGetGlobalWeeknumber_DifferentYearsDiffer(t *testing.T) { + w2020 := GetGlobalWeeknumber(time.Date(2020, 6, 1, 
0, 0, 0, 0, TimezoneBerlin)) + w2021 := GetGlobalWeeknumber(time.Date(2021, 6, 1, 0, 0, 0, 0, TimezoneBerlin)) + if w2021 <= w2020 { + t.Errorf("expected w2021 > w2020, got %d and %d", w2021, w2020) + } + // Approximately 52 weeks apart + delta := w2021 - w2020 + if delta < 50 || delta > 54 { + t.Errorf("expected ~52 week difference, got %d", delta) + } +} diff --git a/totpext/totp_test.go b/totpext/totp_test.go new file mode 100644 index 0000000..d72cdcd --- /dev/null +++ b/totpext/totp_test.go @@ -0,0 +1,290 @@ +package totpext + +import ( + "crypto/sha1" + "encoding/base32" + "fmt" + "net/url" + "regexp" + "strings" + "testing" + "time" +) + +// RFC 6238 reference seed (ASCII "12345678901234567890") +var rfcSeed = []byte("12345678901234567890") + +// generateTOTP tests against RFC 6238 reference test vectors (Appendix B). +// The reference vectors yield 8-digit codes; the package uses 6 digits, so we +// take the last 6 digits of the published 8-digit value. +func TestGenerateTOTPRFCVectors(t *testing.T) { + cases := []struct { + unix int64 + expected string // last 6 digits of the RFC 8-digit value + }{ + {59, "287082"}, + {1111111109, "081804"}, + {1111111111, "050471"}, + {1234567890, "005924"}, + {2000000000, "279037"}, + } + + for _, c := range cases { + got := generateTOTP(sha1.New, rfcSeed, c.unix/30, 6) + if got != c.expected { + t.Errorf("generateTOTP(unix=%d) = %q, expected %q", c.unix, got, c.expected) + } + } +} + +func TestGenerateTOTPLengthAndDigits(t *testing.T) { + for digits := 6; digits <= 8; digits++ { + got := generateTOTP(sha1.New, rfcSeed, 0, digits) + if len(got) != digits { + t.Errorf("expected length %d, got %d (%q)", digits, len(got), got) + } + matched, _ := regexp.MatchString("^[0-9]+$", got) + if !matched { + t.Errorf("expected all-digit string, got %q", got) + } + } +} + +func TestGenerateTOTPLeftPadsWithZeros(t *testing.T) { + // Use a high digit count to make zero-padding likely for some t. 
+ // Across 5000 consecutive timesteps some counter values are numerically shorter than `digits`, so the zero-padding branch is exercised; every result must still be exactly `digits` characters long. + digits := 8 + for ts := range 5000 { + got := generateTOTP(sha1.New, rfcSeed, int64(ts), digits) + if len(got) != digits { + t.Fatalf("length mismatch at ts=%d: %q (len=%d)", ts, got, len(got)) + } + } +} + +func TestGenerateTOTPDeterministic(t *testing.T) { + a := generateTOTP(sha1.New, rfcSeed, 12345, 6) + b := generateTOTP(sha1.New, rfcSeed, 12345, 6) + if a != b { + t.Errorf("generateTOTP not deterministic: %q vs %q", a, b) + } +} + +func TestGenerateTOTPDifferentSecrets(t *testing.T) { + a := generateTOTP(sha1.New, []byte("AAAAAAAAAAAAAAAAAAAA"), 100, 6) + b := generateTOTP(sha1.New, []byte("BBBBBBBBBBBBBBBBBBBB"), 100, 6) + if a == b { + t.Errorf("expected different TOTPs for different secrets, both = %q", a) + } +} + +func TestGenerateTOTPDifferentTimes(t *testing.T) { + a := generateTOTP(sha1.New, rfcSeed, 100, 6) + b := generateTOTP(sha1.New, rfcSeed, 101, 6) + if a == b { + t.Errorf("expected different TOTPs for different times, both = %q", a) + } +} + +func TestTOTPFormat(t *testing.T) { + secret, err := GenerateSecret() + if err != nil { + t.Fatalf("GenerateSecret failed: %v", err) + } + code := TOTP(secret) + if len(code) != 6 { + t.Errorf("expected 6 digit code, got %q (len=%d)", code, len(code)) + } + matched, _ := regexp.MatchString("^[0-9]{6}$", code) + if !matched { + t.Errorf("expected 6 digit numeric code, got %q", code) + } +} + +func TestTOTPMatchesGenerateTOTPForCurrentTime(t *testing.T) { + secret := rfcSeed + // Generate both as close together as possible to share the same 30s window. + for range 3 { + now := time.Now().Unix() + windowStart := now / 30 + got := TOTP(secret) + expected := generateTOTP(sha1.New, secret, windowStart, 6) + // If we crossed a window boundary, retry. 
+ if time.Now().Unix()/30 != windowStart { + continue + } + if got != expected { + t.Errorf("TOTP() = %q, expected %q (window=%d)", got, expected, windowStart) + } + return + } + t.Skip("could not capture stable 30s window for comparison") +} + +func TestValidateCurrentWindow(t *testing.T) { + secret, err := GenerateSecret() + if err != nil { + t.Fatalf("GenerateSecret failed: %v", err) + } + code := TOTP(secret) + if !Validate(secret, code) { + t.Errorf("Validate rejected a freshly generated TOTP") + } +} + +func TestValidatePreviousAndNextWindow(t *testing.T) { + secret := rfcSeed + t0 := time.Now().Unix() / 30 + + prev := generateTOTP(sha1.New, secret, t0-1, 6) + next := generateTOTP(sha1.New, secret, t0+1, 6) + + if !Validate(secret, prev) { + t.Errorf("Validate rejected previous-window code %q", prev) + } + if !Validate(secret, next) { + t.Errorf("Validate rejected next-window code %q", next) + } +} + +func TestValidateRejectsOutOfWindow(t *testing.T) { + secret := rfcSeed + t0 := time.Now().Unix() / 30 + + // Two windows away — must be rejected. + farFuture := generateTOTP(sha1.New, secret, t0+5, 6) + farPast := generateTOTP(sha1.New, secret, t0-5, 6) + + // In the unlikely case of a hash collision with a valid window, skip. 
+ current := generateTOTP(sha1.New, secret, t0, 6) + prev := generateTOTP(sha1.New, secret, t0-1, 6) + next := generateTOTP(sha1.New, secret, t0+1, 6) + validSet := map[string]bool{current: true, prev: true, next: true} + + if !validSet[farFuture] && Validate(secret, farFuture) { + t.Errorf("Validate accepted out-of-window future code %q", farFuture) + } + if !validSet[farPast] && Validate(secret, farPast) { + t.Errorf("Validate accepted out-of-window past code %q", farPast) + } +} + +func TestValidateRejectsGarbage(t *testing.T) { + secret, err := GenerateSecret() + if err != nil { + t.Fatalf("GenerateSecret failed: %v", err) + } + for _, bad := range []string{"", "000000", "abcdef", "12345", "1234567"} { + if Validate(secret, bad) { + // "000000" could theoretically be valid; only fail if it's not the actual code. + if bad == "000000" && TOTP(secret) == "000000" { + continue + } + t.Errorf("Validate accepted garbage input %q", bad) + } + } +} + +func TestValidateRejectsWrongSecret(t *testing.T) { + a, _ := GenerateSecret() + b, _ := GenerateSecret() + code := TOTP(a) + // Extremely unlikely both 20-byte random secrets agree on any window. 
+ if Validate(b, code) { + t.Errorf("Validate accepted code from a different secret") + } +} + +func TestGenerateSecretLength(t *testing.T) { + s, err := GenerateSecret() + if err != nil { + t.Fatalf("GenerateSecret failed: %v", err) + } + if len(s) != 20 { + t.Errorf("expected 20-byte secret, got %d", len(s)) + } +} + +func TestGenerateSecretRandomness(t *testing.T) { + a, err := GenerateSecret() + if err != nil { + t.Fatalf("GenerateSecret failed: %v", err) + } + b, err := GenerateSecret() + if err != nil { + t.Fatalf("GenerateSecret failed: %v", err) + } + if string(a) == string(b) { + t.Errorf("two GenerateSecret calls returned identical output") + } +} + +func TestGenerateOTPAuthFormat(t *testing.T) { + key := []byte("12345678901234567890") + got := GenerateOTPAuth("MyApp", key, "user@example.com", "MyIssuer") + + if !strings.HasPrefix(got, "otpauth://totp/") { + t.Errorf("expected otpauth scheme prefix, got %q", got) + } + + u, err := url.Parse(got) + if err != nil { + t.Fatalf("URL parse failed: %v", err) + } + if u.Scheme != "otpauth" { + t.Errorf("scheme = %q, expected %q", u.Scheme, "otpauth") + } + if u.Host != "totp" { + t.Errorf("host = %q, expected %q", u.Host, "totp") + } + + // The generated URL carries the escaped path "/MyApp:user%40example.com" (the account email is QueryEscaped before formatting); url.Parse decodes u.Path, while u.EscapedPath() preserves the escaped form. 
+ expectedPath := fmt.Sprintf("/MyApp:%s", url.QueryEscape("user@example.com")) + if u.Path != expectedPath && u.EscapedPath() != expectedPath { + t.Errorf("path = %q (escaped %q), expected %q", u.Path, u.EscapedPath(), expectedPath) + } + + q := u.Query() + expectedSecret := base32.StdEncoding.EncodeToString(key) + if q.Get("secret") != expectedSecret { + t.Errorf("secret = %q, expected %q", q.Get("secret"), expectedSecret) + } + if q.Get("issuer") != "MyIssuer" { + t.Errorf("issuer = %q, expected %q", q.Get("issuer"), "MyIssuer") + } + if q.Get("algorithm") != "SHA1" { + t.Errorf("algorithm = %q, expected %q", q.Get("algorithm"), "SHA1") + } + if q.Get("period") != "30" { + t.Errorf("period = %q, expected %q", q.Get("period"), "30") + } + if q.Get("digits") != "6" { + t.Errorf("digits = %q, expected %q", q.Get("digits"), "6") + } +} + +func TestGenerateOTPAuthEscapesAccount(t *testing.T) { + key := []byte("12345678901234567890") + got := GenerateOTPAuth("App", key, "a b@c.com", "Iss") + // The account is run through url.QueryEscape, so the space becomes '+'. 
+ if !strings.Contains(got, "App:a+b%40c.com") { + t.Errorf("expected escaped account in path, got %q", got) + } +} + +func TestGenerateOTPAuthSecretIsBase32(t *testing.T) { + key := []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09} + got := GenerateOTPAuth("App", key, "u@e.com", "Iss") + u, err := url.Parse(got) + if err != nil { + t.Fatalf("URL parse failed: %v", err) + } + secret := u.Query().Get("secret") + decoded, err := base32.StdEncoding.DecodeString(secret) + if err != nil { + t.Fatalf("secret is not valid base32: %v (raw=%q)", err, secret) + } + if string(decoded) != string(key) { + t.Errorf("decoded secret = %v, expected %v", decoded, key) + } +} diff --git a/tst/assertions_test.go b/tst/assertions_test.go new file mode 100644 index 0000000..4839ac7 --- /dev/null +++ b/tst/assertions_test.go @@ -0,0 +1,240 @@ +package tst + +import ( + "testing" +) + +// These tests exercise the success-paths of each assertion: i.e. the cases +// where the assertion should NOT mark the test as failed. They verify the +// boolean/conditional logic of each function. We avoid testing the failure +// branches because the assertions take a concrete *testing.T (not an +// interface), so substituting a mock would require unsafe/reflection tricks +// that are brittle and cascade failures to the parent test. 
+ +// --- AssertEqual ---------------------------------------------------------- + +func TestAssertEqual_Int(t *testing.T) { + AssertEqual(t, 42, 42) + AssertEqual(t, -1, -1) + AssertEqual(t, 0, 0) +} + +func TestAssertEqual_String(t *testing.T) { + AssertEqual(t, "hello", "hello") + AssertEqual(t, "", "") +} + +func TestAssertEqual_Bool(t *testing.T) { + AssertEqual(t, true, true) + AssertEqual(t, false, false) +} + +func TestAssertEqual_Float(t *testing.T) { + AssertEqual(t, 3.14, 3.14) +} + +// --- AssertArrayEqual ----------------------------------------------------- + +func TestAssertArrayEqual_Empty(t *testing.T) { + AssertArrayEqual(t, []int{}, []int{}) +} + +func TestAssertArrayEqual_Equal(t *testing.T) { + AssertArrayEqual(t, []int{1, 2, 3}, []int{1, 2, 3}) + AssertArrayEqual(t, []string{"a", "b"}, []string{"a", "b"}) +} + +func TestAssertArrayEqual_SingleElement(t *testing.T) { + AssertArrayEqual(t, []int{7}, []int{7}) +} + +// --- AssertNotEqual ------------------------------------------------------- + +func TestAssertNotEqual_Int(t *testing.T) { + AssertNotEqual(t, 1, 2) + AssertNotEqual(t, 0, -1) +} + +func TestAssertNotEqual_String(t *testing.T) { + AssertNotEqual(t, "foo", "bar") + AssertNotEqual(t, "", "x") +} + +func TestAssertNotEqual_Bool(t *testing.T) { + AssertNotEqual(t, true, false) +} + +// --- AssertDeepEqual ------------------------------------------------------ + +func TestAssertDeepEqual_Struct(t *testing.T) { + type s struct { + A int + B string + } + AssertDeepEqual(t, s{A: 1, B: "x"}, s{A: 1, B: "x"}) +} + +func TestAssertDeepEqual_Slice(t *testing.T) { + AssertDeepEqual(t, []int{1, 2, 3}, []int{1, 2, 3}) +} + +func TestAssertDeepEqual_Map(t *testing.T) { + AssertDeepEqual(t, map[string]int{"a": 1, "b": 2}, map[string]int{"a": 1, "b": 2}) +} + +func TestAssertDeepEqual_NilSlice(t *testing.T) { + var a, b []int + AssertDeepEqual(t, a, b) +} + +func TestAssertDeepEqual_NestedStruct(t *testing.T) { + type inner struct{ X int } + type 
outer struct { + I inner + S []string + } + a := outer{I: inner{X: 1}, S: []string{"a", "b"}} + b := outer{I: inner{X: 1}, S: []string{"a", "b"}} + AssertDeepEqual(t, a, b) +} + +// --- AssertSetDeepEqual --------------------------------------------------- + +func TestAssertSetDeepEqual_Empty(t *testing.T) { + AssertSetDeepEqual(t, []int{}, []int{}) +} + +func TestAssertSetDeepEqual_SameOrder(t *testing.T) { + AssertSetDeepEqual(t, []int{1, 2, 3}, []int{1, 2, 3}) +} + +func TestAssertSetDeepEqual_DifferentOrder(t *testing.T) { + AssertSetDeepEqual(t, []int{3, 1, 2}, []int{1, 2, 3}) +} + +func TestAssertSetDeepEqual_Strings(t *testing.T) { + AssertSetDeepEqual(t, []string{"b", "a", "c"}, []string{"a", "b", "c"}) +} + +func TestAssertSetDeepEqual_Structs(t *testing.T) { + type s struct{ V int } + AssertSetDeepEqual(t, []s{{V: 2}, {V: 1}}, []s{{V: 1}, {V: 2}}) +} + +// --- AssertNotDeepEqual --------------------------------------------------- + +func TestAssertNotDeepEqual_Struct(t *testing.T) { + type s struct { + A int + } + AssertNotDeepEqual(t, s{A: 1}, s{A: 2}) +} + +func TestAssertNotDeepEqual_Slice(t *testing.T) { + AssertNotDeepEqual(t, []int{1, 2}, []int{1, 2, 3}) +} + +func TestAssertNotDeepEqual_Map(t *testing.T) { + AssertNotDeepEqual(t, map[string]int{"a": 1}, map[string]int{"b": 1}) +} + +// --- AssertDeRefEqual ----------------------------------------------------- + +func TestAssertDeRefEqual_Int(t *testing.T) { + v := 42 + AssertDeRefEqual(t, &v, 42) +} + +func TestAssertDeRefEqual_String(t *testing.T) { + v := "hello" + AssertDeRefEqual(t, &v, "hello") +} + +func TestAssertDeRefEqual_Bool(t *testing.T) { + v := true + AssertDeRefEqual(t, &v, true) +} + +// --- AssertPtrEqual ------------------------------------------------------- + +func TestAssertPtrEqual_BothNil(t *testing.T) { + var a, b *int + AssertPtrEqual(t, a, b) +} + +func TestAssertPtrEqual_SameValue(t *testing.T) { + a, b := 5, 5 + AssertPtrEqual(t, &a, &b) +} + +func 
TestAssertPtrEqual_DifferentPointersSameValue(t *testing.T) { + a, b := "hi", "hi" + AssertPtrEqual(t, &a, &b) +} + +func TestAssertPtrEqual_SamePointer(t *testing.T) { + a := 99 + AssertPtrEqual(t, &a, &a) +} + +// --- AssertHexEqual ------------------------------------------------------- + +func TestAssertHexEqual_Empty(t *testing.T) { + AssertHexEqual(t, "", []byte{}) +} + +func TestAssertHexEqual_Bytes(t *testing.T) { + AssertHexEqual(t, "deadbeef", []byte{0xde, 0xad, 0xbe, 0xef}) +} + +func TestAssertHexEqual_SingleByte(t *testing.T) { + AssertHexEqual(t, "ff", []byte{0xff}) +} + +func TestAssertHexEqual_AllZero(t *testing.T) { + AssertHexEqual(t, "0000", []byte{0x00, 0x00}) +} + +// --- AssertTrue / AssertFalse --------------------------------------------- + +func TestAssertTrue(t *testing.T) { + AssertTrue(t, true) +} + +func TestAssertFalse(t *testing.T) { + AssertFalse(t, false) +} + +// --- AssertNoErr ---------------------------------------------------------- + +func TestAssertNoErr_Nil(t *testing.T) { + AssertNoErr(t, nil) +} + +// --- AssertStrRepEqual ---------------------------------------------------- + +func TestAssertStrRepEqual_Same(t *testing.T) { + AssertStrRepEqual(t, 42, 42) + AssertStrRepEqual(t, "abc", "abc") +} + +func TestAssertStrRepEqual_DifferentTypesSameRep(t *testing.T) { + // 42 and "42" both stringify to "42" via %v + AssertStrRepEqual(t, 42, "42") +} + +func TestAssertStrRepEqual_Bool(t *testing.T) { + AssertStrRepEqual(t, true, true) +} + +// --- AssertStrRepNotEqual ------------------------------------------------- + +func TestAssertStrRepNotEqual_Different(t *testing.T) { + AssertStrRepNotEqual(t, 42, 43) + AssertStrRepNotEqual(t, "abc", "xyz") +} + +func TestAssertStrRepNotEqual_BoolVsInt(t *testing.T) { + // "true" != "1" + AssertStrRepNotEqual(t, true, 1) +} diff --git a/tst/identAssertions_test.go b/tst/identAssertions_test.go new file mode 100644 index 0000000..959ec7e --- /dev/null +++ b/tst/identAssertions_test.go @@ 
-0,0 +1,71 @@ +package tst + +import ( + "testing" +) + +// Success-path tests for the Ident-prefixed assertions. They exercise the +// non-failing branches of each function. + +// --- AssertIdentEqual ----------------------------------------------------- + +func TestAssertIdentEqual_Int(t *testing.T) { + AssertIdentEqual(t, "value", 1, 1) + AssertIdentEqual(t, "value", 0, 0) +} + +func TestAssertIdentEqual_String(t *testing.T) { + AssertIdentEqual(t, "name", "alice", "alice") +} + +func TestAssertIdentEqual_Bool(t *testing.T) { + AssertIdentEqual(t, "flag", true, true) + AssertIdentEqual(t, "flag", false, false) +} + +func TestAssertIdentEqual_EmptyIdent(t *testing.T) { + AssertIdentEqual(t, "", 7, 7) +} + +// --- AssertIdentNotEqual -------------------------------------------------- + +func TestAssertIdentNotEqual_Int(t *testing.T) { + AssertIdentNotEqual(t, "value", 1, 2) +} + +func TestAssertIdentNotEqual_String(t *testing.T) { + AssertIdentNotEqual(t, "name", "alice", "bob") +} + +// --- AssertIdentPtrEqual -------------------------------------------------- + +func TestAssertIdentPtrEqual_BothNil(t *testing.T) { + var a, b *int + AssertIdentPtrEqual(t, "ptr", a, b) +} + +func TestAssertIdentPtrEqual_SameValue(t *testing.T) { + a, b := 10, 10 + AssertIdentPtrEqual(t, "ptr", &a, &b) +} + +func TestAssertIdentPtrEqual_SameStringValue(t *testing.T) { + a, b := "x", "x" + AssertIdentPtrEqual(t, "ptr", &a, &b) +} + +// --- AssertIdentTrue / AssertIdentFalse ----------------------------------- + +func TestAssertIdentTrue(t *testing.T) { + AssertIdentTrue(t, "ok", true) +} + +// AssertIdentFalse has a known quirk in the original implementation: +// `if !value { t.Errorf(...) }` — i.e. it only fails when value is false, +// which means "true" is the success path. We test the success path as +// implemented (we MUST NOT change existing code). +func TestAssertIdentFalse_SuccessPathAsImplemented(t *testing.T) { + // As coded, AssertIdentFalse fails when value is false. 
So passing + // `true` is the no-failure path according to the current implementation. + AssertIdentFalse(t, "ok", true) +} diff --git a/tst/must_test.go b/tst/must_test.go new file mode 100644 index 0000000..ee84f16 --- /dev/null +++ b/tst/must_test.go @@ -0,0 +1,49 @@ +package tst + +import ( + "strconv" + "testing" +) + +// --- Must ----------------------------------------------------------------- + +func TestMust_NoError_Int(t *testing.T) { + v := Must(123, nil)(t) + AssertEqual(t, v, 123) +} + +func TestMust_NoError_String(t *testing.T) { + v := Must("hello", nil)(t) + AssertEqual(t, v, "hello") +} + +func TestMust_NoError_Slice(t *testing.T) { + v := Must([]int{1, 2, 3}, nil)(t) + AssertArrayEqual(t, v, []int{1, 2, 3}) +} + +func TestMust_NoError_Struct(t *testing.T) { + type s struct { + X int + Y string + } + v := Must(s{X: 7, Y: "abc"}, nil)(t) + AssertEqual(t, v, s{X: 7, Y: "abc"}) +} + +func TestMust_StrconvAtoi(t *testing.T) { + v := Must(strconv.Atoi("42"))(t) + AssertEqual(t, v, 42) +} + +func TestMust_ZeroValueOnNoError(t *testing.T) { + v := Must(0, nil)(t) + AssertEqual(t, v, 0) +} + +func TestMust_ReturnedFnIsNotNil(t *testing.T) { + fn := Must("anything", nil) + if fn == nil { + t.Fatal("Must should return a non-nil function") + } +} diff --git a/wmo/wmo_test.go b/wmo/wmo_test.go new file mode 100644 index 0000000..b1def76 --- /dev/null +++ b/wmo/wmo_test.go @@ -0,0 +1,585 @@ +package wmo + +import ( + "context" + "errors" + "reflect" + "testing" + + ct "git.blackforestbytes.com/BlackForestBytes/goext/cursortoken" + "git.blackforestbytes.com/BlackForestBytes/goext/langext" + "git.blackforestbytes.com/BlackForestBytes/goext/tst" + "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/mongo" +) + +// --- Mock helpers for Decodable / Cursorable ------------------------------- + +type mockDecodable struct { + docs []bson.M + cursor int + err error +} + +func (m *mockDecodable) Decode(v any) error { + if m.err != nil { + return m.err 
+ } + if m.cursor >= len(m.docs) { + return mongo.ErrNoDocuments + } + raw, err := bson.Marshal(m.docs[m.cursor]) + if err != nil { + return err + } + return bson.Unmarshal(raw, v) +} + +type mockCursor struct { + docs []bson.M + idx int + closed bool + allErr error + startBefore int +} + +func newMockCursor(docs []bson.M) *mockCursor { + return &mockCursor{docs: docs, idx: -1} +} + +func (m *mockCursor) Decode(v any) error { + if m.idx < 0 || m.idx >= len(m.docs) { + return mongo.ErrNoDocuments + } + raw, err := bson.Marshal(m.docs[m.idx]) + if err != nil { + return err + } + return bson.Unmarshal(raw, v) +} + +func (m *mockCursor) Err() error { return nil } +func (m *mockCursor) Close(_ context.Context) error { m.closed = true; return nil } +func (m *mockCursor) RemainingBatchLength() int { return len(m.docs) - (m.idx + 1) } +func (m *mockCursor) Next(_ context.Context) bool { m.idx++; return m.idx < len(m.docs) } +func (m *mockCursor) All(_ context.Context, results any) error { + if m.allErr != nil { + return m.allErr + } + raws := make([]bson.Raw, 0, len(m.docs)) + for _, d := range m.docs { + r, err := bson.Marshal(d) + if err != nil { + return err + } + raws = append(raws, r) + } + rv := reflect.ValueOf(results).Elem() + rv.SetLen(0) + elemType := rv.Type().Elem() + for _, r := range raws { + ev := reflect.New(elemType) + if err := bson.Unmarshal(r, ev.Interface()); err != nil { + return err + } + rv.Set(reflect.Append(rv, ev.Elem())) + } + return nil +} + +// --- init() / reflection edge cases --------------------------------------- + +func TestInitIgnoresUnsupportedFields(t *testing.T) { + + type Inlined struct { + Inner string `bson:"inner"` + } + + type TestData struct { + ID string `bson:"_id"` + Skipped string `bson:"-"` + NoTag string // no bson tag => skipped + unexported string `bson:"unexp"` + WithOpts string `bson:"opt,omitempty"` + Inline Inlined `bson:",inline"` + } + + coll := W[TestData](&mongo.Collection{}) + + _, errSkipped := 
coll.getFieldType("-") + tst.AssertTrue(t, errSkipped != nil) + + _, errNoTag := coll.getFieldType("NoTag") + tst.AssertTrue(t, errNoTag != nil) + + _, errUnexp := coll.getFieldType("unexp") + tst.AssertTrue(t, errUnexp != nil) + + id, err := coll.getFieldType("_id") + tst.AssertNoErr(t, err) + tst.AssertEqual(t, id.Name, "ID") + + opt, err := coll.getFieldType("opt") + tst.AssertNoErr(t, err) + tst.AssertEqual(t, opt.Name, "WithOpts") + + inner, err := coll.getFieldType("inner") + tst.AssertNoErr(t, err) + tst.AssertEqual(t, inner.Name, "Inner") + + _ = TestData{unexported: "x"} +} + +func TestInitRecursivePointerStructDoesNotInfinitelyRecurse(t *testing.T) { + + type Recursive struct { + Val int `bson:"val"` + Inner *Recursive `bson:"inner"` + } + + type TestData struct { + ID string `bson:"_id"` + Rec Recursive `bson:"rec"` + } + + coll := W[TestData](&mongo.Collection{}) + + val, err := coll.getFieldType("rec.val") + tst.AssertNoErr(t, err) + tst.AssertEqual(t, val.Name, "Val") + + inner, err := coll.getFieldType("rec.inner") + tst.AssertNoErr(t, err) + tst.AssertEqual(t, inner.IsPointer, true) + + _, err = coll.getFieldType("rec.inner.inner.inner.inner.val") + tst.AssertTrue(t, err != nil) +} + +func TestGetFieldTypeUnknownField(t *testing.T) { + + type TestData struct { + ID string `bson:"_id"` + } + + coll := W[TestData](&mongo.Collection{}) + + _, err := coll.getFieldType("does_not_exist") + tst.AssertTrue(t, err != nil) +} + +func TestGetFieldValueUnknownField(t *testing.T) { + + type TestData struct { + ID string `bson:"_id"` + } + + coll := W[TestData](&mongo.Collection{}) + + _, err := coll.getFieldValue(TestData{ID: "x"}, "does_not_exist") + tst.AssertTrue(t, err != nil) +} + +func TestGetFieldValueInterfaceUnknownType(t *testing.T) { + + type TestImpl struct { + ID string `bson:"_id"` + } + + type Iface any + + coll := W[Iface](&mongo.Collection{}) + // no decoder registered, the impl-type-map is empty + _, err := coll.getFieldValue(TestImpl{ID: "x"}, 
"_id") + tst.AssertTrue(t, err != nil) +} + +func TestGetFieldValueInterfaceUnknownField(t *testing.T) { + + type TestImpl struct { + ID string `bson:"_id"` + } + + type Iface any + + df := func(ctx context.Context, dec Decodable) (Iface, error) { return TestImpl{}, nil } + coll := W[Iface](&mongo.Collection{}).WithDecodeFunc(df, TestImpl{}) + + _, err := coll.getFieldValue(TestImpl{ID: "x"}, "missing_field") + tst.AssertTrue(t, err != nil) +} + +func TestEnsureInitializedReflectionIdempotent(t *testing.T) { + + type TestImpl struct { + ID string `bson:"_id"` + } + + type Iface any + + df := func(ctx context.Context, dec Decodable) (Iface, error) { return TestImpl{}, nil } + coll := W[Iface](&mongo.Collection{}).WithDecodeFunc(df, TestImpl{}) + + tst.AssertEqual(t, 1, len(coll.implDataTypeMap)) + + // pointer-deref path + idempotency: calling again should not duplicate. + coll.EnsureInitializedReflection(&TestImpl{ID: "x"}) + tst.AssertEqual(t, 1, len(coll.implDataTypeMap)) + + coll.EnsureInitializedReflection(TestImpl{ID: "y"}) + tst.AssertEqual(t, 1, len(coll.implDataTypeMap)) +} + +func TestEnsureInitializedReflectionNoOpForStruct(t *testing.T) { + + type TestData struct { + ID string `bson:"_id"` + } + + coll := W[TestData](&mongo.Collection{}) + + before := len(coll.implDataTypeMap) + coll.EnsureInitializedReflection(TestData{ID: "x"}) + tst.AssertEqual(t, before, len(coll.implDataTypeMap)) +} + +// --- hooks ---------------------------------------------------------------- + +func TestWithUnmarshalHookRunsOnDecodeSingle(t *testing.T) { + + type TestData struct { + ID string `bson:"_id"` + Value string `bson:"value"` + } + + coll := W[TestData](&mongo.Collection{}).WithUnmarshalHook(func(d TestData) TestData { + d.Value = d.Value + "_hook" + return d + }) + + dec := &mockDecodable{docs: []bson.M{{"_id": "1", "value": "raw"}}} + res, err := coll.decodeSingle(context.Background(), dec) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, res.Value, "raw_hook") +} + 
+func TestWithUnmarshalHookRunsOnDecodeAll(t *testing.T) { + + type TestData struct { + ID string `bson:"_id"` + Value string `bson:"value"` + } + + calls := 0 + coll := W[TestData](&mongo.Collection{}).WithUnmarshalHook(func(d TestData) TestData { + calls++ + d.Value = d.Value + "!" + return d + }) + + cur := newMockCursor([]bson.M{ + {"_id": "1", "value": "a"}, + {"_id": "2", "value": "b"}, + }) + res, err := coll.decodeAll(context.Background(), cur) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, len(res), 2) + tst.AssertEqual(t, res[0].Value, "a!") + tst.AssertEqual(t, res[1].Value, "b!") + tst.AssertEqual(t, calls, 2) +} + +func TestWithMarshalHookAppendsHook(t *testing.T) { + + type TestData struct { + ID string `bson:"_id"` + } + + coll := W[TestData](&mongo.Collection{}). + WithMarshalHook(func(d TestData) TestData { d.ID = "1-" + d.ID; return d }). + WithMarshalHook(func(d TestData) TestData { d.ID = "2-" + d.ID; return d }) + + tst.AssertEqual(t, len(coll.marshalHooks), 2) + + out := coll.marshalHooks[0](TestData{ID: "x"}) + tst.AssertEqual(t, out.ID, "1-x") + + out = coll.marshalHooks[1](TestData{ID: "x"}) + tst.AssertEqual(t, out.ID, "2-x") +} + +func TestCustomDecoderUsedInDecodeSingle(t *testing.T) { + + type TestData struct { + ID string `bson:"_id"` + Value string `bson:"value"` + } + + type Iface any + + df := func(ctx context.Context, dec Decodable) (Iface, error) { + var raw bson.M + if err := dec.Decode(&raw); err != nil { + return nil, err + } + return TestData{ID: raw["_id"].(string), Value: "from-custom"}, nil + } + + coll := W[Iface](&mongo.Collection{}).WithDecodeFunc(df, TestData{}) + + dec := &mockDecodable{docs: []bson.M{{"_id": "abc", "value": "raw"}}} + res, err := coll.decodeSingle(context.Background(), dec) + tst.AssertNoErr(t, err) + + td, ok := res.(TestData) + tst.AssertTrue(t, ok) + tst.AssertEqual(t, td.ID, "abc") + tst.AssertEqual(t, td.Value, "from-custom") +} + +func TestCustomDecoderUsedInDecodeAll(t *testing.T) { + + type 
TestData struct { + ID string `bson:"_id"` + } + + type Iface any + + df := func(ctx context.Context, dec Decodable) (Iface, error) { + var raw bson.M + if err := dec.Decode(&raw); err != nil { + return nil, err + } + return TestData{ID: raw["_id"].(string) + "!"}, nil + } + + coll := W[Iface](&mongo.Collection{}).WithDecodeFunc(df, TestData{}) + + cur := newMockCursor([]bson.M{ + {"_id": "a"}, + {"_id": "b"}, + {"_id": "c"}, + }) + res, err := coll.decodeAll(context.Background(), cur) + tst.AssertNoErr(t, err) + tst.AssertEqual(t, len(res), 3) + tst.AssertEqual(t, res[0].(TestData).ID, "a!") + tst.AssertEqual(t, res[2].(TestData).ID, "c!") +} + +func TestCustomDecoderErrorPropagates(t *testing.T) { + + type TestData struct { + ID string `bson:"_id"` + } + + type Iface any + + myErr := errors.New("custom-decoder-failure") + df := func(ctx context.Context, dec Decodable) (Iface, error) { + return nil, myErr + } + + coll := W[Iface](&mongo.Collection{}).WithDecodeFunc(df, TestData{}) + + _, err := coll.decodeSingle(context.Background(), &mockDecodable{docs: []bson.M{{"_id": "a"}}}) + tst.AssertTrue(t, err != nil) + + cur := newMockCursor([]bson.M{{"_id": "a"}}) + _, err = coll.decodeAll(context.Background(), cur) + tst.AssertTrue(t, err != nil) +} + +func TestDecodeSingleWithDecodeError(t *testing.T) { + + type TestData struct { + ID string `bson:"_id"` + } + + coll := W[TestData](&mongo.Collection{}) + + _, err := coll.decodeSingle(context.Background(), &mockDecodable{err: errors.New("boom")}) + tst.AssertTrue(t, err != nil) +} + +// --- pipeline / sort detection -------------------------------------------- + +func TestWithModifyingPipelineAppends(t *testing.T) { + + type TestData struct { + ID string `bson:"_id"` + } + + coll := W[TestData](&mongo.Collection{}).WithModifyingPipeline(mongo.Pipeline{ + bson.D{{Key: "$set", Value: bson.M{"x": 1}}}, + }) + + tst.AssertEqual(t, len(coll.extraModPipeline), 1) + + stages := coll.extraModPipeline[0](context.Background()) + 
tst.AssertEqual(t, len(stages), 1) + tst.AssertEqual(t, stages[0][0].Key, "$set") +} + +func TestWithModifyingPipelineFunc(t *testing.T) { + + type TestData struct { + ID string `bson:"_id"` + } + + called := 0 + coll := W[TestData](&mongo.Collection{}).WithModifyingPipelineFunc(func(ctx context.Context) mongo.Pipeline { + called++ + return mongo.Pipeline{bson.D{{Key: "$project", Value: bson.M{"_id": 1}}}} + }) + + tst.AssertEqual(t, len(coll.extraModPipeline), 1) + stages := coll.extraModPipeline[0](context.Background()) + tst.AssertEqual(t, called, 1) + tst.AssertEqual(t, stages[0][0].Key, "$project") +} + +func TestNeedsDoubleSortFalseWithoutGroup(t *testing.T) { + + type TestData struct { + ID string `bson:"_id"` + } + + coll := W[TestData](&mongo.Collection{}). + WithModifyingPipeline(mongo.Pipeline{ + bson.D{{Key: "$set", Value: bson.M{"x": 1}}}, + bson.D{{Key: "$unset", Value: "y"}}, + }) + + tst.AssertEqual(t, coll.needsDoubleSort(context.Background()), false) +} + +func TestNeedsDoubleSortTrueWithGroup(t *testing.T) { + + type TestData struct { + ID string `bson:"_id"` + } + + coll := W[TestData](&mongo.Collection{}). 
+ WithModifyingPipeline(mongo.Pipeline{ + bson.D{{Key: "$set", Value: bson.M{"x": 1}}}, + bson.D{{Key: "$group", Value: bson.M{"_id": "$cat"}}}, + }) + + tst.AssertEqual(t, coll.needsDoubleSort(context.Background()), true) +} + +func TestNeedsDoubleSortNoExtraPipeline(t *testing.T) { + + type TestData struct { + ID string `bson:"_id"` + } + + coll := W[TestData](&mongo.Collection{}) + tst.AssertEqual(t, coll.needsDoubleSort(context.Background()), false) +} + +// --- token creation ------------------------------------------------------- + +func TestCreateTokenPrimaryOnly(t *testing.T) { + + type TestData struct { + ID string `bson:"_id"` + Value int `bson:"value"` + } + + coll := W[TestData](&mongo.Collection{}) + + pageSize := 50 + tok, err := coll.createToken("_id", ct.SortASC, nil, nil, TestData{ID: "abc", Value: 1}, &pageSize) + tst.AssertNoErr(t, err) + + cts, ok := tok.(ct.CTKeySort) + tst.AssertTrue(t, ok) + tst.AssertEqual(t, cts.ValuePrimary, "abc") + tst.AssertEqual(t, cts.ValueSecondary, "") + tst.AssertEqual(t, cts.Direction, ct.SortASC) + tst.AssertEqual(t, cts.PageSize, 50) +} + +func TestCreateTokenWithSecondary(t *testing.T) { + + type TestData struct { + ID string `bson:"_id"` + Value int `bson:"value"` + } + + coll := W[TestData](&mongo.Collection{}) + + secField := "value" + secDir := ct.SortDESC + tok, err := coll.createToken("_id", ct.SortASC, &secField, &secDir, TestData{ID: "abc", Value: 7}, nil) + tst.AssertNoErr(t, err) + + cts, ok := tok.(ct.CTKeySort) + tst.AssertTrue(t, ok) + tst.AssertEqual(t, cts.ValuePrimary, "abc") + tst.AssertEqual(t, cts.ValueSecondary, "7") + tst.AssertEqual(t, cts.PageSize, 0) +} + +func TestCreateTokenUnknownPrimaryField(t *testing.T) { + + type TestData struct { + ID string `bson:"_id"` + } + + coll := W[TestData](&mongo.Collection{}) + + _, err := coll.createToken("nonexistent", ct.SortASC, nil, nil, TestData{ID: "x"}, nil) + tst.AssertTrue(t, err != nil) +} + +func TestCreateTokenUnknownSecondaryField(t 
*testing.T) { + + type TestData struct { + ID string `bson:"_id"` + } + + coll := W[TestData](&mongo.Collection{}) + + sec := "nope" + dir := ct.SortASC + _, err := coll.createToken("_id", ct.SortASC, &sec, &dir, TestData{ID: "x"}, nil) + tst.AssertTrue(t, err != nil) +} + +// --- W constructor & Coll basics ------------------------------------------ + +func TestWInitializesMaps(t *testing.T) { + + type TestData struct { + ID string `bson:"_id"` + } + + coll := W[TestData](&mongo.Collection{}) + + tst.AssertTrue(t, coll != nil) + tst.AssertTrue(t, coll.dataTypeMap != nil) + tst.AssertTrue(t, coll.implDataTypeMap != nil) + tst.AssertEqual(t, coll.isInterfaceDataType, false) + tst.AssertTrue(t, len(coll.dataTypeMap) > 0) +} + +func TestWInitForInterfaceLeavesDataTypeMapEmpty(t *testing.T) { + + type Iface any + + coll := W[Iface](&mongo.Collection{}) + + tst.AssertEqual(t, coll.isInterfaceDataType, true) + tst.AssertEqual(t, len(coll.dataTypeMap), 0) + tst.AssertEqual(t, len(coll.implDataTypeMap), 0) +} + +// quick sanity: the langext.Ptr helper used across tests behaves as expected. 
+func TestLangextPtrSanity(t *testing.T) { + p := langext.Ptr(42) + tst.AssertTrue(t, p != nil) + tst.AssertEqual(t, *p, 42) +} diff --git a/wpdf/utils_test.go b/wpdf/utils_test.go new file mode 100644 index 0000000..601c5fa --- /dev/null +++ b/wpdf/utils_test.go @@ -0,0 +1,42 @@ +package wpdf + +import ( + "testing" +) + +func TestHexToColor(t *testing.T) { + cases := []struct { + in uint32 + want PDFColor + }{ + {0x000000, PDFColor{R: 0, G: 0, B: 0}}, + {0xFFFFFF, PDFColor{R: 255, G: 255, B: 255}}, + {0xFF0000, PDFColor{R: 255, G: 0, B: 0}}, + {0x00FF00, PDFColor{R: 0, G: 255, B: 0}}, + {0x0000FF, PDFColor{R: 0, G: 0, B: 255}}, + {0x123456, PDFColor{R: 0x12, G: 0x34, B: 0x56}}, + {0xC0C0C0, PDFColor{R: 192, G: 192, B: 192}}, + } + for _, c := range cases { + got := hexToColor(c.in) + if got != c.want { + t.Errorf("hexToColor(%#x) = %+v, want %+v", c.in, got, c.want) + } + } +} + +func TestHexToColorIgnoresHigherBits(t *testing.T) { + got := hexToColor(0xFF123456) + want := PDFColor{R: 0x12, G: 0x34, B: 0x56} + if got != want { + t.Errorf("hexToColor(0xFF123456) = %+v, want %+v", got, want) + } +} + +func TestRgbToColor(t *testing.T) { + got := rgbToColor(10, 20, 30) + want := PDFColor{R: 10, G: 20, B: 30} + if got != want { + t.Errorf("rgbToColor(10,20,30) = %+v, want %+v", got, want) + } +} diff --git a/wpdf/wpdf_builder_test.go b/wpdf/wpdf_builder_test.go new file mode 100644 index 0000000..36094df --- /dev/null +++ b/wpdf/wpdf_builder_test.go @@ -0,0 +1,368 @@ +package wpdf + +import ( + "bytes" + "testing" +) + +func newBuilderWithPage(t *testing.T) *WPDFBuilder { + t.Helper() + b := NewPDFBuilder(Portrait, SizeA4, false) + b.AddPage() + return b +} + +func TestNewPDFBuilderDefaults(t *testing.T) { + b := NewPDFBuilder(Portrait, SizeA4, false) + + if b == nil { + t.Fatal("expected non-nil builder") + } + if b.FPDF() == nil { + t.Fatal("expected non-nil underlying gofpdf builder") + } + + if b.fontName != FontHelvetica { + t.Errorf("default fontName = %v, want 
%v", b.fontName, FontHelvetica) + } + if b.fontStyle != Normal { + t.Errorf("default fontStyle = %v, want %v", b.fontStyle, Normal) + } + if b.fontSize != 12 { + t.Errorf("default fontSize = %v, want 12", b.fontSize) + } + + left, top, right, _ := b.GetMargins() + if left != 15 || top != 25 || right != 15 { + t.Errorf("default margins = (%v, %v, %v), want (15, 25, 15)", left, top, right) + } +} + +func TestNewPDFBuilderUnicode(t *testing.T) { + b := NewPDFBuilder(Portrait, SizeA4, true) + if b.tr == nil { + t.Fatal("expected unicode translator to be set") + } + // Translator should not panic on simple ASCII input + out := b.tr("hello") + if out == "" { + t.Errorf("translator returned empty string for non-empty input") + } +} + +func TestNewPDFBuilderNonUnicode(t *testing.T) { + b := NewPDFBuilder(Portrait, SizeA4, false) + if got := b.tr("hello"); got != "hello" { + t.Errorf("non-unicode translator = %q, want %q", got, "hello") + } +} + +func TestNewPDFBuilderLandscape(t *testing.T) { + b := NewPDFBuilder(Landscape, SizeA4, false) + w, h := b.GetPageSize() + if w <= h { + t.Errorf("landscape: expected width>height, got w=%v h=%v", w, h) + } +} + +func TestNewPDFBuilderPortrait(t *testing.T) { + b := NewPDFBuilder(Portrait, SizeA4, false) + w, h := b.GetPageSize() + if h <= w { + t.Errorf("portrait: expected height>width, got w=%v h=%v", w, h) + } +} + +func TestSetMargins(t *testing.T) { + b := NewPDFBuilder(Portrait, SizeA4, false) + b.SetMargins(PDFMargins{Left: 5, Top: 10, Right: 7}) + + if got := b.GetMarginLeft(); got != 5 { + t.Errorf("MarginLeft = %v, want 5", got) + } + if got := b.GetMarginTop(); got != 10 { + t.Errorf("MarginTop = %v, want 10", got) + } + if got := b.GetMarginRight(); got != 7 { + t.Errorf("MarginRight = %v, want 7", got) + } + // MarginBottom is not set explicitly via SetMargins, just verify accessor doesn't panic. 
+ _ = b.GetMarginBottom() +} + +func TestGetWorkAreaWidth(t *testing.T) { + b := NewPDFBuilder(Portrait, SizeA4, false) + b.SetMargins(PDFMargins{Left: 10, Top: 20, Right: 5}) + pw := b.GetPageWidth() + want := pw - 10 - 5 + if got := b.GetWorkAreaWidth(); got != want { + t.Errorf("GetWorkAreaWidth = %v, want %v", got, want) + } +} + +func TestGetPageWidthHeight(t *testing.T) { + b := NewPDFBuilder(Portrait, SizeA4, false) + w, h := b.GetPageSize() + if got := b.GetPageWidth(); got != w { + t.Errorf("PageWidth = %v, want %v", got, w) + } + if got := b.GetPageHeight(); got != h { + t.Errorf("PageHeight = %v, want %v", got, h) + } +} + +func TestSetGetTextColor(t *testing.T) { + b := newBuilderWithPage(t) + b.SetTextColor(10, 20, 30) + r, g, bl := b.GetTextColor() + if r != 10 || g != 20 || bl != 30 { + t.Errorf("GetTextColor = (%d,%d,%d), want (10,20,30)", r, g, bl) + } +} + +func TestSetGetDrawColor(t *testing.T) { + b := newBuilderWithPage(t) + b.SetDrawColor(40, 50, 60) + r, g, bl := b.GetDrawColor() + if r != 40 || g != 50 || bl != 60 { + t.Errorf("GetDrawColor = (%d,%d,%d), want (40,50,60)", r, g, bl) + } +} + +func TestSetGetFillColor(t *testing.T) { + b := newBuilderWithPage(t) + b.SetFillColor(70, 80, 90) + r, g, bl := b.GetFillColor() + if r != 70 || g != 80 || bl != 90 { + t.Errorf("GetFillColor = (%d,%d,%d), want (70,80,90)", r, g, bl) + } +} + +func TestSetGetLineWidth(t *testing.T) { + b := newBuilderWithPage(t) + b.SetLineWidth(2.5) + if got := b.GetLineWidth(); got != 2.5 { + t.Errorf("LineWidth = %v, want 2.5", got) + } +} + +func TestSetFont(t *testing.T) { + b := newBuilderWithPage(t) + b.SetFont(FontTimes, Bold, 16) + if b.fontName != FontTimes { + t.Errorf("fontName = %v, want %v", b.fontName, FontTimes) + } + if b.fontStyle != Bold { + t.Errorf("fontStyle = %v, want %v", b.fontStyle, Bold) + } + if got := b.GetFontSize(); got != 16 { + t.Errorf("FontSize = %v, want 16", got) + } + if b.cellHeight <= 0 { + t.Errorf("cellHeight must be >0 after 
SetFont, got %v", b.cellHeight) + } +} + +func TestSetCellSpacing(t *testing.T) { + b := newBuilderWithPage(t) + b.SetCellSpacing(3.5) + if b.cellSpacing != 3.5 { + t.Errorf("cellSpacing = %v, want 3.5", b.cellSpacing) + } +} + +func TestSetGetXY(t *testing.T) { + b := newBuilderWithPage(t) + + b.SetX(50) + if got := b.GetX(); got != 50 { + t.Errorf("X = %v, want 50", got) + } + + b.SetY(60) + if got := b.GetY(); got != 60 { + t.Errorf("Y = %v, want 60", got) + } + + b.SetXY(70, 80) + x, y := b.GetXY() + if x != 70 || y != 80 { + t.Errorf("GetXY = (%v,%v), want (70,80)", x, y) + } +} + +func TestIncX(t *testing.T) { + b := newBuilderWithPage(t) + b.SetX(10) + b.IncX(5) + if got := b.GetX(); got != 15 { + t.Errorf("after IncX: X = %v, want 15", got) + } +} + +func TestDebug(t *testing.T) { + b := NewPDFBuilder(Portrait, SizeA4, false) + if b.debug != false { + t.Errorf("default debug = %v, want false", b.debug) + } + b.Debug(true) + if !b.debug { + t.Errorf("Debug(true) did not enable debug") + } + b.Debug(false) + if b.debug { + t.Errorf("Debug(false) did not disable debug") + } +} + +func TestPageNo(t *testing.T) { + b := NewPDFBuilder(Portrait, SizeA4, false) + if got := b.PageNo(); got != 0 { + t.Errorf("PageNo before AddPage = %v, want 0", got) + } + b.AddPage() + if got := b.PageNo(); got != 1 { + t.Errorf("PageNo after first AddPage = %v, want 1", got) + } + b.AddPage() + if got := b.PageNo(); got != 2 { + t.Errorf("PageNo after second AddPage = %v, want 2", got) + } +} + +func TestBuildEmpty(t *testing.T) { + b := NewPDFBuilder(Portrait, SizeA4, false) + b.AddPage() + bin, err := b.Build() + if err != nil { + t.Fatalf("Build() error: %v", err) + } + if len(bin) == 0 { + t.Fatal("Build() returned empty bytes") + } + // PDFs start with "%PDF-" + if !bytes.HasPrefix(bin, []byte("%PDF-")) { + t.Errorf("Build() output does not start with %%PDF- header: %q", bin[:min(10, len(bin))]) + } +} + +func TestBuildWithContent(t *testing.T) { + b := NewPDFBuilder(Portrait, 
SizeA4, false) + b.AddPage() + b.Cell("Hello", NewPDFCellOpt().Width(50)) + b.Ln(5) + b.MultiCell("multiline content") + b.Rect(10, 10, RectFill, NewPDFRectOpt().X(20).Y(20).FillColor(255, 0, 0)) + b.Line(0, 0, 100, 100, NewPDFLineOpt().LineWidth(0.5)) + + bin, err := b.Build() + if err != nil { + t.Fatalf("Build() error: %v", err) + } + if !bytes.HasPrefix(bin, []byte("%PDF-")) { + t.Errorf("Build() output does not start with %%PDF- header") + } +} + +func TestGetStringWidth(t *testing.T) { + b := newBuilderWithPage(t) + w := b.GetStringWidth("Hello") + if w <= 0 { + t.Errorf("GetStringWidth = %v, want >0", w) + } + + wLong := b.GetStringWidth("Hello, this is a longer string") + if wLong <= w { + t.Errorf("longer string should have wider width: short=%v long=%v", w, wLong) + } +} + +func TestGetStringWidthWithFontOverride(t *testing.T) { + b := newBuilderWithPage(t) + b.SetFont(FontHelvetica, Normal, 10) + + wSmall := b.GetStringWidth("Hello", *NewPDFCellOpt().FontSize(10)) + wLarge := b.GetStringWidth("Hello", *NewPDFCellOpt().FontSize(40)) + + if wLarge <= wSmall { + t.Errorf("larger font should yield wider string: small=%v large=%v", wSmall, wLarge) + } + + // Original font must be restored. 
+ if b.fontSize != 10 { + t.Errorf("font size not restored: %v, want 10", b.fontSize) + } +} + +func TestSetAutoPageBreak(t *testing.T) { + b := newBuilderWithPage(t) + b.SetAutoPageBreak(true, 10) + enabled, margin := b.FPDF().GetAutoPageBreak() + if !enabled || margin != 10 { + t.Errorf("AutoPageBreak = (%v, %v), want (true, 10)", enabled, margin) + } + + b.SetAutoPageBreak(false, 0) + enabled, _ = b.FPDF().GetAutoPageBreak() + if enabled { + t.Errorf("AutoPageBreak should be disabled") + } +} + +func TestSetFooterFunc(t *testing.T) { + b := NewPDFBuilder(Portrait, SizeA4, false) + called := false + b.SetFooterFunc(func() { called = true }) + b.AddPage() + _, err := b.Build() + if err != nil { + t.Fatalf("Build error: %v", err) + } + if !called { + t.Errorf("footer func was not called") + } +} + +func TestBookmark(t *testing.T) { + b := newBuilderWithPage(t) + // Should not panic + b.Bookmark("section 1", 0, b.GetY()) + + _, err := b.Build() + if err != nil { + t.Fatalf("Build error: %v", err) + } +} + +func TestLnAdvancesY(t *testing.T) { + b := newBuilderWithPage(t) + b.SetXY(20, 50) + b.Ln(10) + if y := b.GetY(); y <= 50 { + t.Errorf("Ln did not advance Y: before=50, after=%v", y) + } +} + +func TestDebugAddPage(t *testing.T) { + // Test with debug = true to cover that branch + b := NewPDFBuilder(Portrait, SizeA4, false) + b.Debug(true) + b.AddPage() + // Just verify Build still works + _, err := b.Build() + if err != nil { + t.Fatalf("Build with debug page error: %v", err) + } +} + +func TestDebugLn(t *testing.T) { + b := NewPDFBuilder(Portrait, SizeA4, false) + b.AddPage() + b.Debug(true) + b.SetXY(20, 50) + b.Ln(10) + _, err := b.Build() + if err != nil { + t.Fatalf("Build with debug Ln error: %v", err) + } +} diff --git a/wpdf/wpdf_drawing_test.go b/wpdf/wpdf_drawing_test.go new file mode 100644 index 0000000..222f6d7 --- /dev/null +++ b/wpdf/wpdf_drawing_test.go @@ -0,0 +1,265 @@ +package wpdf + +import ( + "bytes" + "testing" +) + +func 
TestCellWithAllOptions(t *testing.T) { + b := newBuilderWithPage(t) + + b.Cell("test", + NewPDFCellOpt(). + Width(40). + Height(8). + Border(BorderFull). + LnPos(BreakToNextLine). + Align(AlignHorzCenter). + FillBackground(true). + Font(FontTimes, Bold, 10). + LnAfter(2). + TextColor(255, 0, 0). + BorderColor(0, 255, 0). + FillColor(0, 0, 255). + Alpha(0.5, BlendMultiply)) + + out, err := b.Build() + if err != nil { + t.Fatalf("Build error: %v", err) + } + if !bytes.HasPrefix(out, []byte("%PDF-")) { + t.Error("output not a PDF") + } +} + +func TestCellRestoresFont(t *testing.T) { + b := newBuilderWithPage(t) + b.SetFont(FontHelvetica, Normal, 12) + + b.Cell("text", NewPDFCellOpt().Width(50).Font(FontTimes, Bold, 20)) + + if b.fontName != FontHelvetica { + t.Errorf("fontName not restored: %v", b.fontName) + } + if b.fontStyle != Normal { + t.Errorf("fontStyle not restored: %v", b.fontStyle) + } + if b.fontSize != 12 { + t.Errorf("fontSize not restored: %v", b.fontSize) + } +} + +func TestCellRestoresColors(t *testing.T) { + b := newBuilderWithPage(t) + b.SetTextColor(10, 20, 30) + b.SetDrawColor(40, 50, 60) + b.SetFillColor(70, 80, 90) + + b.Cell("text", + NewPDFCellOpt().Width(50). + TextColor(1, 2, 3). + BorderColor(4, 5, 6). 
+ FillColor(7, 8, 9)) + + r, g, bl := b.GetTextColor() + if r != 10 || g != 20 || bl != 30 { + t.Errorf("TextColor not restored: (%d,%d,%d)", r, g, bl) + } + + r, g, bl = b.GetDrawColor() + if r != 40 || g != 50 || bl != 60 { + t.Errorf("DrawColor not restored: (%d,%d,%d)", r, g, bl) + } + + r, g, bl = b.GetFillColor() + if r != 70 || g != 80 || bl != 90 { + t.Errorf("FillColor not restored: (%d,%d,%d)", r, g, bl) + } +} + +func TestCellAutoWidth(t *testing.T) { + b := newBuilderWithPage(t) + + startX := b.GetX() + b.Cell("Hello", NewPDFCellOpt().AutoWidth().AutoWidthPaddingX(2).LnPos(BreakToRight)) + + endX := b.GetX() + if endX <= startX { + t.Errorf("AutoWidth: X did not advance (start=%v end=%v)", startX, endX) + } +} + +func TestCellWithExplicitX(t *testing.T) { + b := newBuilderWithPage(t) + + b.Cell("text", NewPDFCellOpt().Width(50).X(70).LnPos(BreakToRight)) + // X should be 70 + 50 = 120 after cell with BreakToRight + if got := b.GetX(); got < 120-0.01 || got > 120+0.01 { + t.Errorf("X after cell at X=70 with width=50 = %v, want ~120", got) + } +} + +func TestCellWithDebugTrueBreakToRight(t *testing.T) { + b := newBuilderWithPage(t) + b.Debug(true) + b.Cell("text", NewPDFCellOpt().Width(20).LnPos(BreakToRight)) + if _, err := b.Build(); err != nil { + t.Fatalf("Build error: %v", err) + } +} + +func TestCellWithDebugTrueBreakToBelow(t *testing.T) { + b := newBuilderWithPage(t) + b.Debug(true) + b.Cell("text", NewPDFCellOpt().Width(20).LnPos(BreakToBelow)) + if _, err := b.Build(); err != nil { + t.Fatalf("Build error: %v", err) + } +} + +func TestMultiCellAllOptions(t *testing.T) { + b := newBuilderWithPage(t) + b.MultiCell("hello\nworld", + NewPDFMultiCellOpt(). + Width(50). + Height(6). + Border(BorderFull). + Align(AlignLeft). + FillBackground(true). + Font(FontTimes, Italic, 10). + LnAfter(1). + X(20). + TextColor(1, 2, 3). + BorderColor(4, 5, 6). + FillColor(7, 8, 9). 
+ Alpha(0.5, BlendNormal)) + + if _, err := b.Build(); err != nil { + t.Fatalf("Build error: %v", err) + } +} + +func TestMultiCellRestoresFont(t *testing.T) { + b := newBuilderWithPage(t) + b.SetFont(FontHelvetica, Normal, 12) + + b.MultiCell("text", NewPDFMultiCellOpt().Width(50).Font(FontTimes, Bold, 20)) + + if b.fontName != FontHelvetica { + t.Errorf("fontName not restored: %v", b.fontName) + } + if b.fontStyle != Normal { + t.Errorf("fontStyle not restored: %v", b.fontStyle) + } + if b.fontSize != 12 { + t.Errorf("fontSize not restored: %v", b.fontSize) + } +} + +func TestMultiCellWithDebug(t *testing.T) { + b := newBuilderWithPage(t) + b.Debug(true) + b.MultiCell("hello world", NewPDFMultiCellOpt().Width(50)) + if _, err := b.Build(); err != nil { + t.Fatalf("Build error: %v", err) + } +} + +func TestRectAllOptions(t *testing.T) { + b := newBuilderWithPage(t) + b.Rect(20, 10, RectFillOutline, + NewPDFRectOpt(). + X(15).Y(20). + LineWidth(0.5). + DrawColor(10, 20, 30). + FillColor(100, 110, 120). + Alpha(0.5, BlendNormal). + Rounded(2)) + + // Drawing color and line width should be restored after Rect. + r, g, bl := b.GetDrawColor() + if r != 0 || g != 0 || bl != 0 { + t.Errorf("draw color not restored: (%d,%d,%d)", r, g, bl) + } + + if _, err := b.Build(); err != nil { + t.Fatalf("Build error: %v", err) + } +} + +func TestRectStyles(t *testing.T) { + b := newBuilderWithPage(t) + b.Rect(10, 10, RectFill, NewPDFRectOpt().X(10).Y(10).FillColor(255, 0, 0)) + b.Rect(10, 10, RectOutline, NewPDFRectOpt().X(30).Y(10).DrawColor(0, 255, 0)) + b.Rect(10, 10, RectFillOutline, NewPDFRectOpt().X(50).Y(10).FillColor(0, 0, 255).DrawColor(255, 255, 0)) + + if _, err := b.Build(); err != nil { + t.Fatalf("Build error: %v", err) + } +} + +func TestRectIndividualRadii(t *testing.T) { + b := newBuilderWithPage(t) + b.Rect(20, 20, RectOutline, + NewPDFRectOpt(). + X(10).Y(10). 
+ RadiusTL(2).RadiusTR(3).RadiusBR(4).RadiusBL(5)) + if _, err := b.Build(); err != nil { + t.Fatalf("Build error: %v", err) + } +} + +func TestLineAllOptions(t *testing.T) { + b := newBuilderWithPage(t) + b.Line(10, 10, 100, 100, + NewPDFLineOpt(). + LineWidth(0.5). + DrawColor(255, 0, 0). + Alpha(0.7, BlendDarken). + CapRound()) + + // Line width should be restored after Line. + if got := b.GetLineWidth(); got == 0.5 { + t.Errorf("LineWidth was not restored: %v", got) + } + + r, g, bl := b.GetDrawColor() + if r == 255 && g == 0 && bl == 0 { + t.Error("DrawColor not restored") + } + + if _, err := b.Build(); err != nil { + t.Fatalf("Build error: %v", err) + } +} + +func TestLineCapStyles(t *testing.T) { + b := newBuilderWithPage(t) + b.Line(10, 10, 100, 10, NewPDFLineOpt().LineWidth(2).CapButt()) + b.Line(10, 20, 100, 20, NewPDFLineOpt().LineWidth(2).CapRound()) + b.Line(10, 30, 100, 30, NewPDFLineOpt().LineWidth(2).CapSquare()) + + if _, err := b.Build(); err != nil { + t.Fatalf("Build error: %v", err) + } +} + +func TestCellWithExtraLn(t *testing.T) { + b := newBuilderWithPage(t) + yBefore := b.GetY() + b.Cell("text", NewPDFCellOpt().Width(50).LnAfter(5)) + yAfter := b.GetY() + if yAfter <= yBefore { + t.Errorf("Y did not advance after Cell with LnAfter") + } +} + +func TestMultiCellWithExtraLn(t *testing.T) { + b := newBuilderWithPage(t) + yBefore := b.GetY() + b.MultiCell("multiline\ntext", NewPDFMultiCellOpt().Width(50).LnAfter(5)) + yAfter := b.GetY() + if yAfter <= yBefore { + t.Errorf("Y did not advance after MultiCell with LnAfter") + } +} diff --git a/wpdf/wpdf_image_test.go b/wpdf/wpdf_image_test.go new file mode 100644 index 0000000..0d35be2 --- /dev/null +++ b/wpdf/wpdf_image_test.go @@ -0,0 +1,125 @@ +package wpdf + +import ( + "bytes" + "image" + "image/color" + "image/png" + "strings" + "testing" +) + +func makeTestPNG(t *testing.T, w, h int) []byte { + t.Helper() + img := image.NewRGBA(image.Rect(0, 0, w, h)) + for y := range h { + for x := range 
w { + img.Set(x, y, color.RGBA{R: uint8(x), G: uint8(y), B: 128, A: 255}) + } + } + var buf bytes.Buffer + if err := png.Encode(&buf, img); err != nil { + t.Fatalf("png encode: %v", err) + } + return buf.Bytes() +} + +func TestRegisterImageDetectsPNG(t *testing.T) { + b := NewPDFBuilder(Portrait, SizeA4, false) + b.AddPage() + + bin := makeTestPNG(t, 16, 16) + ref := b.RegisterImage(bin) + + if ref == nil { + t.Fatal("RegisterImage returned nil") + } + if !strings.HasPrefix(ref.Name, "fpdf_img_") { + t.Errorf("default name not prefixed: %q", ref.Name) + } + if ref.Mime != "image/png" { + t.Errorf("mime = %q, want image/png", ref.Mime) + } + if ref.Info == nil { + t.Error("info is nil") + } + if !bytes.Equal(ref.Bin, bin) { + t.Error("Bin not stored correctly") + } +} + +func TestRegisterImageWithCustomName(t *testing.T) { + b := NewPDFBuilder(Portrait, SizeA4, false) + b.AddPage() + + bin := makeTestPNG(t, 8, 8) + ref := b.RegisterImage(bin, NewPDFImageRegisterOpt().Name("custom_image")) + + if ref.Name != "custom_image" { + t.Errorf("custom name not used: %q", ref.Name) + } +} + +func TestRegisterImageWithExplicitType(t *testing.T) { + b := NewPDFBuilder(Portrait, SizeA4, false) + b.AddPage() + + bin := makeTestPNG(t, 8, 8) + ref := b.RegisterImage(bin, NewPDFImageRegisterOpt().ImageType("PNG")) + + if ref.Info == nil { + t.Error("info is nil") + } +} + +func TestRegisterImageLargeBuffer(t *testing.T) { + b := NewPDFBuilder(Portrait, SizeA4, false) + b.AddPage() + + // Larger than 512 bytes to exercise that branch in detection. 
+ bin := makeTestPNG(t, 256, 256) + if len(bin) <= 512 { + t.Fatalf("test png unexpectedly small: %d bytes", len(bin)) + } + ref := b.RegisterImage(bin) + + if ref.Mime != "image/png" { + t.Errorf("mime = %q, want image/png", ref.Mime) + } +} + +func TestImageDrawing(t *testing.T) { + b := NewPDFBuilder(Portrait, SizeA4, false) + b.AddPage() + + bin := makeTestPNG(t, 16, 16) + ref := b.RegisterImage(bin) + + b.Image(ref, NewPDFImageOpt().X(10).Y(10).Width(20).Height(20)) + + out, err := b.Build() + if err != nil { + t.Fatalf("Build error: %v", err) + } + if !bytes.HasPrefix(out, []byte("%PDF-")) { + t.Error("output not a PDF") + } +} + +func TestImageDrawingDefaults(t *testing.T) { + b := NewPDFBuilder(Portrait, SizeA4, false) + b.AddPage() + + bin := makeTestPNG(t, 16, 16) + ref := b.RegisterImage(bin) + + b.Image(ref) // no opts: use info defaults + + out, err := b.Build() + if err != nil { + t.Fatalf("Build error: %v", err) + } + if !bytes.HasPrefix(out, []byte("%PDF-")) { + t.Error("output not a PDF") + } +} diff --git a/wpdf/wpdf_opts_test.go b/wpdf/wpdf_opts_test.go new file mode 100644 index 0000000..a3743c1 --- /dev/null +++ b/wpdf/wpdf_opts_test.go @@ -0,0 +1,455 @@ +package wpdf + +import ( + "testing" +) + +func TestPDFCellOptBuilders(t *testing.T) { + opt := NewPDFCellOpt(). + Width(100). + Height(20). + Border(BorderFull). + LnPos(BreakToRight). + Align(AlignHorzCenter). + FillBackground(true). + Link(5). + LinkStr("https://example.com"). + Font(FontTimes, Bold, 14). + LnAfter(2). + X(15). + AutoWidth(). + AutoWidthPaddingX(3). + TextColor(1, 2, 3). + BorderColor(4, 5, 6). + FillColor(7, 8, 9). + Alpha(0.5, BlendMultiply). 
+ Debug(true) + + if opt.width == nil || *opt.width != 100 { + t.Error("width not set") + } + if opt.height == nil || *opt.height != 20 { + t.Error("height not set") + } + if opt.border == nil || *opt.border != BorderFull { + t.Error("border not set") + } + if opt.ln == nil || *opt.ln != BreakToRight { + t.Error("ln not set") + } + if opt.align == nil || *opt.align != AlignHorzCenter { + t.Error("align not set") + } + if opt.fill == nil || *opt.fill != true { + t.Error("fill not set") + } + if opt.link == nil || *opt.link != 5 { + t.Error("link not set") + } + if opt.linkStr == nil || *opt.linkStr != "https://example.com" { + t.Error("linkStr not set") + } + if opt.fontNameOverride == nil || *opt.fontNameOverride != FontTimes { + t.Error("fontNameOverride not set") + } + if opt.fontStyleOverride == nil || *opt.fontStyleOverride != Bold { + t.Error("fontStyleOverride not set") + } + if opt.fontSizeOverride == nil || *opt.fontSizeOverride != 14 { + t.Error("fontSizeOverride not set") + } + if opt.extraLn == nil || *opt.extraLn != 2 { + t.Error("extraLn not set") + } + if opt.x == nil || *opt.x != 15 { + t.Error("x not set") + } + if opt.autoWidth == nil || *opt.autoWidth != true { + t.Error("autoWidth not set") + } + if opt.autoWidthPaddingX == nil || *opt.autoWidthPaddingX != 3 { + t.Error("autoWidthPaddingX not set") + } + if opt.textColor == nil || *opt.textColor != (PDFColor{R: 1, G: 2, B: 3}) { + t.Error("textColor not set") + } + if opt.borderColor == nil || *opt.borderColor != (PDFColor{R: 4, G: 5, B: 6}) { + t.Error("borderColor not set") + } + if opt.fillColor == nil || *opt.fillColor != (PDFColor{R: 7, G: 8, B: 9}) { + t.Error("fillColor not set") + } + if opt.alphaOverride == nil || opt.alphaOverride.V1 != 0.5 || opt.alphaOverride.V2 != BlendMultiply { + t.Error("alphaOverride not set") + } + if opt.debug == nil || *opt.debug != true { + t.Error("debug not set") + } +} + +func TestPDFCellOptHexColors(t *testing.T) { + opt := NewPDFCellOpt(). 
+ TextColorHex(0xFF0000). + BorderColorHex(0x00FF00). + FillColorHex(0x0000FF) + + if *opt.textColor != (PDFColor{R: 0xFF, G: 0, B: 0}) { + t.Errorf("textColorHex got %+v", opt.textColor) + } + if *opt.borderColor != (PDFColor{R: 0, G: 0xFF, B: 0}) { + t.Errorf("borderColorHex got %+v", opt.borderColor) + } + if *opt.fillColor != (PDFColor{R: 0, G: 0, B: 0xFF}) { + t.Errorf("fillColorHex got %+v", opt.fillColor) + } +} + +func TestPDFCellOptBoldItalic(t *testing.T) { + opt := NewPDFCellOpt().Bold() + if opt.fontStyleOverride == nil || *opt.fontStyleOverride != Bold { + t.Error("Bold() should set fontStyleOverride to Bold") + } + + opt2 := NewPDFCellOpt().Italic() + if opt2.fontStyleOverride == nil || *opt2.fontStyleOverride != Italic { + t.Error("Italic() should set fontStyleOverride to Italic") + } +} + +func TestPDFCellOptIndividualFontSetters(t *testing.T) { + opt := NewPDFCellOpt().FontName(FontCourier).FontStyle(Italic).FontSize(11) + if *opt.fontNameOverride != FontCourier { + t.Error("FontName") + } + if *opt.fontStyleOverride != Italic { + t.Error("FontStyle") + } + if *opt.fontSizeOverride != 11 { + t.Error("FontSize") + } +} + +func TestPDFCellOptCopy(t *testing.T) { + orig := NewPDFCellOpt().Width(50).Height(10) + cp := orig.Copy() + + if cp == orig { + t.Error("Copy() returned same pointer") + } + if *cp.width != 50 || *cp.height != 10 { + t.Error("copied opt does not match original") + } + + cp.Width(99) + if *orig.width != 50 { + t.Errorf("modifying copy affected original: orig.width=%v", *orig.width) + } +} + +func TestPDFCellOptToMulti(t *testing.T) { + cell := NewPDFCellOpt(). + Width(100). + Height(20). + Border(BorderTop). + Align(AlignRight). + FillBackground(true). + Font(FontTimes, Bold, 14). + LnAfter(3). + X(5). + TextColor(1, 2, 3). + BorderColor(4, 5, 6). 
+ FillColor(7, 8, 9) + + multi := cell.ToMulti() + + if multi == nil { + t.Fatal("ToMulti returned nil") + } + if *multi.width != 100 { + t.Error("width not transferred") + } + if *multi.height != 20 { + t.Error("height not transferred") + } + if *multi.border != BorderTop { + t.Error("border not transferred") + } + if *multi.align != AlignRight { + t.Error("align not transferred") + } + if !*multi.fill { + t.Error("fill not transferred") + } + if *multi.fontNameOverride != FontTimes { + t.Error("fontName not transferred") + } + if *multi.fontStyleOverride != Bold { + t.Error("fontStyle not transferred") + } + if *multi.fontSizeOverride != 14 { + t.Error("fontSize not transferred") + } + if *multi.extraLn != 3 { + t.Error("extraLn not transferred") + } + if *multi.x != 5 { + t.Error("x not transferred") + } + if *multi.textColor != (PDFColor{R: 1, G: 2, B: 3}) { + t.Error("textColor not transferred") + } + if *multi.borderColor != (PDFColor{R: 4, G: 5, B: 6}) { + t.Error("borderColor not transferred") + } + if *multi.fillColor != (PDFColor{R: 7, G: 8, B: 9}) { + t.Error("fillColor not transferred") + } +} + +func TestPDFMultiCellOptBuilders(t *testing.T) { + opt := NewPDFMultiCellOpt(). + Width(80). + Height(10). + Border(BorderFull). + Align(AlignLeft). + FillBackground(true). + Font(FontTimes, Italic, 11). + LnAfter(1). + X(5). + TextColor(10, 20, 30). + BorderColor(40, 50, 60). + FillColor(70, 80, 90). + Alpha(0.7, BlendScreen). 
+ Debug(true) + + if *opt.width != 80 || *opt.height != 10 || *opt.border != BorderFull { + t.Error("base opts not set") + } + if *opt.align != AlignLeft || !*opt.fill { + t.Error("align/fill not set") + } + if *opt.fontNameOverride != FontTimes || *opt.fontStyleOverride != Italic || *opt.fontSizeOverride != 11 { + t.Error("font not set") + } + if *opt.extraLn != 1 || *opt.x != 5 { + t.Error("layout opts not set") + } + if *opt.textColor != (PDFColor{R: 10, G: 20, B: 30}) { + t.Error("textColor not set") + } + if *opt.borderColor != (PDFColor{R: 40, G: 50, B: 60}) { + t.Error("borderColor not set") + } + if *opt.fillColor != (PDFColor{R: 70, G: 80, B: 90}) { + t.Error("fillColor not set") + } + if opt.alphaOverride.V1 != 0.7 || opt.alphaOverride.V2 != BlendScreen { + t.Error("alpha not set") + } + if !*opt.debug { + t.Error("debug not set") + } +} + +func TestPDFMultiCellOptHexColors(t *testing.T) { + opt := NewPDFMultiCellOpt(). + TextColorHex(0xAABBCC). + BorderColorHex(0x112233). + FillColorHex(0x445566) + + if *opt.textColor != (PDFColor{R: 0xAA, G: 0xBB, B: 0xCC}) { + t.Error("textColorHex") + } + if *opt.borderColor != (PDFColor{R: 0x11, G: 0x22, B: 0x33}) { + t.Error("borderColorHex") + } + if *opt.fillColor != (PDFColor{R: 0x44, G: 0x55, B: 0x66}) { + t.Error("fillColorHex") + } +} + +func TestPDFMultiCellOptBoldItalic(t *testing.T) { + if *NewPDFMultiCellOpt().Bold().fontStyleOverride != Bold { + t.Error("MultiCell Bold") + } + if *NewPDFMultiCellOpt().Italic().fontStyleOverride != Italic { + t.Error("MultiCell Italic") + } +} + +func TestPDFMultiCellOptIndividualFontSetters(t *testing.T) { + opt := NewPDFMultiCellOpt().FontName(FontCourier).FontStyle(Bold).FontSize(11) + if *opt.fontNameOverride != FontCourier { + t.Error("FontName") + } + if *opt.fontStyleOverride != Bold { + t.Error("FontStyle") + } + if *opt.fontSizeOverride != 11 { + t.Error("FontSize") + } +} + +func TestPDFMultiCellOptCopy(t *testing.T) { + orig := NewPDFMultiCellOpt().Width(50) + 
cp := orig.Copy() + if cp == orig { + t.Error("Copy() returned same pointer") + } + cp.Width(99) + if *orig.width != 50 { + t.Error("modifying copy affected original") + } +} + +func TestPDFRectOptBuilders(t *testing.T) { + opt := NewPDFRectOpt(). + X(5).Y(6). + LineWidth(0.3). + DrawColor(10, 20, 30). + FillColor(40, 50, 60). + Alpha(0.4, BlendOverlay). + Rounded(2). + Debug(true) + + if *opt.x != 5 || *opt.y != 6 { + t.Error("x/y not set") + } + if *opt.lineWidth != 0.3 { + t.Error("lineWidth not set") + } + if *opt.drawColor != (PDFColor{R: 10, G: 20, B: 30}) { + t.Error("drawColor not set") + } + if *opt.fillColor != (PDFColor{R: 40, G: 50, B: 60}) { + t.Error("fillColor not set") + } + if opt.alpha.V1 != 0.4 || opt.alpha.V2 != BlendOverlay { + t.Error("alpha not set") + } + if *opt.radiusTL != 2 || *opt.radiusTR != 2 || *opt.radiusBL != 2 || *opt.radiusBR != 2 { + t.Error("Rounded did not set all four corners") + } + if !*opt.debug { + t.Error("debug not set") + } +} + +func TestPDFRectOptIndividualRadii(t *testing.T) { + opt := NewPDFRectOpt().RadiusTL(1).RadiusTR(2).RadiusBR(3).RadiusBL(4) + + if *opt.radiusTL != 1 || *opt.radiusTR != 2 || *opt.radiusBR != 3 || *opt.radiusBL != 4 { + t.Errorf("individual radii not set: TL=%v TR=%v BR=%v BL=%v", + *opt.radiusTL, *opt.radiusTR, *opt.radiusBR, *opt.radiusBL) + } +} + +func TestPDFRectOptHexColors(t *testing.T) { + opt := NewPDFRectOpt().DrawColorHex(0xABCDEF).FillColorHex(0x123456) + if *opt.drawColor != (PDFColor{R: 0xAB, G: 0xCD, B: 0xEF}) { + t.Error("drawColorHex") + } + if *opt.fillColor != (PDFColor{R: 0x12, G: 0x34, B: 0x56}) { + t.Error("fillColorHex") + } +} + +func TestPDFLineOptBuilders(t *testing.T) { + opt := NewPDFLineOpt(). + LineWidth(0.5). + DrawColor(10, 20, 30). + Alpha(0.3, BlendDarken). + CapButt(). 
+ Debug(true) + + if *opt.lineWidth != 0.5 { + t.Error("lineWidth not set") + } + if *opt.drawColor != (PDFColor{R: 10, G: 20, B: 30}) { + t.Error("drawColor not set") + } + if opt.alpha.V1 != 0.3 || opt.alpha.V2 != BlendDarken { + t.Error("alpha not set") + } + if *opt.capStyle != CapButt { + t.Error("capStyle CapButt not set") + } + if !*opt.debug { + t.Error("debug not set") + } +} + +func TestPDFLineOptCapStyles(t *testing.T) { + if *NewPDFLineOpt().CapButt().capStyle != CapButt { + t.Error("CapButt") + } + if *NewPDFLineOpt().CapRound().capStyle != CapRound { + t.Error("CapRound") + } + if *NewPDFLineOpt().CapSquare().capStyle != CapSquare { + t.Error("CapSquare") + } +} + +func TestPDFLineOptHexColor(t *testing.T) { + opt := NewPDFLineOpt().DrawColorHex(0xFEDCBA) + if *opt.drawColor != (PDFColor{R: 0xFE, G: 0xDC, B: 0xBA}) { + t.Error("drawColorHex") + } +} + +func TestPDFImageRegisterOptBuilders(t *testing.T) { + opt := NewPDFImageRegisterOpt(). + ImageType("PNG"). + ReadDpi(true). + AllowNegativePosition(true). + Name("custom_name") + + if *opt.imageType != "PNG" { + t.Error("imageType not set") + } + if !*opt.readDpi { + t.Error("readDpi not set") + } + if !*opt.allowNegativePosition { + t.Error("allowNegativePosition not set") + } + if *opt.name != "custom_name" { + t.Error("name not set") + } +} + +func TestPDFImageOptBuilders(t *testing.T) { + opt := NewPDFImageOpt(). + X(1).Y(2).Width(30).Height(40). + Flow(false). + Link(7).LinkStr("foo"). + ImageType("PNG"). + ReadDpi(true). + AllowNegativePosition(true). + Crop(0.1, 0.2, 0.3, 0.4). + Alpha(0.6, BlendMultiply). 
+ Debug(true) + + if *opt.x != 1 || *opt.y != 2 || *opt.width != 30 || *opt.height != 40 { + t.Error("position/size not set") + } + if *opt.flow != false { + t.Error("flow not set") + } + if *opt.link != 7 || *opt.linkStr != "foo" { + t.Error("link not set") + } + if *opt.imageType != "PNG" || !*opt.readDpi || !*opt.allowNegativePosition { + t.Error("image options not set") + } + if opt.crop == nil || opt.crop.CropX != 0.1 || opt.crop.CropY != 0.2 || opt.crop.CropWidth != 0.3 || opt.crop.CropHeight != 0.4 { + t.Errorf("crop not set: %+v", opt.crop) + } + if opt.alphaOverride.V1 != 0.6 || opt.alphaOverride.V2 != BlendMultiply { + t.Error("alpha not set") + } + if !*opt.debug { + t.Error("debug not set") + } +} diff --git a/wpdf/wpdf_table_test.go b/wpdf/wpdf_table_test.go new file mode 100644 index 0000000..8d7c2ed --- /dev/null +++ b/wpdf/wpdf_table_test.go @@ -0,0 +1,553 @@ +package wpdf + +import ( + "bytes" + "math" + "testing" +) + +func TestTableBuilderInitialState(t *testing.T) { + b := newBuilderWithPage(t) + tb := b.Table() + + if tb == nil { + t.Fatal("Table() returned nil") + } + if tb.builder != b { + t.Error("builder back-reference not set") + } + if tb.padx != 2 { + t.Errorf("default padx = %v, want 2", tb.padx) + } + if tb.pady != 2 { + t.Errorf("default pady = %v, want 2", tb.pady) + } + if tb.defaultCellStyle == nil { + t.Error("default cell style is nil") + } + if tb.RowCount() != 0 { + t.Errorf("RowCount = %v, want 0", tb.RowCount()) + } +} + +func TestTableBuilderConfig(t *testing.T) { + b := newBuilderWithPage(t) + style := NewTableCellStyleOpt() + tb := b.Table(). + PadX(5). + PadY(7). + Widths("10", "20", "auto"). + DefaultStyle(style). 
+ Debug(true) + + if tb.padx != 5 { + t.Errorf("padx = %v, want 5", tb.padx) + } + if tb.pady != 7 { + t.Errorf("pady = %v, want 7", tb.pady) + } + if tb.columnWidths == nil || len(*tb.columnWidths) != 3 { + t.Errorf("columnWidths not set correctly: %v", tb.columnWidths) + } + if tb.defaultCellStyle != style { + t.Error("defaultCellStyle not set") + } + if tb.debug == nil || !*tb.debug { + t.Error("debug not set") + } +} + +func TestTableBuilderAddRowDefault(t *testing.T) { + b := newBuilderWithPage(t) + tb := b.Table().AddRowDefaultStyle("a", "b", "c") + if tb.RowCount() != 1 { + t.Errorf("RowCount = %v, want 1", tb.RowCount()) + } + if len(tb.rows[0].cells) != 3 { + t.Errorf("cells = %v, want 3", len(tb.rows[0].cells)) + } + if tb.rows[0].cells[0].Content != "a" { + t.Errorf("cell[0].Content = %q, want %q", tb.rows[0].cells[0].Content, "a") + } +} + +func TestTableBuilderAddRowWithStyle(t *testing.T) { + b := newBuilderWithPage(t) + style := NewTableCellStyleOpt().Bold().FontSize(10) + + tb := b.Table().AddRowWithStyle(style, "x", "y") + if tb.RowCount() != 1 { + t.Fatalf("RowCount = %v, want 1", tb.RowCount()) + } + if len(tb.rows[0].cells) != 2 { + t.Fatalf("cells = %v, want 2", len(tb.rows[0].cells)) + } + for i, c := range tb.rows[0].cells { + if c.Style.fontStyleOverride == nil || *c.Style.fontStyleOverride != Bold { + t.Errorf("cell[%d] style not Bold", i) + } + if c.Style.fontSizeOverride == nil || *c.Style.fontSizeOverride != 10 { + t.Errorf("cell[%d] fontSize not 10", i) + } + } +} + +func TestTableBuilderAddRow(t *testing.T) { + b := newBuilderWithPage(t) + tc := TableCell{Content: "hello", Style: TableCellStyleOpt{}} + tb := b.Table().AddRow(tc, tc) + if tb.RowCount() != 1 || len(tb.rows[0].cells) != 2 { + t.Errorf("AddRow did not add expected cells") + } +} + +func TestTableBuilderBuildRowFlow(t *testing.T) { + b := newBuilderWithPage(t) + + tb := b.Table(). + BuildRow().Cell("a").Cell("b").Cell("c").BuildRow(). 
+ BuildRow().Cells("d", "e", "f").BuildRow() + + if tb.RowCount() != 2 { + t.Errorf("RowCount = %v, want 2", tb.RowCount()) + } + if len(tb.rows[0].cells) != 3 || len(tb.rows[1].cells) != 3 { + t.Errorf("each row should have 3 cells; got %d and %d", + len(tb.rows[0].cells), len(tb.rows[1].cells)) + } + if tb.rows[0].cells[0].Content != "a" || tb.rows[1].cells[2].Content != "f" { + t.Error("cell content mismatch") + } +} + +func TestTableRowBuilderCellWithStyle(t *testing.T) { + b := newBuilderWithPage(t) + style := NewTableCellStyleOpt().Bold() + tb := b.Table().BuildRow().CellWithStyle("x", style).BuildRow() + + if tb.rows[0].cells[0].Style.fontStyleOverride == nil || + *tb.rows[0].cells[0].Style.fontStyleOverride != Bold { + t.Error("cell style not applied") + } +} + +func TestTableRowBuilderCellObjects(t *testing.T) { + b := newBuilderWithPage(t) + c1 := TableCell{Content: "alpha"} + c2 := TableCell{Content: "beta"} + + tb := b.Table().BuildRow().CellObject(c1).CellObjects(c2).BuildRow() + if len(tb.rows[0].cells) != 2 { + t.Fatalf("cells = %v, want 2", len(tb.rows[0].cells)) + } + if tb.rows[0].cells[0].Content != "alpha" || tb.rows[0].cells[1].Content != "beta" { + t.Error("cell objects not added") + } +} + +func TestTableRowBuilderRowStyle(t *testing.T) { + b := newBuilderWithPage(t) + rowStyle := NewTableCellStyleOpt().Italic() + + tb := b.Table().BuildRow().RowStyle(rowStyle).Cell("x").Cell("y").BuildRow() + + for i, c := range tb.rows[0].cells { + if c.Style.fontStyleOverride == nil || *c.Style.fontStyleOverride != Italic { + t.Errorf("cell[%d] should use row style (Italic)", i) + } + } +} + +func TestTableMaxFontSize(t *testing.T) { + row := tableRow{ + cells: []TableCell{ + {Style: *NewTableCellStyleOpt().FontSize(10)}, + {Style: *NewTableCellStyleOpt().FontSize(20)}, + {Style: *NewTableCellStyleOpt().FontSize(15)}, + }, + } + got := row.maxFontSize(8) + if got != 20 { + t.Errorf("maxFontSize = %v, want 20", got) + } + + rowEmpty := tableRow{cells: 
[]TableCell{{Style: TableCellStyleOpt{}}}} + got = rowEmpty.maxFontSize(12) + if got != 12 { + t.Errorf("maxFontSize default = %v, want 12", got) + } +} + +func TestTableBuilderBuildEmpty(t *testing.T) { + b := newBuilderWithPage(t) + // Should not panic when no rows + b.Table().Build() +} + +func TestTableBuilderBuildNumeric(t *testing.T) { + b := newBuilderWithPage(t) + b.Table(). + Widths("30", "30", "30"). + AddRowDefaultStyle("a", "b", "c"). + AddRowDefaultStyle("d", "e", "f"). + Build() + + bin, err := b.Build() + if err != nil { + t.Fatalf("Build error: %v", err) + } + if !bytes.HasPrefix(bin, []byte("%PDF-")) { + t.Error("output not a PDF") + } +} + +func TestTableBuilderBuildAuto(t *testing.T) { + b := newBuilderWithPage(t) + b.Table(). + Widths("auto", "auto", "auto"). + AddRowDefaultStyle("a", "b", "c"). + Build() + + if _, err := b.Build(); err != nil { + t.Fatalf("Build error: %v", err) + } +} + +func TestTableBuilderBuildFr(t *testing.T) { + b := newBuilderWithPage(t) + b.Table(). + Widths("1fr", "2fr", "*"). + AddRowDefaultStyle("a", "b", "c"). + Build() + + if _, err := b.Build(); err != nil { + t.Fatalf("Build error: %v", err) + } +} + +func TestTableBuilderMixedColumnSpecs(t *testing.T) { + b := newBuilderWithPage(t) + b.Table(). + Widths("auto", "30", "1fr", "*"). + AddRowDefaultStyle("a", "b", "c", "d"). + AddRowDefaultStyle("longer text here", "x", "y", "z"). + Build() + + if _, err := b.Build(); err != nil { + t.Fatalf("Build error: %v", err) + } +} + +func TestTableBuilderNoColumnsDefined(t *testing.T) { + b := newBuilderWithPage(t) + // When Widths not specified, table uses "*" for each cell of first row + b.Table(). + AddRowDefaultStyle("a", "b"). + AddRowDefaultStyle("c", "d"). + Build() + + if _, err := b.Build(); err != nil { + t.Fatalf("Build error: %v", err) + } +} + +func TestTableBuilderMismatchedColumnCount(t *testing.T) { + b := newBuilderWithPage(t) + b.Table(). + Widths("10", "10"). + AddRowDefaultStyle("a", "b", "c"). 
// wrong column count + Build() + + // Should produce an error in the underlying gofpdf, surfaced by Build() + _, err := b.Build() + if err == nil { + t.Error("expected error for mismatched column count, got nil") + } +} + +func TestTableBuilderInvalidColumnWidth(t *testing.T) { + b := newBuilderWithPage(t) + b.Table(). + Widths("not-a-number"). + AddRowDefaultStyle("a"). + Build() + + _, err := b.Build() + if err == nil { + t.Error("expected error for invalid column width, got nil") + } +} + +func TestTableBuilderMultiCellRow(t *testing.T) { + b := newBuilderWithPage(t) + style := NewTableCellStyleOpt().MultiCell(true) + b.Table(). + Widths("auto", "1fr"). + AddRowWithStyle(style, "a", "Multi line\ntext content"). + Build() + + if _, err := b.Build(); err != nil { + t.Fatalf("Build error: %v", err) + } +} + +func TestTableBuilderEllipsizeRow(t *testing.T) { + b := newBuilderWithPage(t) + style := NewTableCellStyleOpt().MultiCell(false).Ellipsize(true) + b.Table(). + Widths("20"). + AddRowWithStyle(style, "this is a very long text that should be ellipsized"). + Build() + + if _, err := b.Build(); err != nil { + t.Fatalf("Build error: %v", err) + } +} + +func TestTableBuilderDefaultTableStyle(t *testing.T) { + s := defaultTableStyle() + if s == nil { + t.Fatal("defaultTableStyle returned nil") + } + if s.minWidth == nil || *s.minWidth != 5 { + t.Errorf("default minWidth = %v, want 5", s.minWidth) + } + if s.ellipsize == nil || !*s.ellipsize { + t.Errorf("default ellipsize should be true") + } + if s.multiCell == nil || *s.multiCell { + t.Errorf("default multiCell should be false") + } + if s.fontSizeOverride == nil || *s.fontSizeOverride != 8 { + t.Errorf("default font size = %v, want 8", s.fontSizeOverride) + } +} + +func TestTableCellStyleOptBuilders(t *testing.T) { + o := NewTableCellStyleOpt(). + MultiCell(true). + Ellipsize(false). + PaddingHorz(3). + MinWidth(7). 
+ FillHeight(true) + + if !*o.multiCell { + t.Error("multiCell") + } + if *o.ellipsize { + t.Error("ellipsize") + } + if *o.paddingHorz != 3 { + t.Error("paddingHorz") + } + if *o.minWidth != 7 { + t.Error("minWidth") + } + if !*o.fillHeight { + t.Error("fillHeight") + } +} + +func TestTableCellStyleOptCellStyle(t *testing.T) { + cell := *NewPDFCellOpt().Width(50).Bold() + o := NewTableCellStyleOpt().CellStyle(cell) + + if o.PDFCellOpt.width == nil || *o.PDFCellOpt.width != 50 { + t.Error("CellStyle did not transfer width") + } + if o.PDFCellOpt.fontStyleOverride == nil || *o.PDFCellOpt.fontStyleOverride != Bold { + t.Error("CellStyle did not transfer style") + } +} + +func TestTableCellStyleOptDelegates(t *testing.T) { + o := NewTableCellStyleOpt(). + Width(10). + Height(20). + Border(BorderFull). + LnPos(BreakToBelow). + Align(AlignRight). + FillBackground(true). + Link(5). + LinkStr("link"). + Font(FontTimes, Bold, 11). + LnAfter(2). + X(3). + AutoWidth(). + AutoWidthPaddingX(1). + TextColor(1, 2, 3). + BorderColor(4, 5, 6). + FillColor(7, 8, 9). + Alpha(0.5, BlendNormal). 
+ Debug(true) + + if *o.PDFCellOpt.width != 10 { + t.Error("width") + } + if *o.PDFCellOpt.height != 20 { + t.Error("height") + } + if *o.PDFCellOpt.border != BorderFull { + t.Error("border") + } + if *o.PDFCellOpt.ln != BreakToBelow { + t.Error("ln") + } + if *o.PDFCellOpt.align != AlignRight { + t.Error("align") + } + if !*o.PDFCellOpt.fill { + t.Error("fill") + } + if *o.PDFCellOpt.link != 5 { + t.Error("link") + } + if *o.PDFCellOpt.linkStr != "link" { + t.Error("linkStr") + } + if *o.PDFCellOpt.fontNameOverride != FontTimes { + t.Error("fontName") + } + if *o.PDFCellOpt.fontStyleOverride != Bold { + t.Error("fontStyle") + } + if *o.PDFCellOpt.fontSizeOverride != 11 { + t.Error("fontSize") + } + if *o.PDFCellOpt.extraLn != 2 { + t.Error("extraLn") + } + if *o.PDFCellOpt.x != 3 { + t.Error("x") + } + if !*o.PDFCellOpt.autoWidth { + t.Error("autoWidth") + } + if *o.PDFCellOpt.autoWidthPaddingX != 1 { + t.Error("autoWidthPaddingX") + } + if *o.PDFCellOpt.textColor != (PDFColor{R: 1, G: 2, B: 3}) { + t.Error("textColor") + } + if *o.PDFCellOpt.borderColor != (PDFColor{R: 4, G: 5, B: 6}) { + t.Error("borderColor") + } + if *o.PDFCellOpt.fillColor != (PDFColor{R: 7, G: 8, B: 9}) { + t.Error("fillColor") + } + if o.PDFCellOpt.alphaOverride.V1 != 0.5 || o.PDFCellOpt.alphaOverride.V2 != BlendNormal { + t.Error("alpha") + } + if !*o.PDFCellOpt.debug { + t.Error("debug") + } +} + +func TestTableCellStyleOptBoldItalic(t *testing.T) { + if *NewTableCellStyleOpt().Bold().PDFCellOpt.fontStyleOverride != Bold { + t.Error("Bold") + } + if *NewTableCellStyleOpt().Italic().PDFCellOpt.fontStyleOverride != Italic { + t.Error("Italic") + } +} + +func TestTableCellStyleOptHexColors(t *testing.T) { + o := NewTableCellStyleOpt(). + TextColorHex(0x010203). + BorderColorHex(0x040506). 
+ FillColorHex(0x070809) + + if *o.PDFCellOpt.textColor != (PDFColor{R: 1, G: 2, B: 3}) { + t.Error("text hex") + } + if *o.PDFCellOpt.borderColor != (PDFColor{R: 4, G: 5, B: 6}) { + t.Error("border hex") + } + if *o.PDFCellOpt.fillColor != (PDFColor{R: 7, G: 8, B: 9}) { + t.Error("fill hex") + } +} + +func TestTableCellStyleOptIndividualFontSetters(t *testing.T) { + o := NewTableCellStyleOpt().FontName(FontCourier).FontStyle(Italic).FontSize(9) + if *o.PDFCellOpt.fontNameOverride != FontCourier { + t.Error("FontName") + } + if *o.PDFCellOpt.fontStyleOverride != Italic { + t.Error("FontStyle") + } + if *o.PDFCellOpt.fontSizeOverride != 9 { + t.Error("FontSize") + } +} + +func TestTableCalculateColumnsNumeric(t *testing.T) { + b := newBuilderWithPage(t) + tb := b.Table(). + Widths("30", "20", "40"). + AddRowDefaultStyle("a", "b", "c") + + w := tb.calculateColumns() + if len(w) != 3 { + t.Fatalf("widths = %v, want 3", len(w)) + } + if w[0] != 30 || w[1] != 20 || w[2] != 40 { + t.Errorf("widths = %v, want [30, 20, 40]", w) + } +} + +func TestTableCalculateColumnsFrSplit(t *testing.T) { + b := newBuilderWithPage(t) + b.SetMargins(PDFMargins{Left: 0, Top: 0, Right: 0}) + + tb := b.Table(). + PadX(0). + Widths("1fr", "1fr"). + AddRowDefaultStyle("a", "b") + + pageW := b.GetPageWidth() + w := tb.calculateColumns() + if len(w) != 2 { + t.Fatalf("widths = %v, want 2", len(w)) + } + + // fr columns are bounded by autoWidths (max content); since "a" and "b" + // are very narrow strings, both columns get the same auto-bounded width. + if math.Abs(w[0]-w[1]) > 0.01 { + t.Errorf("expected fr split widths roughly equal, got %v and %v", w[0], w[1]) + } + + // Total should not exceed available page width. 
+ if w[0]+w[1] > pageW+0.01 { + t.Errorf("total width %v exceeds pageW %v", w[0]+w[1], pageW) + } +} + +func TestTableCalculateColumnsAutoUsesMinWidth(t *testing.T) { + b := newBuilderWithPage(t) + style := *NewTableCellStyleOpt() + mw := 50.0 + style.minWidth = &mw + + tb := b.Table(). + Widths("auto"). + AddRowWithStyle(&style, "x") + + w := tb.calculateColumns() + if len(w) != 1 { + t.Fatalf("widths = %v, want 1", len(w)) + } + if w[0] < 50 { + t.Errorf("width %v should respect minWidth=50", w[0]) + } +} + +func TestTableCalculateColumnsNoRows(t *testing.T) { + b := newBuilderWithPage(t) + tb := b.Table() + w := tb.calculateColumns() + if len(w) != 0 { + t.Errorf("widths for empty table = %v, want []", w) + } +} diff --git a/wsw/websocketWrapper_test.go b/wsw/websocketWrapper_test.go new file mode 100644 index 0000000..dfe5e72 --- /dev/null +++ b/wsw/websocketWrapper_test.go @@ -0,0 +1,652 @@ +package wsw + +import ( + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/gorilla/websocket" +) + +type testMessage struct { + Kind string `json:"kind"` + Value int `json:"value"` +} + +func newDummyReqRes() (*httptest.ResponseRecorder, *http.Request) { + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "http://example.com/ws", nil) + return rec, req +} + +func TestNewWebSocketWrapper_NilOrigins_NoCheckOrigin(t *testing.T) { + rec, req := newDummyReqRes() + w := NewWebSocketWrapper[testMessage](rec, req, nil) + + if w == nil { + t.Fatal("expected non-nil wrapper") + } + if w.upgrader.CheckOrigin != nil { + t.Errorf("expected CheckOrigin to be nil when allowedOrigins is nil") + } + if w.running { + t.Errorf("expected running to be false initially") + } + if w.writer != rec { + t.Errorf("expected writer to be the same passed in") + } + if w.request != req { + t.Errorf("expected request to be the same passed in") + } +} + +func TestNewWebSocketWrapper_WithOrigins_SetsCheckOrigin(t 
*testing.T) { + rec, req := newDummyReqRes() + allowed := []string{"example.com"} + w := NewWebSocketWrapper[testMessage](rec, req, &allowed) + + if w == nil { + t.Fatal("expected non-nil wrapper") + } + if w.upgrader.CheckOrigin == nil { + t.Errorf("expected CheckOrigin to be set when allowedOrigins is non-nil") + } +} + +func makeReqWithOrigin(origin string) *http.Request { + r := httptest.NewRequest(http.MethodGet, "http://example.com/ws", nil) + if origin != "" { + r.Header.Set("Origin", origin) + } + return r +} + +func TestCheckOrigin_NoOriginHeader_AllowsRequest(t *testing.T) { + rec, req := newDummyReqRes() + allowed := []string{"example.com"} + w := NewWebSocketWrapper[testMessage](rec, req, &allowed) + + r := makeReqWithOrigin("") + if !w.upgrader.CheckOrigin(r) { + t.Errorf("expected CheckOrigin to return true when Origin header is missing") + } +} + +func TestCheckOrigin_InvalidURL_Rejects(t *testing.T) { + rec, req := newDummyReqRes() + allowed := []string{"example.com"} + w := NewWebSocketWrapper[testMessage](rec, req, &allowed) + + // Construct a request whose Origin cannot be parsed by url.Parse. 
+ r := httptest.NewRequest(http.MethodGet, "http://example.com/ws", nil) + r.Header["Origin"] = []string{"http://[::1:bad"} + if w.upgrader.CheckOrigin(r) { + t.Errorf("expected CheckOrigin to return false for invalid Origin URL") + } +} + +func TestCheckOrigin_ExactMatch_Allowed(t *testing.T) { + rec, req := newDummyReqRes() + allowed := []string{"example.com"} + w := NewWebSocketWrapper[testMessage](rec, req, &allowed) + + r := makeReqWithOrigin("https://example.com") + if !w.upgrader.CheckOrigin(r) { + t.Errorf("expected CheckOrigin to allow exact host match") + } +} + +func TestCheckOrigin_NoMatch_Rejected(t *testing.T) { + rec, req := newDummyReqRes() + allowed := []string{"example.com"} + w := NewWebSocketWrapper[testMessage](rec, req, &allowed) + + r := makeReqWithOrigin("https://other.com") + if w.upgrader.CheckOrigin(r) { + t.Errorf("expected CheckOrigin to reject non-matching host") + } +} + +func TestCheckOrigin_WildcardMatch_Allowed(t *testing.T) { + rec, req := newDummyReqRes() + allowed := []string{"*.example.com"} + w := NewWebSocketWrapper[testMessage](rec, req, &allowed) + + r := makeReqWithOrigin("https://api.example.com") + if !w.upgrader.CheckOrigin(r) { + t.Errorf("expected CheckOrigin to allow wildcard host match") + } +} + +func TestCheckOrigin_WildcardMatch_DoesNotMatchTopLevel(t *testing.T) { + rec, req := newDummyReqRes() + allowed := []string{"*.example.com"} + w := NewWebSocketWrapper[testMessage](rec, req, &allowed) + + // path.Match's "*" does not cross dots, so "example.com" must not match "*.example.com". 
+ r := makeReqWithOrigin("https://example.com") + if w.upgrader.CheckOrigin(r) { + t.Errorf("expected CheckOrigin to reject bare host against wildcard subdomain pattern") + } +} + +func TestCheckOrigin_EmptyAllowedList_Rejects(t *testing.T) { + rec, req := newDummyReqRes() + allowed := []string{} + w := NewWebSocketWrapper[testMessage](rec, req, &allowed) + + r := makeReqWithOrigin("https://example.com") + if w.upgrader.CheckOrigin(r) { + t.Errorf("expected CheckOrigin to reject when allowed list is empty and Origin is set") + } +} + +func TestCheckOrigin_PortIsPartOfHost(t *testing.T) { + rec, req := newDummyReqRes() + allowed := []string{"example.com:8080"} + w := NewWebSocketWrapper[testMessage](rec, req, &allowed) + + rOK := makeReqWithOrigin("http://example.com:8080") + if !w.upgrader.CheckOrigin(rOK) { + t.Errorf("expected host:port to match allowed host:port") + } + + // Without port -> different host -> reject. + rNo := makeReqWithOrigin("http://example.com") + if w.upgrader.CheckOrigin(rNo) { + t.Errorf("expected host without port to not match allowed host:port") + } +} + +func TestRunning_BeforeStart_False(t *testing.T) { + rec, req := newDummyReqRes() + w := NewWebSocketWrapper[testMessage](rec, req, nil) + + if w.Running() { + t.Errorf("expected Running() to be false before Start") + } +} + +func TestDecode_ValidJSON(t *testing.T) { + rec, req := newDummyReqRes() + w := NewWebSocketWrapper[testMessage](rec, req, nil) + + msg, err := w.decode([]byte(`{"kind":"hello","value":7}`)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if msg.Kind != "hello" || msg.Value != 7 { + t.Errorf("unexpected decoded message: %+v", msg) + } +} + +func TestDecode_InvalidJSON_NoFallback_ReturnsError(t *testing.T) { + rec, req := newDummyReqRes() + w := NewWebSocketWrapper[testMessage](rec, req, nil) + + _, err := w.decode([]byte("not json")) + if err == nil { + t.Fatalf("expected error when decoding invalid JSON without fallback") + } +} + +func 
TestDecode_InvalidJSON_WithFallback_UsesFallback(t *testing.T) { + rec, req := newDummyReqRes() + w := NewWebSocketWrapper[testMessage](rec, req, nil) + + called := false + w.SetFallbackDecoder(func(b []byte) (testMessage, error) { + called = true + return testMessage{Kind: "fallback:" + string(b), Value: 42}, nil + }) + + msg, err := w.decode([]byte("not json")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !called { + t.Errorf("expected fallback decoder to be called") + } + if msg.Kind != "fallback:not json" || msg.Value != 42 { + t.Errorf("unexpected fallback result: %+v", msg) + } +} + +func TestDecode_InvalidJSON_FallbackError_Propagates(t *testing.T) { + rec, req := newDummyReqRes() + w := NewWebSocketWrapper[testMessage](rec, req, nil) + + sentinel := errors.New("nope") + w.SetFallbackDecoder(func(b []byte) (testMessage, error) { + return testMessage{}, sentinel + }) + + _, err := w.decode([]byte("not json")) + if !errors.Is(err, sentinel) { + t.Errorf("expected sentinel error from fallback decoder, got %v", err) + } +} + +func TestDecode_ValidJSON_DoesNotCallFallback(t *testing.T) { + rec, req := newDummyReqRes() + w := NewWebSocketWrapper[testMessage](rec, req, nil) + + called := false + w.SetFallbackDecoder(func(b []byte) (testMessage, error) { + called = true + return testMessage{}, nil + }) + + _, err := w.decode([]byte(`{"kind":"x","value":1}`)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if called { + t.Errorf("did not expect fallback decoder to be called when JSON is valid") + } +} + +func TestSetFallbackDecoder_StoresFunction(t *testing.T) { + rec, req := newDummyReqRes() + w := NewWebSocketWrapper[testMessage](rec, req, nil) + + if w.fallbackDecoder != nil { + t.Fatalf("expected fallbackDecoder to be nil initially") + } + w.SetFallbackDecoder(func(b []byte) (testMessage, error) { + return testMessage{}, nil + }) + if w.fallbackDecoder == nil { + t.Errorf("expected fallbackDecoder to be set after 
SetFallbackDecoder") + } +} + +// --------------------------------------------------------------------------- +// Integration tests (using httptest + gorilla/websocket loopback). +// --------------------------------------------------------------------------- + +// startTestServer wires up a httptest server that, on a websocket upgrade +// request, builds a WebSocketWrapper, starts it, and returns it via the +// supplied channel for the test to interact with. The returned cleanup +// function should be deferred by the caller. +func startTestServer(t *testing.T, allowedOrigins *[]string) (*httptest.Server, <-chan *WebSocketWrapper[testMessage]) { + t.Helper() + wrapperCh := make(chan *WebSocketWrapper[testMessage], 1) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ww := NewWebSocketWrapper[testMessage](w, r, allowedOrigins) + if err := ww.Start(); err != nil { + t.Errorf("Start failed: %v", err) + wrapperCh <- nil + return + } + wrapperCh <- ww + })) + t.Cleanup(srv.Close) + return srv, wrapperCh +} + +// httpToWS rewrites an http(s) URL into ws(s). 
+func httpToWS(u string) string { + if rest, ok := strings.CutPrefix(u, "https://"); ok { + return "wss://" + rest + } + rest, _ := strings.CutPrefix(u, "http://") + return "ws://" + rest +} + +func dialWS(t *testing.T, server *httptest.Server) *websocket.Conn { + t.Helper() + wsURL := httpToWS(server.URL) + c, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("dial: %v", err) + } + return c +} + +func TestIntegration_StartReceivesValidMessage(t *testing.T) { + srv, wrapperCh := startTestServer(t, nil) + c := dialWS(t, srv) + defer c.Close() + + ww := <-wrapperCh + if ww == nil { + t.Fatal("wrapper not initialized") + } + + if !ww.Running() { + t.Errorf("expected Running() to be true after Start") + } + + payload := testMessage{Kind: "ping", Value: 9} + b, _ := json.Marshal(payload) + if err := c.WriteMessage(websocket.TextMessage, b); err != nil { + t.Fatalf("client write: %v", err) + } + + select { + case got := <-ww.MessageChan: + if got.Kind != "ping" || got.Value != 9 { + t.Errorf("unexpected message: %+v", got) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for message") + } + + // Drain CloseChan in a goroutine because Close() sends to it synchronously. 
+ closeReason := make(chan string, 1) + go func() { + if r, ok := <-ww.CloseChan; ok { + closeReason <- r + } else { + closeReason <- "" + } + }() + + ww.Close("test-done") + + select { + case r := <-closeReason: + if r != "test-done" { + t.Errorf("expected close reason 'test-done', got %q", r) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for close reason") + } + + if ww.Running() { + t.Errorf("expected Running() to be false after Close") + } +} + +func TestIntegration_InvalidJSON_SendsErrorOnErrorChan(t *testing.T) { + srv, wrapperCh := startTestServer(t, nil) + c := dialWS(t, srv) + defer c.Close() + + ww := <-wrapperCh + if ww == nil { + t.Fatal("wrapper not initialized") + } + + if err := c.WriteMessage(websocket.TextMessage, []byte("garbage")); err != nil { + t.Fatalf("client write: %v", err) + } + + select { + case err := <-ww.ErrorChan: + if err == nil { + t.Errorf("expected non-nil error on ErrorChan") + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for error on ErrorChan") + } + + // Drain CloseChan and close cleanly. 
+ go func() { <-ww.CloseChan }() + ww.Close() +} + +func TestIntegration_InvalidJSON_FallbackDecoder(t *testing.T) { + wrapperCh := make(chan *WebSocketWrapper[testMessage], 1) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ww := NewWebSocketWrapper[testMessage](w, r, nil) + ww.SetFallbackDecoder(func(b []byte) (testMessage, error) { + return testMessage{Kind: "raw:" + string(b), Value: 1}, nil + }) + if err := ww.Start(); err != nil { + t.Errorf("Start failed: %v", err) + wrapperCh <- nil + return + } + wrapperCh <- ww + })) + defer srv.Close() + + c := dialWS(t, srv) + defer c.Close() + + ww := <-wrapperCh + if ww == nil { + t.Fatal("wrapper not initialized") + } + + if err := c.WriteMessage(websocket.TextMessage, []byte("not-json-but-ok")); err != nil { + t.Fatalf("client write: %v", err) + } + + select { + case got := <-ww.MessageChan: + if got.Kind != "raw:not-json-but-ok" { + t.Errorf("unexpected fallback-decoded message: %+v", got) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for fallback-decoded message") + } + + go func() { <-ww.CloseChan }() + ww.Close() +} + +func TestIntegration_BinaryMessage_SendsErrorOnErrorChan(t *testing.T) { + srv, wrapperCh := startTestServer(t, nil) + c := dialWS(t, srv) + defer c.Close() + + ww := <-wrapperCh + if ww == nil { + t.Fatal("wrapper not initialized") + } + + if err := c.WriteMessage(websocket.BinaryMessage, []byte{0x01, 0x02, 0x03}); err != nil { + t.Fatalf("client write: %v", err) + } + + select { + case err := <-ww.ErrorChan: + if err == nil { + t.Errorf("expected non-nil error for binary message") + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for binary-message error") + } + + go func() { <-ww.CloseChan }() + ww.Close() +} + +func TestIntegration_Send_DeliversToClient(t *testing.T) { + srv, wrapperCh := startTestServer(t, nil) + c := dialWS(t, srv) + defer c.Close() + + ww := <-wrapperCh + if ww == nil { + 
t.Fatal("wrapper not initialized") + } + + go func() { + ww.Send(testMessage{Kind: "outbound", Value: 123}) + }() + + mt, raw, err := c.ReadMessage() + if err != nil { + t.Fatalf("client read: %v", err) + } + if mt != websocket.TextMessage { + t.Errorf("expected text message, got type %d", mt) + } + + var got testMessage + if err := json.Unmarshal(raw, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if got.Kind != "outbound" || got.Value != 123 { + t.Errorf("unexpected payload: %+v", got) + } + + go func() { <-ww.CloseChan }() + ww.Close() +} + +func TestIntegration_Send_NotRunning_NoPanic(t *testing.T) { + rec, req := newDummyReqRes() + w := NewWebSocketWrapper[testMessage](rec, req, nil) + + // Should not panic; should just print an internal error and return. + defer func() { + if r := recover(); r != nil { + t.Errorf("Send panicked when not running: %v", r) + } + }() + w.Send(testMessage{Kind: "x", Value: 1}) +} + +func TestIntegration_Close_BeforeStart_NoPanic(t *testing.T) { + rec, req := newDummyReqRes() + w := NewWebSocketWrapper[testMessage](rec, req, nil) + + defer func() { + if r := recover(); r != nil { + t.Errorf("Close panicked when not running: %v", r) + } + }() + w.Close("never started") +} + +func TestIntegration_Close_IsIdempotent(t *testing.T) { + srv, wrapperCh := startTestServer(t, nil) + c := dialWS(t, srv) + defer c.Close() + + ww := <-wrapperCh + if ww == nil { + t.Fatal("wrapper not initialized") + } + + // First close: drain CloseChan in goroutine, since send is synchronous. + done := make(chan struct{}) + go func() { + <-ww.CloseChan + close(done) + }() + ww.Close("first") + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for first close") + } + + // Second close: must be a no-op (channels are already closed). 
+ defer func() { + if r := recover(); r != nil { + t.Errorf("second Close() panicked: %v", r) + } + }() + ww.Close("second") + + if ww.Running() { + t.Errorf("expected Running() to remain false after repeated Close") + } +} + +func TestIntegration_ClientClose_TriggersCloseChan(t *testing.T) { + srv, wrapperCh := startTestServer(t, nil) + c := dialWS(t, srv) + + ww := <-wrapperCh + if ww == nil { + t.Fatal("wrapper not initialized") + } + + // Reader for CloseChan. + gotClose := make(chan string, 1) + go func() { + if r, ok := <-ww.CloseChan; ok { + gotClose <- r + } else { + gotClose <- "" + } + }() + + // Client closes the connection abruptly. + _ = c.Close() + + select { + case r := <-gotClose: + if !strings.Contains(strings.ToLower(r), "fail") && !strings.Contains(strings.ToLower(r), "close") { + t.Errorf("expected close reason mentioning failure/close, got %q", r) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for server-side close on client disconnect") + } + + if ww.Running() { + t.Errorf("expected Running() to be false after client close propagated") + } +} + +func TestIntegration_OriginCheck_RejectsDisallowedOrigin(t *testing.T) { + allowed := []string{"allowed.example"} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ww := NewWebSocketWrapper[testMessage](w, r, &allowed) + _ = ww.Start() // expected to fail with 403, returns error - we ignore it. 
	}))
	defer srv.Close()

	wsURL := httpToWS(srv.URL)
	hdr := http.Header{}
	hdr.Set("Origin", "http://disallowed.example")
	c, resp, err := websocket.DefaultDialer.Dial(wsURL, hdr)
	if err == nil {
		_ = c.Close()
		t.Fatalf("expected dial to fail due to origin rejection")
	}
	// On a failed handshake the dialer still returns the HTTP response,
	// which lets us assert the exact rejection status.
	if resp == nil {
		t.Fatalf("expected an HTTP response on origin rejection")
	}
	if resp.StatusCode != http.StatusForbidden {
		t.Errorf("expected status 403, got %d", resp.StatusCode)
	}
}

// TestIntegration_OriginCheck_AcceptsAllowedOrigin verifies that an Origin
// on the allow-list upgrades successfully and the wrapper reports Running().
func TestIntegration_OriginCheck_AcceptsAllowedOrigin(t *testing.T) {
	allowed := []string{"allowed.example"}
	wrapperCh := make(chan *WebSocketWrapper[testMessage], 1)
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ww := NewWebSocketWrapper[testMessage](w, r, &allowed)
		if err := ww.Start(); err != nil {
			// Signal failure so the test body cannot block on wrapperCh.
			wrapperCh <- nil
			return
		}
		wrapperCh <- ww
	}))
	defer srv.Close()

	wsURL := httpToWS(srv.URL)
	hdr := http.Header{}
	hdr.Set("Origin", "http://allowed.example")
	c, _, err := websocket.DefaultDialer.Dial(wsURL, hdr)
	if err != nil {
		t.Fatalf("dial with allowed origin failed: %v", err)
	}
	defer c.Close()

	ww := <-wrapperCh
	if ww == nil {
		t.Fatal("wrapper not initialized")
	}
	if !ww.Running() {
		t.Errorf("expected Running() true after successful upgrade")
	}

	// Drain CloseChan so the close handshake cannot block.
	go func() { <-ww.CloseChan }()
	ww.Close()
}

diff --git a/zipext/zip_test.go b/zipext/zip_test.go
new file mode 100644
index 0000000..03dcc04
--- /dev/null
+++ b/zipext/zip_test.go
@@ -0,0 +1,251 @@
package zipext

import (
	"archive/zip"
	"bytes"
	"io"
	"testing"

	"git.blackforestbytes.com/BlackForestBytes/goext/tst"
)

// readZipEntries parses the given bytes as a zip archive and returns a map
// of entry name -> decompressed content. Any error fails the test.
func readZipEntries(t *testing.T, data []byte) map[string][]byte {
	t.Helper()

	r, err := zip.NewReader(bytes.NewReader(data), int64(len(data)))
	tst.AssertNoErr(t, err)

	out := make(map[string][]byte, len(r.File))
	for _, f := range r.File {
		rc, err := f.Open()
		tst.AssertNoErr(t, err)

		body, err :=
io.ReadAll(rc)
		tst.AssertNoErr(t, err)

		// Close explicitly per iteration (not deferred) so entry readers are
		// released as soon as each one is consumed.
		err = rc.Close()
		tst.AssertNoErr(t, err)

		out[f.Name] = body
	}
	return out
}

// TestNewMemoryZipZipOnlyEmpty: a zip-only MemoryZip with no files added
// must still produce a valid (empty) archive.
func TestNewMemoryZipZipOnlyEmpty(t *testing.T) {
	mz := NewMemoryZip(true, false)

	data, err := mz.GetZip()
	tst.AssertNoErr(t, err)
	tst.AssertTrue(t, len(data) > 0)

	entries := readZipEntries(t, data)
	tst.AssertEqual(t, len(entries), 0)
}

// TestMemoryZipAddSingleFile: a single added file round-trips with its
// exact name and content.
func TestMemoryZipAddSingleFile(t *testing.T) {
	mz := NewMemoryZip(true, false)

	payload := []byte("Hello World")
	err := mz.AddFile("hello.txt", payload)
	tst.AssertNoErr(t, err)

	data, err := mz.GetZip()
	tst.AssertNoErr(t, err)

	entries := readZipEntries(t, data)
	tst.AssertEqual(t, len(entries), 1)

	got, ok := entries["hello.txt"]
	tst.AssertTrue(t, ok)
	tst.AssertArrayEqual(t, got, payload)
}

// TestMemoryZipAddMultipleFiles: mixed content — nested paths, binary
// bytes, an empty file — all round-trip intact.
func TestMemoryZipAddMultipleFiles(t *testing.T) {
	mz := NewMemoryZip(true, false)

	files := map[string][]byte{
		"a.txt":          []byte("aaa"),
		"sub/b.txt":      []byte("bbbb"),
		"sub/dir/c.bin":  {0, 1, 2, 3, 4, 5, 250, 251, 252, 253, 254, 255},
		"empty.txt":      {},
		"d/e/f/g/h.json": []byte(`{"k":"v"}`),
	}

	for name, body := range files {
		err := mz.AddFile(name, body)
		tst.AssertNoErr(t, err)
	}

	data, err := mz.GetZip()
	tst.AssertNoErr(t, err)

	entries := readZipEntries(t, data)
	tst.AssertEqual(t, len(entries), len(files))

	for name, expected := range files {
		got, ok := entries[name]
		tst.AssertTrue(t, ok)
		tst.AssertArrayEqual(t, got, expected)
	}
}

// TestMemoryZipAddDuplicatePaths: the zip format permits duplicate entry
// names, so two AddFile calls with the same path must yield two entries.
// (readZipEntries cannot be used here — its map would collapse duplicates.)
func TestMemoryZipAddDuplicatePaths(t *testing.T) {
	mz := NewMemoryZip(true, false)

	err := mz.AddFile("dup.txt", []byte("first"))
	tst.AssertNoErr(t, err)

	err = mz.AddFile("dup.txt", []byte("second"))
	tst.AssertNoErr(t, err)

	data, err := mz.GetZip()
	tst.AssertNoErr(t, err)

	r, err := zip.NewReader(bytes.NewReader(data), int64(len(data)))
	tst.AssertNoErr(t, err)
	tst.AssertEqual(t, len(r.File), 2)
}

// TestMemoryZipAddLargerData: a 64 KiB payload survives the
// compress/decompress round-trip unchanged.
func TestMemoryZipAddLargerData(t *testing.T) {
	mz :=
NewMemoryZip(true, false)

	// 64 KiB of patterned data; modulo a prime (251) so the pattern does
	// not align with common block sizes.
	payload := make([]byte, 64*1024)
	for i := range payload {
		payload[i] = byte(i % 251)
	}

	err := mz.AddFile("big.bin", payload)
	tst.AssertNoErr(t, err)

	data, err := mz.GetZip()
	tst.AssertNoErr(t, err)

	entries := readZipEntries(t, data)
	tst.AssertEqual(t, len(entries), 1)
	tst.AssertArrayEqual(t, entries["big.bin"], payload)
}

// TestMemoryZipGetZipClosesWriter: GetZip finalizes the archive, i.e. it
// flips the unexported `open` flag. White-box test on internal state.
func TestMemoryZipGetZipClosesWriter(t *testing.T) {
	mz := NewMemoryZip(true, false)

	err := mz.AddFile("a.txt", []byte("data"))
	tst.AssertNoErr(t, err)

	tst.AssertTrue(t, mz.open)

	_, err = mz.GetZip()
	tst.AssertNoErr(t, err)

	tst.AssertFalse(t, mz.open)
}

// TestMemoryZipGetZipIdempotent: repeated GetZip calls return identical
// bytes (the second call must not re-finalize or corrupt the buffer).
func TestMemoryZipGetZipIdempotent(t *testing.T) {
	mz := NewMemoryZip(true, false)

	err := mz.AddFile("a.txt", []byte("data"))
	tst.AssertNoErr(t, err)

	first, err := mz.GetZip()
	tst.AssertNoErr(t, err)

	second, err := mz.GetZip()
	tst.AssertNoErr(t, err)

	tst.AssertArrayEqual(t, first, second)
}

// TestMemoryZipAddFileAfterCloseFails: writes after Close are rejected
// with the package sentinel errAlreadyClosed.
// NOTE(review): sentinel compared with `==`; fine for a same-package
// sentinel value, errors.Is would also work.
func TestMemoryZipAddFileAfterCloseFails(t *testing.T) {
	mz := NewMemoryZip(true, false)

	err := mz.Close()
	tst.AssertNoErr(t, err)

	err = mz.AddFile("a.txt", []byte("data"))
	tst.AssertTrue(t, err == errAlreadyClosed)
}

// TestMemoryZipAddFileAfterGetZipFails: GetZip implicitly closes, so a
// subsequent AddFile must fail the same way as after an explicit Close.
func TestMemoryZipAddFileAfterGetZipFails(t *testing.T) {
	mz := NewMemoryZip(true, false)

	_, err := mz.GetZip()
	tst.AssertNoErr(t, err)

	err = mz.AddFile("a.txt", []byte("data"))
	tst.AssertTrue(t, err == errAlreadyClosed)
}

// TestMemoryZipDoubleCloseIsNoop: closing twice returns no error.
func TestMemoryZipDoubleCloseIsNoop(t *testing.T) {
	mz := NewMemoryZip(true, false)

	err := mz.Close()
	tst.AssertNoErr(t, err)

	err = mz.Close()
	tst.AssertNoErr(t, err)
}

// TestMemoryZipGetZipNotEnabled: requesting a zip from an instance created
// without zip support yields errZipNotEnabled.
func TestMemoryZipGetZipNotEnabled(t *testing.T) {
	mz := NewMemoryZip(false, false)

	_, err := mz.GetZip()
	tst.AssertTrue(t, err == errZipNotEnabled)
}

// TestMemoryZipGetTarGzNotEnabled: requesting a tar.gz from a zip-only
// instance yields errTgzNotEnabled.
func TestMemoryZipGetTarGzNotEnabled(t *testing.T) {
	mz := NewMemoryZip(true, false)

	_, err := mz.GetTarGz()
	tst.AssertTrue(t, err == errTgzNotEnabled)
}

func
TestMemoryZipBothDisabledClose(t *testing.T) {
	// With both formats disabled, AddFile and Close still succeed, but
	// every getter reports its format-specific "not enabled" sentinel.
	mz := NewMemoryZip(false, false)

	err := mz.AddFile("a.txt", []byte("data"))
	tst.AssertNoErr(t, err)

	err = mz.Close()
	tst.AssertNoErr(t, err)

	_, err = mz.GetZip()
	tst.AssertTrue(t, err == errZipNotEnabled)

	_, err = mz.GetTarGz()
	tst.AssertTrue(t, err == errTgzNotEnabled)
}

// TestMemoryZipBothDisabledAddFileAfterCloseFails: even with both formats
// disabled, the closed-state check still takes precedence for AddFile.
func TestMemoryZipBothDisabledAddFileAfterCloseFails(t *testing.T) {
	mz := NewMemoryZip(false, false)

	err := mz.Close()
	tst.AssertNoErr(t, err)

	err = mz.AddFile("a.txt", []byte("data"))
	tst.AssertTrue(t, err == errAlreadyClosed)
}

// TestMemoryZipNewMemoryZipFlags: white-box check that the constructor
// stores its two enable flags and starts in the open state.
func TestMemoryZipNewMemoryZipFlags(t *testing.T) {
	z1 := NewMemoryZip(true, false)
	tst.AssertTrue(t, z1.zipEnabled)
	tst.AssertFalse(t, z1.tarEnabled)
	tst.AssertTrue(t, z1.open)

	z2 := NewMemoryZip(false, false)
	tst.AssertFalse(t, z2.zipEnabled)
	tst.AssertFalse(t, z2.tarEnabled)
	tst.AssertTrue(t, z2.open)
}

// TestMemoryZipZipBytesAreValidZipMagic: the output starts with the zip
// local-file-header signature "PK" (0x50 0x4B).
func TestMemoryZipZipBytesAreValidZipMagic(t *testing.T) {
	mz := NewMemoryZip(true, false)

	err := mz.AddFile("x.txt", []byte("y"))
	tst.AssertNoErr(t, err)

	data, err := mz.GetZip()
	tst.AssertNoErr(t, err)

	tst.AssertTrue(t, len(data) >= 4)
	tst.AssertEqual(t, data[0], byte(0x50))
	tst.AssertEqual(t, data[1], byte(0x4B))
}