Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@ import (
// forming a single master multiple slaves topology.
// Reads and writes are automatically directed to the correct physical db.
type DB struct {
pdbs []*sql.DB // Physical databases
count uint64 // Monotonically incrementing counter on each query
pdbs []*sql.DB // Physical databases
count uint64 // Monotonically incrementing counter on each query
queryRowDB func(query string, args ...interface{}) *sql.DB
}

// Open concurrently opens each underlying physical db.
Expand Down Expand Up @@ -147,7 +148,12 @@ func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{
// Errors are deferred until Row's Scan method is called.
// QueryRow uses a slave as the physical db.
func (db *DB) QueryRow(query string, args ...interface{}) *sql.Row {
return db.Slave().QueryRow(query, args...)
d := db.Slave() // default to slave
if db.queryRowDB != nil {
// conditional function that decide what db to use
d = db.queryRowDB(query, args)
}
return d.QueryRow(query, args...)
}

// QueryRowContext executes a query that is expected to return at most one row.
Expand Down Expand Up @@ -206,3 +212,7 @@ func (db *DB) slave(n int) int {
}
return int(1 + (atomic.AddUint64(&db.count, 1) % uint64(n-1)))
}

func (db *DB) SetQueryRowDB(fn func(query string, args ...interface{}) *sql.DB) {
db.queryRowDB = fn
}
42 changes: 42 additions & 0 deletions db_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package nap

import (
"database/sql"
"strings"
"testing"
"testing/quick"

Expand Down Expand Up @@ -59,3 +61,43 @@ func TestSlave(t *testing.T) {
t.Error(err)
}
}

func TestQueryRow(t *testing.T) {
// https://www.sqlite.org/inmemorydb.html
db, err := Open("sqlite3", ":memory:;:memory:;:memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()

if err = db.Ping(); err != nil {
t.Error(err)
}

master := false

db.SetQueryRowDB(func(query string, args ...interface{}) *sql.DB {
if len(query) > 12 && strings.ToLower(query)[0:12] == "insert into " {
master = true
return db.Master()
}
master = false
return db.Slave()
})

res := db.QueryRow("insert into t1(c1,c2) values(1,1);", nil)
if res == nil {
t.Errorf("func QueryRow has no results")
}
if !master {
t.Errorf("query row expected to use master database")
}

res = db.QueryRow("select * from t1", nil)
if res == nil {
t.Errorf("func QueryRow has no results")
}
if master {
t.Errorf("query row expected to use slave database")
}
}