Added a SQL-Preprocessor - this way we can unmarshal recursive structures (LEFT JOIN etc)

This commit is contained in:
2022-12-21 18:14:13 +01:00
parent bbf7962e29
commit dbc014f819
19 changed files with 740 additions and 162 deletions

View File

@@ -93,18 +93,21 @@ func (db *Database) CreateChannel(ctx TxContext, userid models.UserID, name stri
}, nil
}
func (db *Database) ListChannelsByOwner(ctx TxContext, userid models.UserID) ([]models.Channel, error) {
func (db *Database) ListChannelsByOwner(ctx TxContext, userid models.UserID, subUserID models.UserID) ([]models.ChannelWithSubscription, error) {
tx, err := ctx.GetOrCreateTransaction(db)
if err != nil {
return nil, err
}
rows, err := tx.Query(ctx, "SELECT * FROM channels WHERE owner_user_id = :ouid", sq.PP{"ouid": userid})
rows, err := tx.Query(ctx, "SELECT channels.*, sub.* FROM channels LEFT JOIN subscriptions AS sub ON channels.channel_id = sub.channel_id AND sub.subscriber_user_id = :subuid WHERE owner_user_id = :ouid", sq.PP{
"ouid": userid,
"subuid": subUserID,
})
if err != nil {
return nil, err
}
data, err := models.DecodeChannels(rows)
data, err := models.DecodeChannelsWithSubscription(rows)
if err != nil {
return nil, err
}
@@ -112,25 +115,27 @@ func (db *Database) ListChannelsByOwner(ctx TxContext, userid models.UserID) ([]
return data, nil
}
func (db *Database) ListChannelsBySubscriber(ctx TxContext, userid models.UserID, confirmed bool) ([]models.Channel, error) {
func (db *Database) ListChannelsBySubscriber(ctx TxContext, userid models.UserID, confirmed *bool) ([]models.ChannelWithSubscription, error) {
tx, err := ctx.GetOrCreateTransaction(db)
if err != nil {
return nil, err
}
confCond := ""
if confirmed {
if confirmed != nil && *confirmed {
confCond = " AND sub.confirmed = 1"
} else if confirmed != nil && !*confirmed {
confCond = " AND sub.confirmed = 0"
}
rows, err := tx.Query(ctx, "SELECT * FROM channels LEFT JOIN subscriptions sub on channels.channel_id = sub.channel_id WHERE sub.subscriber_user_id = :suid "+confCond, sq.PP{
"suid": userid,
rows, err := tx.Query(ctx, "SELECT channels.*, sub.* FROM channels LEFT JOIN subscriptions AS sub on channels.channel_id = sub.channel_id AND sub.subscriber_user_id = :subuid WHERE sub.subscription_id IS NOT NULL "+confCond, sq.PP{
"subuid": userid,
})
if err != nil {
return nil, err
}
data, err := models.DecodeChannels(rows)
data, err := models.DecodeChannelsWithSubscription(rows)
if err != nil {
return nil, err
}
@@ -138,25 +143,28 @@ func (db *Database) ListChannelsBySubscriber(ctx TxContext, userid models.UserID
return data, nil
}
func (db *Database) ListChannelsByAccess(ctx TxContext, userid models.UserID, confirmed bool) ([]models.Channel, error) {
func (db *Database) ListChannelsByAccess(ctx TxContext, userid models.UserID, confirmed *bool) ([]models.ChannelWithSubscription, error) {
tx, err := ctx.GetOrCreateTransaction(db)
if err != nil {
return nil, err
}
confCond := "OR sub.subscriber_user_id = ?"
if confirmed {
confCond = "OR (sub.subscriber_user_id = ? AND sub.confirmed = 1)"
confCond := ""
if confirmed != nil && *confirmed {
confCond = "OR sub.confirmed = 1"
} else if confirmed != nil && !*confirmed {
confCond = "OR sub.confirmed = 0"
}
rows, err := tx.Query(ctx, "SELECT * FROM channels LEFT JOIN subscriptions sub on channels.channel_id = sub.channel_id WHERE owner_user_id = :ouid "+confCond, sq.PP{
"ouid": userid,
rows, err := tx.Query(ctx, "SELECT channels.*, sub.* FROM channels LEFT JOIN subscriptions AS sub on channels.channel_id = sub.channel_id AND sub.subscriber_user_id = :subuid WHERE owner_user_id = :ouid "+confCond, sq.PP{
"ouid": userid,
"subuid": userid,
})
if err != nil {
return nil, err
}
data, err := models.DecodeChannels(rows)
data, err := models.DecodeChannelsWithSubscription(rows)
if err != nil {
return nil, err
}
@@ -164,26 +172,27 @@ func (db *Database) ListChannelsByAccess(ctx TxContext, userid models.UserID, co
return data, nil
}
func (db *Database) GetChannel(ctx TxContext, userid models.UserID, channelid models.ChannelID) (models.Channel, error) {
func (db *Database) GetChannel(ctx TxContext, userid models.UserID, channelid models.ChannelID) (models.ChannelWithSubscription, error) {
tx, err := ctx.GetOrCreateTransaction(db)
if err != nil {
return models.Channel{}, err
return models.ChannelWithSubscription{}, err
}
rows, err := tx.Query(ctx, "SELECT * FROM channels WHERE owner_user_id = :ouid AND channel_id = :cid LIMIT 1", sq.PP{
"ouid": userid,
"cid": channelid,
rows, err := tx.Query(ctx, "SELECT channels.*, sub.* FROM channels LEFT JOIN subscriptions AS sub on channels.channel_id = sub.channel_id AND sub.subscriber_user_id = :subuid WHERE owner_user_id = :ouid AND channels.channel_id = :cid LIMIT 1", sq.PP{
"ouid": userid,
"cid": channelid,
"subuid": userid,
})
if err != nil {
return models.Channel{}, err
return models.ChannelWithSubscription{}, err
}
client, err := models.DecodeChannel(rows)
channel, err := models.DecodeChannelWithSubscription(rows)
if err != nil {
return models.Channel{}, err
return models.ChannelWithSubscription{}, err
}
return client, nil
return channel, nil
}
func (db *Database) IncChannelMessageCounter(ctx TxContext, channel models.Channel) error {

View File

@@ -2,6 +2,7 @@ package db
import (
server "blackforestbytes.com/simplecloudnotifier"
"blackforestbytes.com/simplecloudnotifier/db/dbtools"
"blackforestbytes.com/simplecloudnotifier/db/schema"
"context"
"database/sql"
@@ -9,7 +10,6 @@ import (
"fmt"
"github.com/jmoiron/sqlx"
_ "github.com/mattn/go-sqlite3"
"github.com/rs/zerolog/log"
"gogs.mikescher.com/BlackForestBytes/goext/langext"
"gogs.mikescher.com/BlackForestBytes/goext/sq"
"time"
@@ -40,7 +40,9 @@ func NewDatabase(conf server.Config) (*Database, error) {
scndb := &Database{qqdb}
qqdb.SetListener(scndb)
qqdb.AddListener(dbtools.DBLogger{})
qqdb.AddListener(dbtools.NewDBPreprocessor(scndb.db))
return scndb, nil
}
@@ -83,35 +85,3 @@ func (db *Database) Ping(ctx context.Context) error {
func (db *Database) BeginTx(ctx context.Context) (sq.Tx, error) {
return db.db.BeginTransaction(ctx, sql.LevelDefault)
}
func (db *Database) OnQuery(txID *uint16, sql string, _ *sq.PP) {
if txID == nil {
log.Debug().Msg(fmt.Sprintf("[SQL-QUERY] %s", fmtSQLPrint(sql)))
} else {
log.Debug().Msg(fmt.Sprintf("[SQL-TX<%d>-QUERY] %s", *txID, fmtSQLPrint(sql)))
}
}
func (db *Database) OnExec(txID *uint16, sql string, _ *sq.PP) {
if txID == nil {
log.Debug().Msg(fmt.Sprintf("[SQL-EXEC] %s", fmtSQLPrint(sql)))
} else {
log.Debug().Msg(fmt.Sprintf("[SQL-TX<%d>-EXEC] %s", *txID, fmtSQLPrint(sql)))
}
}
func (db *Database) OnPing() {
log.Debug().Msg("[SQL-PING]")
}
func (db *Database) OnTxBegin(txid uint16) {
log.Debug().Msg(fmt.Sprintf("[SQL-TX<%d>-START]", txid))
}
func (db *Database) OnTxCommit(txid uint16) {
log.Debug().Msg(fmt.Sprintf("[SQL-TX<%d>-COMMIT]", txid))
}
func (db *Database) OnTxRollback(txid uint16) {
log.Debug().Msg(fmt.Sprintf("[SQL-TX<%d>-ROLLBACK]", txid))
}

View File

@@ -0,0 +1,90 @@
package dbtools
import (
"context"
"fmt"
"github.com/rs/zerolog/log"
"gogs.mikescher.com/BlackForestBytes/goext/sq"
"strings"
)
type DBLogger struct{}
func (l DBLogger) PrePing(ctx context.Context) error {
log.Debug().Msg("[SQL-PING]")
return nil
}
func (l DBLogger) PreTxBegin(ctx context.Context, txid uint16) error {
log.Debug().Msg(fmt.Sprintf("[SQL-TX<%d>-START]", txid))
return nil
}
func (l DBLogger) PreTxCommit(txid uint16) error {
log.Debug().Msg(fmt.Sprintf("[SQL-TX<%d>-COMMIT]", txid))
return nil
}
func (l DBLogger) PreTxRollback(txid uint16) error {
log.Debug().Msg(fmt.Sprintf("[SQL-TX<%d>-ROLLBACK]", txid))
return nil
}
func (l DBLogger) PreQuery(ctx context.Context, txID *uint16, sql *string, params *sq.PP) error {
if txID == nil {
log.Debug().Msg(fmt.Sprintf("[SQL-QUERY] %s", fmtSQLPrint(*sql)))
} else {
log.Debug().Msg(fmt.Sprintf("[SQL-TX<%d>-QUERY] %s", *txID, fmtSQLPrint(*sql)))
}
return nil
}
func (l DBLogger) PreExec(ctx context.Context, txID *uint16, sql *string, params *sq.PP) error {
if txID == nil {
log.Debug().Msg(fmt.Sprintf("[SQL-EXEC] %s", fmtSQLPrint(*sql)))
} else {
log.Debug().Msg(fmt.Sprintf("[SQL-TX<%d>-EXEC] %s", *txID, fmtSQLPrint(*sql)))
}
return nil
}
func (l DBLogger) PostPing(result error) {
//
}
func (l DBLogger) PostTxBegin(txid uint16, result error) {
//
}
func (l DBLogger) PostTxCommit(txid uint16, result error) {
//
}
func (l DBLogger) PostTxRollback(txid uint16, result error) {
//
}
func (l DBLogger) PostQuery(txID *uint16, sqlOriginal string, sqlReal string, params sq.PP) {
//
}
func (l DBLogger) PostExec(txID *uint16, sqlOriginal string, sqlReal string, params sq.PP) {
//
}
func fmtSQLPrint(sql string) string {
if strings.Contains(sql, ";") {
return "(...multi...)"
}
sql = strings.ReplaceAll(sql, "\r", "")
sql = strings.ReplaceAll(sql, "\n", " ")
return sql
}

View File

@@ -0,0 +1,238 @@
package dbtools
import (
"context"
"errors"
"fmt"
"github.com/rs/zerolog/log"
"gogs.mikescher.com/BlackForestBytes/goext/langext"
"gogs.mikescher.com/BlackForestBytes/goext/sq"
"regexp"
"strings"
"sync"
)
//
// This is..., not good...
//
// for sq.ScanAll to work with (left-)joined tables _need_ to get column names aka "alias.column"
// But sqlite (and all other db server) only return "column" if we don't manually specify `alias.column as "alias.columnname"`
// But always specifying all columns (and their alias) would be __very__ cumbersome...
//
// The "solution" is this preprocessor, which translates queries of the form `SELECT tab1.*, tab2.* From tab1` into `SELECT tab1.col1 AS "tab1.col1", tab1.col2 AS "tab1.col2" ....`
//
// Prerequisites:
// - all aliased tables must be written as `tablename AS alias` (the variant without the AS keyword is invalid)
// - a star only expands to the (single) table in FROM. Use *, table2.* if there exists a second (joined) table
// - No weird SQL syntax, this "parser" is not very robust...
//
type DBPreprocessor struct {
db sq.DB
lock sync.Mutex
cacheColumns map[string][]string
cacheQuery map[string]string
}
var regexAlias = regexp.MustCompile("([A-Za-z_\\-0-9]+)\\s+AS\\s+([A-Za-z_\\-0-9]+)")
func NewDBPreprocessor(db sq.DB) *DBPreprocessor {
return &DBPreprocessor{
db: db,
lock: sync.Mutex{},
cacheColumns: make(map[string][]string),
cacheQuery: make(map[string]string),
}
}
func (pp *DBPreprocessor) PrePing(ctx context.Context) error {
return nil
}
func (pp *DBPreprocessor) PreTxBegin(ctx context.Context, txid uint16) error {
return nil
}
func (pp *DBPreprocessor) PreTxCommit(txid uint16) error {
return nil
}
func (pp *DBPreprocessor) PreTxRollback(txid uint16) error {
return nil
}
func (pp *DBPreprocessor) PreQuery(ctx context.Context, txID *uint16, sql *string, params *sq.PP) error {
sqlOriginal := *sql
pp.lock.Lock()
v, ok := pp.cacheQuery[sqlOriginal]
pp.lock.Unlock()
if ok {
*sql = v
return nil
}
if !strings.HasPrefix(sqlOriginal, "SELECT ") {
return nil
}
idxFrom := strings.Index(sqlOriginal, " FROM ")
if idxFrom < 0 {
return nil
}
fromTableName := strings.Split(strings.TrimSpace(sqlOriginal[idxFrom+len(" FROM"):]), " ")[0]
sels := strings.TrimSpace(sqlOriginal[len("SELECT "):idxFrom])
split := strings.Split(sels, ",")
newsel := make([]string, 0)
aliasMap := make(map[string]string)
for _, v := range regexAlias.FindAllStringSubmatch(sqlOriginal, idxFrom+len(" FROM")) {
aliasMap[strings.TrimSpace(v[2])] = strings.TrimSpace(v[1])
}
for _, expr := range split {
expr = strings.TrimSpace(expr)
if expr == "*" {
columns, err := pp.getTableColumns(ctx, fromTableName)
if err != nil {
return err
}
for _, colname := range columns {
newsel = append(newsel, fmt.Sprintf("%s.%s AS \"%s\"", fromTableName, colname, colname))
}
} else if strings.HasSuffix(expr, ".*") {
tableName := expr[0 : len(expr)-2]
if tableRealName, ok := aliasMap[tableName]; ok {
columns, err := pp.getTableColumns(ctx, tableRealName)
if err != nil {
return err
}
for _, colname := range columns {
newsel = append(newsel, fmt.Sprintf("%s.%s AS \"%s.%s\"", tableName, colname, tableName, colname))
}
} else if tableName == fromTableName {
columns, err := pp.getTableColumns(ctx, tableName)
if err != nil {
return err
}
for _, colname := range columns {
newsel = append(newsel, fmt.Sprintf("%s.%s AS \"%s\"", tableName, colname, colname))
}
} else {
columns, err := pp.getTableColumns(ctx, tableName)
if err != nil {
return err
}
for _, colname := range columns {
newsel = append(newsel, fmt.Sprintf("%s.%s AS \"%s.%s\"", tableName, colname, tableName, colname))
}
}
} else {
return nil
}
}
newSQL := "SELECT " + strings.Join(newsel, ", ") + sqlOriginal[idxFrom:]
pp.lock.Lock()
pp.cacheQuery[sqlOriginal] = newSQL
pp.lock.Unlock()
log.Debug().Msgf("Preprocessed SQL statement from '%s' --to--> '%s'", sqlOriginal, newSQL)
*sql = newSQL
return nil
}
func (pp *DBPreprocessor) PreExec(ctx context.Context, txID *uint16, sql *string, params *sq.PP) error {
return nil
}
func (pp *DBPreprocessor) PostPing(result error) {
//
}
func (pp *DBPreprocessor) PostTxBegin(txid uint16, result error) {
//
}
func (pp *DBPreprocessor) PostTxCommit(txid uint16, result error) {
//
}
func (pp *DBPreprocessor) PostTxRollback(txid uint16, result error) {
//
}
func (pp *DBPreprocessor) PostQuery(txID *uint16, sqlOriginal string, sqlReal string, params sq.PP) {
//
}
func (pp *DBPreprocessor) PostExec(txID *uint16, sqlOriginal string, sqlReal string, params sq.PP) {
//
}
func (pp *DBPreprocessor) getTableColumns(ctx context.Context, tablename string) ([]string, error) {
pp.lock.Lock()
v, ok := pp.cacheColumns[tablename]
pp.lock.Unlock()
if ok {
return v, nil
}
type res struct {
CID int64 `db:"cid"`
Name string `db:"name"`
Type string `db:"type"`
NotNull int `db:"notnull"`
DFLT *string `db:"dflt_value"`
PK int `db:"pk"`
}
rows, err := pp.db.Query(ctx, "PRAGMA table_info('"+tablename+"');", sq.PP{})
if err != nil {
return nil, err
}
resrows, err := sq.ScanAll[res](rows, true)
if err != nil {
return nil, err
}
columns := langext.ArrMap(resrows, func(v res) string { return v.Name })
if len(columns) == 0 {
return nil, errors.New("no columns in table '" + tablename + "' (table does not exist?)")
}
pp.lock.Lock()
pp.cacheColumns[tablename] = columns
pp.lock.Unlock()
return columns, nil
}

View File

@@ -62,13 +62,46 @@ func (db *Database) ListSubscriptionsByChannel(ctx TxContext, channelID models.C
return data, nil
}
func (db *Database) ListSubscriptionsByOwner(ctx TxContext, ownerUserID models.UserID) ([]models.Subscription, error) {
func (db *Database) ListSubscriptionsByOwner(ctx TxContext, ownerUserID models.UserID, confirmed *bool) ([]models.Subscription, error) {
tx, err := ctx.GetOrCreateTransaction(db)
if err != nil {
return nil, err
}
rows, err := tx.Query(ctx, "SELECT * FROM subscriptions WHERE channel_owner_user_id = :ouid", sq.PP{"ouid": ownerUserID})
cond := ""
if confirmed != nil && *confirmed {
cond = " AND confirmed = 1"
} else if confirmed != nil && !*confirmed {
cond = " AND confirmed = 0"
}
rows, err := tx.Query(ctx, "SELECT * FROM subscriptions WHERE channel_owner_user_id = :ouid"+cond, sq.PP{"ouid": ownerUserID})
if err != nil {
return nil, err
}
data, err := models.DecodeSubscriptions(rows)
if err != nil {
return nil, err
}
return data, nil
}
func (db *Database) ListSubscriptionsBySubscriber(ctx TxContext, subscriberUserID models.UserID, confirmed *bool) ([]models.Subscription, error) {
tx, err := ctx.GetOrCreateTransaction(db)
if err != nil {
return nil, err
}
cond := ""
if confirmed != nil && *confirmed {
cond = " AND confirmed = 1"
} else if confirmed != nil && !*confirmed {
cond = " AND confirmed = 0"
}
rows, err := tx.Query(ctx, "SELECT * FROM subscriptions WHERE subscriber_user_id = :suid"+cond, sq.PP{"suid": subscriberUserID})
if err != nil {
return nil, err
}

View File

@@ -186,7 +186,7 @@ func (db *Database) UpdateUserKeys(ctx TxContext, userid models.UserID, sendKey
return err
}
_, err = tx.Exec(ctx, "UPDATE users SET send_key = :sk, read_key = :rk, admin_key = :ak WHERE user_id = ?", sq.PP{
_, err = tx.Exec(ctx, "UPDATE users SET send_key = :sk, read_key = :rk, admin_key = :ak WHERE user_id = :uid", sq.PP{
"sk": sendKey,
"rk": readKey,
"ak": adminKey,

View File

@@ -2,7 +2,6 @@ package db
import (
"gogs.mikescher.com/BlackForestBytes/goext/langext"
"strings"
"time"
)
@@ -24,14 +23,3 @@ func time2DBOpt(t *time.Time) *int64 {
}
return langext.Ptr(t.UnixMilli())
}
func fmtSQLPrint(sql string) string {
if strings.Contains(sql, ";") {
return "(...multi...)"
}
sql = strings.ReplaceAll(sql, "\r", "")
sql = strings.ReplaceAll(sql, "\n", " ")
return sql
}