v0.0.351 sq value converter
All checks were successful
Build Docker and Deploy / Run goext test-suite (push) Successful in 2m30s

This commit is contained in:
2023-12-29 19:25:36 +01:00
parent 6e90239fef
commit f9ccafb976
12 changed files with 280 additions and 111 deletions

View File

@@ -4,17 +4,11 @@ import (
"errors"
"fmt"
"gogs.mikescher.com/BlackForestBytes/goext/langext"
"gogs.mikescher.com/BlackForestBytes/goext/rfctime"
"reflect"
"time"
)
//TODO UNFINISHED
// this is not finished
// idea was that we can register converter in the database struct
// they get inherited from the transactions
// and when marshallingunmarshaling (sq.Query | sq.QueryAll)
// or marshaling (sq.InsertSingle)
// the types get converter automatically...
type DBTypeConverter interface {
ModelTypeString() string
DBTypeString() string
@@ -40,16 +34,42 @@ var ConverterTimeToUnixMillis = NewDBTypeConverter[time.Time, int64](func(v time
return time.UnixMilli(v), nil
})
var ConverterOptTimeToUnixMillis = NewDBTypeConverter[*time.Time, *int64](func(v *time.Time) (*int64, error) {
if v == nil {
return nil, nil
var ConverterRFCUnixMilliTimeToUnixMillis = NewDBTypeConverter[rfctime.UnixMilliTime, int64](func(v rfctime.UnixMilliTime) (int64, error) {
return v.UnixMilli(), nil
}, func(v int64) (rfctime.UnixMilliTime, error) {
return rfctime.NewUnixMilli(time.UnixMilli(v)), nil
})
var ConverterRFCUnixNanoTimeToUnixNanos = NewDBTypeConverter[rfctime.UnixNanoTime, int64](func(v rfctime.UnixNanoTime) (int64, error) {
return v.UnixNano(), nil
}, func(v int64) (rfctime.UnixNanoTime, error) {
return rfctime.NewUnixNano(time.Unix(0, v)), nil
})
var ConverterRFCUnixTimeToUnixSeconds = NewDBTypeConverter[rfctime.UnixTime, int64](func(v rfctime.UnixTime) (int64, error) {
return v.Unix(), nil
}, func(v int64) (rfctime.UnixTime, error) {
return rfctime.NewUnix(time.Unix(v, 0)), nil
})
var ConverterRFC339TimeToString = NewDBTypeConverter[rfctime.RFC3339Time, string](func(v rfctime.RFC3339Time) (string, error) {
return v.Format(time.RFC3339), nil
}, func(v string) (rfctime.RFC3339Time, error) {
t, err := time.Parse(time.RFC3339Nano, v)
if err != nil {
return rfctime.RFC3339Time{}, err
}
return langext.Ptr(v.UnixMilli()), nil
}, func(v *int64) (*time.Time, error) {
if v == nil {
return nil, nil
return rfctime.NewRFC3339(t), nil
})
var ConverterRFC339NanoTimeToString = NewDBTypeConverter[rfctime.RFC3339NanoTime, string](func(v rfctime.RFC3339NanoTime) (string, error) {
return v.Format(time.RFC3339Nano), nil
}, func(v string) (rfctime.RFC3339NanoTime, error) {
t, err := time.Parse(time.RFC3339Nano, v)
if err != nil {
return rfctime.RFC3339NanoTime{}, err
}
return langext.Ptr(time.UnixMilli(*v)), nil
return rfctime.NewRFC3339Nano(t), nil
})
type dbTypeConverterImpl[TModelData any, TDBData any] struct {
@@ -89,3 +109,36 @@ func NewDBTypeConverter[TModelData any, TDBData any](todb func(v TModelData) (TD
tomodel: tomodel,
}
}
func convertValueToDB(q Queryable, value any) (any, error) {
modelTypeStr := fmt.Sprintf("%T", value)
for _, conv := range q.ListConverter() {
if conv.ModelTypeString() == modelTypeStr {
return conv.ModelToDB(value)
}
}
if value != nil && reflect.TypeOf(value).Kind() == reflect.Ptr {
vof := reflect.ValueOf(value)
if vof.IsNil() {
return nil, nil
} else {
return convertValueToDB(q, vof.Elem().Interface())
}
}
return value, nil
}
func convertValueToModel(q Queryable, value any, destinationType string) (any, error) {
dbTypeString := fmt.Sprintf("%T", value)
for _, conv := range q.ListConverter() {
if conv.ModelTypeString() == destinationType && conv.DBTypeString() == dbTypeString {
return conv.DBToModel(value)
}
}
return value, nil
}

View File

@@ -4,16 +4,19 @@ import (
"context"
"database/sql"
"github.com/jmoiron/sqlx"
"gogs.mikescher.com/BlackForestBytes/goext/langext"
"sync"
)
type DB interface {
Exec(ctx context.Context, sql string, prep PP) (sql.Result, error)
Query(ctx context.Context, sql string, prep PP) (*sqlx.Rows, error)
Queryable
Ping(ctx context.Context) error
BeginTransaction(ctx context.Context, iso sql.IsolationLevel) (Tx, error)
AddListener(listener Listener)
Exit() error
RegisterConverter(DBTypeConverter)
RegisterDefaultConverter()
}
type database struct {
@@ -21,6 +24,7 @@ type database struct {
txctr uint16
lock sync.Mutex
lstr []Listener
conv []DBTypeConverter
}
func NewDB(db *sqlx.DB) DB {
@@ -120,9 +124,28 @@ func (db *database) BeginTransaction(ctx context.Context, iso sql.IsolationLevel
v.PostTxBegin(txid, err)
}
return NewTransaction(xtx, txid, db.lstr), nil
return NewTransaction(xtx, txid, db), nil
}
func (db *database) Exit() error {
return db.db.Close()
}
func (db *database) ListConverter() []DBTypeConverter {
return db.conv
}
func (db *database) RegisterConverter(conv DBTypeConverter) {
db.conv = langext.ArrFilter(db.conv, func(v DBTypeConverter) bool { return v.ModelTypeString() != conv.ModelTypeString() })
db.conv = append(db.conv, conv)
}
func (db *database) RegisterDefaultConverter() {
db.RegisterConverter(ConverterBoolToBit)
db.RegisterConverter(ConverterTimeToUnixMillis)
db.RegisterConverter(ConverterRFCUnixMilliTimeToUnixMillis)
db.RegisterConverter(ConverterRFCUnixNanoTimeToUnixNanos)
db.RegisterConverter(ConverterRFCUnixTimeToUnixSeconds)
db.RegisterConverter(ConverterRFC339TimeToString)
db.RegisterConverter(ConverterRFC339NanoTimeToString)
}

View File

@@ -121,7 +121,7 @@ func CreateSqliteDatabaseSchemaString(ctx context.Context, db Queryable) (string
if err != nil {
return "", err
}
tableList, err := ScanAll[tabInfo](rowsTableList, SModeFast, Unsafe, true)
tableList, err := ScanAll[tabInfo](ctx, db, rowsTableList, SModeFast, Unsafe, true)
if err != nil {
return "", err
}
@@ -143,7 +143,7 @@ func CreateSqliteDatabaseSchemaString(ctx context.Context, db Queryable) (string
return "", err
}
columnList, err := ScanAll[colInfo](rowsColumnList, SModeFast, Unsafe, true)
columnList, err := ScanAll[colInfo](ctx, db, rowsColumnList, SModeFast, Unsafe, true)
if err != nil {
return "", err
}
@@ -158,7 +158,7 @@ func CreateSqliteDatabaseSchemaString(ctx context.Context, db Queryable) (string
if err != nil {
return "", err
}
idxList, err := ScanAll[idxInfo](rowsIdxList, SModeFast, Unsafe, true)
idxList, err := ScanAll[idxInfo](ctx, db, rowsIdxList, SModeFast, Unsafe, true)
if err != nil {
return "", err
}
@@ -173,7 +173,7 @@ func CreateSqliteDatabaseSchemaString(ctx context.Context, db Queryable) (string
if err != nil {
return "", err
}
fkyList, err := ScanAll[fkyInfo](rowsIdxList, SModeFast, Unsafe, true)
fkyList, err := ScanAll[fkyInfo](ctx, db, rowsIdxList, SModeFast, Unsafe, true)
if err != nil {
return "", err
}

View File

@@ -9,4 +9,5 @@ import (
type Queryable interface {
Exec(ctx context.Context, sql string, prep PP) (sql.Result, error)
Query(ctx context.Context, sql string, prep PP) (*sqlx.Rows, error)
ListConverter() []DBTypeConverter
}

View File

@@ -13,8 +13,8 @@ import (
type StructScanMode string
const (
SModeFast StructScanMode = "FAST"
SModeExtended StructScanMode = "EXTENDED"
SModeFast StructScanMode = "FAST" // Use default sq.Scan, does not work with joined/resolved types and/or custom value converter
SModeExtended StructScanMode = "EXTENDED" // Fully featured perhaps (?) a tiny bit slower - default
)
type StructScanSafety string
@@ -51,7 +51,13 @@ func InsertSingle[TData any](ctx context.Context, q Queryable, tableName string,
columns = append(columns, "\""+columnName+"\"")
params = append(params, ":"+paramkey)
pp[paramkey] = rvfield.Interface()
val, err := convertValueToDB(q, rvfield.Interface())
if err != nil {
return nil, err
}
pp[paramkey] = val
}
@@ -71,7 +77,7 @@ func QuerySingle[TData any](ctx context.Context, q Queryable, sql string, pp PP,
return *new(TData), err
}
data, err := ScanSingle[TData](rows, mode, sec, true)
data, err := ScanSingle[TData](ctx, q, rows, mode, sec, true)
if err != nil {
return *new(TData), err
}
@@ -85,7 +91,7 @@ func QueryAll[TData any](ctx context.Context, q Queryable, sql string, pp PP, mo
return nil, err
}
data, err := ScanAll[TData](rows, mode, sec, true)
data, err := ScanAll[TData](ctx, q, rows, mode, sec, true)
if err != nil {
return nil, err
}
@@ -93,7 +99,7 @@ func QueryAll[TData any](ctx context.Context, q Queryable, sql string, pp PP, mo
return data, nil
}
func ScanSingle[TData any](rows *sqlx.Rows, mode StructScanMode, sec StructScanSafety, close bool) (TData, error) {
func ScanSingle[TData any](ctx context.Context, q Queryable, rows *sqlx.Rows, mode StructScanMode, sec StructScanSafety, close bool) (TData, error) {
if rows.Next() {
var strscan *StructScanner
@@ -123,7 +129,7 @@ func ScanSingle[TData any](rows *sqlx.Rows, mode StructScanMode, sec StructScanS
return *new(TData), err
}
} else if mode == SModeExtended {
err := strscan.StructScanExt(&data)
err := strscan.StructScanExt(q, &data)
if err != nil {
return *new(TData), err
}
@@ -149,6 +155,10 @@ func ScanSingle[TData any](rows *sqlx.Rows, mode StructScanMode, sec StructScanS
return *new(TData), err
}
if err := ctx.Err(); err != nil {
return *new(TData), err
}
return data, nil
} else {
@@ -159,7 +169,7 @@ func ScanSingle[TData any](rows *sqlx.Rows, mode StructScanMode, sec StructScanS
}
}
func ScanAll[TData any](rows *sqlx.Rows, mode StructScanMode, sec StructScanSafety, close bool) ([]TData, error) {
func ScanAll[TData any](ctx context.Context, q Queryable, rows *sqlx.Rows, mode StructScanMode, sec StructScanSafety, close bool) ([]TData, error) {
var strscan *StructScanner
if sec == Safe {
@@ -182,6 +192,11 @@ func ScanAll[TData any](rows *sqlx.Rows, mode StructScanMode, sec StructScanSafe
res := make([]TData, 0)
for rows.Next() {
if err := ctx.Err(); err != nil {
return nil, err
}
if mode == SModeFast {
var data TData
err := strscan.StructScanBase(&data)
@@ -191,7 +206,7 @@ func ScanAll[TData any](rows *sqlx.Rows, mode StructScanMode, sec StructScanSafe
res = append(res, data)
} else if mode == SModeExtended {
var data TData
err := strscan.StructScanExt(&data)
err := strscan.StructScanExt(q, &data)
if err != nil {
return nil, err
}

89
sq/sq_test.go Normal file
View File

@@ -0,0 +1,89 @@
package sq
import (
"context"
"fmt"
"github.com/jmoiron/sqlx"
"github.com/mattn/go-sqlite3"
"gogs.mikescher.com/BlackForestBytes/goext/langext"
"gogs.mikescher.com/BlackForestBytes/goext/rfctime"
"gogs.mikescher.com/BlackForestBytes/goext/tst"
"os"
"path/filepath"
"testing"
"time"
)
func TestTypeConverter1(t *testing.T) {
type RequestData struct {
ID string `db:"id"`
Timestamp time.Time `db:"timestamp"`
}
ctx := context.Background()
dbdir := t.TempDir()
dbfile1 := filepath.Join(dbdir, langext.MustHexUUID()+".sqlite3")
sqlite3.Version() // ensure loaded
tst.AssertNoErr(t, os.MkdirAll(dbdir, os.ModePerm))
url := fmt.Sprintf("file:%s?_journal=%s&_timeout=%d&_fk=%s&_busy_timeout=%d", dbfile1, "DELETE", 1000, "true", 1000)
xdb := tst.Must(sqlx.Open("sqlite3", url))(t)
db := NewDB(xdb)
db.RegisterDefaultConverter()
_, err := db.Exec(ctx, "CREATE TABLE `requests` ( id TEXT NOT NULL, timestamp INTEGER NOT NULL, PRIMARY KEY (id) ) STRICT", PP{})
tst.AssertNoErr(t, err)
_, err = InsertSingle(ctx, db, "requests", RequestData{
ID: "001",
Timestamp: time.Date(2000, 06, 15, 12, 0, 0, 0, time.UTC),
})
tst.AssertNoErr(t, err)
}
func TestTypeConverter2(t *testing.T) {
type RequestData struct {
ID string `db:"id"`
Timestamp rfctime.UnixMilliTime `db:"timestamp"`
}
ctx := context.Background()
dbdir := t.TempDir()
dbfile1 := filepath.Join(dbdir, langext.MustHexUUID()+".sqlite3")
sqlite3.Version() // ensure loaded
tst.AssertNoErr(t, os.MkdirAll(dbdir, os.ModePerm))
url := fmt.Sprintf("file:%s?_journal=%s&_timeout=%d&_fk=%s&_busy_timeout=%d", dbfile1, "DELETE", 1000, "true", 1000)
xdb := tst.Must(sqlx.Open("sqlite3", url))(t)
db := NewDB(xdb)
db.RegisterDefaultConverter()
_, err := db.Exec(ctx, "CREATE TABLE `requests` ( id TEXT NOT NULL, timestamp INTEGER NOT NULL, PRIMARY KEY (id) ) STRICT", PP{})
tst.AssertNoErr(t, err)
t0 := rfctime.NewUnixMilli(time.Date(2012, 03, 01, 16, 0, 0, 0, time.UTC))
_, err = InsertSingle(ctx, db, "requests", RequestData{
ID: "002",
Timestamp: t0,
})
tst.AssertNoErr(t, err)
r, err := QuerySingle[RequestData](ctx, db, "SELECT * FROM requests WHERE id = '002'", PP{}, SModeExtended, Safe)
tst.AssertNoErr(t, err)
fmt.Printf("%+v\n", r)
tst.AssertEqual(t, "002", r.ID)
tst.AssertEqual(t, t0.UnixNano(), r.Timestamp.UnixNano())
}

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"github.com/jmoiron/sqlx"
"github.com/jmoiron/sqlx/reflectx"
"gogs.mikescher.com/BlackForestBytes/goext/langext"
"reflect"
)
@@ -15,9 +16,10 @@ type StructScanner struct {
Mapper *reflectx.Mapper
unsafe bool
fields [][]int
values []any
columns []string
fields [][]int
values []any
converter []DBTypeConverter
columns []string
}
func NewStructScanner(rows *sqlx.Rows, unsafe bool) *StructScanner {
@@ -47,13 +49,15 @@ func (r *StructScanner) Start(dest any) error {
return fmt.Errorf("missing destination name %s in %T", columns[f], dest)
}
r.values = make([]interface{}, len(columns))
r.converter = make([]DBTypeConverter, len(columns))
return nil
}
// StructScanExt forked from github.com/jmoiron/sqlx@v1.3.5/sqlx.go
// does also wok with nullabel structs (from LEFT JOIN's)
func (r *StructScanner) StructScanExt(dest any) error {
// does also work with nullabel structs (from LEFT JOIN's)
// does also work with custom value converters
func (r *StructScanner) StructScanExt(q Queryable, dest any) error {
v := reflect.ValueOf(dest)
if v.Kind() != reflect.Ptr {
@@ -64,7 +68,7 @@ func (r *StructScanner) StructScanExt(dest any) error {
v = v.Elem()
err := fieldsByTraversalExtended(v, r.fields, r.values)
err := fieldsByTraversalExtended(q, v, r.fields, r.values, r.converter)
if err != nil {
return err
}
@@ -131,7 +135,6 @@ func (r *StructScanner) StructScanExt(dest any) error {
val1 := reflect.ValueOf(r.values[i])
val2 := val1.Elem()
val3 := val2.Elem()
if val2.IsNil() {
if f.Kind() != reflect.Pointer {
@@ -140,7 +143,16 @@ func (r *StructScanner) StructScanExt(dest any) error {
f.Set(reflect.Zero(f.Type())) // set to nil
} else {
f.Set(val3)
if r.converter[i] != nil {
val3 := val2.Elem().Interface()
conv3, err := r.converter[i].DBToModel(val3)
if err != nil {
return err
}
f.Set(reflect.ValueOf(conv3))
} else {
f.Set(val2.Elem())
}
}
}
@@ -172,7 +184,7 @@ func (r *StructScanner) StructScanBase(dest any) error {
}
// fieldsByTraversal forked from github.com/jmoiron/sqlx@v1.3.5/sqlx.go
func fieldsByTraversalExtended(v reflect.Value, traversals [][]int, values []interface{}) error {
func fieldsByTraversalExtended(q Queryable, v reflect.Value, traversals [][]int, values []interface{}, converter []DBTypeConverter) error {
v = reflect.Indirect(v)
if v.Kind() != reflect.Struct {
return errors.New("argument not a struct")
@@ -185,7 +197,23 @@ func fieldsByTraversalExtended(v reflect.Value, traversals [][]int, values []int
}
f := reflectx.FieldByIndexes(v, traversal)
values[i] = reflect.New(reflect.PointerTo(f.Type())).Interface()
typeStr := f.Type().String()
foundConverter := false
for _, conv := range q.ListConverter() {
if conv.ModelTypeString() == typeStr {
_v := langext.Ptr[any](nil)
values[i] = _v
foundConverter = true
converter[i] = conv
break
}
}
if !foundConverter {
values[i] = reflect.New(reflect.PointerTo(f.Type())).Interface()
converter[i] = nil
}
}
return nil
}

View File

@@ -17,35 +17,35 @@ const (
)
type Tx interface {
Queryable
Rollback() error
Commit() error
Status() TxStatus
Exec(ctx context.Context, sql string, prep PP) (sql.Result, error)
Query(ctx context.Context, sql string, prep PP) (*sqlx.Rows, error)
}
type transaction struct {
tx *sqlx.Tx
id uint16
lstr []Listener
status TxStatus
execCtr int
queryCtr int
db *database
}
func NewTransaction(xtx *sqlx.Tx, txid uint16, lstr []Listener) Tx {
func NewTransaction(xtx *sqlx.Tx, txid uint16, db *database) Tx {
return &transaction{
tx: xtx,
id: txid,
lstr: lstr,
status: TxStatusInitial,
execCtr: 0,
queryCtr: 0,
db: db,
}
}
func (tx *transaction) Rollback() error {
for _, v := range tx.lstr {
for _, v := range tx.db.lstr {
err := v.PreTxRollback(tx.id)
if err != nil {
return err
@@ -58,7 +58,7 @@ func (tx *transaction) Rollback() error {
tx.status = TxStatusRollback
}
for _, v := range tx.lstr {
for _, v := range tx.db.lstr {
v.PostTxRollback(tx.id, result)
}
@@ -66,7 +66,7 @@ func (tx *transaction) Rollback() error {
}
func (tx *transaction) Commit() error {
for _, v := range tx.lstr {
for _, v := range tx.db.lstr {
err := v.PreTxCommit(tx.id)
if err != nil {
return err
@@ -79,7 +79,7 @@ func (tx *transaction) Commit() error {
tx.status = TxStatusComitted
}
for _, v := range tx.lstr {
for _, v := range tx.db.lstr {
v.PostTxRollback(tx.id, result)
}
@@ -88,7 +88,7 @@ func (tx *transaction) Commit() error {
func (tx *transaction) Exec(ctx context.Context, sqlstr string, prep PP) (sql.Result, error) {
origsql := sqlstr
for _, v := range tx.lstr {
for _, v := range tx.db.lstr {
err := v.PreExec(ctx, langext.Ptr(tx.id), &sqlstr, &prep)
if err != nil {
return nil, err
@@ -101,7 +101,7 @@ func (tx *transaction) Exec(ctx context.Context, sqlstr string, prep PP) (sql.Re
tx.status = TxStatusActive
}
for _, v := range tx.lstr {
for _, v := range tx.db.lstr {
v.PostExec(langext.Ptr(tx.id), origsql, sqlstr, prep)
}
@@ -113,7 +113,7 @@ func (tx *transaction) Exec(ctx context.Context, sqlstr string, prep PP) (sql.Re
func (tx *transaction) Query(ctx context.Context, sqlstr string, prep PP) (*sqlx.Rows, error) {
origsql := sqlstr
for _, v := range tx.lstr {
for _, v := range tx.db.lstr {
err := v.PreQuery(ctx, langext.Ptr(tx.id), &sqlstr, &prep)
if err != nil {
return nil, err
@@ -126,7 +126,7 @@ func (tx *transaction) Query(ctx context.Context, sqlstr string, prep PP) (*sqlx
tx.status = TxStatusActive
}
for _, v := range tx.lstr {
for _, v := range tx.db.lstr {
v.PostQuery(langext.Ptr(tx.id), origsql, sqlstr, prep)
}
@@ -140,6 +140,10 @@ func (tx *transaction) Status() TxStatus {
return tx.status
}
func (tx *transaction) ListConverter() []DBTypeConverter {
return tx.db.conv
}
func (tx *transaction) Traffic() (int, int) {
return tx.execCtr, tx.queryCtr
}