From 95779800d4f476de358938deb93dcf328ca19b2e Mon Sep 17 00:00:00 2001 From: Kushtrim Junuzi Date: Mon, 19 Aug 2019 03:43:23 +0000 Subject: [PATCH] Add database picker master/slave when QueryRow is called --- db.go | 16 +++++++++++++--- db_test.go | 42 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 3 deletions(-) diff --git a/db.go b/db.go index 00bda4f..ffea504 100644 --- a/db.go +++ b/db.go @@ -13,8 +13,9 @@ import ( // forming a single master multiple slaves topology. // Reads and writes are automatically directed to the correct physical db. type DB struct { - pdbs []*sql.DB // Physical databases - count uint64 // Monotonically incrementing counter on each query + pdbs []*sql.DB // Physical databases + count uint64 // Monotonically incrementing counter on each query + queryRowDB func(query string, args ...interface{}) *sql.DB } // Open concurrently opens each underlying physical db. @@ -147,7 +148,12 @@ func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{ // Errors are deferred until Row's Scan method is called. // QueryRow uses a slave as the physical db. func (db *DB) QueryRow(query string, args ...interface{}) *sql.Row { - return db.Slave().QueryRow(query, args...) + d := db.Slave() // default to slave + if db.queryRowDB != nil { + // conditional function that decide what db to use + d = db.queryRowDB(query, args) + } + return d.QueryRow(query, args...) } // QueryRowContext executes a query that is expected to return at most one row. @@ -206,3 +212,7 @@ func (db *DB) slave(n int) int { } return int(1 + (atomic.AddUint64(&db.count, 1) % uint64(n-1))) } + +func (db *DB) SetQueryRowDB(fn func(query string, args ...interface{}) *sql.DB) { + db.queryRowDB = fn +} diff --git a/db_test.go b/db_test.go index 39c3855..088c478 100644 --- a/db_test.go +++ b/db_test.go @@ -1,6 +1,8 @@ package nap import ( + "database/sql" + "strings" "testing" "testing/quick" @@ -59,3 +61,43 @@ func TestSlave(t *testing.T) { t.Error(err) } } + +func TestQueryRow(t *testing.T) { + // https://www.sqlite.org/inmemorydb.html + db, err := Open("sqlite3", ":memory:;:memory:;:memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + if err = db.Ping(); err != nil { + t.Error(err) + } + + master := false + + db.SetQueryRowDB(func(query string, args ...interface{}) *sql.DB { + if len(query) > 12 && strings.ToLower(query)[0:12] == "insert into " { + master = true + return db.Master() + } + master = false + return db.Slave() + }) + + res := db.QueryRow("insert into t1(c1,c2) values(1,1);", nil) + if res == nil { + t.Errorf("func QueryRow has no results") + } + if !master { + t.Errorf("query row expected to use master database") + } + + res = db.QueryRow("select * from t1", nil) + if res == nil { + t.Errorf("func QueryRow has no results") + } + if master { + t.Errorf("query row expected to use slave database") + } +}