diff --git a/README.md b/README.md index d541750..1eb438b 100644 --- a/README.md +++ b/README.md @@ -207,7 +207,40 @@ Imported functions participate in compile-time checking for arity and return typ ## Transactions and Batching -TODO! +Use `transaction:` to run multiple queries atomically. The block commits on success and rolls back on any exception. Nesting is supported via savepoints. `tryTransaction:` behaves the same but returns `bool` (false on database errors) without raising. + +Examples: + +```nim +# Commit on success +transaction: + query: + insert person(id = ?(1), name = ?"alice", password = ?"p", email = ?"a@x", salt = ?"s", status = ?"ok") + query: + update thread(views = views + 1) + where id == ?(42) + +# Rollback on error +let ok = tryTransaction: + query: + insert person(id = ?(2), name = ?"bob", password = ?"p", email = ?"b@x", salt = ?"s", status = ?"ok") + # Primary key violation => entire block is rolled back, ok = false + query: + insert person(id = ?(2), name = ?"duplicate", password = ?"p", email = ?"d@x", salt = ?"s", status = ?"x") + +# Nested transactions via savepoints +transaction: + query: + insert person(id = ?(3), name = ?"carol", password = ?"p", email = ?"c@x", salt = ?"s", status = ?"ok") + let innerOk = tryTransaction: + # This will fail and roll back to the savepoint + query: + insert person(id = ?(3), name = ?"duplicate", password = ?"p", email = ?"d@x", salt = ?"s", status = ?"x") + doAssert innerOk == false + # Continue outer transaction normally +``` + +PostgreSQL and SQLite are supported. The macros use `BEGIN/COMMIT/ROLLBACK` for the outermost transaction and `SAVEPOINT/RELEASE/ROLLBACK TO` for nested scopes. ## Reusable Procedures and Iterators diff --git a/ormin/ormin_postgre.nim b/ormin/ormin_postgre.nim index fce1ab9..0d09799 100644 --- a/ormin/ormin_postgre.nim +++ b/ormin/ormin_postgre.nim @@ -33,6 +33,8 @@ proc c_strtol(buf: cstring, endptr: ptr cstring = nil, base: cint = 10): int {. var sid {.compileTime.}: int proc prepareStmt*(db: DbConn; q: string): PStmt = + when defined(debugOrminTrace): + echo "[[Ormin Executing]]: ", q inc sid result = "ormin" & $sid var res = pqprepare(db, result, q, 0, nil) diff --git a/ormin/ormin_sqlite.nim b/ormin/ormin_sqlite.nim index 2f19c89..321a8e0 100644 --- a/ormin/ormin_sqlite.nim +++ b/ormin/ormin_sqlite.nim @@ -28,6 +28,8 @@ proc dbError*(db: DbConn) {.noreturn.} = raise e proc prepareStmt*(db: DbConn; q: string): PStmt = + when defined(debugOrminTrace): + echo "[[Ormin Executing]]: ", q if prepare_v2(db, q, q.len.cint, result, nil) != SQLITE_OK: dbError(db) diff --git a/ormin/queries.nim b/ormin/queries.nim index a5f416e..e4e82ff 100644 --- a/ormin/queries.nim +++ b/ormin/queries.nim @@ -140,6 +140,27 @@ type # For SQLite: expression to return instead of last_insert_rowid() retExpr: NimNode +# Execute a non-row SQL statement strictly (errors on failure) +template execNoRowsStrict*(sqlStmt: string) = + when defined(debugOrminTrace): + echo "[[Ormin Executing]]: ", q + let s {.gensym.} = prepareStmt(db, sqlStmt) + startQuery(db, s) + if stepQuery(db, s, false): + stopQuery(db, s) + else: + stopQuery(db, s) + dbError(db) + +# Execute a non-row SQL statement, relying on startQuery to raise on failure +template execNoRowsLoose(sqlStmt: string) = + when defined(debugOrminTrace): + echo "[[Ormin Executing]]: ", sqlStmt + let s {.gensym.} = prepareStmt(db, sqlStmt) + startQuery(db, s) + discard stepQuery(db, s, false) + stopQuery(db, s) + proc newQueryBuilder(): QueryBuilder {.compileTime.} = QueryBuilder(head: "", fromm: "", join: "", values: "", where: "", groupby: "", having: "", orderby: "", limit: "", offset: "", @@ -549,7 +570,12 @@ proc generateRoutine(name: NimNode, q: QueryBuilder; if k != nnkIteratorDef: rtyp = nnkBracketExpr.newTree(ident"seq", rtyp) finalParams.add rtyp - finalParams.add newIdentDefs(ident"db", ident("DbConn")) + when dbBackend == DbBackend.postgre: + finalParams.add newIdentDefs(ident"db", newTree(nnkDotExpr, ident"ormin_postgre", ident"DbConn")) + elif dbBackend == DbBackend.sqlite: + finalParams.add newIdentDefs(ident"db", newTree(nnkDotExpr, ident"ormin_sqlite", ident"DbConn")) + else: + finalParams.add newIdentDefs(ident"db", ident("DbConn")) var i = 1 if q.params.len > 0: body.add newCall(bindSym"startBindings", prepStmt, newLit(q.params.len)) @@ -1081,6 +1107,86 @@ macro tryQuery*(body: untyped): untyped = when defined(debugOrminDsl): macros.hint("Ormin Query: " & repr(result), body) +# ------------------------- +# Transactions DSL +# ------------------------- + +# Transaction state for nested transactions +var txDepth {.threadvar.}: int + +proc getTxDepth*(): int = + result = txDepth + +proc isTopTx*(): bool = + result = txDepth == 1 + +proc incTxDepth*() = + inc txDepth + +proc decTxDepth*() = + dec txDepth + +template txBegin*(sp: untyped) = + if isTopTx(): + execNoRowsLoose("begin transaction") + else: + execNoRowsLoose("savepoint " & sp) + +template txCommit*(sp: untyped) = + if isTopTx(): + execNoRowsLoose("commit") + else: + execNoRowsLoose("release savepoint " & sp) + +template txRollback*(sp: untyped) = + if isTopTx(): + execNoRowsLoose("rollback") + else: + execNoRowsLoose("rollback to savepoint " & sp) + +template transaction*(body: untyped) = + ## Runs the body inside a database transaction. Commits on success, + ## rolls back on any exception and rethrows. Supports nesting via savepoints. + block: + incTxDepth() + let sp = "ormin_tx_" & $txDepth + + try: + txBegin(sp) + `body` + txCommit(sp) + except DbError: + txRollback(sp) + raise + except CatchableError, Defect: + txRollback(sp) + raise + finally: + decTxDepth() + +macro getBlock(blk: untyped): untyped = + result = blk[0] + +template transaction*(body, other: untyped) = + ## Runs the body inside a database transaction. Commits on success, + ## rolls back on any exception and rethrows. Supports nesting via savepoints. + block: + incTxDepth() + let sp = "ormin_tx_" & $txDepth + + try: + txBegin(sp) + `body` + txCommit(sp) + except DbError: + txRollback(sp) + getBlock(`other`) + except CatchableError, Defect: + txRollback(sp) + raise + finally: + decTxDepth() + proc createRoutine(name, query: NimNode; k: NimNodeKind): NimNode = expectKind query, nnkStmtList expectMinLen query, 1 diff --git a/tests/tpostgre.nim b/tests/tpostgre.nim index 0929252..1caee77 100644 --- a/tests/tpostgre.nim +++ b/tests/tpostgre.nim @@ -60,23 +60,23 @@ suite "timestamp_insert": test "insert": query: insert tb_timestamp(dt = ?dt1, dtn = ?dtn1, dtz = ?dtz1) - check db.getValue(sql"select count(*) from tb_timestamp") == "1" + check db_postgres.getValue(db, sql"select count(*) from tb_timestamp") == "1" test "json": query: insert tb_timestamp(dt = %dtjson1["dt"], dtn = %dtjson1["dtn"], dtz = %dtjson1["dtz"]) - check db.getValue(sql"select count(*) from tb_timestamp") == "1" + check db_postgres.getValue(db, sql"select count(*) from tb_timestamp") == "1" suite "timestamp": db.dropTable(sqlFile, "tb_timestamp") db.createTable(sqlFile, "tb_timestamp") - db.exec(insertSql, dtStr1, dtnStr1, dtzStr1) - db.exec(insertSql, dtStr2, dtnStr2, dtzStr2) - db.exec(insertSql, dtStr3, dtnStr3, dtzStr3) - doAssert db.getValue(sql"select count(*) from tb_timestamp") == "3" + db_postgres.exec(db, insertSql, dtStr1, dtnStr1, dtzStr1) + db_postgres.exec(db, insertSql, dtStr2, dtnStr2, dtzStr2) + db_postgres.exec(db, insertSql, dtStr3, dtnStr3, dtzStr3) + doAssert db_postgres.getValue(db, sql"select count(*) from tb_timestamp") == "3" test "query": let res = query: diff --git a/tests/ttransactions.nim b/tests/ttransactions.nim new file mode 100644 index 0000000..35d0834 --- /dev/null +++ b/tests/ttransactions.nim @@ -0,0 +1,132 @@ +import unittest, os, strformat +import ormin +import ormin/db_utils +when NimVersion < "1.2.0": import ./compat + +let testDir = currentSourcePath.parentDir() + +when defined postgre: + when defined(macosx): + {.passL: "-Wl,-rpath,/opt/homebrew/lib/postgresql@14".} + from db_connector/db_postgres import exec, getValue + const backend = DbBackend.postgre + importModel(backend, "forum_model_postgres") + const sqlFileName = "forum_model_postgres.sql" + let db {.global.} = open("localhost", "test", "test", "test_ormin") +else: + from db_connector/db_sqlite import exec, getValue + const backend = DbBackend.sqlite + importModel(backend, "forum_model_sqlite") + const sqlFileName = "forum_model_sqlite.sql" + var memoryPath = testDir & "/" & ":memory:" + let db {.global.} = open(memoryPath, "", "", "") + +var sqlFilePath = Path(testDir & "/" & sqlFileName) + +# Fresh schema +db.dropTable(sqlFilePath) +db.createTable(sqlFilePath) + +suite &"Transactions ({backend})": + + test "commit on success": + transaction: + query: + insert person(id = ?(101), name = ?"john101", password = ?"p101", email = ?"john101@mail.com", salt = ?"s101", status = ?"ok") + check db.getValue(sql"select count(*) from person where id = 101") == "1" + + test "rollback on error with manual try except": + # prepare one row + query: + insert person(id = ?(201), name = ?"john201", password = ?"p201", email = ?"john201@mail.com", salt = ?"s201", status = ?"ok") + # in transaction insert a new row and then violate PK + try: + transaction: + query: + insert person(id = ?(202), name = ?"john202", password = ?"p202", email = ?"john202@mail.com", salt = ?"s202", status = ?"ok") + # duplicate key error + query: + insert person(id = ?(201), name = ?"dup", password = ?"p", email = ?"e", salt = ?"s", status = ?"x") + check false # should not reach + except DbError as e: + discard + # both inserts inside the transaction should be rolled back + check db.getValue(sql"select count(*) from person where id = 202") == "0" + check db.getValue(sql"select count(*) from person where id = 201 and name = 'dup'") == "0" + + test "rollback on error with else": + # prepare one row + var failed = false + query: + insert person(id = ?(501), name = ?"john501", password = ?"p501", email = ?"john501@mail.com", salt = ?"s501", status = ?"ok") + # in transaction insert a new row and then violate PK + transaction: + query: + insert person(id = ?(502), name = ?"john502", password = ?"p502", email = ?"john502@mail.com", salt = ?"s502", status = ?"ok") + # duplicate key error + query: + insert person(id = ?(501), name = ?"dup", password = ?"p", email = ?"e", salt = ?"s", status = ?"x") + check false # should not reach + else: + echo "do something else..." + failed = true + + check failed + # both inserts inside the transaction should be rolled back + check db.getValue(sql"select count(*) from person where id = 502") == "0" + check db.getValue(sql"select count(*) from person where id = 501 and name = 'dup'") == "0" + + test "commit normally with else": + # prepare one row + var failed = false + query: + insert person(id = ?(601), name = ?"john601", password = ?"p601", email = ?"john601@mail.com", salt = ?"s601", status = ?"ok") + # in transaction insert a new row and then violate PK + transaction: + query: + insert person(id = ?(602), name = ?"john602", password = ?"p602", email = ?"john602@mail.com", salt = ?"s602", status = ?"ok") + query: + insert person(id = ?(603), name = ?"dup", password = ?"p", email = ?"e", salt = ?"s", status = ?"x") + else: + failed = true + + check not failed + # both inserts inside the transaction should be rolled back + check db.getValue(sql"select count(*) from person where id = 603") == "1" + check db.getValue(sql"select count(*) from person where id = 602") == "1" + check db.getValue(sql"select count(*) from person where id = 601") == "1" + + test "transaction set false on DbError": + var failed = false + transaction: + query: + insert person(id = ?(301), name = ?"john301", password = ?"p301", email = ?"john301@mail.com", salt = ?"s301", status = ?"ok") + query: + insert person(id = ?(301), name = ?"dup", password = ?"p", email = ?"e", salt = ?"s", status = ?"x") + else: + failed = true + check failed + check db.getValue(sql"select count(*) from person where id = 301") == "0" + + test "nested savepoints": + var failed = false + transaction: + query: + insert person(id = ?(401), name = ?"john401", password = ?"p401", email = ?"john401@mail.com", salt = ?"s401", status = ?"ok") + var innerOk = true + transaction: + query: + insert person(id = ?(402), name = ?"john402", password = ?"p402", email = ?"john402@mail.com", salt = ?"s402", status = ?"ok") + query: + insert person(id = ?(401), name = ?"dup401", password = ?"p", email = ?"e", salt = ?"s", status = ?"x") + else: + innerOk = false + + check innerOk == false + + # after inner rollback, we can still insert another row and commit outer + query: + insert person(id = ?(403), name = ?"john403", password = ?"p403", email = ?"john403@mail.com", salt = ?"s403", status = ?"ok") + else: + failed = true + check db.getValue(sql"select count(*) from person where id in (401,402,403)") == "2"