Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions .github/workflows/db_provider_tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
name: DB Provider Tests

on:
push:
branches: [main]
pull_request_review:
types: [submitted]
branches: [main]

jobs:
test:
if: github.event.review.state == 'approved' || github.event_name == 'push'
runs-on: ubuntu-latest

services:
postgres:
image: postgres:latest
env:
POSTGRES_USER: test
POSTGRES_PASSWORD: test
POSTGRES_DB: test
ports:
- 5432:5432
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5

mysql:
image: mysql:latest
env:
MYSQL_ROOT_PASSWORD: test
MYSQL_DATABASE: test
ports:
- 3306:3306
options: >-
--health-cmd="mysqladmin ping"
--health-interval=10s
--health-timeout=5s
--health-retries=3

steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: '3.11'
- name: Install dependencies
run: |
pip install -r requirements-dev.txt
pip install -e .
- name: Run tests
run: pytest
2 changes: 2 additions & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,5 @@ anthropic
click
flask
pytest
pytest-postgresql
psycopg2-binary
3 changes: 2 additions & 1 deletion tabletalk/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ def get_db_provider(config: Dict[str, Any]) -> DatabaseProvider:
elif provider_type == "postgres":
return PostgresProvider(
host=config["host"],
database=config["database"],
port=int(config.get("port", 5432)),
dbname=config["database"],
user=config["user"],
password=config["password"],
)
Expand Down
11 changes: 7 additions & 4 deletions tabletalk/providers/postgres_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,26 @@


class PostgresProvider(DatabaseProvider):
def __init__(self, host: str, database: str, user: str, password: str):
def __init__(self, host: str, port: int, dbname: str, user: str, password: str):
"""
Initialize PostgreSQL provider with connection string.

Args:
host (str): PostgreSQL host
database (str): PostgreSQL database name
port (int): PostgreSQL port
dbname (str): PostgreSQL database name
user (str): PostgreSQL user
password (str): PostgreSQL password
"""
self.host = host
self.database = database
self.port = port
self.dbname = dbname
self.user = user
self.password = password
self.connection = psycopg2.connect(
host=self.host,
database=self.database,
port=self.port,
dbname=self.dbname,
user=self.user,
password=self.password,
)
Expand Down
125 changes: 125 additions & 0 deletions tabletalk/tests/providers/test_mysql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
from typing import Any, Dict, Generator, Union

import mysql.connector
import pytest
from mysql.connector.abstracts import MySQLConnectionAbstract
from mysql.connector.pooling import PooledMySQLConnection

from tabletalk.providers.mysql_provider import MySQLProvider

TEST_CONFIG = {
"host": "localhost",
"port": 3306,
"database": "test",
"user": "root",
"password": "test",
}

ConnectionType = Union[PooledMySQLConnection, MySQLConnectionAbstract]


@pytest.fixture(scope="function")
def mysql_db() -> Generator[Dict[str, Any], None, None]:
"""Set up a simple test database"""
conn: ConnectionType = mysql.connector.connect(
host=TEST_CONFIG["host"],
port=TEST_CONFIG["port"],
user=TEST_CONFIG["user"],
password=TEST_CONFIG["password"],
)

# Handle autocommit based on connection type
if isinstance(conn, PooledMySQLConnection):
real_conn = conn.get_connection()
real_conn.autocommit = True
else:
conn.autocommit = True

with conn.cursor() as cur:
cur.execute(f"DROP DATABASE IF EXISTS {TEST_CONFIG['database']}")
cur.execute(f"CREATE DATABASE {TEST_CONFIG['database']}")

conn.close()
conn = mysql.connector.connect(
host=TEST_CONFIG["host"],
port=TEST_CONFIG["port"],
user=TEST_CONFIG["user"],
password=TEST_CONFIG["password"],
database=TEST_CONFIG["database"],
)

# Handle autocommit again for the new connection
if isinstance(conn, PooledMySQLConnection):
real_conn = conn.get_connection()
real_conn.autocommit = True
else:
conn.autocommit = True

with conn.cursor() as cur:
cur.execute(
"""
CREATE TABLE users (
id INT AUTO_INCREMENT PRIMARY KEY,
name TEXT,
age INTEGER
)
"""
)
cur.execute(
"""
INSERT INTO users (name, age)
VALUES ('Alice', 30), ('Bob', 25)
"""
)
cur.execute(
"""
CREATE VIEW adult_users AS
SELECT name, age FROM users WHERE age >= 18
"""
)
conn.commit()
conn.close()

yield TEST_CONFIG


@pytest.fixture
def mysql_provider(
mysql_db: Dict[str, Any],
) -> Generator[MySQLProvider, None, None]:
provider = MySQLProvider(
host=mysql_db["host"],
port=mysql_db["port"],
database=mysql_db["database"],
user=mysql_db["user"],
password=mysql_db["password"],
)
yield provider
provider.connection.close()


def test_basic_query(mysql_provider: MySQLProvider) -> None:
results = mysql_provider.execute_query("SELECT * FROM users ORDER BY id")
assert len(results) == 2
assert results[0]["name"] == "Alice"
assert results[1]["name"] == "Bob"


def test_table_and_view_schema(mysql_provider: MySQLProvider) -> None:
schemas = mysql_provider.get_compact_tables()
assert len(schemas) == 2
table_schema = next(s for s in schemas if s["t"] == "users")
view_schema = next(s for s in schemas if s["t"] == "adult_users")
assert len(table_schema["f"]) == 3
assert [f["n"] for f in table_schema["f"]] == ["id", "name", "age"]
assert len(view_schema["f"]) == 2
assert [f["n"] for f in view_schema["f"]] == ["name", "age"]


def test_view_query(mysql_provider: MySQLProvider) -> None:
results = mysql_provider.execute_query("SELECT * FROM adult_users ORDER BY name")
assert len(results) == 2
assert results[0]["name"] == "Alice"
assert results[0]["age"] == 30
assert results[1]["name"] == "Bob"
assert results[1]["age"] == 25
110 changes: 110 additions & 0 deletions tabletalk/tests/providers/test_postgres.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from typing import Any, Dict, Generator

import pytest
from psycopg2 import connect

from tabletalk.providers.postgres_provider import PostgresProvider

TEST_CONFIG = {
"host": "localhost",
"port": 5432,
"dbname": "test_db",
"user": "test",
"password": "test",
}


@pytest.fixture(scope="function")
def postgres_db() -> Generator[Dict[str, Any], None, None]:
"""Set up a simple test database"""
# Create fresh database
admin_conn = connect(
host=TEST_CONFIG["host"],
port=TEST_CONFIG["port"],
user=TEST_CONFIG["user"],
password=TEST_CONFIG["password"],
dbname="postgres",
)
admin_conn.autocommit = True
with admin_conn.cursor() as cur:
cur.execute(f"DROP DATABASE IF EXISTS {TEST_CONFIG['dbname']}")
cur.execute(f"CREATE DATABASE {TEST_CONFIG['dbname']}")
admin_conn.close()

# Set up schema and data
conn = connect(
host=TEST_CONFIG["host"],
port=TEST_CONFIG["port"],
dbname=TEST_CONFIG["dbname"],
user=TEST_CONFIG["user"],
password=TEST_CONFIG["password"],
)
with conn.cursor() as cur:
cur.execute(
"""
CREATE TABLE users (
id SERIAL PRIMARY KEY,
name TEXT,
age INTEGER
)
"""
)
cur.execute(
"""
INSERT INTO users (name, age)
VALUES ('Alice', 30), ('Bob', 25)
"""
)
cur.execute(
"""
CREATE VIEW adult_users AS
SELECT name, age FROM users WHERE age >= 18
"""
)
conn.commit()
conn.close()

yield TEST_CONFIG


# Rest of the code remains unchanged
@pytest.fixture
def postgres_provider(
postgres_db: Dict[str, Any],
) -> Generator[PostgresProvider, None, None]:
provider = PostgresProvider(
host=postgres_db["host"],
port=postgres_db["port"],
dbname=postgres_db["dbname"],
user=postgres_db["user"],
password=postgres_db["password"],
)
yield provider
provider.connection.close()


def test_basic_query(postgres_provider: PostgresProvider) -> None:
results = postgres_provider.execute_query("SELECT * FROM users ORDER BY id")
assert len(results) == 2
assert results[0]["name"] == "Alice"
assert results[1]["name"] == "Bob"


def test_table_and_view_schema(postgres_provider: PostgresProvider) -> None:
schemas = postgres_provider.get_compact_tables()
assert len(schemas) == 2
table_schema = next(s for s in schemas if s["t"] == "users")
view_schema = next(s for s in schemas if s["t"] == "adult_users")
assert len(table_schema["f"]) == 3
assert [f["n"] for f in table_schema["f"]] == ["id", "name", "age"]
assert len(view_schema["f"]) == 2
assert [f["n"] for f in view_schema["f"]] == ["name", "age"]


def test_view_query(postgres_provider: PostgresProvider) -> None:
results = postgres_provider.execute_query("SELECT * FROM adult_users ORDER BY name")
assert len(results) == 2
assert results[0]["name"] == "Alice"
assert results[0]["age"] == 30
assert results[1]["name"] == "Bob"
assert results[1]["age"] == 25