v0.0.351 sq value converter
All checks were successful
Build Docker and Deploy / Run goext test-suite (push) Successful in 2m30s
All checks were successful
Build Docker and Deploy / Run goext test-suite (push) Successful in 2m30s
This commit is contained in:
@@ -4,17 +4,11 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/langext"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/rfctime"
|
||||
"reflect"
|
||||
"time"
|
||||
)
|
||||
|
||||
//TODO UNFINISHED
|
||||
// this is not finished
|
||||
// idea was that we can register converter in the database struct
|
||||
// they get inherited from the transactions
|
||||
// and when marshallingunmarshaling (sq.Query | sq.QueryAll)
|
||||
// or marshaling (sq.InsertSingle)
|
||||
// the types get converter automatically...
|
||||
|
||||
type DBTypeConverter interface {
|
||||
ModelTypeString() string
|
||||
DBTypeString() string
|
||||
@@ -40,16 +34,42 @@ var ConverterTimeToUnixMillis = NewDBTypeConverter[time.Time, int64](func(v time
|
||||
return time.UnixMilli(v), nil
|
||||
})
|
||||
|
||||
var ConverterOptTimeToUnixMillis = NewDBTypeConverter[*time.Time, *int64](func(v *time.Time) (*int64, error) {
|
||||
if v == nil {
|
||||
return nil, nil
|
||||
var ConverterRFCUnixMilliTimeToUnixMillis = NewDBTypeConverter[rfctime.UnixMilliTime, int64](func(v rfctime.UnixMilliTime) (int64, error) {
|
||||
return v.UnixMilli(), nil
|
||||
}, func(v int64) (rfctime.UnixMilliTime, error) {
|
||||
return rfctime.NewUnixMilli(time.UnixMilli(v)), nil
|
||||
})
|
||||
|
||||
var ConverterRFCUnixNanoTimeToUnixNanos = NewDBTypeConverter[rfctime.UnixNanoTime, int64](func(v rfctime.UnixNanoTime) (int64, error) {
|
||||
return v.UnixNano(), nil
|
||||
}, func(v int64) (rfctime.UnixNanoTime, error) {
|
||||
return rfctime.NewUnixNano(time.Unix(0, v)), nil
|
||||
})
|
||||
|
||||
var ConverterRFCUnixTimeToUnixSeconds = NewDBTypeConverter[rfctime.UnixTime, int64](func(v rfctime.UnixTime) (int64, error) {
|
||||
return v.Unix(), nil
|
||||
}, func(v int64) (rfctime.UnixTime, error) {
|
||||
return rfctime.NewUnix(time.Unix(v, 0)), nil
|
||||
})
|
||||
|
||||
var ConverterRFC339TimeToString = NewDBTypeConverter[rfctime.RFC3339Time, string](func(v rfctime.RFC3339Time) (string, error) {
|
||||
return v.Format(time.RFC3339), nil
|
||||
}, func(v string) (rfctime.RFC3339Time, error) {
|
||||
t, err := time.Parse(time.RFC3339Nano, v)
|
||||
if err != nil {
|
||||
return rfctime.RFC3339Time{}, err
|
||||
}
|
||||
return langext.Ptr(v.UnixMilli()), nil
|
||||
}, func(v *int64) (*time.Time, error) {
|
||||
if v == nil {
|
||||
return nil, nil
|
||||
return rfctime.NewRFC3339(t), nil
|
||||
})
|
||||
|
||||
var ConverterRFC339NanoTimeToString = NewDBTypeConverter[rfctime.RFC3339NanoTime, string](func(v rfctime.RFC3339NanoTime) (string, error) {
|
||||
return v.Format(time.RFC3339Nano), nil
|
||||
}, func(v string) (rfctime.RFC3339NanoTime, error) {
|
||||
t, err := time.Parse(time.RFC3339Nano, v)
|
||||
if err != nil {
|
||||
return rfctime.RFC3339NanoTime{}, err
|
||||
}
|
||||
return langext.Ptr(time.UnixMilli(*v)), nil
|
||||
return rfctime.NewRFC3339Nano(t), nil
|
||||
})
|
||||
|
||||
type dbTypeConverterImpl[TModelData any, TDBData any] struct {
|
||||
@@ -89,3 +109,36 @@ func NewDBTypeConverter[TModelData any, TDBData any](todb func(v TModelData) (TD
|
||||
tomodel: tomodel,
|
||||
}
|
||||
}
|
||||
|
||||
func convertValueToDB(q Queryable, value any) (any, error) {
|
||||
modelTypeStr := fmt.Sprintf("%T", value)
|
||||
|
||||
for _, conv := range q.ListConverter() {
|
||||
if conv.ModelTypeString() == modelTypeStr {
|
||||
return conv.ModelToDB(value)
|
||||
}
|
||||
}
|
||||
|
||||
if value != nil && reflect.TypeOf(value).Kind() == reflect.Ptr {
|
||||
vof := reflect.ValueOf(value)
|
||||
if vof.IsNil() {
|
||||
return nil, nil
|
||||
} else {
|
||||
return convertValueToDB(q, vof.Elem().Interface())
|
||||
}
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func convertValueToModel(q Queryable, value any, destinationType string) (any, error) {
|
||||
dbTypeString := fmt.Sprintf("%T", value)
|
||||
|
||||
for _, conv := range q.ListConverter() {
|
||||
if conv.ModelTypeString() == destinationType && conv.DBTypeString() == dbTypeString {
|
||||
return conv.DBToModel(value)
|
||||
}
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
@@ -4,16 +4,19 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/langext"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type DB interface {
|
||||
Exec(ctx context.Context, sql string, prep PP) (sql.Result, error)
|
||||
Query(ctx context.Context, sql string, prep PP) (*sqlx.Rows, error)
|
||||
Queryable
|
||||
|
||||
Ping(ctx context.Context) error
|
||||
BeginTransaction(ctx context.Context, iso sql.IsolationLevel) (Tx, error)
|
||||
AddListener(listener Listener)
|
||||
Exit() error
|
||||
RegisterConverter(DBTypeConverter)
|
||||
RegisterDefaultConverter()
|
||||
}
|
||||
|
||||
type database struct {
|
||||
@@ -21,6 +24,7 @@ type database struct {
|
||||
txctr uint16
|
||||
lock sync.Mutex
|
||||
lstr []Listener
|
||||
conv []DBTypeConverter
|
||||
}
|
||||
|
||||
func NewDB(db *sqlx.DB) DB {
|
||||
@@ -120,9 +124,28 @@ func (db *database) BeginTransaction(ctx context.Context, iso sql.IsolationLevel
|
||||
v.PostTxBegin(txid, err)
|
||||
}
|
||||
|
||||
return NewTransaction(xtx, txid, db.lstr), nil
|
||||
return NewTransaction(xtx, txid, db), nil
|
||||
}
|
||||
|
||||
func (db *database) Exit() error {
|
||||
return db.db.Close()
|
||||
}
|
||||
|
||||
func (db *database) ListConverter() []DBTypeConverter {
|
||||
return db.conv
|
||||
}
|
||||
|
||||
func (db *database) RegisterConverter(conv DBTypeConverter) {
|
||||
db.conv = langext.ArrFilter(db.conv, func(v DBTypeConverter) bool { return v.ModelTypeString() != conv.ModelTypeString() })
|
||||
db.conv = append(db.conv, conv)
|
||||
}
|
||||
|
||||
func (db *database) RegisterDefaultConverter() {
|
||||
db.RegisterConverter(ConverterBoolToBit)
|
||||
db.RegisterConverter(ConverterTimeToUnixMillis)
|
||||
db.RegisterConverter(ConverterRFCUnixMilliTimeToUnixMillis)
|
||||
db.RegisterConverter(ConverterRFCUnixNanoTimeToUnixNanos)
|
||||
db.RegisterConverter(ConverterRFCUnixTimeToUnixSeconds)
|
||||
db.RegisterConverter(ConverterRFC339TimeToString)
|
||||
db.RegisterConverter(ConverterRFC339NanoTimeToString)
|
||||
}
|
||||
|
@@ -121,7 +121,7 @@ func CreateSqliteDatabaseSchemaString(ctx context.Context, db Queryable) (string
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
tableList, err := ScanAll[tabInfo](rowsTableList, SModeFast, Unsafe, true)
|
||||
tableList, err := ScanAll[tabInfo](ctx, db, rowsTableList, SModeFast, Unsafe, true)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -143,7 +143,7 @@ func CreateSqliteDatabaseSchemaString(ctx context.Context, db Queryable) (string
|
||||
return "", err
|
||||
}
|
||||
|
||||
columnList, err := ScanAll[colInfo](rowsColumnList, SModeFast, Unsafe, true)
|
||||
columnList, err := ScanAll[colInfo](ctx, db, rowsColumnList, SModeFast, Unsafe, true)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -158,7 +158,7 @@ func CreateSqliteDatabaseSchemaString(ctx context.Context, db Queryable) (string
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
idxList, err := ScanAll[idxInfo](rowsIdxList, SModeFast, Unsafe, true)
|
||||
idxList, err := ScanAll[idxInfo](ctx, db, rowsIdxList, SModeFast, Unsafe, true)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -173,7 +173,7 @@ func CreateSqliteDatabaseSchemaString(ctx context.Context, db Queryable) (string
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
fkyList, err := ScanAll[fkyInfo](rowsIdxList, SModeFast, Unsafe, true)
|
||||
fkyList, err := ScanAll[fkyInfo](ctx, db, rowsIdxList, SModeFast, Unsafe, true)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
@@ -9,4 +9,5 @@ import (
|
||||
type Queryable interface {
|
||||
Exec(ctx context.Context, sql string, prep PP) (sql.Result, error)
|
||||
Query(ctx context.Context, sql string, prep PP) (*sqlx.Rows, error)
|
||||
ListConverter() []DBTypeConverter
|
||||
}
|
||||
|
@@ -13,8 +13,8 @@ import (
|
||||
type StructScanMode string
|
||||
|
||||
const (
|
||||
SModeFast StructScanMode = "FAST"
|
||||
SModeExtended StructScanMode = "EXTENDED"
|
||||
SModeFast StructScanMode = "FAST" // Use default sq.Scan, does not work with joined/resolved types and/or custom value converter
|
||||
SModeExtended StructScanMode = "EXTENDED" // Fully featured perhaps (?) a tiny bit slower - default
|
||||
)
|
||||
|
||||
type StructScanSafety string
|
||||
@@ -51,7 +51,13 @@ func InsertSingle[TData any](ctx context.Context, q Queryable, tableName string,
|
||||
|
||||
columns = append(columns, "\""+columnName+"\"")
|
||||
params = append(params, ":"+paramkey)
|
||||
pp[paramkey] = rvfield.Interface()
|
||||
|
||||
val, err := convertValueToDB(q, rvfield.Interface())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pp[paramkey] = val
|
||||
|
||||
}
|
||||
|
||||
@@ -71,7 +77,7 @@ func QuerySingle[TData any](ctx context.Context, q Queryable, sql string, pp PP,
|
||||
return *new(TData), err
|
||||
}
|
||||
|
||||
data, err := ScanSingle[TData](rows, mode, sec, true)
|
||||
data, err := ScanSingle[TData](ctx, q, rows, mode, sec, true)
|
||||
if err != nil {
|
||||
return *new(TData), err
|
||||
}
|
||||
@@ -85,7 +91,7 @@ func QueryAll[TData any](ctx context.Context, q Queryable, sql string, pp PP, mo
|
||||
return nil, err
|
||||
}
|
||||
|
||||
data, err := ScanAll[TData](rows, mode, sec, true)
|
||||
data, err := ScanAll[TData](ctx, q, rows, mode, sec, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -93,7 +99,7 @@ func QueryAll[TData any](ctx context.Context, q Queryable, sql string, pp PP, mo
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func ScanSingle[TData any](rows *sqlx.Rows, mode StructScanMode, sec StructScanSafety, close bool) (TData, error) {
|
||||
func ScanSingle[TData any](ctx context.Context, q Queryable, rows *sqlx.Rows, mode StructScanMode, sec StructScanSafety, close bool) (TData, error) {
|
||||
if rows.Next() {
|
||||
var strscan *StructScanner
|
||||
|
||||
@@ -123,7 +129,7 @@ func ScanSingle[TData any](rows *sqlx.Rows, mode StructScanMode, sec StructScanS
|
||||
return *new(TData), err
|
||||
}
|
||||
} else if mode == SModeExtended {
|
||||
err := strscan.StructScanExt(&data)
|
||||
err := strscan.StructScanExt(q, &data)
|
||||
if err != nil {
|
||||
return *new(TData), err
|
||||
}
|
||||
@@ -149,6 +155,10 @@ func ScanSingle[TData any](rows *sqlx.Rows, mode StructScanMode, sec StructScanS
|
||||
return *new(TData), err
|
||||
}
|
||||
|
||||
if err := ctx.Err(); err != nil {
|
||||
return *new(TData), err
|
||||
}
|
||||
|
||||
return data, nil
|
||||
|
||||
} else {
|
||||
@@ -159,7 +169,7 @@ func ScanSingle[TData any](rows *sqlx.Rows, mode StructScanMode, sec StructScanS
|
||||
}
|
||||
}
|
||||
|
||||
func ScanAll[TData any](rows *sqlx.Rows, mode StructScanMode, sec StructScanSafety, close bool) ([]TData, error) {
|
||||
func ScanAll[TData any](ctx context.Context, q Queryable, rows *sqlx.Rows, mode StructScanMode, sec StructScanSafety, close bool) ([]TData, error) {
|
||||
var strscan *StructScanner
|
||||
|
||||
if sec == Safe {
|
||||
@@ -182,6 +192,11 @@ func ScanAll[TData any](rows *sqlx.Rows, mode StructScanMode, sec StructScanSafe
|
||||
|
||||
res := make([]TData, 0)
|
||||
for rows.Next() {
|
||||
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if mode == SModeFast {
|
||||
var data TData
|
||||
err := strscan.StructScanBase(&data)
|
||||
@@ -191,7 +206,7 @@ func ScanAll[TData any](rows *sqlx.Rows, mode StructScanMode, sec StructScanSafe
|
||||
res = append(res, data)
|
||||
} else if mode == SModeExtended {
|
||||
var data TData
|
||||
err := strscan.StructScanExt(&data)
|
||||
err := strscan.StructScanExt(q, &data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
89
sq/sq_test.go
Normal file
89
sq/sq_test.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package sq
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/mattn/go-sqlite3"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/langext"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/rfctime"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/tst"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestTypeConverter1(t *testing.T) {
|
||||
type RequestData struct {
|
||||
ID string `db:"id"`
|
||||
Timestamp time.Time `db:"timestamp"`
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
dbdir := t.TempDir()
|
||||
dbfile1 := filepath.Join(dbdir, langext.MustHexUUID()+".sqlite3")
|
||||
|
||||
sqlite3.Version() // ensure loaded
|
||||
|
||||
tst.AssertNoErr(t, os.MkdirAll(dbdir, os.ModePerm))
|
||||
|
||||
url := fmt.Sprintf("file:%s?_journal=%s&_timeout=%d&_fk=%s&_busy_timeout=%d", dbfile1, "DELETE", 1000, "true", 1000)
|
||||
|
||||
xdb := tst.Must(sqlx.Open("sqlite3", url))(t)
|
||||
|
||||
db := NewDB(xdb)
|
||||
db.RegisterDefaultConverter()
|
||||
|
||||
_, err := db.Exec(ctx, "CREATE TABLE `requests` ( id TEXT NOT NULL, timestamp INTEGER NOT NULL, PRIMARY KEY (id) ) STRICT", PP{})
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
_, err = InsertSingle(ctx, db, "requests", RequestData{
|
||||
ID: "001",
|
||||
Timestamp: time.Date(2000, 06, 15, 12, 0, 0, 0, time.UTC),
|
||||
})
|
||||
tst.AssertNoErr(t, err)
|
||||
}
|
||||
|
||||
func TestTypeConverter2(t *testing.T) {
|
||||
type RequestData struct {
|
||||
ID string `db:"id"`
|
||||
Timestamp rfctime.UnixMilliTime `db:"timestamp"`
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
dbdir := t.TempDir()
|
||||
dbfile1 := filepath.Join(dbdir, langext.MustHexUUID()+".sqlite3")
|
||||
|
||||
sqlite3.Version() // ensure loaded
|
||||
|
||||
tst.AssertNoErr(t, os.MkdirAll(dbdir, os.ModePerm))
|
||||
|
||||
url := fmt.Sprintf("file:%s?_journal=%s&_timeout=%d&_fk=%s&_busy_timeout=%d", dbfile1, "DELETE", 1000, "true", 1000)
|
||||
|
||||
xdb := tst.Must(sqlx.Open("sqlite3", url))(t)
|
||||
|
||||
db := NewDB(xdb)
|
||||
db.RegisterDefaultConverter()
|
||||
|
||||
_, err := db.Exec(ctx, "CREATE TABLE `requests` ( id TEXT NOT NULL, timestamp INTEGER NOT NULL, PRIMARY KEY (id) ) STRICT", PP{})
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
t0 := rfctime.NewUnixMilli(time.Date(2012, 03, 01, 16, 0, 0, 0, time.UTC))
|
||||
|
||||
_, err = InsertSingle(ctx, db, "requests", RequestData{
|
||||
ID: "002",
|
||||
Timestamp: t0,
|
||||
})
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
r, err := QuerySingle[RequestData](ctx, db, "SELECT * FROM requests WHERE id = '002'", PP{}, SModeExtended, Safe)
|
||||
tst.AssertNoErr(t, err)
|
||||
|
||||
fmt.Printf("%+v\n", r)
|
||||
|
||||
tst.AssertEqual(t, "002", r.ID)
|
||||
tst.AssertEqual(t, t0.UnixNano(), r.Timestamp.UnixNano())
|
||||
}
|
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/jmoiron/sqlx/reflectx"
|
||||
"gogs.mikescher.com/BlackForestBytes/goext/langext"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
@@ -15,9 +16,10 @@ type StructScanner struct {
|
||||
Mapper *reflectx.Mapper
|
||||
unsafe bool
|
||||
|
||||
fields [][]int
|
||||
values []any
|
||||
columns []string
|
||||
fields [][]int
|
||||
values []any
|
||||
converter []DBTypeConverter
|
||||
columns []string
|
||||
}
|
||||
|
||||
func NewStructScanner(rows *sqlx.Rows, unsafe bool) *StructScanner {
|
||||
@@ -47,13 +49,15 @@ func (r *StructScanner) Start(dest any) error {
|
||||
return fmt.Errorf("missing destination name %s in %T", columns[f], dest)
|
||||
}
|
||||
r.values = make([]interface{}, len(columns))
|
||||
r.converter = make([]DBTypeConverter, len(columns))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// StructScanExt forked from github.com/jmoiron/sqlx@v1.3.5/sqlx.go
|
||||
// does also wok with nullabel structs (from LEFT JOIN's)
|
||||
func (r *StructScanner) StructScanExt(dest any) error {
|
||||
// does also work with nullabel structs (from LEFT JOIN's)
|
||||
// does also work with custom value converters
|
||||
func (r *StructScanner) StructScanExt(q Queryable, dest any) error {
|
||||
v := reflect.ValueOf(dest)
|
||||
|
||||
if v.Kind() != reflect.Ptr {
|
||||
@@ -64,7 +68,7 @@ func (r *StructScanner) StructScanExt(dest any) error {
|
||||
|
||||
v = v.Elem()
|
||||
|
||||
err := fieldsByTraversalExtended(v, r.fields, r.values)
|
||||
err := fieldsByTraversalExtended(q, v, r.fields, r.values, r.converter)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -131,7 +135,6 @@ func (r *StructScanner) StructScanExt(dest any) error {
|
||||
|
||||
val1 := reflect.ValueOf(r.values[i])
|
||||
val2 := val1.Elem()
|
||||
val3 := val2.Elem()
|
||||
|
||||
if val2.IsNil() {
|
||||
if f.Kind() != reflect.Pointer {
|
||||
@@ -140,7 +143,16 @@ func (r *StructScanner) StructScanExt(dest any) error {
|
||||
|
||||
f.Set(reflect.Zero(f.Type())) // set to nil
|
||||
} else {
|
||||
f.Set(val3)
|
||||
if r.converter[i] != nil {
|
||||
val3 := val2.Elem().Interface()
|
||||
conv3, err := r.converter[i].DBToModel(val3)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
f.Set(reflect.ValueOf(conv3))
|
||||
} else {
|
||||
f.Set(val2.Elem())
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -172,7 +184,7 @@ func (r *StructScanner) StructScanBase(dest any) error {
|
||||
}
|
||||
|
||||
// fieldsByTraversal forked from github.com/jmoiron/sqlx@v1.3.5/sqlx.go
|
||||
func fieldsByTraversalExtended(v reflect.Value, traversals [][]int, values []interface{}) error {
|
||||
func fieldsByTraversalExtended(q Queryable, v reflect.Value, traversals [][]int, values []interface{}, converter []DBTypeConverter) error {
|
||||
v = reflect.Indirect(v)
|
||||
if v.Kind() != reflect.Struct {
|
||||
return errors.New("argument not a struct")
|
||||
@@ -185,7 +197,23 @@ func fieldsByTraversalExtended(v reflect.Value, traversals [][]int, values []int
|
||||
}
|
||||
f := reflectx.FieldByIndexes(v, traversal)
|
||||
|
||||
values[i] = reflect.New(reflect.PointerTo(f.Type())).Interface()
|
||||
typeStr := f.Type().String()
|
||||
|
||||
foundConverter := false
|
||||
for _, conv := range q.ListConverter() {
|
||||
if conv.ModelTypeString() == typeStr {
|
||||
_v := langext.Ptr[any](nil)
|
||||
values[i] = _v
|
||||
foundConverter = true
|
||||
converter[i] = conv
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !foundConverter {
|
||||
values[i] = reflect.New(reflect.PointerTo(f.Type())).Interface()
|
||||
converter[i] = nil
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@@ -17,35 +17,35 @@ const (
|
||||
)
|
||||
|
||||
type Tx interface {
|
||||
Queryable
|
||||
|
||||
Rollback() error
|
||||
Commit() error
|
||||
Status() TxStatus
|
||||
Exec(ctx context.Context, sql string, prep PP) (sql.Result, error)
|
||||
Query(ctx context.Context, sql string, prep PP) (*sqlx.Rows, error)
|
||||
}
|
||||
|
||||
type transaction struct {
|
||||
tx *sqlx.Tx
|
||||
id uint16
|
||||
lstr []Listener
|
||||
status TxStatus
|
||||
execCtr int
|
||||
queryCtr int
|
||||
db *database
|
||||
}
|
||||
|
||||
func NewTransaction(xtx *sqlx.Tx, txid uint16, lstr []Listener) Tx {
|
||||
func NewTransaction(xtx *sqlx.Tx, txid uint16, db *database) Tx {
|
||||
return &transaction{
|
||||
tx: xtx,
|
||||
id: txid,
|
||||
lstr: lstr,
|
||||
status: TxStatusInitial,
|
||||
execCtr: 0,
|
||||
queryCtr: 0,
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
func (tx *transaction) Rollback() error {
|
||||
for _, v := range tx.lstr {
|
||||
for _, v := range tx.db.lstr {
|
||||
err := v.PreTxRollback(tx.id)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -58,7 +58,7 @@ func (tx *transaction) Rollback() error {
|
||||
tx.status = TxStatusRollback
|
||||
}
|
||||
|
||||
for _, v := range tx.lstr {
|
||||
for _, v := range tx.db.lstr {
|
||||
v.PostTxRollback(tx.id, result)
|
||||
}
|
||||
|
||||
@@ -66,7 +66,7 @@ func (tx *transaction) Rollback() error {
|
||||
}
|
||||
|
||||
func (tx *transaction) Commit() error {
|
||||
for _, v := range tx.lstr {
|
||||
for _, v := range tx.db.lstr {
|
||||
err := v.PreTxCommit(tx.id)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -79,7 +79,7 @@ func (tx *transaction) Commit() error {
|
||||
tx.status = TxStatusComitted
|
||||
}
|
||||
|
||||
for _, v := range tx.lstr {
|
||||
for _, v := range tx.db.lstr {
|
||||
v.PostTxRollback(tx.id, result)
|
||||
}
|
||||
|
||||
@@ -88,7 +88,7 @@ func (tx *transaction) Commit() error {
|
||||
|
||||
func (tx *transaction) Exec(ctx context.Context, sqlstr string, prep PP) (sql.Result, error) {
|
||||
origsql := sqlstr
|
||||
for _, v := range tx.lstr {
|
||||
for _, v := range tx.db.lstr {
|
||||
err := v.PreExec(ctx, langext.Ptr(tx.id), &sqlstr, &prep)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -101,7 +101,7 @@ func (tx *transaction) Exec(ctx context.Context, sqlstr string, prep PP) (sql.Re
|
||||
tx.status = TxStatusActive
|
||||
}
|
||||
|
||||
for _, v := range tx.lstr {
|
||||
for _, v := range tx.db.lstr {
|
||||
v.PostExec(langext.Ptr(tx.id), origsql, sqlstr, prep)
|
||||
}
|
||||
|
||||
@@ -113,7 +113,7 @@ func (tx *transaction) Exec(ctx context.Context, sqlstr string, prep PP) (sql.Re
|
||||
|
||||
func (tx *transaction) Query(ctx context.Context, sqlstr string, prep PP) (*sqlx.Rows, error) {
|
||||
origsql := sqlstr
|
||||
for _, v := range tx.lstr {
|
||||
for _, v := range tx.db.lstr {
|
||||
err := v.PreQuery(ctx, langext.Ptr(tx.id), &sqlstr, &prep)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -126,7 +126,7 @@ func (tx *transaction) Query(ctx context.Context, sqlstr string, prep PP) (*sqlx
|
||||
tx.status = TxStatusActive
|
||||
}
|
||||
|
||||
for _, v := range tx.lstr {
|
||||
for _, v := range tx.db.lstr {
|
||||
v.PostQuery(langext.Ptr(tx.id), origsql, sqlstr, prep)
|
||||
}
|
||||
|
||||
@@ -140,6 +140,10 @@ func (tx *transaction) Status() TxStatus {
|
||||
return tx.status
|
||||
}
|
||||
|
||||
func (tx *transaction) ListConverter() []DBTypeConverter {
|
||||
return tx.db.conv
|
||||
}
|
||||
|
||||
func (tx *transaction) Traffic() (int, int) {
|
||||
return tx.execCtr, tx.queryCtr
|
||||
}
|
||||
|
Reference in New Issue
Block a user