From b56f736e7cb36c76cba53034f94b58ad3e60dc2c Mon Sep 17 00:00:00 2001 From: saville Date: Mon, 13 Oct 2025 23:01:17 -0600 Subject: [PATCH] Fix handling for multitenancy --- README.rst | 48 +++++++---- dysql/__init__.py | 3 + dysql/databases.py | 97 ++++++++++++---------- dysql/multitenancy.py | 69 +++++++++++++++ dysql/test/__init__.py | 22 ++--- dysql/test/conftest.py | 19 +++++ dysql/test/test_database_initialization.py | 61 +++++++------- dysql/test/test_sql_decorator.py | 5 +- dysql/test/test_sql_exists_decorator.py | 4 +- dysql/test/test_sql_in_list_templates.py | 44 +++++----- dysql/test/test_sql_insert_templates.py | 18 ++-- pyproject.toml | 2 +- uv.lock | 2 +- 13 files changed, 241 insertions(+), 153 deletions(-) create mode 100644 dysql/multitenancy.py create mode 100644 dysql/test/conftest.py diff --git a/README.rst b/README.rst index 218f7b0..6633540 100644 --- a/README.rst +++ b/README.rst @@ -35,10 +35,10 @@ Component Breakdown initialization so that when a decorator function is called, it can setup a connection pool to a correct database * **is_set_current_database_supported** - this function may be used to determine if the ``*_current_database`` methods may be used or not -* **set_current_database** - (only supported on Python 3.7+) this function may be used to set the database name for the - current async context (not thread), this is especially useful for multitenant applications -* **reset_current_database** - (only supported on Python 3.7+) helper method to reset the current database after - ``set_current_database`` has been used in an async context +* **set_current_database** - this function may be used to set the database name for the current async context + (not thread), this is especially useful for multitenant applications +* **reset_current_database** - helper method to reset the current database after ``set_current_database`` has + been used in an async context * **set_database_init_hook** - sets a method to call whenever a new database is initialized * **QueryData** - a class that may be returned or yielded from ``sql*`` decorated methods which contains query information @@ -99,28 +99,46 @@ The ``set_database_init_hook`` method may be used in this case. As an example, t Multitenancy ============ In some applications, it may be useful to set a database other than the default database in order to support -database-per-tenant configurations. This may be done using the ``set_current_database`` and ``reset_current_database`` -methods. +database-per-tenant configurations. This may be done using various provided methods. .. code-block:: python - from dysql import reset_current_database, set_current_database - - def use_database_for_query(): - set_database_parameters( + from dysql import ( + set_default_connection_parameters, + reset_current_database, + set_current_database, + use_database_tenant, + tenant_database_manager, + sqlquery, + QueryData, + ) + + def init(): + # Initialize all databases up-front using an arbitrary database key to refer to them later + set_default_connection_parameters( + ... + database_key='db1', + ) + set_default_connection_parameters( ... - 'db1', + database_key='db2', ) + + def tenant_query_with_manual_set_reset(): set_current_database('db2') try: - # Queries db2 and not db1 query_database() finally: reset_current_database() -.. warning:: - These methods are only supported in Python 3.7+ due to their use of the ``contextvars`` module. The - ``is_set_current_database_supported`` method is provided to help tell if these methods may be used. + def tenant_query_with_context_manager(): + with tenant_database_manager("db2"): + return query_database() + + @use_database_tenant("db2") + @sqlquery() + def tenant_query_with_decorator(): + return QueryData("SELECT * FROM users") Decorators ========== diff --git a/dysql/__init__.py b/dysql/__init__.py index 2dd5258..492ed75 100644 --- a/dysql/__init__.py +++ b/dysql/__init__.py @@ -30,6 +30,7 @@ set_database_init_hook, set_default_connection_parameters, ) +from .multitenancy import use_database_tenant, tenant_database_manager from .exceptions import DBNotPreparedError @@ -53,5 +54,7 @@ "set_current_database", "set_database_init_hook", "set_default_connection_parameters", + "use_database_tenant", + "tenant_database_manager", "DBNotPreparedError", ] diff --git a/dysql/databases.py b/dysql/databases.py index 72a359b..9c0213f 100644 --- a/dysql/databases.py +++ b/dysql/databases.py @@ -6,8 +6,9 @@ with the terms of the Adobe license agreement accompanying it. """ +import contextvars import logging -import sys +from collections import defaultdict from typing import Callable, Optional import sqlalchemy @@ -17,14 +18,8 @@ logger = logging.getLogger("database") -_DEFAULT_CONNECTION_PARAMS = {} - -try: - import contextvars - - CURRENT_DATABASE_VAR = contextvars.ContextVar("dysql_current_database", default="") -except ImportError: - CURRENT_DATABASE_VAR = None +_DEFAULT_CONNECTION_PARAMS_BY_KEY = defaultdict(dict) +CURRENT_DATABASE_VAR = contextvars.ContextVar("dysql_current_database", default="") def set_database_init_hook( @@ -41,24 +36,20 @@ def set_database_init_hook( def is_set_current_database_supported() -> bool: """ - Determines if the set_current_database method is available on this python runtime. - :return: True if available, False otherwise + Deprecated, left in for backwards compatibility but always returns true. + :return: True """ - return bool(CURRENT_DATABASE_VAR) + return True -def set_current_database(database: str) -> None: +def set_current_database(database_key: str) -> None: """ - Sets the current database, may be used for multitenancy. This is only supported on Python 3.7+. This uses + Sets the current database key, may be used for multitenancy. This is only supported on Python 3.7+. This uses contextvars internally to set the name for the current async context. - :param database: the database name to use for this async context + :param database_key: the arbitrary database key to use for this async context """ - if not CURRENT_DATABASE_VAR: - raise DBNotPreparedError( - f'Cannot set the current database on Python "{sys.version}", please upgrade your Python version' - ) - CURRENT_DATABASE_VAR.set(database) - logger.debug(f"Set current database to {database}") + CURRENT_DATABASE_VAR.set(database_key) + logger.debug(f"Set current database to {database_key}") def reset_current_database() -> None: @@ -69,16 +60,17 @@ def reset_current_database() -> None: set_current_database("") -def _get_current_database() -> str: +def _get_current_database_key() -> str: """ - The current database name, using contextvars (if on python 3.7+) or the default database name. - :return: The current database name + The current database key, using contextvars (if on python 3.7+) or the default database key. + :return: The current database key """ database: Optional[str] = None if CURRENT_DATABASE_VAR: database = CURRENT_DATABASE_VAR.get() - if not database: - database = _DEFAULT_CONNECTION_PARAMS.get("database") + if not database and _DEFAULT_CONNECTION_PARAMS_BY_KEY: + # Get first database key + database = next(iter(_DEFAULT_CONNECTION_PARAMS_BY_KEY)) return database @@ -94,12 +86,14 @@ def set_default_connection_parameters( user: str, password: str, database: str, + database_key: Optional[str] = None, port: int = 3306, pool_size: int = 10, pool_recycle: int = 3600, echo_queries: bool = False, charset: str = "utf8", -): # pylint: disable=too-many-arguments,unused-argument + collation: Optional[str] = None, +): """ Initializes the parameters to use when connecting to the database. This is a subset of the parameters used by sqlalchemy. These may be overridden by parameters provided in the QueryData, hence the "default". @@ -108,12 +102,14 @@ def set_default_connection_parameters( :param user: user to connect to the database with :param password: password for given user :param database: database to connect to + :param database_key: optional database key that may be used for multitenant DBs, defaults to the database name :param port: the port to connect to (default 3306) :param pool_size: number of connections to maintain in the connection pool (default 10) :param pool_recycle: amount of time to wait between resetting the connections in the pool (default 3600) :param echo_queries: this tells sqlalchemy to print the queries when set to True (default false) :param charset: the charset for the sql engine to initialize with. (default utf8) + :param collation: the collation for the sql engine to initialize with. (default is not set) :exception DBNotPrepareError: happens when required parameters are missing """ _validate_param("host", host) @@ -121,14 +117,14 @@ def set_default_connection_parameters( _validate_param("password", password) _validate_param("database", database) - _DEFAULT_CONNECTION_PARAMS.update(locals()) + if not database_key: + database_key = database + _DEFAULT_CONNECTION_PARAMS_BY_KEY[database_key].update(locals()) class Database: - # pylint: disable=too-few-public-methods - - def __init__(self, database: Optional[str]) -> None: - self.database = database + def __init__(self, database_key: Optional[str]) -> None: + self.database = database_key # Engine is lazy-initialized self._engine: Optional[sqlalchemy.engine.Engine] = None @@ -142,25 +138,35 @@ def set_init_hook( @property def engine(self) -> sqlalchemy.engine.Engine: if not self._engine: - user = _DEFAULT_CONNECTION_PARAMS.get("user") - password = _DEFAULT_CONNECTION_PARAMS.get("password") - host = _DEFAULT_CONNECTION_PARAMS.get("host") - port = _DEFAULT_CONNECTION_PARAMS.get("port") - charset = _DEFAULT_CONNECTION_PARAMS.get("charset") - - url = f"mysql+mysqlconnector://{user}:{password}@{host}:{port}/{self.database}?charset={charset}" + connection_params = _DEFAULT_CONNECTION_PARAMS_BY_KEY.get(self.database, {}) + if not connection_params: + raise DBNotPreparedError( + f"No connection parameters found for database key '{self.database}'" + ) + user = connection_params.get("user") + password = connection_params.get("password") + database = connection_params.get("database") + host = connection_params.get("host") + port = connection_params.get("port") + charset = connection_params.get("charset") + collation = connection_params.get("collation") + collation_str = "" + if collation: + collation_str = f"&collation={collation}" + + url = f"mysql+mysqlconnector://{user}:{password}@{host}:{port}/{database}?charset={charset}{collation_str}" self._engine = sqlalchemy.create_engine( url, - pool_recycle=_DEFAULT_CONNECTION_PARAMS.get("pool_recycle"), - pool_size=_DEFAULT_CONNECTION_PARAMS.get("pool_size"), - echo=_DEFAULT_CONNECTION_PARAMS.get("echo_queries"), + pool_recycle=connection_params.get("pool_recycle"), + pool_size=connection_params.get("pool_size"), + echo=connection_params.get("echo_queries"), pool_pre_ping=True, ) hook_method: Optional[ Callable[[Optional[str], sqlalchemy.engine.Engine], None] ] = getattr(self.__class__, "hook_method", None) if hook_method: - hook_method(self.database, self._engine) + hook_method(database, self._engine) return self._engine @@ -178,7 +184,7 @@ def __getitem__(self, database: Optional[str]) -> Database: :return: a database instance :raises DBNotPreparedError: when set_default_connection_parameters has not yet been called """ - if not _DEFAULT_CONNECTION_PARAMS: + if not _DEFAULT_CONNECTION_PARAMS_BY_KEY: raise DBNotPreparedError( "Unable to connect to a database, set_default_connection_parameters must first be called" ) @@ -192,8 +198,7 @@ def current_database(self) -> Database: """ The current database instance, retrieved using contextvars (if python 3.7+) or the default database. """ - # pylint: disable=unnecessary-dunder-call - return self.__getitem__(_get_current_database()) + return self.__getitem__(_get_current_database_key()) class DatabaseContainerSingleton(DatabaseContainer): diff --git a/dysql/multitenancy.py b/dysql/multitenancy.py new file mode 100644 index 0000000..5cb8a76 --- /dev/null +++ b/dysql/multitenancy.py @@ -0,0 +1,69 @@ +""" +Copyright 2025 Adobe +All Rights Reserved. + +NOTICE: Adobe permits you to use, modify, and distribute this file in accordance +with the terms of the Adobe license agreement accompanying it. +""" + +import functools +import logging +from typing import Any, Callable, TypeVar +from contextlib import contextmanager + +from dysql.exceptions import DBNotPreparedError +from dysql import ( + set_current_database, + reset_current_database, +) + +LOGGER = logging.getLogger(__name__) + +F = TypeVar("F", bound=Callable[..., Any]) + + +@contextmanager +def tenant_database_manager(database_key: str): + """ + Context manager for temporarily switching to a different database. + + :param database_key: the database key to switch to + :raises DBNotPreparedError: if the database key is not set + """ + if not database_key: + raise DBNotPreparedError( + "Cannot switch to database tenant with empty database key" + ) + + try: + LOGGER.debug(f"Switching to database {database_key}") + set_current_database(database_key) + yield + except Exception as e: + LOGGER.error(f"Error while using database {database_key}: {e}") + raise + finally: + try: + reset_current_database() + LOGGER.debug(f"Reset database context from: {database_key}") + except Exception as e: + LOGGER.error(f"Error resetting database context: {e}") + # Don't re-raise here to avoid masking the original exception + + +def use_database_tenant(database_key: str): + """ + Decorator that switches to a specific database for the duration of the function call. + :param database_key: the database key to use + :return: the decorator function + """ + + def decorator(func: F) -> F: + @functools.wraps(func) + def wrapper(*args, **kwargs): + with tenant_database_manager(database_key): + return func(*args, **kwargs) + + return wrapper + + return decorator diff --git a/dysql/test/__init__.py b/dysql/test/__init__.py index 0d74f41..c728d6b 100644 --- a/dysql/test/__init__.py +++ b/dysql/test/__init__.py @@ -6,21 +6,11 @@ with the terms of the Adobe license agreement accompanying it. """ -from unittest.mock import Mock, patch -import pytest +from unittest.mock import Mock from dysql import set_default_connection_parameters, databases -@pytest.fixture(name="mock_create_engine") -def mock_create_engine_fixture(): - create_mock = patch("dysql.databases.sqlalchemy.create_engine") - try: - yield create_mock.start() - finally: - create_mock.stop() - - def setup_mock_engine(mock_create_engine): """ build up the basics of a mock engine for the database @@ -37,12 +27,12 @@ def setup_mock_engine(mock_create_engine): return mock_engine -def _verify_query_params(mock_engine, expected_query, expected_args): - _verify_query(mock_engine, expected_query) - _verify_query_args(mock_engine, expected_args) +def verify_query_params(mock_engine, expected_query, expected_args): + verify_query(mock_engine, expected_query) + verify_query_args(mock_engine, expected_args) -def _verify_query(mock_engine, expected_query): +def verify_query(mock_engine, expected_query): execute_call = ( mock_engine.connect.return_value.execution_options.return_value.execute ) @@ -52,7 +42,7 @@ def _verify_query(mock_engine, expected_query): assert query == expected_query -def _verify_query_args(mock_engine, expected_args): +def verify_query_args(mock_engine, expected_args): execute_call = ( mock_engine.connect.return_value.execution_options.return_value.execute ) diff --git a/dysql/test/conftest.py b/dysql/test/conftest.py new file mode 100644 index 0000000..f8ddc39 --- /dev/null +++ b/dysql/test/conftest.py @@ -0,0 +1,19 @@ +""" +Copyright 2021 Adobe +All Rights Reserved. + +NOTICE: Adobe permits you to use, modify, and distribute this file in accordance +with the terms of the Adobe license agreement accompanying it. +""" + +from unittest.mock import patch +import pytest + + +@pytest.fixture(name="mock_create_engine") +def mock_create_engine_fixture(): + create_mock = patch("dysql.databases.sqlalchemy.create_engine") + try: + yield create_mock.start() + finally: + create_mock.stop() diff --git a/dysql/test/test_database_initialization.py b/dysql/test/test_database_initialization.py index c4d46e3..4c226f8 100644 --- a/dysql/test/test_database_initialization.py +++ b/dysql/test/test_database_initialization.py @@ -6,8 +6,6 @@ with the terms of the Adobe license agreement accompanying it. """ -# pylint: disable=protected-access -import sys from unittest import mock import pytest @@ -20,9 +18,7 @@ QueryData, set_database_init_hook, ) -from dysql.test import mock_create_engine_fixture, setup_mock_engine - -_ = mock_create_engine_fixture +from dysql.test import setup_mock_engine """ @@ -44,17 +40,15 @@ def query(): @pytest.fixture(autouse=True, name="mock_engine") def fixture_mock_engine(mock_create_engine): dysql.databases.DatabaseContainerSingleton().clear() - dysql.databases._DEFAULT_CONNECTION_PARAMS.clear() + dysql.databases._DEFAULT_CONNECTION_PARAMS_BY_KEY.clear() # Reset the database before the test - if dysql.databases.is_set_current_database_supported(): - dysql.databases.reset_current_database() + dysql.databases.reset_current_database() yield setup_mock_engine(mock_create_engine) # Reset database after the test as well - if dysql.databases.is_set_current_database_supported(): - dysql.databases.reset_current_database() + dysql.databases.reset_current_database() @pytest.fixture(autouse=True) @@ -65,7 +59,7 @@ def fixture_reset_init_hook(): def test_nothing_set(): - dysql.databases._DEFAULT_CONNECTION_PARAMS.clear() + dysql.databases._DEFAULT_CONNECTION_PARAMS_BY_KEY.clear() with pytest.raises(DBNotPreparedError) as error: query() assert ( @@ -107,25 +101,31 @@ def test_init_hook(mock_engine): init_hook = mock.MagicMock() set_database_init_hook(init_hook) set_default_connection_parameters("h", "u", "p", "d") + dysql.databases.set_current_database("d") mock_engine.connect().execution_options().execute.return_value = [] query() init_hook.assert_called_once_with("d", mock_engine) -@pytest.mark.skipif( - "3.6" in sys.version, reason="set_current_database is not supported on python 3.6" -) def test_init_hook_multiple_databases(mock_engine): init_hook = mock.MagicMock() set_database_init_hook(init_hook) - set_default_connection_parameters("h", "u", "p", "d1") + set_default_connection_parameters("h1", "u1", "p1", "d1") + set_default_connection_parameters("h2", "u2", "p2", "d2") mock_engine.connect().execution_options().execute.return_value = [] query() + dysql.databases.set_current_database("d1") + query() + assert init_hook.call_args_list == [ + mock.call("test", mock_engine), + mock.call("d1", mock_engine), + ] dysql.databases.set_current_database("d2") query() assert init_hook.call_args_list == [ + mock.call("test", mock_engine), mock.call("d1", mock_engine), mock.call("d2", mock_engine), ] @@ -150,18 +150,24 @@ def test_current_database_default(mock_engine, mock_create_engine): ) -def test_different_charset(mock_engine, mock_create_engine): +def test_different_charset_collation(mock_engine, mock_create_engine): db_container = dysql.databases.DatabaseContainerSingleton() set_default_connection_parameters( - "host", "user", "password", "database", charset="other" + "host", + "user", + "password", + "database", + charset="other", + collation="other_collation", ) assert len(db_container) == 0 + dysql.databases.set_current_database("database") mock_engine.connect().execution_options().execute.return_value = [] query() # Only one database is initialized mock_create_engine.assert_called_once_with( - "mysql+mysqlconnector://user:password@host:3306/database?charset=other", + "mysql+mysqlconnector://user:password@host:3306/database?charset=other&collation=other_collation", echo=False, pool_pre_ping=True, pool_recycle=3600, @@ -170,18 +176,13 @@ def test_different_charset(mock_engine, mock_create_engine): def test_is_set_current_database_supported(): - # This test only returns different outputs depending on the python runtime - if "3.6" in sys.version: - assert not dysql.databases.is_set_current_database_supported() - else: - assert dysql.databases.is_set_current_database_supported() + # This test only asserts that true is always returned since this method is deprecated + assert dysql.databases.is_set_current_database_supported() -@pytest.mark.skipif( - "3.6" in sys.version, reason="set_current_database is not supported on python 3.6" -) def test_current_database_set(mock_engine, mock_create_engine): db_container = dysql.databases.DatabaseContainerSingleton() + set_default_connection_parameters("h1", "u1", "p1", "d1", database_key="db1") dysql.databases.set_current_database("db1") mock_engine.connect().execution_options().execute.return_value = [] query() @@ -190,7 +191,7 @@ def test_current_database_set(mock_engine, mock_create_engine): assert "db1" in db_container assert db_container.current_database.database == "db1" mock_create_engine.assert_called_once_with( - "mysql+mysqlconnector://user:password@fake:3306/db1?charset=utf8", + "mysql+mysqlconnector://u1:p1@h1:3306/d1?charset=utf8", echo=False, pool_pre_ping=True, pool_recycle=3600, @@ -198,9 +199,6 @@ def test_current_database_set(mock_engine, mock_create_engine): ) -@pytest.mark.skipif( - "3.6" in sys.version, reason="set_current_database is not supported on python 3.6" -) def test_current_database_cached(mock_engine, mock_create_engine): db_container = dysql.databases.DatabaseContainerSingleton() mock_engine.connect().execution_options().execute.return_value = [] @@ -210,6 +208,7 @@ def test_current_database_cached(mock_engine, mock_create_engine): assert "test" in db_container assert db_container.current_database.database == "test" + set_default_connection_parameters("h1", "u1", "p1", "db1") dysql.databases.set_current_database("db1") query() assert len(db_container) == 2 @@ -232,7 +231,7 @@ def test_current_database_cached(mock_engine, mock_create_engine): pool_size=10, ), mock.call( - "mysql+mysqlconnector://user:password@fake:3306/db1?charset=utf8", + "mysql+mysqlconnector://u1:p1@h1:3306/db1?charset=utf8", echo=False, pool_pre_ping=True, pool_recycle=3600, diff --git a/dysql/test/test_sql_decorator.py b/dysql/test/test_sql_decorator.py index cf4f3ff..c236dd0 100644 --- a/dysql/test/test_sql_decorator.py +++ b/dysql/test/test_sql_decorator.py @@ -17,10 +17,7 @@ QueryData, QueryDataError, ) -from dysql.test import mock_create_engine_fixture, setup_mock_engine - - -_ = mock_create_engine_fixture +from dysql.test import setup_mock_engine class TestSqlSelectDecorator: diff --git a/dysql/test/test_sql_exists_decorator.py b/dysql/test/test_sql_exists_decorator.py index 0eb8af9..daf56a1 100644 --- a/dysql/test/test_sql_exists_decorator.py +++ b/dysql/test/test_sql_exists_decorator.py @@ -11,11 +11,9 @@ import pytest from dysql import sqlexists, QueryData -from dysql.test import mock_create_engine_fixture, setup_mock_engine +from dysql.test import setup_mock_engine -_ = mock_create_engine_fixture - TRUE_QUERY = "SELECT 1 from table" TRUE_QUERY_PARAMS = "SELECT 1 from table where key=:key" FALSE_QUERY = "SELECT 1 from false_table " diff --git a/dysql/test/test_sql_in_list_templates.py b/dysql/test/test_sql_in_list_templates.py index 9a0987c..a12531a 100644 --- a/dysql/test/test_sql_in_list_templates.py +++ b/dysql/test/test_sql_in_list_templates.py @@ -12,17 +12,13 @@ import dysql from dysql import QueryData, sqlquery from dysql.test import ( - _verify_query, - _verify_query_args, - _verify_query_params, - mock_create_engine_fixture, + verify_query, + verify_query_args, + verify_query_params, setup_mock_engine, ) -_ = mock_create_engine_fixture - - @pytest.fixture(name="mock_engine", autouse=True) def mock_engine_fixture(mock_create_engine): mock_engine = setup_mock_engine(mock_create_engine) @@ -35,7 +31,7 @@ def test_list_in_numbers(mock_engine): "SELECT * FROM table WHERE {in__column_a}", template_params={"in__column_a": [1, 2, 3, 4]}, ) - _verify_query_params( + verify_query_params( mock_engine, "SELECT * FROM table WHERE column_a IN ( :in__column_a_0, :in__column_a_1, :in__column_a_2, :in__column_a_3 ) ", { @@ -52,7 +48,7 @@ def test_list_in__strings(mock_engine): "SELECT * FROM table WHERE {in__column_a}", template_params={"in__column_a": ["a", "b", "c", "d"]}, ) - _verify_query_params( + verify_query_params( mock_engine, "SELECT * FROM table WHERE column_a IN ( :in__column_a_0, :in__column_a_1, :in__column_a_2, :in__column_a_3 ) ", { @@ -69,7 +65,7 @@ def test_list_not_in_numbers(mock_engine): "SELECT * FROM table WHERE {not_in__column_b}", template_params={"not_in__column_b": [1, 2, 3, 4]}, ) - _verify_query_params( + verify_query_params( mock_engine, "SELECT * FROM table WHERE column_b NOT IN ( :not_in__column_b_0, :not_in__column_b_1, " ":not_in__column_b_2, :not_in__column_b_3 ) ", @@ -87,7 +83,7 @@ def test_list_not_in_strings(mock_engine): "SELECT * FROM table WHERE {not_in__column_b}", template_params={"not_in__column_b": ["a", "b", "c", "d"]}, ) - _verify_query_params( + verify_query_params( mock_engine, "SELECT * FROM table WHERE column_b NOT IN ( :not_in__column_b_0, :not_in__column_b_1, " ":not_in__column_b_2, :not_in__column_b_3 ) ", @@ -104,7 +100,7 @@ def test_list_in_handles_empty(mock_engine): _query( "SELECT * FROM table WHERE {in__column_a}", template_params={"in__column_a": []} ) - _verify_query(mock_engine, "SELECT * FROM table WHERE 1 <> 1 ") + verify_query(mock_engine, "SELECT * FROM table WHERE 1 <> 1 ") def test_list_in_handles_no_param(): @@ -119,7 +115,7 @@ def test_list_in_multiple_lists(mock_engine): "SELECT * FROM table WHERE {in__column_a} OR {in__column_b}", template_params={"in__column_a": ["first", "second"], "in__column_b": [1, 2]}, ) - _verify_query( + verify_query( mock_engine, "SELECT * FROM table WHERE column_a IN ( :in__column_a_0, :in__column_a_1 ) " "OR column_b IN ( :in__column_b_0, :in__column_b_1 ) ", @@ -131,7 +127,7 @@ def test_list_in_multiple_lists_one_empty(mock_engine): "SELECT * FROM table WHERE {in__column_a} OR {in__column_b}", template_params={"in__column_a": ["first", "second"], "in__column_b": []}, ) - _verify_query( + verify_query( mock_engine, "SELECT * FROM table WHERE column_a IN ( :in__column_a_0, :in__column_a_1 ) OR 1 <> 1 ", ) @@ -159,7 +155,7 @@ def test_list_not_in_handles_empty(mock_engine): "SELECT * FROM table WHERE {not_in__column_b}", template_params={"not_in__column_b": []}, ) - _verify_query(mock_engine, "SELECT * FROM table WHERE 1 = 1 ") + verify_query(mock_engine, "SELECT * FROM table WHERE 1 = 1 ") def test_list_not_in_handles_no_param(): @@ -173,7 +169,7 @@ def test_list_gives_template_space_before(mock_engine): _query( "SELECT * FROM table WHERE{in__space}", template_params={"in__space": [9, 8]} ) - _verify_query( + verify_query( mock_engine, "SELECT * FROM table WHERE space IN ( :in__space_0, :in__space_1 ) ", ) @@ -184,7 +180,7 @@ def test_list_gives_template_space_after(mock_engine): "SELECT * FROM table WHERE {in__space}AND other_condition = 1", template_params={"in__space": [9, 8]}, ) - _verify_query( + verify_query( mock_engine, "SELECT * FROM table WHERE space IN ( :in__space_0, :in__space_1 ) AND other_condition = 1", ) @@ -195,7 +191,7 @@ def test_list_gives_template_space_before_and_after(mock_engine): "SELECT * FROM table WHERE{in__space}AND other_condition = 1", template_params={"in__space": [9, 8]}, ) - _verify_query( + verify_query( mock_engine, "SELECT * FROM table WHERE space IN ( :in__space_0, :in__space_1 ) AND other_condition = 1", ) @@ -203,7 +199,7 @@ def test_list_gives_template_space_before_and_after(mock_engine): def test_in_contains_whitespace(mock_engine): _query("{in__column_one}", template_params={"in__column_one": [1, 2]}) - _verify_query( + verify_query( mock_engine, " column_one IN ( :in__column_one_0, :in__column_one_1 ) " ) @@ -220,11 +216,11 @@ def test_template_handles_table_qualifier(mock_engine): "SELECT * FROM table WHERE {in__table.column}", template_params={"in__table.column": [1, 2]}, ) - _verify_query( + verify_query( mock_engine, "SELECT * FROM table WHERE table.column IN ( :in__table_column_0, :in__table_column_1 ) ", ) - _verify_query_args(mock_engine, {"in__table_column_0": 1, "in__table_column_1": 2}) + verify_query_args(mock_engine, {"in__table_column_0": 1, "in__table_column_1": 2}) def test_template_handles_multiple_table_qualifier(mock_engine): @@ -235,12 +231,12 @@ def test_template_handles_multiple_table_qualifier(mock_engine): "not_in__other_column": ["a", "b"], }, ) - _verify_query( + verify_query( mock_engine, "SELECT * FROM table WHERE table.column IN ( :in__table_column_0, :in__table_column_1 ) " "AND other_column NOT IN ( :not_in__other_column_0, :not_in__other_column_1 ) ", ) - _verify_query_args( + verify_query_args( mock_engine, { "in__table_column_0": 1, @@ -253,7 +249,7 @@ def test_template_handles_multiple_table_qualifier(mock_engine): def test_empty_in_contains_whitespace(mock_engine): _query("{in__column_one}", template_params={"in__column_one": []}) - _verify_query(mock_engine, " 1 <> 1 ") + verify_query(mock_engine, " 1 <> 1 ") def test_multiple_templates_same_column_diff_table(mock_engine): diff --git a/dysql/test/test_sql_insert_templates.py b/dysql/test/test_sql_insert_templates.py index ad7fc5d..56eb584 100644 --- a/dysql/test/test_sql_insert_templates.py +++ b/dysql/test/test_sql_insert_templates.py @@ -11,16 +11,12 @@ import dysql from dysql import QueryData, sqlupdate, QueryDataError from dysql.test import ( - _verify_query, - _verify_query_args, - mock_create_engine_fixture, + verify_query, + verify_query_args, setup_mock_engine, ) -_ = mock_create_engine_fixture - - @pytest.fixture(name="mock_engine", autouse=True) def mock_engine_fixture(mock_create_engine): initial_id = 0 @@ -50,7 +46,7 @@ def select_with_string(): def test_insert_single_column(mock_engine): insert_into_single_value(["Tom", "Jerry"]) - _verify_query( + verify_query( mock_engine, "INSERT INTO table(name) VALUES ( :values__name_col_0 ), ( :values__name_col_1 ) ", ) @@ -58,9 +54,7 @@ def test_insert_single_column(mock_engine): def test_insert_single_column_single_value(mock_engine): insert_into_single_value("Tom") - _verify_query( - mock_engine, "INSERT INTO table(name) VALUES ( :values__name_col_0 ) " - ) + verify_query(mock_engine, "INSERT INTO table(name) VALUES ( :values__name_col_0 ) ") def test_insert_single_value_empty(): @@ -84,12 +78,12 @@ def test_insert_multiple_values(mock_engine): {"name": "Jerry", "email": "jerry@adobe.com"}, ] ) - _verify_query( + verify_query( mock_engine, "INSERT INTO table(name, email) VALUES ( :values__users_0_0, :values__users_0_1 ), " "( :values__users_1_0, :values__users_1_1 ) ", ) - _verify_query_args( + verify_query_args( mock_engine, { "values__users_0_0": "Tom", diff --git a/pyproject.toml b/pyproject.toml index d6ef188..8a226da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ authors = [ {name = "Adobe", email = "noreply@adobe.com"} ] urls = { "Homepage" = "https://github.com/adobe/dy-sql" } -version = "3.1" +version = "3.2" dependencies = [ # SQLAlchemy 2+ is not yet supported "sqlalchemy<2", diff --git a/uv.lock b/uv.lock index 11f32ee..dae7c32 100644 --- a/uv.lock +++ b/uv.lock @@ -261,7 +261,7 @@ wheels = [ [[package]] name = "dy-sql" -version = "3.1" +version = "3.2" source = { editable = "." } dependencies = [ { name = "pydantic" },