From 0e3eeaa6d405a1f2eef6a437b300354cf9089e7e Mon Sep 17 00:00:00 2001
From: Jaromir Hamala
Date: Mon, 30 Dec 2024 13:36:52 +0100
Subject: [PATCH] Implements QuestDB's LIMIT clause in SQLAlchemy dialect.

Key features:
- Standard LIMIT N syntax
- LIMIT lower,upper range syntax (lower exclusive, upper inclusive)
- Support for expressions and bind parameters
- Automatic conversion of SQLAlchemy's LIMIT/OFFSET

Example:
select(table).limit(5) # LIMIT 5
select(table).limit(3).offset(2) # LIMIT 2,5
select(table).offset(8) # LIMIT 8,BIGINT_MAX
---
 pyproject.toml                   |   2 +-
 src/questdb_connect/compilers.py |  33 ++++++
 tests/conftest.py                |  31 ++++++
 tests/test_dialect.py            | 185 +++++++++++++++++++++++++++++++
 4 files changed, 250 insertions(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index adb048c..e82bf0d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -60,7 +60,7 @@ max-branches = 20
 max-args = 10
 
 [tool.ruff.per-file-ignores]
-'tests/test_dialect.py' = ['S101']
+'tests/test_dialect.py' = ['S101', 'PLR2004']
 'tests/test_types.py' = ['S101']
 'tests/test_superset.py' = ['S101']
 'tests/conftest.py' = ['S608']
diff --git a/src/questdb_connect/compilers.py b/src/questdb_connect/compilers.py
index 7994514..e45430e 100644
--- a/src/questdb_connect/compilers.py
+++ b/src/questdb_connect/compilers.py
@@ -30,9 +30,42 @@ def get_column_specification(self, column: sqlalchemy.Column, **_):
 
 
 class QDBSQLCompiler(sqlalchemy.sql.compiler.SQLCompiler, abc.ABC):
+    # Maximum value for 64-bit signed integer (2^63 - 1)
+    BIGINT_MAX = 9223372036854775807
+
     def _is_safe_for_fast_insert_values_helper(self):
         return True
 
     def visit_textclause(self, textclause, add_to_result_map=None, **kw):
         textclause.text = remove_public_schema(textclause.text)
         return super().visit_textclause(textclause, add_to_result_map, **kw)
+
+    def limit_clause(self, select, **kw):
+        """
+        Generate QuestDB-style LIMIT clause from SQLAlchemy select statement.
+        QuestDB supports arbitrary expressions in LIMIT clause.
+        """
+        text = ""
+        limit = select._limit_clause
+        offset = select._offset_clause
+
+        if limit is None and offset is None:
+            return text
+
+        text += "\n LIMIT "
+
+        # Handle cases based on presence of limit and offset
+        if limit is not None and offset is not None:
+            # Convert LIMIT x OFFSET y to LIMIT y,y+x
+            lower_bound = self.process(offset, **kw)
+            limit_val = self.process(limit, **kw)
+            text += f"{lower_bound},{lower_bound} + {limit_val}"
+
+        elif limit is not None:
+            text += self.process(limit, **kw)
+
+        elif offset is not None:
+            # If only offset is specified, use max bigint as upper bound
+            text += f"{self.process(offset, **kw)},{self.BIGINT_MAX}"
+
+        return text
diff --git a/tests/conftest.py b/tests/conftest.py
index 850f5fa..222925f 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,4 +1,5 @@
 import os
+import time
 from typing import NamedTuple
 
 import pytest
@@ -126,6 +127,36 @@ def collect_select_all(session, expected_rows) -> str:
             return '\n'.join(str(row) for row in rs)
 
 
+def wait_until_table_is_ready(test_engine, table_name, expected_rows, timeout=10):
+    """
+    Wait until a table has the expected number of rows, with timeout.
+
+    Args:
+        test_engine: SQLAlchemy engine
+        table_name: Name of the table to check
+        expected_rows: Expected number of rows
+        timeout: Maximum time to wait in seconds (default: 10 seconds)
+
+    Returns:
+        bool: True if table is ready, False if timeout occurred
+
+    Raises:
+        sqlalchemy.exc.SQLAlchemyError: If there's a database error
+    """
+    start_time = time.monotonic()
+
+    while time.monotonic() - start_time < timeout:
+        with test_engine.connect() as conn:
+            result = conn.execute(text(f'SELECT count(*) FROM {table_name}'))
+            row = result.fetchone()
+            if row and row[0] == expected_rows:
+                return True
+
+        print(f'Waiting for table {table_name} to have {expected_rows} rows, current: {row[0] if row else 0}')
+        time.sleep(0.01)  # Wait 10ms between checks
+    return False
+
+
 def collect_select_all_raw_connection(test_engine, expected_rows) -> str:
     conn = test_engine.raw_connection()
     try:
diff --git a/tests/test_dialect.py b/tests/test_dialect.py
index d7c0b6f..ba529cf 100644
--- a/tests/test_dialect.py
+++ b/tests/test_dialect.py
@@ -9,6 +9,7 @@
     METRICS_TABLE_NAME,
     collect_select_all,
     collect_select_all_raw_connection,
+    wait_until_table_is_ready,
 )
 
 
@@ -257,3 +258,187 @@ def test_keywords(test_engine):
         sql = sqla.text("SELECT keyword FROM keywords()")
         expected = [row[0] for row in conn.execute(sql).fetchall()]
         assert qdbc.get_keywords_list() == expected
+
+
+def test_limit_clause_basic(test_engine, test_model):
+    """Test basic LIMIT clause functionality."""
+    now = datetime.datetime(2023, 4, 12, 23, 55, 59, 342380)
+    now_date = now.date()
+    session = Session(test_engine)
+    num_rows = 10
+
+    try:
+        # Insert test data
+        models = [
+            test_model(
+                col_boolean=True,
+                col_byte=8,
+                col_short=12,
+                col_int=idx,  # Using idx to make rows distinct and ordered
+                col_long=14,
+                col_float=15.234,
+                col_double=16.88993244,
+                col_symbol='coconut',
+                col_string='banana',
+                col_char='C',
+                col_uuid='6d5eb038-63d1-4971-8484-30c16e13de5b',
+                col_date=now_date,
+                col_ts=now,
+                col_geohash='dfvgsj2vptwu',
+                col_long256='0xa3b400fcf6ed707d710d5d4e672305203ed3cc6254d1cefe313e4a465861f42a',
+                col_varchar='pineapple'
+            ) for idx in range(num_rows)
+        ]
+        session.bulk_save_objects(models)
+        session.commit()
+
+        metadata = sqla.MetaData()
+        table = sqla.Table(ALL_TYPES_TABLE_NAME, metadata, autoload_with=test_engine)
+
+        assert wait_until_table_is_ready(test_engine, ALL_TYPES_TABLE_NAME, num_rows)
+
+        with test_engine.connect() as conn:
+            # simple LIMIT
+            query = sqla.select(table).limit(5)
+            result = conn.execute(query)
+            rows = result.fetchall()
+            assert len(rows) == 5
+            assert rows[0].col_int == 0
+            assert rows[-1].col_int == 4
+
+            # LIMIT with OFFSET
+            query = sqla.select(table).limit(3).offset(2)
+            result = conn.execute(query)
+            rows = result.fetchall()
+            assert len(rows) == 3
+            assert rows[0].col_int == 2
+            assert rows[-1].col_int == 4
+
+            # OFFSET only
+            query = sqla.select(table).offset(8)
+            result = conn.execute(query)
+            rows = result.fetchall()
+            assert len(rows) == 2
+            assert rows[0].col_int == 8
+            assert rows[-1].col_int == 9
+
+            # LIMIT 0
+            query = sqla.select(table).limit(0)
+            result = conn.execute(query)
+            rows = result.fetchall()
+            assert len(rows) == 0
+
+            # LIMIT 0 and offset
+            query = sqla.select(table).limit(0).offset(1)
+            result = conn.execute(query)
+            rows = result.fetchall()
+            assert len(rows) == 0
+
+    finally:
+        if session:
+            session.close()
+
+
+def test_limit_clause_with_binds_and_expressions(test_engine, test_model):
+    """Test LIMIT clause with bind parameters and expressions."""
+    # Setup test data
+    now = datetime.datetime(2023, 4, 12, 23, 55, 59, 342380)
+    now_date = now.date()
+    session = Session(test_engine)
+    num_rows = 10
+
+    try:
+        # Insert test data
+        models = [
+            test_model(
+                col_boolean=True,
+                col_byte=8,
+                col_short=12,
+                col_int=idx,  # Using idx to make rows distinct and ordered
+                col_long=14,
+                col_float=15.234,
+                col_double=16.88993244,
+                col_symbol='coconut',
+                col_string='banana',
+                col_char='C',
+                col_uuid='6d5eb038-63d1-4971-8484-30c16e13de5b',
+                col_date=now_date,
+                col_ts=now,
+                col_geohash='dfvgsj2vptwu',
+                col_long256='0xa3b400fcf6ed707d710d5d4e672305203ed3cc6254d1cefe313e4a465861f42a',
+                col_varchar='pineapple'
+            ) for idx in range(num_rows)
+        ]
+        session.bulk_save_objects(models)
+        session.commit()
+
+        metadata = sqla.MetaData()
+        table = sqla.Table(ALL_TYPES_TABLE_NAME, metadata, autoload_with=test_engine)
+
+        assert wait_until_table_is_ready(test_engine, ALL_TYPES_TABLE_NAME, num_rows)
+
+        with test_engine.connect() as conn:
+            # simple bindparam
+            result = conn.execute(
+                sqla.select(table).limit(sqla.bindparam('limit_val')),
+                {"limit_val": 5}
+            )
+            rows = result.fetchall()
+            assert len(rows) == 5
+            assert rows[0].col_int == 0
+            assert rows[-1].col_int == 4
+
+            # bindparam with expressions
+            result = conn.execute(
+                sqla.select(table).limit(sqla.bindparam('base_limit') * 2),
+                {"base_limit": 3}
+            )
+            rows = result.fetchall()
+            assert len(rows) == 6
+            assert rows[0].col_int == 0
+            assert rows[-1].col_int == 5
+
+            # multiple bindparams in expression
+            result = conn.execute(
+                sqla.select(table).limit(
+                    sqla.bindparam('limit_val')
+                ).offset(
+                    sqla.bindparam('offset_val')
+                ),
+                {
+                    "limit_val": 3,
+                    "offset_val": 2
+                }
+            )
+            rows = result.fetchall()
+            assert len(rows) == 3
+            assert rows[0].col_int == 2
+            assert rows[-1].col_int == 4
+
+            # bindparam with type specification
+            from sqlalchemy import Integer
+            result = conn.execute(
+                sqla.select(table).limit(
+                    sqla.bindparam('limit_val', type_=Integer) + 1
+                ),
+                {"limit_val": 4}
+            )
+            rows = result.fetchall()
+            assert len(rows) == 5
+            assert rows[0].col_int == 0
+            assert rows[-1].col_int == 4
+
+            # text() and bindparam
+            from sqlalchemy import text
+            result = conn.execute(
+                text("SELECT * FROM all_types_table LIMIT :lo, :hi"),
+                {"lo": 3, "hi": 8}
+            )
+            rows = result.fetchall()
+            assert len(rows) == 5
+            assert rows[0].col_int == 3
+            assert rows[-1].col_int == 7
+
+    finally:
+        if session:
+            session.close()