Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ PyAthena is a Python DB API 2.0 (PEP 249) compliant client for Amazon Athena. Se

### Code Quality — Always Run Before Committing
```bash
make fmt # Auto-fix formatting and imports
make chk # Lint + format check + mypy
make format # Auto-fix formatting and imports
make lint # Lint + format check + mypy
```

### Testing
```bash
# ALWAYS run `make chk` first — tests will fail if lint doesn't pass
# ALWAYS run `make lint` first — tests will fail if lint doesn't pass
make test # Unit tests (runs chk first)
make test-sqla # SQLAlchemy dialect tests
```
Expand Down
10 changes: 5 additions & 5 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
RUFF_VERSION := 0.14.14
TOX_VERSION := 4.34.1

.PHONY: fmt
fmt:
.PHONY: format
format:
# TODO: https://github.com/astral-sh/uv/issues/5903
uvx ruff@$(RUFF_VERSION) check --select I --fix .
uvx ruff@$(RUFF_VERSION) format .

.PHONY: chk
chk:
.PHONY: lint
lint:
uvx ruff@$(RUFF_VERSION) check .
uvx ruff@$(RUFF_VERSION) format --check .
uv run mypy .

.PHONY: test
test: chk
test: lint
uv run pytest -n 8 --cov pyathena --cov-report html --cov-report term tests/pyathena/

.PHONY: test-sqla
Expand Down
6 changes: 3 additions & 3 deletions docs/testing.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ The code formatting uses [ruff](https://github.com/astral-sh/ruff).
### Appy format

```bash
$ make fmt
$ make format
```

### Check format
### Lint and check format

```bash
$ make chk
$ make lint
```

## GitHub Actions
Expand Down
23 changes: 11 additions & 12 deletions pyathena/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
# -*- coding: utf-8 -*-
from __future__ import annotations

import datetime
from typing import TYPE_CHECKING, Any, FrozenSet, Type, overload
from typing import TYPE_CHECKING, Any, overload

from pyathena.error import * # noqa
from pyathena.error import * # noqa: F403

if TYPE_CHECKING:
from pyathena.aio.connection import AioConnection
Expand All @@ -28,7 +27,7 @@
paramstyle: str = "pyformat"


class DBAPITypeObject(FrozenSet[str]):
class DBAPITypeObject(frozenset[str]):
"""Type Objects and Constructors

https://www.python.org/dev/peps/pep-0249/#type-objects-and-constructors
Expand Down Expand Up @@ -60,22 +59,22 @@ def __hash__(self):
DATETIME: DBAPITypeObject = DBAPITypeObject(("timestamp", "timestamp with time zone"))
JSON: DBAPITypeObject = DBAPITypeObject(("json",))

Date: Type[datetime.date] = datetime.date
Time: Type[datetime.time] = datetime.time
Timestamp: Type[datetime.datetime] = datetime.datetime
Date: type[datetime.date] = datetime.date
Time: type[datetime.time] = datetime.time
Timestamp: type[datetime.datetime] = datetime.datetime


@overload
def connect(*args, cursor_class: None = ..., **kwargs) -> "Connection[Cursor]": ...
def connect(*args, cursor_class: None = ..., **kwargs) -> Connection[Cursor]: ...


@overload
def connect(
*args, cursor_class: Type[ConnectionCursor], **kwargs
) -> "Connection[ConnectionCursor]": ...
*args, cursor_class: type[ConnectionCursor], **kwargs
) -> Connection[ConnectionCursor]: ...


def connect(*args, **kwargs) -> "Connection[Any]":
def connect(*args, **kwargs) -> Connection[Any]:
"""Create a new database connection to Amazon Athena.

This function provides the main entry point for establishing connections
Expand Down Expand Up @@ -131,7 +130,7 @@ def connect(*args, **kwargs) -> "Connection[Any]":
return Connection(*args, **kwargs)


async def aio_connect(*args, **kwargs) -> "AioConnection":
async def aio_connect(*args, **kwargs) -> AioConnection:
"""Create a new async database connection to Amazon Athena.

This is the async counterpart of :func:`connect`. It returns an
Expand Down
1 change: 0 additions & 1 deletion pyathena/aio/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +0,0 @@
# -*- coding: utf-8 -*-
1 change: 0 additions & 1 deletion pyathena/aio/arrow/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +0,0 @@
# -*- coding: utf-8 -*-
55 changes: 27 additions & 28 deletions pyathena/aio/arrow/cursor.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
# -*- coding: utf-8 -*-
from __future__ import annotations

import asyncio
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
from typing import TYPE_CHECKING, Any, cast

from pyathena.aio.common import WithAsyncFetch
from pyathena.arrow.converter import (
Expand All @@ -19,7 +18,7 @@
import polars as pl
from pyarrow import Table

_logger = logging.getLogger(__name__) # type: ignore
_logger = logging.getLogger(__name__)


class AioArrowCursor(WithAsyncFetch):
Expand All @@ -37,19 +36,19 @@ class AioArrowCursor(WithAsyncFetch):

def __init__(
self,
s3_staging_dir: Optional[str] = None,
schema_name: Optional[str] = None,
catalog_name: Optional[str] = None,
work_group: Optional[str] = None,
s3_staging_dir: str | None = None,
schema_name: str | None = None,
catalog_name: str | None = None,
work_group: str | None = None,
poll_interval: float = 1,
encryption_option: Optional[str] = None,
kms_key: Optional[str] = None,
encryption_option: str | None = None,
kms_key: str | None = None,
kill_on_interrupt: bool = True,
unload: bool = False,
result_reuse_enable: bool = False,
result_reuse_minutes: int = CursorIterator.DEFAULT_RESULT_REUSE_MINUTES,
connect_timeout: Optional[float] = None,
request_timeout: Optional[float] = None,
connect_timeout: float | None = None,
request_timeout: float | None = None,
**kwargs,
) -> None:
super().__init__(
Expand All @@ -68,29 +67,29 @@ def __init__(
self._unload = unload
self._connect_timeout = connect_timeout
self._request_timeout = request_timeout
self._result_set: Optional[AthenaArrowResultSet] = None
self._result_set: AthenaArrowResultSet | None = None

@staticmethod
def get_default_converter(
unload: bool = False,
) -> Union[DefaultArrowTypeConverter, DefaultArrowUnloadTypeConverter, Any]:
) -> DefaultArrowTypeConverter | DefaultArrowUnloadTypeConverter | Any:
if unload:
return DefaultArrowUnloadTypeConverter()
return DefaultArrowTypeConverter()

async def execute( # type: ignore[override]
self,
operation: str,
parameters: Optional[Union[Dict[str, Any], List[str]]] = None,
work_group: Optional[str] = None,
s3_staging_dir: Optional[str] = None,
cache_size: Optional[int] = 0,
cache_expiration_time: Optional[int] = 0,
result_reuse_enable: Optional[bool] = None,
result_reuse_minutes: Optional[int] = None,
paramstyle: Optional[str] = None,
parameters: dict[str, Any] | list[str] | None = None,
work_group: str | None = None,
s3_staging_dir: str | None = None,
cache_size: int | None = 0,
cache_expiration_time: int | None = 0,
result_reuse_enable: bool | None = None,
result_reuse_minutes: int | None = None,
paramstyle: str | None = None,
**kwargs,
) -> "AioArrowCursor":
) -> AioArrowCursor:
"""Execute a SQL query asynchronously and return results as Arrow Tables.

Args:
Expand Down Expand Up @@ -143,7 +142,7 @@ async def execute( # type: ignore[override]

async def fetchone( # type: ignore[override]
self,
) -> Optional[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]:
) -> tuple[Any | None, ...] | dict[Any, Any | None] | None:
"""Fetch the next row of the result set.

Wraps the synchronous fetch in ``asyncio.to_thread`` to avoid
Expand All @@ -161,8 +160,8 @@ async def fetchone( # type: ignore[override]
return await asyncio.to_thread(result_set.fetchone)

async def fetchmany( # type: ignore[override]
self, size: Optional[int] = None
) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]:
self, size: int | None = None
) -> list[tuple[Any | None, ...] | dict[Any, Any | None]]:
"""Fetch multiple rows from the result set.

Wraps the synchronous fetch in ``asyncio.to_thread`` to avoid
Expand All @@ -184,7 +183,7 @@ async def fetchmany( # type: ignore[override]

async def fetchall( # type: ignore[override]
self,
) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]:
) -> list[tuple[Any | None, ...] | dict[Any, Any | None]]:
"""Fetch all remaining rows from the result set.

Wraps the synchronous fetch in ``asyncio.to_thread`` to avoid
Expand All @@ -207,7 +206,7 @@ async def __anext__(self):
raise StopAsyncIteration
return row

def as_arrow(self) -> "Table":
def as_arrow(self) -> Table:
"""Return query results as an Apache Arrow Table.

Returns:
Expand All @@ -218,7 +217,7 @@ def as_arrow(self) -> "Table":
result_set = cast(AthenaArrowResultSet, self.result_set)
return result_set.as_arrow()

def as_polars(self) -> "pl.DataFrame":
def as_polars(self) -> pl.DataFrame:
"""Return query results as a Polars DataFrame.

Returns:
Expand Down
Loading