diff --git a/.github/workflows/db_provider_tests.yml b/.github/workflows/db_provider_tests.yml new file mode 100644 index 0000000..50d00a4 --- /dev/null +++ b/.github/workflows/db_provider_tests.yml @@ -0,0 +1,53 @@ +name: DB Provider Tests + +on: + push: + branches: [main] + pull_request_review: + types: [submitted] + branches: [main] + +jobs: + test: + if: github.event.review.state == 'approved' || github.event_name == 'push' + runs-on: ubuntu-latest + + services: + postgres: + image: postgres:latest + env: + POSTGRES_USER: test + POSTGRES_PASSWORD: test + POSTGRES_DB: test + ports: + - 5432:5432 + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + + mysql: + image: mysql:latest + env: + MYSQL_ROOT_PASSWORD: test + MYSQL_DATABASE: test + ports: + - 3306:3306 + options: >- + --health-cmd="mysqladmin ping" + --health-interval=10s + --health-timeout=5s + --health-retries=3 + + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: '3.11' + - name: Install dependencies + run: | + pip install -r requirements-dev.txt + pip install -e . + - name: Run tests + run: pytest diff --git a/requirements-dev.txt b/requirements-dev.txt index 864a7dc..876f8c1 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -8,3 +8,5 @@ anthropic click flask pytest +pytest-postgresql +psycopg2-binary diff --git a/tabletalk/factories.py b/tabletalk/factories.py index f37155a..2a85abd 100644 --- a/tabletalk/factories.py +++ b/tabletalk/factories.py @@ -55,7 +55,8 @@ def get_db_provider(config: Dict[str, Any]) -> DatabaseProvider: elif provider_type == "postgres": return PostgresProvider( host=config["host"], - database=config["database"], + port=int(config.get("port", 5432)), + dbname=config["database"], user=config["user"], password=config["password"], ) diff --git a/tabletalk/providers/postgres_provider.py b/tabletalk/providers/postgres_provider.py index c334481..e53b54d 100644 --- a/tabletalk/providers/postgres_provider.py +++ b/tabletalk/providers/postgres_provider.py @@ -7,23 +7,26 @@ class PostgresProvider(DatabaseProvider): - def __init__(self, host: str, database: str, user: str, password: str): + def __init__(self, host: str, port: int, dbname: str, user: str, password: str): """ Initialize PostgreSQL provider with connection string. Args: host (str): PostgreSQL host - database (str): PostgreSQL database name + port (int): PostgreSQL port + dbname (str): PostgreSQL database name user (str): PostgreSQL user password (str): PostgreSQL password """ self.host = host - self.database = database + self.port = port + self.dbname = dbname self.user = user self.password = password self.connection = psycopg2.connect( host=self.host, - database=self.database, + port=self.port, + dbname=self.dbname, user=self.user, password=self.password, ) diff --git a/tabletalk/tests/providers/test_mysql.py b/tabletalk/tests/providers/test_mysql.py new file mode 100644 index 0000000..6a25037 --- /dev/null +++ b/tabletalk/tests/providers/test_mysql.py @@ -0,0 +1,125 @@ +from typing import Any, Dict, Generator, Union + +import mysql.connector +import pytest +from mysql.connector.abstracts import MySQLConnectionAbstract +from mysql.connector.pooling import PooledMySQLConnection + +from tabletalk.providers.mysql_provider import MySQLProvider + +TEST_CONFIG = { + "host": "localhost", + "port": 3306, + "database": "test", + "user": "root", + "password": "test", +} + +ConnectionType = Union[PooledMySQLConnection, MySQLConnectionAbstract] + + +@pytest.fixture(scope="function") +def mysql_db() -> Generator[Dict[str, Any], None, None]: + """Set up a simple test database""" + conn: ConnectionType = mysql.connector.connect( + host=TEST_CONFIG["host"], + port=TEST_CONFIG["port"], + user=TEST_CONFIG["user"], + password=TEST_CONFIG["password"], + ) + + # Handle autocommit based on connection type + if isinstance(conn, PooledMySQLConnection): + real_conn = conn.get_connection() + real_conn.autocommit = True + else: + conn.autocommit = True + + with conn.cursor() as cur: + cur.execute(f"DROP DATABASE IF EXISTS {TEST_CONFIG['database']}") + cur.execute(f"CREATE DATABASE {TEST_CONFIG['database']}") + + conn.close() + conn = mysql.connector.connect( + host=TEST_CONFIG["host"], + port=TEST_CONFIG["port"], + user=TEST_CONFIG["user"], + password=TEST_CONFIG["password"], + database=TEST_CONFIG["database"], + ) + + # Handle autocommit again for the new connection + if isinstance(conn, PooledMySQLConnection): + real_conn = conn.get_connection() + real_conn.autocommit = True + else: + conn.autocommit = True + + with conn.cursor() as cur: + cur.execute( + """ + CREATE TABLE users ( + id INT AUTO_INCREMENT PRIMARY KEY, + name TEXT, + age INTEGER + ) + """ + ) + cur.execute( + """ + INSERT INTO users (name, age) + VALUES ('Alice', 30), ('Bob', 25) + """ + ) + cur.execute( + """ + CREATE VIEW adult_users AS + SELECT name, age FROM users WHERE age >= 18 + """ + ) + conn.commit() + conn.close() + + yield TEST_CONFIG + + +@pytest.fixture +def mysql_provider( + mysql_db: Dict[str, Any], +) -> Generator[MySQLProvider, None, None]: + provider = MySQLProvider( + host=mysql_db["host"], + port=mysql_db["port"], + database=mysql_db["database"], + user=mysql_db["user"], + password=mysql_db["password"], + ) + yield provider + provider.connection.close() + + +def test_basic_query(mysql_provider: MySQLProvider) -> None: + results = mysql_provider.execute_query("SELECT * FROM users ORDER BY id") + assert len(results) == 2 + assert results[0]["name"] == "Alice" + assert results[1]["name"] == "Bob" + + +def test_table_and_view_schema(mysql_provider: MySQLProvider) -> None: + schemas = mysql_provider.get_compact_tables() + assert len(schemas) == 2 + table_schema = next(s for s in schemas if s["t"] == "users") + view_schema = next(s for s in schemas if s["t"] == "adult_users") + assert len(table_schema["f"]) == 3 + assert [f["n"] for f in table_schema["f"]] == ["id", "name", "age"] + assert len(view_schema["f"]) == 2 + assert [f["n"] for f in view_schema["f"]] == ["name", "age"] + + +def test_view_query(mysql_provider: MySQLProvider) -> None: + results = mysql_provider.execute_query("SELECT * FROM adult_users ORDER BY name") + assert len(results) == 2 + assert results[0]["name"] == "Alice" + assert results[0]["age"] == 30 + assert results[1]["name"] == "Bob" + assert results[1]["age"] == 25 diff --git a/tabletalk/tests/providers/test_postgres.py b/tabletalk/tests/providers/test_postgres.py new file mode 100644 index 0000000..2cb5995 --- /dev/null +++ b/tabletalk/tests/providers/test_postgres.py @@ -0,0 +1,110 @@ +from typing import Any, Dict, Generator + +import pytest +from psycopg2 import connect + +from tabletalk.providers.postgres_provider import PostgresProvider + +TEST_CONFIG = { + "host": "localhost", + "port": 5432, + "dbname": "test_db", + "user": "test", + "password": "test", +} + + +@pytest.fixture(scope="function") +def postgres_db() -> Generator[Dict[str, Any], None, None]: + """Set up a simple test database""" + # Create fresh database + admin_conn = connect( + host=TEST_CONFIG["host"], + port=TEST_CONFIG["port"], + user=TEST_CONFIG["user"], + password=TEST_CONFIG["password"], + dbname="postgres", + ) + admin_conn.autocommit = True + with admin_conn.cursor() as cur: + cur.execute(f"DROP DATABASE IF EXISTS {TEST_CONFIG['dbname']}") + cur.execute(f"CREATE DATABASE {TEST_CONFIG['dbname']}") + admin_conn.close() + + # Set up schema and data + conn = connect( + host=TEST_CONFIG["host"], + port=TEST_CONFIG["port"], + dbname=TEST_CONFIG["dbname"], + user=TEST_CONFIG["user"], + password=TEST_CONFIG["password"], + ) + with conn.cursor() as cur: + cur.execute( + """ + CREATE TABLE users ( + id SERIAL PRIMARY KEY, + name TEXT, + age INTEGER + ) + """ + ) + cur.execute( + """ + INSERT INTO users (name, age) + VALUES ('Alice', 30), ('Bob', 25) + """ + ) + cur.execute( + """ + CREATE VIEW adult_users AS + SELECT name, age FROM users WHERE age >= 18 + """ + ) + conn.commit() + conn.close() + + yield TEST_CONFIG + + +# Rest of the code remains unchanged +@pytest.fixture +def postgres_provider( + postgres_db: Dict[str, Any], +) -> Generator[PostgresProvider, None, None]: + provider = PostgresProvider( + host=postgres_db["host"], + port=postgres_db["port"], + dbname=postgres_db["dbname"], + user=postgres_db["user"], + password=postgres_db["password"], + ) + yield provider + provider.connection.close() + + +def test_basic_query(postgres_provider: PostgresProvider) -> None: + results = postgres_provider.execute_query("SELECT * FROM users ORDER BY id") + assert len(results) == 2 + assert results[0]["name"] == "Alice" + assert results[1]["name"] == "Bob" + + +def test_table_and_view_schema(postgres_provider: PostgresProvider) -> None: + schemas = postgres_provider.get_compact_tables() + assert len(schemas) == 2 + table_schema = next(s for s in schemas if s["t"] == "users") + view_schema = next(s for s in schemas if s["t"] == "adult_users") + assert len(table_schema["f"]) == 3 + assert [f["n"] for f in table_schema["f"]] == ["id", "name", "age"] + assert len(view_schema["f"]) == 2 + assert [f["n"] for f in view_schema["f"]] == ["name", "age"] + + +def test_view_query(postgres_provider: PostgresProvider) -> None: + results = postgres_provider.execute_query("SELECT * FROM adult_users ORDER BY name") + assert len(results) == 2 + assert results[0]["name"] == "Alice" + assert results[0]["age"] == 30 + assert results[1]["name"] == "Bob" + assert results[1]["age"] == 25 diff --git a/tabletalk/tests/providers/test_sqlite_provider.py b/tabletalk/tests/providers/test_sqlite.py similarity index 100% rename from tabletalk/tests/providers/test_sqlite_provider.py rename to tabletalk/tests/providers/test_sqlite.py