Merged
15 commits
4 changes: 4 additions & 0 deletions Makefile
@@ -21,6 +21,10 @@ test: chk
test-sqla:
	uv run pytest -n 8 --cov pyathena --cov-report html --cov-report term tests/sqlalchemy/

.PHONY: test-sqla-async
test-sqla-async:
	uv run pytest -n 8 --cov pyathena --cov-report html --cov-report term tests/sqlalchemy/ --dburi async

.PHONY: tox
tox:
	uvx tox@$(TOX_VERSION) -c pyproject.toml run
69 changes: 68 additions & 1 deletion docs/sqlalchemy.md
@@ -2,9 +2,14 @@

# SQLAlchemy

Install SQLAlchemy with `pip install "SQLAlchemy>=1.0.0"` or `pip install PyAthena[SQLAlchemy]`.
Install SQLAlchemy with `pip install "SQLAlchemy>=1.0.0"` or `pip install PyAthena[sqlalchemy]`.
Supported SQLAlchemy is 1.0.0 or higher.

For async support (`create_async_engine`), install with `pip install PyAthena[aiosqlalchemy]`
(requires SQLAlchemy 2.0+).

### Sync

```python
from sqlalchemy import func, select
from sqlalchemy.engine import create_engine
@@ -24,6 +29,48 @@ with engine.connect() as connection:
    print(result.scalar())
```

### Async

```python
from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine

conn_str = "awsathena+aiorest://{aws_access_key_id}:{aws_secret_access_key}@athena.{region_name}.amazonaws.com:443/"\
           "{schema_name}?s3_staging_dir={s3_staging_dir}"
engine = create_async_engine(conn_str.format(
    aws_access_key_id="YOUR_ACCESS_KEY_ID",
    aws_secret_access_key="YOUR_SECRET_ACCESS_KEY",
    region_name="us-west-2",
    schema_name="default",
    s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/"))

async def main():
    async with engine.connect() as connection:
        result = await connection.execute(text("SELECT * FROM many_rows"))
        print(result.fetchall())
    await engine.dispose()
```

SQLAlchemy's reflection API (`Table(..., autoload_with=)`, `inspect()`) is synchronous
internally, so it cannot be called directly on an async connection. Use `run_sync()` to
bridge the gap:

```python
from sqlalchemy.sql.schema import Table, MetaData

async with engine.connect() as connection:
    # Table reflection
    table = await connection.run_sync(
        lambda sync_conn: Table("my_table", MetaData(), autoload_with=sync_conn)
    )

    # Schema inspection
    import sqlalchemy
    schemas = await connection.run_sync(
        lambda sync_conn: sqlalchemy.inspect(sync_conn).get_schema_names()
    )
```

## Connection string

The connection string has the following format:
@@ -38,8 +85,16 @@ If you do not specify `aws_access_key_id` and `aws_secret_access_key` using inst
awsathena+rest://:@athena.{region_name}.amazonaws.com:443/{schema_name}?s3_staging_dir={s3_staging_dir}&...
```

For async, replace the driver portion (e.g. `+rest` with `+aiorest`):

```text
awsathena+aiorest://:@athena.{region_name}.amazonaws.com:443/{schema_name}?s3_staging_dir={s3_staging_dir}&...
```

## Dialect & driver

### Sync

| Dialect | Driver | Schema | Cursor |
|-----------|--------|------------------|------------------------|
| awsathena | | awsathena | DefaultCursor |
@@ -49,6 +104,18 @@ awsathena+rest://:@athena.{region_name}.amazonaws.com:443/{schema_name}?s3_stagi
| awsathena | polars | awsathena+polars | {ref}`polars-cursor` |
| awsathena | s3fs | awsathena+s3fs | {ref}`s3fs-cursor` |

### Async

Requires `pip install PyAthena[aiosqlalchemy]` (SQLAlchemy 2.0+).

| Dialect | Driver | Schema | Cursor |
|-----------|-----------|---------------------|------------------------------|
| awsathena | aiorest | awsathena+aiorest | DefaultCursor (async) |
| awsathena | aiopandas | awsathena+aiopandas | {ref}`pandas-cursor` (async) |
| awsathena | aioarrow | awsathena+aioarrow | {ref}`arrow-cursor` (async) |
| awsathena | aiopolars | awsathena+aiopolars | {ref}`polars-cursor` (async) |
| awsathena | aios3fs | awsathena+aios3fs | {ref}`s3fs-cursor` (async) |

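The driver swap described in the tables above can be sketched with a small helper. This is only an illustration: `build_athena_url` is a hypothetical function, not part of PyAthena, and percent-encoding the credentials and staging directory with `quote_plus` is an assumption about what is safe to put in the URL.

```python
from urllib.parse import quote_plus


def build_athena_url(
    driver: str,
    region_name: str,
    schema_name: str,
    s3_staging_dir: str,
    aws_access_key_id: str = "",
    aws_secret_access_key: str = "",
) -> str:
    """Assemble an ``awsathena+{driver}`` connection string.

    Hypothetical helper for illustration; it is not part of PyAthena.
    """
    # Credentials and the staging location may contain "/", "+" or ":",
    # so they are percent-encoded before being placed in the URL.
    return (
        f"awsathena+{driver}://"
        f"{quote_plus(aws_access_key_id)}:{quote_plus(aws_secret_access_key)}"
        f"@athena.{region_name}.amazonaws.com:443/{schema_name}"
        f"?s3_staging_dir={quote_plus(s3_staging_dir)}"
    )


# Same arguments, only the driver portion differs between sync and async.
sync_url = build_athena_url("rest", "us-west-2", "default", "s3://YOUR_S3_BUCKET/path/to/")
async_url = build_athena_url("aiorest", "us-west-2", "default", "s3://YOUR_S3_BUCKET/path/to/")
```

The `sync_url` value would go to `create_engine` and the `async_url` value to `create_async_engine`; everything after the driver portion is identical.
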
## Dialect options

### Table options
1 change: 1 addition & 0 deletions pyathena/aio/sqlalchemy/__init__.py
@@ -0,0 +1 @@
# -*- coding: utf-8 -*-
55 changes: 55 additions & 0 deletions pyathena/aio/sqlalchemy/arrow.py
@@ -0,0 +1,55 @@
# -*- coding: utf-8 -*-
from typing import TYPE_CHECKING

from pyathena.aio.sqlalchemy.base import AthenaAioDialect
from pyathena.util import strtobool

if TYPE_CHECKING:
    from types import ModuleType


class AthenaAioArrowDialect(AthenaAioDialect):
    """Async SQLAlchemy dialect for Amazon Athena with Apache Arrow result format.

    This dialect uses ``AioArrowCursor`` for native asyncio query execution
    with Apache Arrow Table results.

    Connection URL Format:
        ``awsathena+aioarrow://{access_key}:{secret_key}@athena.{region}.amazonaws.com/{schema}``

    Query Parameters:
        In addition to the base dialect parameters:
        - unload: If "true", use UNLOAD for Parquet output

    Example:
        >>> from sqlalchemy.ext.asyncio import create_async_engine
        >>> engine = create_async_engine(
        ...     "awsathena+aioarrow://:@athena.us-west-2.amazonaws.com/default"
        ...     "?s3_staging_dir=s3://my-bucket/athena-results/"
        ...     "&unload=true"
        ... )

    See Also:
        :class:`~pyathena.aio.arrow.cursor.AioArrowCursor`: The underlying async cursor.
        :class:`~pyathena.aio.sqlalchemy.base.AthenaAioDialect`: Base async dialect.
    """

    driver = "aioarrow"
    supports_statement_cache = True

    def create_connect_args(self, url):
        from pyathena.aio.arrow.cursor import AioArrowCursor

        opts = super()._create_connect_args(url)
        opts.update({"cursor_class": AioArrowCursor})
        cursor_kwargs = {}
        if "unload" in opts:
            cursor_kwargs.update({"unload": bool(strtobool(opts.pop("unload")))})
        if cursor_kwargs:
            opts.update({"cursor_kwargs": cursor_kwargs})
        self._connect_options = opts
        return [[], opts]

    @classmethod
    def import_dbapi(cls) -> "ModuleType":
        return super().import_dbapi()
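
The handling of `unload` in `create_connect_args` above can be illustrated in isolation. The following is a rough standalone sketch, not the dialect's code: `connect_options` is a hypothetical function that models only the URL query parameters (the real `_create_connect_args` presumably adds more keys), and `_strtobool` is a stand-in whose accepted spellings are an assumption about `pyathena.util.strtobool`.

```python
from typing import Any, Dict
from urllib.parse import parse_qsl, urlsplit


def _strtobool(value: str) -> bool:
    """Stand-in for ``pyathena.util.strtobool`` (accepted spellings are assumed)."""
    lowered = value.lower()
    if lowered in ("y", "yes", "t", "true", "on", "1"):
        return True
    if lowered in ("n", "no", "f", "false", "off", "0"):
        return False
    raise ValueError(f"invalid truth value {value!r}")


def connect_options(url: str) -> Dict[str, Any]:
    """Mimic the shape of the options dict built in ``create_connect_args``."""
    # Hypothetical simplification: only the query string is considered here.
    opts: Dict[str, Any] = dict(parse_qsl(urlsplit(url).query))
    cursor_kwargs: Dict[str, Any] = {}
    if "unload" in opts:
        # The flag leaves the top-level options and moves under cursor_kwargs.
        cursor_kwargs["unload"] = _strtobool(opts.pop("unload"))
    if cursor_kwargs:
        opts["cursor_kwargs"] = cursor_kwargs
    return opts


opts = connect_options(
    "awsathena+aioarrow://:@athena.us-west-2.amazonaws.com/default"
    "?s3_staging_dir=s3://my-bucket/athena-results/&unload=true"
)
```

The point of the pattern is that `unload` is a cursor setting rather than a connection setting, so it is converted from a string to a boolean and nested under `cursor_kwargs` instead of being passed to the connection directly.
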
218 changes: 218 additions & 0 deletions pyathena/aio/sqlalchemy/base.py
@@ -0,0 +1,218 @@
# -*- coding: utf-8 -*-
from __future__ import annotations

from collections import deque
from typing import TYPE_CHECKING, Any, Dict, List, MutableMapping, Optional, Tuple, Union, cast

from sqlalchemy import pool
from sqlalchemy.engine import AdaptedConnection
from sqlalchemy.util.concurrency import await_only

import pyathena
from pyathena.aio.connection import AioConnection
from pyathena.error import (
    DatabaseError,
    DataError,
    Error,
    IntegrityError,
    InterfaceError,
    InternalError,
    NotSupportedError,
    OperationalError,
    ProgrammingError,
)
from pyathena.sqlalchemy.base import AthenaDialect

if TYPE_CHECKING:
    from types import ModuleType

    from sqlalchemy import URL


class AsyncAdapt_pyathena_cursor:  # noqa: N801 - follows SQLAlchemy's internal async adapter naming convention (e.g. AsyncAdapt_asyncpg_dbapi)
    """Wraps any async PyAthena cursor with a sync DBAPI interface.

    SQLAlchemy's async engine uses greenlet-based ``await_only()`` to call
    async methods from synchronous code running inside the greenlet context.
    This adapter wraps an ``AioCursor`` (or variant) so that the dialect can
    use a normal synchronous DBAPI interface while the underlying I/O is async.
    """

    server_side = False
    __slots__ = ("_cursor", "_rows")

    def __init__(self, cursor: Any) -> None:
        self._cursor = cursor
        self._rows: deque[Any] = deque()

    @property
    def description(self) -> Any:
        return self._cursor.description

    @property
    def rowcount(self) -> int:
        return self._cursor.rowcount  # type: ignore[no-any-return]

    def close(self) -> None:
        self._cursor.close()
        self._rows.clear()

    def execute(self, operation: str, parameters: Any = None, **kwargs: Any) -> Any:
        result = await_only(self._cursor.execute(operation, parameters, **kwargs))
        if self._cursor.description:
            self._rows = deque(await_only(self._cursor.fetchall()))
        else:
            self._rows.clear()
        return result

    def executemany(
        self,
        operation: str,
        seq_of_parameters: List[Optional[Union[Dict[str, Any], List[str]]]],
        **kwargs: Any,
    ) -> None:
        for parameters in seq_of_parameters:
            await_only(self._cursor.execute(operation, parameters, **kwargs))
        self._rows.clear()

    def fetchone(self) -> Any:
        if self._rows:
            return self._rows.popleft()
        return None

    def fetchmany(self, size: Optional[int] = None) -> Any:
        if size is None:
            size = self._cursor.arraysize if hasattr(self._cursor, "arraysize") else 1
        return [self._rows.popleft() for _ in range(min(size, len(self._rows)))]

    def fetchall(self) -> Any:
        items = list(self._rows)
        self._rows.clear()
        return items

    def setinputsizes(self, sizes: Any) -> None:
        self._cursor.setinputsizes(sizes)

    # PyAthena-specific methods used by AthenaDialect reflection
    def list_databases(self, *args: Any, **kwargs: Any) -> Any:
        return await_only(self._cursor.list_databases(*args, **kwargs))

    def get_table_metadata(self, *args: Any, **kwargs: Any) -> Any:
        return await_only(self._cursor.get_table_metadata(*args, **kwargs))

    def list_table_metadata(self, *args: Any, **kwargs: Any) -> Any:
        return await_only(self._cursor.list_table_metadata(*args, **kwargs))

    def __enter__(self) -> "AsyncAdapt_pyathena_cursor":
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        self.close()


class AsyncAdapt_pyathena_connection(AdaptedConnection):  # noqa: N801 - follows SQLAlchemy's internal async adapter naming convention (e.g. AsyncAdapt_asyncpg_dbapi)
    """Wraps ``AioConnection`` with a sync DBAPI interface.

    This adapted connection delegates ``cursor()`` to the underlying
    ``AioConnection`` and wraps each returned async cursor with
    ``AsyncAdapt_pyathena_cursor``.
    """

    __slots__ = ("dbapi", "_connection")

    def __init__(self, dbapi: "AsyncAdapt_pyathena_dbapi", connection: AioConnection) -> None:
        self.dbapi = dbapi
        self._connection = connection

    @property
    def driver_connection(self) -> AioConnection:
        return self._connection  # type: ignore[no-any-return]

    @property
    def catalog_name(self) -> Optional[str]:
        return self._connection.catalog_name  # type: ignore[no-any-return]

    @property
    def schema_name(self) -> Optional[str]:
        return self._connection.schema_name  # type: ignore[no-any-return]

    def cursor(self) -> AsyncAdapt_pyathena_cursor:
        raw_cursor = self._connection.cursor()
        return AsyncAdapt_pyathena_cursor(raw_cursor)

    def close(self) -> None:
        self._connection.close()

    def commit(self) -> None:
        self._connection.commit()

    def rollback(self) -> None:
        pass


class AsyncAdapt_pyathena_dbapi:  # noqa: N801 - follows SQLAlchemy's internal async adapter naming convention (e.g. AsyncAdapt_asyncpg_dbapi)
    """Fake DBAPI module for the async SQLAlchemy engine.

    SQLAlchemy expects ``import_dbapi()`` to return a module-like object
    with ``connect()``, ``paramstyle``, and the standard DBAPI exception
    hierarchy. This class fulfils that contract while routing connections
    through ``AioConnection``.
    """

    paramstyle = "pyformat"

    # DBAPI exception hierarchy
    Error = Error
    Warning = pyathena.Warning
    InterfaceError = InterfaceError
    DatabaseError = DatabaseError
    InternalError = InternalError
    OperationalError = OperationalError
    ProgrammingError = ProgrammingError
    IntegrityError = IntegrityError
    DataError = DataError
    NotSupportedError = NotSupportedError

    def connect(self, **kwargs: Any) -> AsyncAdapt_pyathena_connection:
        connection = await_only(AioConnection.create(**kwargs))
        return AsyncAdapt_pyathena_connection(self, connection)


class AthenaAioDialect(AthenaDialect):
    """Base async SQLAlchemy dialect for Amazon Athena.

    Extends the synchronous ``AthenaDialect`` with async capability
    by setting ``is_async = True`` and providing an adapted DBAPI module
    that wraps ``AioConnection`` and async cursors via greenlet-based
    ``await_only()``.

    Subclasses (e.g. ``AthenaAioRestDialect``, ``AthenaAioPandasDialect``)
    register concrete ``awsathena+aio*`` drivers.

    See Also:
        :class:`~pyathena.sqlalchemy.base.AthenaDialect`: Synchronous base dialect.
        :class:`~pyathena.aio.connection.AioConnection`: Native async connection.
    """

    is_async = True
    supports_statement_cache = True

    @classmethod
    def get_pool_class(cls, url: "URL") -> type:
        return pool.AsyncAdaptedQueuePool

    @classmethod
    def import_dbapi(cls) -> "ModuleType":
        return AsyncAdapt_pyathena_dbapi()  # type: ignore[return-value]

    @classmethod
    def dbapi(cls) -> "ModuleType":  # type: ignore[override]
        return AsyncAdapt_pyathena_dbapi()  # type: ignore[return-value]

    def create_connect_args(self, url: "URL") -> Tuple[Tuple[str], MutableMapping[str, Any]]:
        opts = self._create_connect_args(url)
        self._connect_options = opts
        return cast(Tuple[str], ()), opts

    def get_driver_connection(self, connection: Any) -> Any:
        return connection
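
The buffering strategy used by `AsyncAdapt_pyathena_cursor` above (fetch every row once in `execute()`, then serve `fetchone`/`fetchmany`/`fetchall` from an in-memory deque) can be sketched without SQLAlchemy or Athena. This is a simplified illustration under stated assumptions: `FakeAsyncCursor` and `BufferedSyncCursor` are hypothetical names, and a plain event loop's `run_until_complete()` stands in for SQLAlchemy's greenlet-based `await_only()`, which works differently in practice.

```python
import asyncio
from collections import deque
from typing import Any, Deque, List, Optional, Tuple


class FakeAsyncCursor:
    """Hypothetical async cursor that returns three fixed rows."""

    arraysize = 2
    description = [("n",)]

    async def execute(self, operation: str) -> "FakeAsyncCursor":
        return self

    async def fetchall(self) -> List[Tuple[int]]:
        return [(1,), (2,), (3,)]


class BufferedSyncCursor:
    """Sync facade that drains the async cursor once, then serves from memory."""

    def __init__(self, cursor: FakeAsyncCursor, loop: asyncio.AbstractEventLoop) -> None:
        self._cursor = cursor
        self._loop = loop
        self._rows: Deque[Any] = deque()

    def _wait(self, awaitable: Any) -> Any:
        # Stand-in for await_only(); the real adapter relies on greenlets instead.
        return self._loop.run_until_complete(awaitable)

    def execute(self, operation: str) -> None:
        self._wait(self._cursor.execute(operation))
        # All rows are pulled eagerly so later fetch calls never need to await.
        self._rows = deque(self._wait(self._cursor.fetchall()))

    def fetchone(self) -> Optional[Any]:
        return self._rows.popleft() if self._rows else None

    def fetchmany(self, size: Optional[int] = None) -> List[Any]:
        size = self._cursor.arraysize if size is None else size
        return [self._rows.popleft() for _ in range(min(size, len(self._rows)))]

    def fetchall(self) -> List[Any]:
        rows = list(self._rows)
        self._rows.clear()
        return rows


loop = asyncio.new_event_loop()
cursor = BufferedSyncCursor(FakeAsyncCursor(), loop)
cursor.execute("SELECT n FROM t")
first = cursor.fetchone()    # consumes the first buffered row
rest = cursor.fetchmany()    # consumes up to arraysize rows
leftover = cursor.fetchall() # whatever remains in the buffer
loop.close()
```

One consequence of this design is that the full result set is held in memory after `execute()`, which is consistent with the adapter declaring `server_side = False`.
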